Compare commits

..

17 Commits

Author SHA1 Message Date
Danny Avila
278590d0bb refactor: update processOpenIDAuth to add a flag for processing existing users only 2025-09-17 20:43:27 -04:00
Danny Avila
41a4674469 chore: update logger import to use data-schemas in login controller 2025-09-17 20:14:15 -04:00
Danny Avila
e7a9cf88ac chore: update middleware imports/exports 2025-09-17 20:14:14 -04:00
Danny Avila
f6925f906b WIP: first pass, OpenID Proxy Auth 2025-09-17 20:14:12 -04:00
Danny Avila
e90fd1df15 refactor: reorder middleware imports for clarity 2025-09-17 20:13:46 -04:00
Danny Avila
a1f9f3dd39 refactor: re-use logic for admin routes 2025-09-17 20:13:46 -04:00
Danny Avila
fbe0def2fa WIP: admin auth 2025-09-17 20:13:46 -04:00
Federico Ruggi
d04da60b3b 💫 feat: MCP OAuth Auto-Reconnect (#9646)
* add oauth reconnect tracker

* add connection tracker to mcp manager

* reconnect oauth mcp servers function

* call reconnection in auth controller

* make sure to check connection in panel

* wait for isConnected

* add const for poll interval

* add logging to tryReconnect

* check expiration

* check mcp manager is not null

* check mcp manager is not null

* add test for reconnecting mcp server

* unify logic inside OAuthReconnectionManager

* test reconnection manager, adjust

* chore: reorder import statements in index.js

* chore: imports

* chore: imports

* chore: imports

* chore: imports

* chore: imports

* chore: imports and use types explicitly

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
2025-09-17 16:49:36 -04:00
keltschdt
0e94d97bfb fix: Disable TTL For Transient OIDC Users In Permission Service (#9643) 2025-09-17 14:21:36 -04:00
Danny Avila
45ab4d4503 🎋 refactor: Improve Message UI State Handling (#9678)
* refactor: `ExecuteCode` component with submission state handling and cancellation message

* fix: Remove unnecessary argument check for execute_code tool call

* refactor: streamlined messages context

* chore: remove unused Convo prop

* chore: remove unnecessary whitespace in Message component

* refactor: enhance message context with submission state and latest message tracking

* chore: import order
2025-09-17 13:07:56 -04:00
github-actions[bot]
0ceef12eea 🌍 i18n: Update translation.json with latest translations (#9648)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2025-09-15 18:41:34 -04:00
Danny Avila
6738360051 📋 refactor: Agent Tool Permissions for File Upload Options (#9647)
- Added isEphemeralAgent function to streamline checks for ephemeral agents.
- Updated logic in useAgentToolPermissions to utilize the new function for determining tool access.
- Introduced comprehensive tests for useAgentToolPermissions covering various scenarios including ephemeral agents, regular agents with tools, and edge cases.
2025-09-15 12:57:40 -04:00
Dustin Healy
52b65492d5 👻 fix: Phantom MCP Tool Calls (#9634)
* fix: mcp tool calls no longer happening when unselected (without breaking new convo behavior)

* refactor: Improve ephemeral agent synchronization logic in useMCPSelect

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
2025-09-15 10:35:15 -04:00
Danny Avila
7a9a99d2a0 🔗 refactor: URL sanitization for MCP logging (#9632) 2025-09-14 18:55:32 -04:00
Danny Avila
5bfb06b417 💻 feat: Add Proxy Config for Mistral OCR API (#9629)
* 💻 feat: Add proxy configuration support for Mistral OCR API requests

* refactor: Implement proxy support for Mistral API requests using HttpsProxyAgent
2025-09-14 18:50:41 -04:00
github-actions[bot]
2ce8f1f686 🌍 i18n: Update translation.json with latest translations (#9626)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2025-09-14 18:48:17 -04:00
Danny Avila
1a47601533 🔃 fix: Refresh Token Edge Cases (#9625)
* 🔃 fix: Refresh Token Edge Cases

* chore: Update parameter type for setAuthTokens function
2025-09-13 21:36:45 -04:00
95 changed files with 3038 additions and 1619 deletions

View File

@@ -233,7 +233,6 @@ class BaseClient {
sender: 'User',
text,
isCreatedByUser: true,
targetModel: this.modelOptions?.model ?? this.model,
};
}

View File

@@ -1,6 +1,6 @@
const { MCPManager, FlowStateManager } = require('@librechat/api');
const { EventSource } = require('eventsource');
const { Time } = require('librechat-data-provider');
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
const logger = require('./winston');
global.EventSource = EventSource;
@@ -26,4 +26,6 @@ module.exports = {
createMCPManager: MCPManager.createInstance,
getMCPManager: MCPManager.getInstance,
getFlowStateManager,
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
};

View File

@@ -112,17 +112,8 @@ module.exports = {
update.expiredAt = null;
}
/** @type {{ $set: Partial<TConversation>; $addToSet?: Record<string, any>; $unset?: Record<keyof TConversation, number> }} */
/** @type {{ $set: Partial<TConversation>; $unset?: Record<keyof TConversation, number> }} */
const updateOperation = { $set: update };
if (convo.model && convo.endpoint) {
updateOperation.$addToSet = {
modelHistory: {
model: convo.model,
endpoint: convo.endpoint,
},
};
}
if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
updateOperation.$unset = metadata.unsetFields;
}

View File

@@ -11,8 +11,9 @@ const {
registerUser,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const { getOpenIdConfig } = require('~/strategies');
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
const { getOAuthReconnectionManager } = require('~/config');
const { getOpenIdConfig } = require('~/strategies');
const registrationController = async (req, res) => {
try {
@@ -96,14 +97,25 @@ const refreshController = async (req, res) => {
return res.status(200).send({ token, user });
}
// Find the session with the hashed refresh token
const session = await findSession({
userId: userId,
refreshToken: refreshToken,
});
/** Session with the hashed refresh token */
const session = await findSession(
{
userId: userId,
refreshToken: refreshToken,
},
{ lean: false },
);
if (session && session.expiration > new Date()) {
const token = await setAuthTokens(userId, res, session._id);
const token = await setAuthTokens(userId, res, session);
// trigger OAuth MCP server reconnection asynchronously (best effort)
void getOAuthReconnectionManager()
.reconnectServers(userId)
.catch((err) => {
logger.error('Error reconnecting OAuth MCP servers:', err);
});
res.status(200).send({ token, user });
} else if (req?.query?.retry) {
// Retrying from a refresh token request that failed (401)

View File

@@ -1,6 +1,6 @@
const { logger } = require('@librechat/data-schemas');
const { generate2FATempToken } = require('~/server/services/twoFactorService');
const { setAuthTokens } = require('~/server/services/AuthService');
const { logger } = require('~/config');
const loginController = async (req, res) => {
try {

View File

@@ -0,0 +1,50 @@
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
const { checkBan } = require('~/server/middleware');
const domains = {
client: process.env.DOMAIN_CLIENT,
server: process.env.DOMAIN_SERVER,
};
function createOAuthHandler(redirectUri = domains.client) {
/**
* A handler to process OAuth authentication results.
* @type {Function}
* @param {ServerRequest} req - Express request object.
* @param {ServerResponse} res - Express response object.
* @param {NextFunction} next - Express next middleware function.
*/
return async (req, res, next) => {
try {
if (res.headersSent) {
return;
}
await checkBan(req, res);
if (req.banned) {
return;
}
if (
req.user &&
req.user.provider == 'openid' &&
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
) {
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
} else {
await setAuthTokens(req.user._id, res);
}
res.redirect(redirectUri);
} catch (err) {
logger.error('Error in setting authentication tokens:', err);
next(err);
}
};
}
module.exports = {
createOAuthHandler,
};

View File

@@ -12,6 +12,7 @@ const { logger } = require('@librechat/data-schemas');
const mongoSanitize = require('express-mongo-sanitize');
const { isEnabled, ErrorController } = require('@librechat/api');
const { connectDb, indexSync } = require('~/db');
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
const createValidateImageRequest = require('./middleware/validateImageRequest');
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
const { updateInterfacePermissions } = require('~/models/interface');
@@ -108,6 +109,7 @@ const startServer = async () => {
app.use('/oauth', routes.oauth);
/* API Endpoints */
app.use('/api/auth', routes.auth);
app.use('/api/admin', routes.adminAuth);
app.use('/api/actions', routes.actions);
app.use('/api/keys', routes.keys);
app.use('/api/user', routes.user);
@@ -154,7 +156,7 @@ const startServer = async () => {
res.send(updatedIndexHtml);
});
app.listen(port, host, () => {
app.listen(port, host, async () => {
if (host === '0.0.0.0') {
logger.info(
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
@@ -163,7 +165,9 @@ const startServer = async () => {
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
}
initializeMCPs().then(() => checkMigrations());
await initializeMCPs();
await initializeOAuthReconnectManager();
await checkMigrations();
});
};

View File

@@ -14,6 +14,7 @@ const checkInviteUser = require('./checkInviteUser');
const requireJwtAuth = require('./requireJwtAuth');
const configMiddleware = require('./config/app');
const validateModel = require('./validateModel');
const requireAdmin = require('./requireAdmin');
const moderateText = require('./moderateText');
const logHeaders = require('./logHeaders');
const setHeaders = require('./setHeaders');
@@ -36,6 +37,7 @@ module.exports = {
setHeaders,
logHeaders,
moderateText,
requireAdmin,
validateModel,
requireJwtAuth,
checkInviteUser,

View File

@@ -0,0 +1,22 @@
const { logger } = require('@librechat/data-schemas');
const { SystemRoles } = require('librechat-data-provider');
/**
* Middleware to check if authenticated user has admin role
* Should be used AFTER authentication middleware (requireJwtAuth, requireLocalAuth, etc.)
*/
const requireAdmin = (req, res, next) => {
if (!req.user) {
logger.warn('[requireAdmin] No user found in request');
return res.status(401).json({ message: 'Authentication required' });
}
if (!req.user.role || req.user.role !== SystemRoles.ADMIN) {
logger.debug('[requireAdmin] Access denied for non-admin user:', req.user.email);
return res.status(403).json({ message: 'Access denied: Admin privileges required' });
}
next();
};
module.exports = requireAdmin;

View File

@@ -1,6 +1,6 @@
const passport = require('passport');
const cookies = require('cookie');
const { isEnabled } = require('~/server/utils');
const passport = require('passport');
const { isEnabled } = require('@librechat/api');
/**
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse

View File

@@ -1,680 +0,0 @@
const express = require('express');
const request = require('supertest');
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
jest.mock('@librechat/data-schemas', () => ({
logger: {
debug: jest.fn(),
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
},
createMethods: jest.fn(() => ({})),
createModels: jest.fn(() => ({})),
}));
jest.mock('~/server/middleware', () => ({
requireJwtAuth: (req, res, next) => next(),
validateMessageReq: (req, res, next) => next(),
}));
jest.mock('~/models', () => ({
getConvo: jest.fn(),
saveConvo: jest.fn(),
saveMessage: jest.fn(),
getMessage: jest.fn(),
getMessages: jest.fn(),
updateMessage: jest.fn(),
deleteMessages: jest.fn(),
}));
jest.mock('~/db/models', () => {
let User, Message, Transaction, Conversation;
return {
get User() {
return User;
},
get Message() {
return Message;
},
get Transaction() {
return Transaction;
},
get Conversation() {
return Conversation;
},
setUser: (model) => {
User = model;
},
setMessage: (model) => {
Message = model;
},
setTransaction: (model) => {
Transaction = model;
},
setConversation: (model) => {
Conversation = model;
},
};
});
describe('Costs Endpoint', () => {
let app;
let mongoServer;
let messagesRouter;
let User, Message, Transaction, Conversation;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
await mongoose.connect(mongoServer.getUri());
const userSchema = new mongoose.Schema({
_id: String,
name: String,
email: String,
});
const conversationSchema = new mongoose.Schema({
conversationId: String,
user: String,
title: String,
createdAt: Date,
});
const messageSchema = new mongoose.Schema({
messageId: String,
conversationId: String,
user: String,
isCreatedByUser: Boolean,
tokenCount: Number,
createdAt: Date,
});
const transactionSchema = new mongoose.Schema({
conversationId: String,
user: String,
tokenType: String,
tokenValue: Number,
createdAt: Date,
});
User = mongoose.model('User', userSchema);
Conversation = mongoose.model('Conversation', conversationSchema);
Message = mongoose.model('Message', messageSchema);
Transaction = mongoose.model('Transaction', transactionSchema);
const dbModels = require('~/db/models');
dbModels.setUser(User);
dbModels.setMessage(Message);
dbModels.setTransaction(Transaction);
dbModels.setConversation(Conversation);
require('~/db/models');
try {
messagesRouter = require('../messages');
} catch (error) {
console.error('Error loading messages router:', error);
throw error;
}
app = express();
app.use(express.json());
app.use((req, res, next) => {
req.user = { id: 'test-user-id' };
next();
});
app.use('/api/messages', messagesRouter);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await User.deleteMany({});
await Conversation.deleteMany({});
await Message.deleteMany({});
await Transaction.deleteMany({});
});
describe('GET /:conversationId/costs', () => {
const conversationId = 'test-conversation-123';
const userId = 'test-user-id';
it('should return cost data for valid conversation', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const aiMessage = new Message({
messageId: 'ai-msg-1',
conversationId,
user: userId,
isCreatedByUser: false,
tokenCount: 150,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage.save(), aiMessage.save()]);
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const completionTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'completion',
tokenValue: 750000,
createdAt: new Date('2024-01-01T10:01:30Z'),
});
await Promise.all([promptTransaction.save(), completionTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body).toMatchObject({
conversationId,
totals: {
prompt: { usd: 0.5, tokenCount: 100 },
completion: { usd: 0.75, tokenCount: 150 },
total: { usd: 1.25, tokenCount: 250 },
},
perMessage: [
{ messageId: 'user-msg-1', tokenType: 'prompt', tokenCount: 100, usd: 0.5 },
{ messageId: 'ai-msg-1', tokenType: 'completion', tokenCount: 150, usd: 0.75 },
],
});
});
it('should return empty data for conversation with no messages', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body).toMatchObject({
conversationId,
totals: {
prompt: { usd: 0, tokenCount: 0 },
completion: { usd: 0, tokenCount: 0 },
total: { usd: 0, tokenCount: 0 },
},
perMessage: [],
});
});
it('should handle messages without transactions', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const aiMessage = new Message({
messageId: 'ai-msg-1',
conversationId,
user: userId,
isCreatedByUser: false,
tokenCount: 150,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage.save(), aiMessage.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0);
expect(response.body.totals.completion.usd).toBe(0);
expect(response.body.totals.total.usd).toBe(0);
});
it('should aggregate multiple transactions correctly', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction1 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 300000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const promptTransaction2 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 200000,
createdAt: new Date('2024-01-01T10:00:45Z'),
});
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.perMessage[0].usd).toBe(0.5);
});
it('should handle null tokenCount values', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: null,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.tokenCount).toBe(0);
});
it('should handle null tokenValue in transactions', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: null,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await promptTransaction.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0);
});
it('should handle negative tokenValue using Math.abs', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: -500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await promptTransaction.save();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
});
it('should filter by user correctly', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const otherUserId = 'other-user-id';
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const otherUserMessage = new Message({
messageId: 'other-user-msg-1',
conversationId,
user: otherUserId,
isCreatedByUser: true,
tokenCount: 200,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await Promise.all([userMessage.save(), otherUserMessage.save()]);
const userTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const otherUserTransaction = new Transaction({
conversationId,
user: otherUserId,
tokenType: 'prompt',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await Promise.all([userTransaction.save(), otherUserTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.perMessage).toHaveLength(1);
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
});
it('should filter transactions by tokenType', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
await userMessage.save();
const promptTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const otherTransaction = new Transaction({
conversationId,
user: userId,
tokenType: 'other',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
await Promise.all([promptTransaction.save(), otherTransaction.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.totals.prompt.usd).toBe(0.5);
expect(response.body.totals.completion.usd).toBe(0);
expect(response.body.totals.total.usd).toBe(0.5);
});
it('should map transactions to messages chronologically', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
const userMessage1 = new Message({
messageId: 'user-msg-1',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 100,
createdAt: new Date('2024-01-01T10:00:00Z'),
});
const userMessage2 = new Message({
messageId: 'user-msg-2',
conversationId,
user: userId,
isCreatedByUser: true,
tokenCount: 200,
createdAt: new Date('2024-01-01T10:01:00Z'),
});
await Promise.all([userMessage1.save(), userMessage2.save()]);
const promptTransaction1 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 500000,
createdAt: new Date('2024-01-01T10:00:30Z'),
});
const promptTransaction2 = new Transaction({
conversationId,
user: userId,
tokenType: 'prompt',
tokenValue: 1000000,
createdAt: new Date('2024-01-01T10:01:30Z'),
});
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(200);
expect(response.body.perMessage).toHaveLength(2);
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
expect(response.body.perMessage[0].usd).toBe(0.5);
expect(response.body.perMessage[1].messageId).toBe('user-msg-2');
expect(response.body.perMessage[1].usd).toBe(1.0);
});
it('should handle database errors', async () => {
const { getConvo } = require('~/models');
getConvo.mockResolvedValue({
conversationId,
user: userId,
title: 'Test Conversation',
});
const conversation = new Conversation({
conversationId,
user: userId,
title: 'Test Conversation',
createdAt: new Date('2024-01-01T09:00:00Z'),
});
await conversation.save();
await mongoose.connection.close();
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
expect(response.status).toBe(500);
expect(response.body).toHaveProperty('error');
});
});
});

View File

@@ -0,0 +1,66 @@
const express = require('express');
const passport = require('passport');
const { randomState } = require('openid-client');
const { createSetBalanceConfig } = require('@librechat/api');
const { loginController } = require('~/server/controllers/auth/LoginController');
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
const { getAppConfig } = require('~/server/services/Config');
const { getOpenIdConfig } = require('~/strategies');
const middleware = require('~/server/middleware');
const { Balance } = require('~/db/models');
const setBalanceConfig = createSetBalanceConfig({
getAppConfig,
Balance,
});
const router = express.Router();
router.post(
'/login/local',
middleware.logHeaders,
middleware.loginLimiter,
middleware.checkBan,
middleware.requireLocalAuth,
middleware.requireAdmin,
setBalanceConfig,
loginController,
);
router.get('/verify', middleware.requireJwtAuth, middleware.requireAdmin, (req, res) => {
const { password: _p, totpSecret: _t, __v, ...user } = req.user;
user.id = user._id.toString();
res.status(200).json({ user });
});
router.get('/oauth/openid/check', (req, res) => {
const openidConfig = getOpenIdConfig();
if (!openidConfig) {
return res.status(404).json({ message: 'OpenID configuration not found' });
}
res.status(200).json({ message: 'OpenID check successful' });
});
router.get('/oauth/openid', (req, res, next) => {
return passport.authenticate('openidAdmin', {
session: false,
state: randomState(),
})(req, res, next);
});
router.get(
'/oauth/openid/callback',
passport.authenticate('openidAdmin', {
failureRedirect: `${process.env.DOMAIN_CLIENT}/oauth/error`,
failureMessage: true,
session: false,
}),
middleware.requireAdmin,
setBalanceConfig,
middleware.checkDomainAllowed,
createOAuthHandler(
(process.env.ADMIN_PANEL_URL || 'http://localhost:3000') + '/auth/openid/callback',
),
);
module.exports = router;

View File

@@ -1,6 +1,7 @@
const accessPermissions = require('./accessPermissions');
const assistants = require('./assistants');
const categories = require('./categories');
const adminAuth = require('./admin/auth');
const tokenizer = require('./tokenizer');
const endpoints = require('./endpoints');
const staticRoute = require('./static');
@@ -32,6 +33,7 @@ module.exports = {
mcp,
edit,
auth,
adminAuth,
keys,
user,
tags,

View File

@@ -1,12 +1,12 @@
const { Router } = require('express');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
const { CacheKeys, Constants } = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
const { requireJwtAuth } = require('~/server/middleware');
const { findPluginAuthsByKeys } = require('~/models');
@@ -144,6 +144,10 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
);
// clear any reconnection attempts
const oauthReconnectionManager = getOAuthReconnectionManager();
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
const tools = await userConnection.fetchTools();
await updateMCPUserTools({
userId: flowState.userId,

View File

@@ -11,7 +11,6 @@ const {
} = require('~/models');
const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update');
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
const { tokenValues, getValueKey, defaultRate } = require('~/models/tx');
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
const { getConvosQueried } = require('~/models/Conversation');
const { countTokens } = require('~/server/utils');
@@ -161,41 +160,6 @@ router.post('/artifact/:messageId', async (req, res) => {
}
});
/**
* POST /costs
* Get cost information for models in modelHistory array
*/
router.post('/costs', async (req, res) => {
try {
const { modelHistory } = req.body;
if (!Array.isArray(modelHistory)) {
return res.status(400).json({ error: 'modelHistory must be an array' });
}
const modelCostTable = {};
modelHistory.forEach((modelEntry) => {
if (modelEntry && typeof modelEntry === 'object' && modelEntry.model && modelEntry.endpoint) {
const { model, endpoint } = modelEntry;
const valueKey = getValueKey(model, endpoint);
const pricing = tokenValues[valueKey];
modelCostTable[model] = {
prompt: pricing?.prompt ?? defaultRate,
completion: pricing?.completion ?? defaultRate,
};
}
});
res.status(200).json({ modelCostTable });
} catch (error) {
logger.error('Error fetching model costs:', error);
res.status(500).json({ error: 'Internal server error' });
}
});
/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */
router.get('/:conversationId', validateMessageReq, async (req, res) => {
try {

View File

@@ -4,10 +4,9 @@ const passport = require('passport');
const { randomState } = require('openid-client');
const { logger } = require('@librechat/data-schemas');
const { ErrorTypes } = require('librechat-data-provider');
const { isEnabled, createSetBalanceConfig } = require('@librechat/api');
const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware');
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
const { createSetBalanceConfig } = require('@librechat/api');
const { checkDomainAllowed, loginLimiter, logHeaders } = require('~/server/middleware');
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
const { getAppConfig } = require('~/server/services/Config');
const { Balance } = require('~/db/models');
@@ -26,32 +25,7 @@ const domains = {
router.use(logHeaders);
router.use(loginLimiter);
const oauthHandler = async (req, res, next) => {
try {
if (res.headersSent) {
return;
}
await checkBan(req, res);
if (req.banned) {
return;
}
if (
req.user &&
req.user.provider == 'openid' &&
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
) {
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
} else {
await setAuthTokens(req.user._id, res);
}
res.redirect(domains.client);
} catch (err) {
logger.error('Error in setting authentication tokens:', err);
next(err);
}
};
const oauthHandler = createOAuthHandler();
router.get('/error', (req, res) => {
/** A single error message is pushed by passport when authentication fails. */

View File

@@ -357,23 +357,18 @@ const resetPassword = async (userId, token, password) => {
/**
* Set Auth Tokens
*
* @param {String | ObjectId} userId
* @param {Object} res
* @param {String} sessionId
* @param {ServerResponse} res
* @param {ISession | null} [session=null]
* @returns
*/
const setAuthTokens = async (userId, res, sessionId = null) => {
const setAuthTokens = async (userId, res, _session = null) => {
try {
const user = await getUserById(userId);
const token = await generateToken(user);
let session;
let session = _session;
let refreshToken;
let refreshTokenExpires;
if (sessionId) {
session = await findSession({ sessionId: sessionId }, { lean: false });
if (session && session._id && session.expiration != null) {
refreshTokenExpires = session.expiration.getTime();
refreshToken = await generateRefreshToken(session);
} else {
@@ -383,6 +378,9 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
refreshTokenExpires = session.expiration.getTime();
}
const user = await getUserById(userId);
const token = await generateToken(user);
res.cookie('refreshToken', refreshToken, {
expires: new Date(refreshTokenExpires),
httpOnly: true,

View File

@@ -20,8 +20,8 @@ const {
ContentTypes,
isAssistantsEndpoint,
} = require('librechat-data-provider');
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
const { findToken, createToken, updateToken } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, getAppConfig } = require('./Config');
const { reinitMCPServer } = require('./Tools/mcp');
const { getLogStores } = require('~/cache');
@@ -538,13 +538,20 @@ async function getServerConnectionStatus(
const baseConnectionState = getConnectionState();
let finalConnectionState = baseConnectionState;
// connection state overrides specific to OAuth servers
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
// check if server is actively being reconnected
const oauthReconnectionManager = getOAuthReconnectionManager();
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
finalConnectionState = 'connecting';
} else {
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
if (hasFailedFlow) {
finalConnectionState = 'error';
} else if (hasActiveFlow) {
finalConnectionState = 'connecting';
}
}
}

View File

@@ -31,6 +31,7 @@ jest.mock('./Config', () => ({
jest.mock('~/config', () => ({
getMCPManager: jest.fn(),
getFlowStateManager: jest.fn(),
getOAuthReconnectionManager: jest.fn(),
}));
jest.mock('~/cache', () => ({
@@ -48,6 +49,7 @@ describe('tests for the new helper functions used by the MCP connection status e
let mockGetMCPManager;
let mockGetFlowStateManager;
let mockGetLogStores;
let mockGetOAuthReconnectionManager;
beforeEach(() => {
jest.clearAllMocks();
@@ -56,6 +58,7 @@ describe('tests for the new helper functions used by the MCP connection status e
mockGetMCPManager = require('~/config').getMCPManager;
mockGetFlowStateManager = require('~/config').getFlowStateManager;
mockGetLogStores = require('~/cache').getLogStores;
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
});
describe('getMCPSetupData', () => {
@@ -354,6 +357,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
@@ -370,6 +379,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return failed flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@@ -401,6 +416,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return active flow
const mockFlowManager = {
getFlowState: jest.fn(() => ({
@@ -432,6 +453,12 @@ describe('tests for the new helper functions used by the MCP connection status e
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => false),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
// Mock flow state to return no flow
const mockFlowManager = {
getFlowState: jest.fn(() => null),
@@ -454,6 +481,35 @@ describe('tests for the new helper functions used by the MCP connection status e
});
});
it('should return connecting state when OAuth server is reconnecting', async () => {
const appConnections = new Map();
const userConnections = new Map();
const oauthServers = new Set([mockServerName]);
// Mock OAuthReconnectionManager to return true for isReconnecting
const mockOAuthReconnectionManager = {
isReconnecting: jest.fn(() => true),
};
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
const result = await getServerConnectionStatus(
mockUserId,
mockServerName,
appConnections,
userConnections,
oauthServers,
);
expect(result).toEqual({
requiresOAuth: true,
connectionState: 'connecting',
});
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
mockUserId,
mockServerName,
);
});
it('should not check OAuth flow status when server is connected', async () => {
const mockFlowManager = {
getFlowState: jest.fn(),

View File

@@ -313,7 +313,7 @@ const ensurePrincipalExists = async function (principal) {
idOnTheSource: principal.idOnTheSource,
};
const userId = await createUser(userData, true, false);
const userId = await createUser(userData, true, true);
return userId.toString();
}

View File

@@ -88,7 +88,6 @@ async function saveUserMessage(req, params) {
parentMessageId: params.parentMessageId ?? Constants.NO_PARENT,
/* For messages, use the assistant_id instead of model */
model: params.assistant_id,
targetModel: params.model,
thread_id: params.thread_id,
sender: 'User',
text: params.text,

View File

@@ -0,0 +1,26 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { getLogStores } = require('~/cache');
/**
* Initialize OAuth reconnect manager
*/
async function initializeOAuthReconnectManager() {
try {
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
const tokenMethods = {
findToken,
updateToken,
createToken,
deleteTokens,
};
await createOAuthReconnectionManager(flowManager, tokenMethods);
logger.info(`OAuth reconnect manager initialized successfully.`);
} catch (error) {
logger.error('Failed to initialize OAuth reconnect manager:', error);
}
}
module.exports = initializeOAuthReconnectManager;

View File

@@ -1,14 +1,14 @@
const appleLogin = require('./appleStrategy');
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
const openIdJwtLogin = require('./openIdJwtStrategy');
const facebookLogin = require('./facebookStrategy');
const discordLogin = require('./discordStrategy');
const passportLogin = require('./localStrategy');
const googleLogin = require('./googleStrategy');
const githubLogin = require('./githubStrategy');
const discordLogin = require('./discordStrategy');
const facebookLogin = require('./facebookStrategy');
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
const jwtLogin = require('./jwtStrategy');
const ldapLogin = require('./ldapStrategy');
const { setupSaml } = require('./samlStrategy');
const openIdJwtLogin = require('./openIdJwtStrategy');
const appleLogin = require('./appleStrategy');
const ldapLogin = require('./ldapStrategy');
const jwtLogin = require('./jwtStrategy');
module.exports = {
appleLogin,

View File

@@ -281,6 +281,221 @@ function convertToUsername(input, defaultValue = '') {
return defaultValue;
}
/**
* Process OpenID authentication tokenset and userinfo
* This is the core logic extracted from the passport strategy callback
* Can be reused by both the passport strategy and proxy authentication
*
* @param {Object} tokenset - The OpenID tokenset containing access_token, id_token, etc.
* @param {boolean} existingUsersOnly - If true, only existing users will be processed
* @returns {Promise<Object>} The authenticated user object with tokenset
*/
async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
const claims = tokenset.claims ? tokenset.claims() : tokenset;
const userinfo = {
...claims,
};
// Get userinfo from provider if we have access_token and haven't already
if (tokenset.access_token) {
const providerUserinfo = await getUserInfo(openidConfig, tokenset.access_token, claims.sub);
Object.assign(userinfo, providerUserinfo);
}
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
logger.error(
`[OpenID Auth] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
);
throw new Error('Email domain not allowed');
}
const result = await findOpenIDUser({
openidId: claims.sub || userinfo.sub,
email: claims.email || userinfo.email,
strategyName: 'openidStrategy',
findUser,
});
let user = result.user;
const error = result.error;
if (error) {
throw new Error(ErrorTypes.AUTH_FAILED);
}
const fullName = getFullName(userinfo);
/** Required role if configured */
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
if (requiredRole) {
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
let decodedToken = '';
if (requiredRoleTokenKind === 'access' && tokenset.access_token) {
decodedToken = jwtDecode(tokenset.access_token);
} else if (requiredRoleTokenKind === 'id' && tokenset.id_token) {
decodedToken = jwtDecode(tokenset.id_token);
} else if (userinfo.roles) {
// If roles are already in userinfo, use them directly
const roles = Array.isArray(userinfo.roles) ? userinfo.roles : [userinfo.roles];
if (!roles.includes(requiredRole)) {
throw new Error(`You must have the "${requiredRole}" role to log in.`);
}
} else if (requiredRoleParameterPath) {
const pathParts = requiredRoleParameterPath.split('.');
let found = true;
let roles = pathParts.reduce((o, key) => {
if (o === null || o === undefined || !(key in o)) {
found = false;
return [];
}
return o[key];
}, decodedToken);
if (!found) {
logger.error(
`[OpenID Auth] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
);
}
if (!roles.includes(requiredRole)) {
throw new Error(`You must have the "${requiredRole}" role to log in.`);
}
}
}
let username = '';
if (process.env.OPENID_USERNAME_CLAIM) {
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
} else {
username = convertToUsername(
userinfo.preferred_username || userinfo.username || userinfo.email,
);
}
if (existingUsersOnly && !user) {
throw new Error('User does not exist');
}
if (!user) {
user = {
provider: 'openid',
openidId: userinfo.sub,
username,
email: userinfo.email || '',
emailVerified: userinfo.email_verified || false,
name: fullName,
idOnTheSource: userinfo.oid,
};
const balanceConfig = getBalanceConfig(appConfig);
user = await createUser(user, balanceConfig, true, true);
} else {
user.provider = 'openid';
user.openidId = userinfo.sub;
user.username = username;
user.name = fullName;
user.idOnTheSource = userinfo.oid;
}
// Handle avatar
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
const imageUrl = userinfo.picture;
let fileName;
if (crypto) {
fileName = (await hashToken(userinfo.sub)) + '.png';
} else {
fileName = userinfo.sub + '.png';
}
const imageBuffer = await downloadImage(
imageUrl,
openidConfig,
tokenset.access_token,
userinfo.sub,
);
if (imageBuffer) {
const { saveBuffer } = getStrategyFunctions(
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
);
const imagePath = await saveBuffer({
fileName,
userId: user._id.toString(),
buffer: imageBuffer,
});
user.avatar = imagePath ?? '';
}
}
user = await updateUser(user._id, user);
logger.info(
`[OpenID Auth] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username}`,
{
user: {
openidId: user.openidId,
username: user.username,
email: user.email,
name: user.name,
},
},
);
return { ...user, tokenset };
}
/**
* @param {boolean | undefined} [existingUsersOnly]
*/
function createOpenIDCallback(existingUsersOnly) {
return async (tokenset, done) => {
try {
const user = await processOpenIDAuth(tokenset, existingUsersOnly);
done(null, user);
} catch (err) {
if (err.message === 'Email domain not allowed') {
return done(null, false, { message: err.message });
}
if (err.message === ErrorTypes.AUTH_FAILED) {
return done(null, false, { message: err.message });
}
if (err.message && err.message.includes('role to log in')) {
return done(null, false, { message: err.message });
}
logger.error('[openidStrategy] login failed', err);
done(err);
}
};
}
/**
* Sets up the OpenID strategy specifically for admin authentication.
* @param {Configuration} openidConfig
*/
const setupOpenIdAdmin = (openidConfig) => {
try {
if (!openidConfig) {
throw new Error('OpenID configuration not initialized');
}
const openidAdminLogin = new CustomOpenIDStrategy(
{
config: openidConfig,
scope: process.env.OPENID_SCOPE,
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
callbackURL: process.env.DOMAIN_SERVER + '/api/admin/oauth/openid/callback',
},
createOpenIDCallback(true),
);
passport.use('openidAdmin', openidAdminLogin);
} catch (err) {
logger.error('[openidStrategy] setupOpenIdAdmin', err);
}
};
/**
* Sets up the OpenID strategy for authentication.
* This function configures the OpenID client, handles proxy settings,
@@ -318,10 +533,6 @@ async function setupOpenId() {
},
);
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
const usePKCE = isEnabled(process.env.OPENID_USE_PKCE);
logger.info(`[openidStrategy] OpenID authentication configuration`, {
generateNonce: shouldGenerateNonce,
reason: shouldGenerateNonce
@@ -335,159 +546,19 @@ async function setupOpenId() {
scope: process.env.OPENID_SCOPE,
callbackURL: process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL,
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
usePKCE,
},
async (tokenset, done) => {
try {
const claims = tokenset.claims();
const userinfo = {
...claims,
...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)),
};
const appConfig = await getAppConfig();
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
logger.error(
`[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
);
return done(null, false, { message: 'Email domain not allowed' });
}
const result = await findOpenIDUser({
openidId: claims.sub,
email: claims.email,
strategyName: 'openidStrategy',
findUser,
});
let user = result.user;
const error = result.error;
if (error) {
return done(null, false, {
message: ErrorTypes.AUTH_FAILED,
});
}
const fullName = getFullName(userinfo);
if (requiredRole) {
let decodedToken = '';
if (requiredRoleTokenKind === 'access') {
decodedToken = jwtDecode(tokenset.access_token);
} else if (requiredRoleTokenKind === 'id') {
decodedToken = jwtDecode(tokenset.id_token);
}
const pathParts = requiredRoleParameterPath.split('.');
let found = true;
let roles = pathParts.reduce((o, key) => {
if (o === null || o === undefined || !(key in o)) {
found = false;
return [];
}
return o[key];
}, decodedToken);
if (!found) {
logger.error(
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
);
}
if (!roles.includes(requiredRole)) {
return done(null, false, {
message: `You must have the "${requiredRole}" role to log in.`,
});
}
}
let username = '';
if (process.env.OPENID_USERNAME_CLAIM) {
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
} else {
username = convertToUsername(
userinfo.preferred_username || userinfo.username || userinfo.email,
);
}
if (!user) {
user = {
provider: 'openid',
openidId: userinfo.sub,
username,
email: userinfo.email || '',
emailVerified: userinfo.email_verified || false,
name: fullName,
idOnTheSource: userinfo.oid,
};
const balanceConfig = getBalanceConfig(appConfig);
user = await createUser(user, balanceConfig, true, true);
} else {
user.provider = 'openid';
user.openidId = userinfo.sub;
user.username = username;
user.name = fullName;
user.idOnTheSource = userinfo.oid;
}
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
/** @type {string | undefined} */
const imageUrl = userinfo.picture;
let fileName;
if (crypto) {
fileName = (await hashToken(userinfo.sub)) + '.png';
} else {
fileName = userinfo.sub + '.png';
}
const imageBuffer = await downloadImage(
imageUrl,
openidConfig,
tokenset.access_token,
userinfo.sub,
);
if (imageBuffer) {
const { saveBuffer } = getStrategyFunctions(
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
);
const imagePath = await saveBuffer({
fileName,
userId: user._id.toString(),
buffer: imageBuffer,
});
user.avatar = imagePath ?? '';
}
}
user = await updateUser(user._id, user);
logger.info(
`[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `,
{
user: {
openidId: user.openidId,
username: user.username,
email: user.email,
name: user.name,
},
},
);
done(null, { ...user, tokenset });
} catch (err) {
logger.error('[openidStrategy] login failed', err);
done(err);
}
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
},
createOpenIDCallback(),
);
passport.use('openid', openidLogin);
setupOpenIdAdmin(openidConfig);
return openidConfig;
} catch (err) {
logger.error('[openidStrategy]', err);
return null;
}
}
/**
* @function getOpenIdConfig
* @description Returns the OpenID client instance.

View File

@@ -873,6 +873,13 @@
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
* @memberof typedefs
*/
/**
* @exports ISession
* @typedef {import('@librechat/data-schemas').ISession} ISession
* @memberof typedefs
*/
/**
* @exports IBalance
* @typedef {import('@librechat/data-schemas').IBalance} IBalance

View File

@@ -1,10 +1,15 @@
import { createContext, useContext } from 'react';
type MessageContext = {
messageId: string;
nextType?: string;
partIndex?: number;
isExpanded: boolean;
conversationId?: string | null;
/** Submission state for cursor display - only true for latest message when submitting */
isSubmitting?: boolean;
/** Whether this is the latest message in the conversation */
isLatestMessage?: boolean;
};
export const MessageContext = createContext<MessageContext>({} as MessageContext);

View File

@@ -0,0 +1,150 @@
import React, { createContext, useContext, useMemo } from 'react';
import { useAddedChatContext } from './AddedChatContext';
import { useChatContext } from './ChatContext';
interface MessagesViewContextValue {
/** Core conversation data */
conversation: ReturnType<typeof useChatContext>['conversation'];
conversationId: string | null | undefined;
/** Submission and control states */
isSubmitting: ReturnType<typeof useChatContext>['isSubmitting'];
isSubmittingFamily: boolean;
abortScroll: ReturnType<typeof useChatContext>['abortScroll'];
setAbortScroll: ReturnType<typeof useChatContext>['setAbortScroll'];
/** Message operations */
ask: ReturnType<typeof useChatContext>['ask'];
regenerate: ReturnType<typeof useChatContext>['regenerate'];
handleContinue: ReturnType<typeof useChatContext>['handleContinue'];
/** Message state management */
index: ReturnType<typeof useChatContext>['index'];
latestMessage: ReturnType<typeof useChatContext>['latestMessage'];
setLatestMessage: ReturnType<typeof useChatContext>['setLatestMessage'];
getMessages: ReturnType<typeof useChatContext>['getMessages'];
setMessages: ReturnType<typeof useChatContext>['setMessages'];
}
const MessagesViewContext = createContext<MessagesViewContextValue | undefined>(undefined);
export function MessagesViewProvider({ children }: { children: React.ReactNode }) {
const chatContext = useChatContext();
const addedChatContext = useAddedChatContext();
const {
ask,
index,
regenerate,
isSubmitting: isSubmittingRoot,
conversation,
latestMessage,
setAbortScroll,
handleContinue,
setLatestMessage,
abortScroll,
getMessages,
setMessages,
} = chatContext;
const { isSubmitting: isSubmittingAdditional } = addedChatContext;
/** Memoize conversation-related values */
const conversationValues = useMemo(
() => ({
conversation,
conversationId: conversation?.conversationId,
}),
[conversation],
);
/** Memoize submission states */
const submissionStates = useMemo(
() => ({
isSubmitting: isSubmittingRoot,
isSubmittingFamily: isSubmittingRoot || isSubmittingAdditional,
abortScroll,
setAbortScroll,
}),
[isSubmittingRoot, isSubmittingAdditional, abortScroll, setAbortScroll],
);
/** Memoize message operations (these are typically stable references) */
const messageOperations = useMemo(
() => ({
ask,
regenerate,
getMessages,
setMessages,
handleContinue,
}),
[ask, regenerate, handleContinue, getMessages, setMessages],
);
/** Memoize message state values */
const messageState = useMemo(
() => ({
index,
latestMessage,
setLatestMessage,
}),
[index, latestMessage, setLatestMessage],
);
/** Combine all values into final context value */
const contextValue = useMemo<MessagesViewContextValue>(
() => ({
...conversationValues,
...submissionStates,
...messageOperations,
...messageState,
}),
[conversationValues, submissionStates, messageOperations, messageState],
);
return (
<MessagesViewContext.Provider value={contextValue}>{children}</MessagesViewContext.Provider>
);
}
export function useMessagesViewContext() {
const context = useContext(MessagesViewContext);
if (!context) {
throw new Error('useMessagesViewContext must be used within MessagesViewProvider');
}
return context;
}
/** Hook for components that only need conversation data */
export function useMessagesConversation() {
const { conversation, conversationId } = useMessagesViewContext();
return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]);
}
/** Hook for components that only need submission states */
export function useMessagesSubmission() {
const { isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll } =
useMessagesViewContext();
return useMemo(
() => ({ isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll }),
[isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll],
);
}
/** Hook for components that only need message operations */
export function useMessagesOperations() {
const { ask, regenerate, handleContinue, getMessages, setMessages } = useMessagesViewContext();
return useMemo(
() => ({ ask, regenerate, handleContinue, getMessages, setMessages }),
[ask, regenerate, handleContinue, getMessages, setMessages],
);
}
/** Hook for components that only need message state */
export function useMessagesState() {
const { index, latestMessage, setLatestMessage } = useMessagesViewContext();
return useMemo(
() => ({ index, latestMessage, setLatestMessage }),
[index, latestMessage, setLatestMessage],
);
}

View File

@@ -26,4 +26,5 @@ export * from './SidePanelContext';
export * from './MCPPanelContext';
export * from './ArtifactsContext';
export * from './PromptGroupsContext';
export * from './MessagesViewContext';
export { default as BadgeRowProvider } from './BadgeRowContext';

View File

@@ -1,4 +1,4 @@
import { memo, useCallback, useState, useEffect, useRef } from 'react';
import { memo, useCallback } from 'react';
import { useRecoilValue } from 'recoil';
import { useForm } from 'react-hook-form';
import { Spinner } from '@librechat/client';
@@ -13,7 +13,6 @@ import { useGetMessagesByConvoId } from '~/data-provider';
import MessagesView from './Messages/MessagesView';
import Presentation from './Presentation';
import ChatForm from './Input/ChatForm';
import CostBar from './CostBar';
import Landing from './Landing';
import Header from './Header';
import Footer from './Footer';
@@ -30,13 +29,7 @@ function LoadingSpinner() {
);
}
function ChatView({
index = 0,
modelCosts,
}: {
index?: number;
modelCosts?: { modelCostTable: Record<string, { prompt: number; completion: number }> };
}) {
function ChatView({ index = 0 }: { index?: number }) {
const { conversationId } = useParams();
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
@@ -44,9 +37,6 @@ function ChatView({
const fileMap = useFileMapContext();
const [showCostBar, setShowCostBar] = useState(false);
const lastScrollY = useRef(0);
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', {
select: useCallback(
(data: TMessage[]) => {
@@ -64,58 +54,6 @@ function ChatView({
useSSE(rootSubmission, chatHelpers, false);
useSSE(addedSubmission, addedChatHelpers, true);
const checkIfAtBottom = useCallback(
(container: HTMLElement) => {
const currentScrollY = container.scrollTop;
const scrollHeight = container.scrollHeight;
const clientHeight = container.clientHeight;
const distanceFromBottom = scrollHeight - currentScrollY - clientHeight;
const isAtBottom = distanceFromBottom < 10;
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
setShowCostBar(isAtBottom && !isStreaming);
lastScrollY.current = currentScrollY;
},
[chatHelpers.isSubmitting, addedChatHelpers.isSubmitting],
);
useEffect(() => {
const handleScroll = (event: Event) => {
const target = event.target as HTMLElement;
checkIfAtBottom(target);
};
const findAndAttachScrollListener = () => {
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
if (messagesContainer) {
checkIfAtBottom(messagesContainer as HTMLElement);
messagesContainer.addEventListener('scroll', handleScroll, { passive: true });
return () => {
messagesContainer.removeEventListener('scroll', handleScroll);
};
}
setTimeout(findAndAttachScrollListener, 100);
};
const cleanup = findAndAttachScrollListener();
return cleanup;
}, [messagesTree, checkIfAtBottom]);
useEffect(() => {
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
if (isStreaming) {
setShowCostBar(false);
} else {
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
if (messagesContainer) {
checkIfAtBottom(messagesContainer as HTMLElement);
}
}
}, [chatHelpers.isSubmitting, addedChatHelpers.isSubmitting, checkIfAtBottom]);
const methods = useForm<ChatFormValues>({
defaultValues: { text: '' },
});
@@ -131,22 +69,7 @@ function ChatView({
} else if ((isLoading || isNavigating) && !isLandingPage) {
content = <LoadingSpinner />;
} else if (!isLandingPage) {
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
content = (
<MessagesView
messagesTree={messagesTree}
costBar={
!isLandingPage &&
modelCosts && (
<CostBar
messagesTree={messagesTree}
modelCosts={modelCosts}
showCostBar={showCostBar && !isStreaming}
/>
)
}
/>
);
content = <MessagesView messagesTree={messagesTree} />;
} else {
content = <Landing centerFormOnLanding={centerFormOnLanding} />;
}

View File

@@ -1,112 +0,0 @@
import { useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import { ArrowIcon } from '@librechat/client';
import { TModelCosts, TMessage } from 'librechat-data-provider';
import { useLocalize } from '~/hooks';
import { cn } from '~/utils';
import store from '~/store';
interface CostBarProps {
messagesTree: TMessage[];
modelCosts: TModelCosts;
showCostBar: boolean;
}
export default function CostBar({ messagesTree, modelCosts, showCostBar }: CostBarProps) {
const localize = useLocalize();
const showCostTracking = useRecoilValue(store.showCostTracking);
const conversationCosts = useMemo(() => {
if (!modelCosts?.modelCostTable || !messagesTree) {
return null;
}
let totalPromptTokens = 0;
let totalCompletionTokens = 0;
let totalPromptUSD = 0;
let totalCompletionUSD = 0;
const flattenMessages = (messages: TMessage[]) => {
const flattened: TMessage[] = [];
messages.forEach((message: TMessage) => {
flattened.push(message);
if (message.children && message.children.length > 0) {
flattened.push(...flattenMessages(message.children));
}
});
return flattened;
};
const allMessages = flattenMessages(messagesTree);
allMessages.forEach((message) => {
if (!message.tokenCount) {
return null;
}
const modelToUse = message.isCreatedByUser ? message.targetModel : message.model;
const modelPricing = modelCosts.modelCostTable[modelToUse];
if (message.isCreatedByUser) {
totalPromptTokens += message.tokenCount;
totalPromptUSD += (message.tokenCount / 1000000) * modelPricing.prompt;
} else {
totalCompletionTokens += message.tokenCount;
totalCompletionUSD += (message.tokenCount / 1000000) * modelPricing.completion;
}
});
const totalTokens = totalPromptTokens + totalCompletionTokens;
const totalUSD = totalPromptUSD + totalCompletionUSD;
return {
totals: {
prompt: { tokenCount: totalPromptTokens, usd: totalPromptUSD },
completion: { tokenCount: totalCompletionTokens, usd: totalCompletionUSD },
total: { tokenCount: totalTokens, usd: totalUSD },
},
};
}, [modelCosts, messagesTree]);
if (!showCostTracking || !conversationCosts || !conversationCosts.totals) {
return null;
}
return (
<div
className={cn(
'mx-auto w-full max-w-md px-4 text-xs text-muted-foreground transition-all duration-300 ease-in-out',
showCostBar ? 'opacity-100' : 'opacity-0',
)}
>
<div className="grid grid-cols-3 gap-2 text-center">
<div>
<div>
<ArrowIcon direction="up" />
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.prompt.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.prompt.usd).toFixed(6)}</div>
</div>
<div>
<div>
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.total.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.total.usd).toFixed(6)}</div>
</div>
<div>
<div>
<ArrowIcon direction="down" />
{localize('com_ui_token_abbreviation', {
0: conversationCosts.totals.completion.tokenCount,
})}
</div>
<div>${Math.abs(conversationCosts.totals.completion.usd).toFixed(6)}</div>
</div>
</div>
</div>
);
}

View File

@@ -26,6 +26,7 @@ type ContentPartsProps = {
isCreatedByUser: boolean;
isLast: boolean;
isSubmitting: boolean;
isLatestMessage?: boolean;
edit?: boolean;
enterEdit?: (cancel?: boolean) => void | null | undefined;
siblingIdx?: number;
@@ -45,6 +46,7 @@ const ContentParts = memo(
isCreatedByUser,
isLast,
isSubmitting,
isLatestMessage,
edit,
enterEdit,
siblingIdx,
@@ -55,6 +57,8 @@ const ContentParts = memo(
const [isExpanded, setIsExpanded] = useState(showThinking);
const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]);
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const hasReasoningParts = useMemo(() => {
const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false;
const allThinkPartsHaveContent =
@@ -134,7 +138,9 @@ const ContentParts = memo(
})
}
label={
isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts')
effectiveIsSubmitting && isLast
? localize('com_ui_thinking')
: localize('com_ui_thoughts')
}
/>
</div>
@@ -155,12 +161,14 @@ const ContentParts = memo(
conversationId,
partIndex: idx,
nextType: content[idx + 1]?.type,
isSubmitting: effectiveIsSubmitting,
isLatestMessage,
}}
>
<Part
part={part}
attachments={attachments}
isSubmitting={isSubmitting}
isSubmitting={effectiveIsSubmitting}
key={`part-${messageId}-${idx}`}
isCreatedByUser={isCreatedByUser}
isLast={idx === content.length - 1}

View File

@@ -4,7 +4,7 @@ import { useRecoilState, useRecoilValue } from 'recoil';
import { TextareaAutosize, TooltipAnchor } from '@librechat/client';
import { useUpdateMessageMutation } from 'librechat-data-provider/react-query';
import type { TEditProps } from '~/common';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
import { cn, removeFocusRings } from '~/utils';
import { useLocalize } from '~/hooks';
import Container from './Container';
@@ -22,7 +22,8 @@ const EditMessage = ({
const { addedIndex } = useAddedChatContext();
const saveButtonRef = useRef<HTMLButtonElement | null>(null);
const submitButtonRef = useRef<HTMLButtonElement | null>(null);
const { getMessages, setMessages, conversation } = useChatContext();
const { conversation } = useMessagesConversation();
const { getMessages, setMessages } = useMessagesOperations();
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
store.latestMessageFamily(addedIndex),
);

View File

@@ -5,7 +5,7 @@ import type { TMessage } from 'librechat-data-provider';
import type { TMessageContentProps, TDisplayProps } from '~/common';
import Error from '~/components/Messages/Content/Error';
import Thinking from '~/components/Artifacts/Thinking';
import { useChatContext } from '~/Providers';
import { useMessageContext } from '~/Providers';
import MarkdownLite from './MarkdownLite';
import EditMessage from './EditMessage';
import { useLocalize } from '~/hooks';
@@ -70,16 +70,12 @@ export const ErrorMessage = ({
};
const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => {
const { isSubmitting, latestMessage } = useChatContext();
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
const showCursorState = useMemo(
() => showCursor === true && isSubmitting,
[showCursor, isSubmitting],
);
const isLatestMessage = useMemo(
() => message.messageId === latestMessage?.messageId,
[message.messageId, latestMessage?.messageId],
);
let content: React.ReactElement;
if (!isCreatedByUser) {

View File

@@ -85,13 +85,14 @@ const Part = memo(
const isToolCall =
'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL);
if (isToolCall && toolCall.name === Tools.execute_code && toolCall.args) {
if (isToolCall && toolCall.name === Tools.execute_code) {
return (
<ExecuteCode
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
attachments={attachments}
isSubmitting={isSubmitting}
output={toolCall.output ?? ''}
initialProgress={toolCall.progress ?? 0.1}
attachments={attachments}
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
/>
);
} else if (

View File

@@ -6,8 +6,8 @@ import { useRecoilState, useRecoilValue } from 'recoil';
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
import type { Agents } from 'librechat-data-provider';
import type { TEditProps } from '~/common';
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
import Container from '~/components/Chat/Messages/Content/Container';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { cn, removeFocusRings } from '~/utils';
import { useLocalize } from '~/hooks';
import store from '~/store';
@@ -25,7 +25,8 @@ const EditTextPart = ({
}) => {
const localize = useLocalize();
const { addedIndex } = useAddedChatContext();
const { ask, getMessages, setMessages, conversation } = useChatContext();
const { conversation } = useMessagesConversation();
const { ask, getMessages, setMessages } = useMessagesOperations();
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
store.latestMessageFamily(addedIndex),
);

View File

@@ -45,26 +45,28 @@ export function useParseArgs(args?: string): ParsedArgs | null {
}
export default function ExecuteCode({
isSubmitting,
initialProgress = 0.1,
args,
output = '',
attachments,
}: {
initialProgress: number;
isSubmitting: boolean;
args?: string;
output?: string;
attachments?: TAttachment[];
}) {
const localize = useLocalize();
const showAnalysisCode = useRecoilValue(store.showCode);
const [showCode, setShowCode] = useState(showAnalysisCode);
const codeContentRef = useRef<HTMLDivElement>(null);
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
const [isAnimating, setIsAnimating] = useState(false);
const hasOutput = output.length > 0;
const outputRef = useRef<string>(output);
const prevShowCodeRef = useRef<boolean>(showCode);
const codeContentRef = useRef<HTMLDivElement>(null);
const [isAnimating, setIsAnimating] = useState(false);
const showAnalysisCode = useRecoilValue(store.showCode);
const [showCode, setShowCode] = useState(showAnalysisCode);
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
const prevShowCodeRef = useRef<boolean>(showCode);
const { lang, code } = useParseArgs(args) ?? ({} as ParsedArgs);
const progress = useProgress(initialProgress);
@@ -136,6 +138,8 @@ export default function ExecuteCode({
};
}, [showCode, isAnimating]);
const cancelled = !isSubmitting && progress < 1;
return (
<>
<div className="relative my-2.5 flex size-5 shrink-0 items-center gap-2.5">
@@ -143,9 +147,12 @@ export default function ExecuteCode({
progress={progress}
onClick={() => setShowCode((prev) => !prev)}
inProgressText={localize('com_ui_analyzing')}
finishedText={localize('com_ui_analyzing_finished')}
finishedText={
cancelled ? localize('com_ui_cancelled') : localize('com_ui_analyzing_finished')
}
hasInput={!!code?.length}
isExpanded={showCode}
error={cancelled}
/>
</div>
<div

View File

@@ -2,7 +2,7 @@ import { memo, useMemo, ReactElement } from 'react';
import { useRecoilValue } from 'recoil';
import MarkdownLite from '~/components/Chat/Messages/Content/MarkdownLite';
import Markdown from '~/components/Chat/Messages/Content/Markdown';
import { useChatContext, useMessageContext } from '~/Providers';
import { useMessageContext } from '~/Providers';
import { cn } from '~/utils';
import store from '~/store';
@@ -18,14 +18,9 @@ type ContentType =
| ReactElement;
const TextPart = memo(({ text, isCreatedByUser, showCursor }: TextPartProps) => {
const { messageId } = useMessageContext();
const { isSubmitting, latestMessage } = useChatContext();
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
const showCursorState = useMemo(() => showCursor && isSubmitting, [showCursor, isSubmitting]);
const isLatestMessage = useMemo(
() => messageId === latestMessage?.messageId,
[messageId, latestMessage?.messageId],
);
const content: ContentType = useMemo(() => {
if (!isCreatedByUser) {

View File

@@ -21,7 +21,7 @@ type THoverButtons = {
latestMessage: TMessage | null;
isLast: boolean;
index: number;
handleFeedback: ({ feedback }: { feedback: TFeedback | undefined }) => void;
handleFeedback?: ({ feedback }: { feedback: TFeedback | undefined }) => void;
};
type HoverButtonProps = {
@@ -238,7 +238,7 @@ const HoverButtons = ({
/>
{/* Feedback Buttons */}
{!isCreatedByUser && (
{!isCreatedByUser && handleFeedback != null && (
<Feedback handleFeedback={handleFeedback} feedback={message.feedback} isLast={isLast} />
)}

View File

@@ -1,7 +1,6 @@
import React from 'react';
import { useRecoilValue } from 'recoil';
import { useMessageProcess } from '~/hooks';
import type { TConversationCosts } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
import MessageRender from './ui/MessageRender';
// eslint-disable-next-line import/no-cycle
@@ -29,7 +28,7 @@ const MessageContainer = React.memo(
},
);
export default function Message(props: TMessageProps & { costs?: TConversationCosts }) {
export default function Message(props: TMessageProps) {
const {
showSibling,
conversation,
@@ -38,7 +37,7 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
latestMultiMessage,
isSubmittingFamily,
} = useMessageProcess({ message: props.message });
const { message, currentEditId, setCurrentEditId, costs } = props;
const { message, currentEditId, setCurrentEditId } = props;
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
if (!message || typeof message !== 'object') {
@@ -63,7 +62,6 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
message={message}
isSubmittingFamily={isSubmittingFamily}
isCard
costs={costs}
/>
<MessageRender
{...props}
@@ -71,13 +69,12 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
isCard
message={siblingMessage ?? latestMultiMessage ?? undefined}
isSubmittingFamily={isSubmittingFamily}
costs={costs}
/>
</div>
</div>
) : (
<div className="m-auto justify-center p-4 py-2 md:gap-6">
<MessageRender {...props} costs={costs} />
<MessageRender {...props} />
</div>
)}
</MessageContainer>
@@ -88,7 +85,6 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
messagesTree={children ?? []}
currentEditId={currentEditId}
setCurrentEditId={setCurrentEditId}
costs={costs}
/>
</>
);

View File

@@ -1,6 +1,6 @@
import React, { useMemo } from 'react';
import { useRecoilValue } from 'recoil';
import type { TMessageContentParts, TConversationCosts } from 'librechat-data-provider';
import type { TMessageContentParts } from 'librechat-data-provider';
import type { TMessageProps, TMessageIcon } from '~/common';
import { useMessageHelpers, useLocalize, useAttachments } from '~/hooks';
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
@@ -12,17 +12,10 @@ import SubRow from './SubRow';
import { cn } from '~/utils';
import store from '~/store';
export default function Message(props: TMessageProps & { costs?: TConversationCosts }) {
export default function Message(props: TMessageProps) {
const localize = useLocalize();
const {
message,
siblingIdx,
siblingCount,
setSiblingIdx,
currentEditId,
setCurrentEditId,
costs,
} = props;
const { message, siblingIdx, siblingCount, setSiblingIdx, currentEditId, setCurrentEditId } =
props;
const { attachments, searchResults } = useAttachments({
messageId: message?.messageId,
attachments: message?.attachments,
@@ -132,6 +125,7 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
setSiblingIdx={setSiblingIdx}
isCreatedByUser={message.isCreatedByUser}
conversationId={conversation?.conversationId}
isLatestMessage={messageId === latestMessage?.messageId}
content={message.content as Array<TMessageContentParts | undefined>}
/>
</div>
@@ -171,7 +165,6 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
messagesTree={children ?? []}
currentEditId={currentEditId}
setCurrentEditId={setCurrentEditId}
costs={costs}
/>
</>
);

View File

@@ -1,21 +1,18 @@
import { useState } from 'react';
import { useRecoilValue } from 'recoil';
import { CSSTransition } from 'react-transition-group';
import type { TMessage, TConversationCosts } from 'librechat-data-provider';
import type { TMessage } from 'librechat-data-provider';
import { useScreenshot, useMessageScrolling, useLocalize } from '~/hooks';
import ScrollToBottom from '~/components/Messages/ScrollToBottom';
import { MessagesViewProvider } from '~/Providers';
import MultiMessage from './MultiMessage';
import { cn } from '~/utils';
import store from '~/store';
export default function MessagesView({
function MessagesViewContent({
messagesTree: _messagesTree,
costBar,
costs,
}: {
messagesTree?: TMessage[] | null;
costBar?: React.ReactNode;
costs?: TConversationCosts;
}) {
const localize = useLocalize();
const fontSize = useRecoilValue(store.fontSize);
@@ -48,7 +45,7 @@ export default function MessagesView({
width: '100%',
}}
>
<div className="flex flex-col dark:bg-transparent">
<div className="flex flex-col pb-9 dark:bg-transparent">
{(_messagesTree && _messagesTree.length == 0) || _messagesTree === null ? (
<div
className={cn(
@@ -67,25 +64,18 @@ export default function MessagesView({
messageId={conversationId ?? null}
setCurrentEditId={setCurrentEditId}
currentEditId={currentEditId ?? null}
costs={costs}
/>
</div>
</>
)}
<div
id="messages-end"
className="group h-1 w-full flex-shrink-0 pb-7"
className="group h-0 w-full flex-shrink-0"
ref={messagesEndRef}
/>
</div>
</div>
{costBar && (
<div className="pointer-events-none absolute bottom-2 left-1/2 z-10 -translate-x-1/2">
{costBar}
</div>
)}
<CSSTransition
in={showScrollButton && scrollButtonPreference}
timeout={{
@@ -103,3 +93,11 @@ export default function MessagesView({
</>
);
}
export default function MessagesView({ messagesTree }: { messagesTree?: TMessage[] | null }) {
return (
<MessagesViewProvider>
<MessagesViewContent messagesTree={messagesTree} />
</MessagesViewProvider>
);
}

View File

@@ -1,7 +1,7 @@
import { useRecoilState } from 'recoil';
import { useEffect, useCallback } from 'react';
import { isAssistantsEndpoint } from 'librechat-data-provider';
import type { TMessage, TConversationCosts } from 'librechat-data-provider';
import type { TMessage } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
import MessageContent from '~/components/Messages/MessageContent';
import MessageParts from './MessageParts';
@@ -14,8 +14,7 @@ export default function MultiMessage({
messagesTree,
currentEditId,
setCurrentEditId,
costs,
}: TMessageProps & { costs?: TConversationCosts }) {
}: TMessageProps) {
const [siblingIdx, setSiblingIdx] = useRecoilState(store.messagesSiblingIdxFamily(messageId));
const setSiblingIdxRev = useCallback(
@@ -28,7 +27,7 @@ export default function MultiMessage({
useEffect(() => {
// reset siblingIdx when the tree changes, mostly when a new message is submitting.
setSiblingIdx(0);
}, [messagesTree?.length]);
}, [messagesTree?.length, setSiblingIdx]);
useEffect(() => {
if (messagesTree?.length && siblingIdx >= messagesTree.length) {
@@ -56,7 +55,6 @@ export default function MultiMessage({
siblingIdx={messagesTree.length - siblingIdx - 1}
siblingCount={messagesTree.length}
setSiblingIdx={setSiblingIdxRev}
costs={costs}
/>
);
} else if (message.content) {
@@ -69,7 +67,6 @@ export default function MultiMessage({
siblingIdx={messagesTree.length - siblingIdx - 1}
siblingCount={messagesTree.length}
setSiblingIdx={setSiblingIdxRev}
costs={costs}
/>
);
}
@@ -83,7 +80,6 @@ export default function MultiMessage({
siblingIdx={messagesTree.length - siblingIdx - 1}
siblingCount={messagesTree.length}
setSiblingIdx={setSiblingIdxRev}
costs={costs}
/>
);
}

View File

@@ -1,17 +1,16 @@
import React, { useCallback, useMemo, memo } from 'react';
import { useRecoilValue } from 'recoil';
import { ArrowIcon } from '@librechat/client';
import { type TMessage, TConversationCosts } from 'librechat-data-provider';
import { type TMessage } from 'librechat-data-provider';
import type { TMessageProps, TMessageIcon } from '~/common';
import MessageContent from '~/components/Chat/Messages/Content/MessageContent';
import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow';
import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch';
import HoverButtons from '~/components/Chat/Messages/HoverButtons';
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
import { useMessageActions, useLocalize } from '~/hooks';
import { Plugin } from '~/components/Messages/Content';
import SubRow from '~/components/Chat/Messages/SubRow';
import { MessageContext } from '~/Providers';
import { useMessageActions } from '~/hooks';
import { cn, logger } from '~/utils';
import store from '~/store';
@@ -20,7 +19,6 @@ type MessageRenderProps = {
isCard?: boolean;
isMultiMessage?: boolean;
isSubmittingFamily?: boolean;
costs?: TConversationCosts;
} & Pick<
TMessageProps,
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
@@ -37,9 +35,7 @@ const MessageRender = memo(
isMultiMessage = false,
setCurrentEditId,
isSubmittingFamily = false,
costs,
}: MessageRenderProps) => {
const localize = useLocalize();
const {
ask,
edit,
@@ -64,18 +60,6 @@ const MessageRender = memo(
});
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
const fontSize = useRecoilValue(store.fontSize);
const showCostTracking = useRecoilValue(store.showCostTracking);
const perMessageCost = useMemo(() => {
if (!showCostTracking || !costs || !costs.perMessage || !msg?.messageId) {
return null;
}
const entry = costs.perMessage.find((p) => p.messageId === msg.messageId);
if (!entry) {
return null;
}
return entry;
}, [showCostTracking, costs, msg?.messageId]);
const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]);
const hasNoChildren = !(msg?.children?.length ?? 0);
@@ -87,6 +71,9 @@ const MessageRender = memo(
const showCardRender = isLast && !isSubmittingFamily && isCard;
const isLatestCard = isCard && !isSubmittingFamily && isLatestMessage;
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
const iconData: TMessageIcon = useMemo(
() => ({
endpoint: msg?.endpoint ?? conversation?.endpoint,
@@ -173,26 +160,7 @@ const MessageRender = memo(
msg.isCreatedByUser ? 'user-turn' : 'agent-turn',
)}
>
<h2 className={cn('select-none font-semibold', fontSize)}>
{messageLabel}
{perMessageCost && (
<span className="ml-2 inline-flex items-center gap-2 px-2 py-0.5 text-xs text-muted-foreground">
{perMessageCost.tokenCount > 0 && (
<span>
{perMessageCost.tokenType === 'prompt' ? (
<ArrowIcon direction="up" className="inline" />
) : (
<ArrowIcon direction="down" className="inline" />
)}
{localize('com_ui_token_abbreviation', {
0: perMessageCost.tokenCount,
})}
</span>
)}
<span className="whitespace-pre">${Math.abs(perMessageCost.usd).toFixed(6)}</span>
</span>
)}
</h2>
<h2 className={cn('select-none font-semibold', fontSize)}>{messageLabel}</h2>
<div className="flex flex-col gap-1">
<div className="flex max-w-full flex-grow flex-col gap-0">
@@ -201,6 +169,8 @@ const MessageRender = memo(
messageId: msg.messageId,
conversationId: conversation?.conversationId,
isExpanded: false,
isSubmitting: effectiveIsSubmitting,
isLatestMessage,
}}
>
{msg.plugin && <Plugin plugin={msg.plugin} />}
@@ -212,7 +182,7 @@ const MessageRender = memo(
message={msg}
enterEdit={enterEdit}
error={!!(msg.error ?? false)}
isSubmitting={isSubmitting}
isSubmitting={effectiveIsSubmitting}
unfinished={msg.unfinished ?? false}
isCreatedByUser={msg.isCreatedByUser ?? true}
siblingIdx={siblingIdx ?? 0}
@@ -221,7 +191,7 @@ const MessageRender = memo(
</MessageContext.Provider>
</div>
{hasNoChildren && (isSubmittingFamily === true || isSubmitting) ? (
{hasNoChildren && (isSubmittingFamily === true || effectiveIsSubmitting) ? (
<PlaceholderRow isCard={isCard} />
) : (
<SubRow classes="text-xs">

View File

@@ -1,6 +1,5 @@
import { useMemo, memo, type FC, useCallback } from 'react';
import throttle from 'lodash/throttle';
import { parseISO, isToday } from 'date-fns';
import { Spinner, useMediaQuery } from '@librechat/client';
import { List, AutoSizer, CellMeasurer, CellMeasurerCache } from 'react-virtualized';
import { TConversation } from 'librechat-data-provider';
@@ -50,27 +49,17 @@ const MemoizedConvo = memo(
conversation,
retainView,
toggleNav,
isLatestConvo,
}: {
conversation: TConversation;
retainView: () => void;
toggleNav: () => void;
isLatestConvo: boolean;
}) => {
return (
<Convo
conversation={conversation}
retainView={retainView}
toggleNav={toggleNav}
isLatestConvo={isLatestConvo}
/>
);
return <Convo conversation={conversation} retainView={retainView} toggleNav={toggleNav} />;
},
(prevProps, nextProps) => {
return (
prevProps.conversation.conversationId === nextProps.conversation.conversationId &&
prevProps.conversation.title === nextProps.conversation.title &&
prevProps.isLatestConvo === nextProps.isLatestConvo &&
prevProps.conversation.endpoint === nextProps.conversation.endpoint
);
},
@@ -98,13 +87,6 @@ const Conversations: FC<ConversationsProps> = ({
[filteredConversations],
);
const firstTodayConvoId = useMemo(
() =>
filteredConversations.find((convo) => convo.updatedAt && isToday(parseISO(convo.updatedAt)))
?.conversationId ?? undefined,
[filteredConversations],
);
const flattenedItems = useMemo(() => {
const items: FlattenedItem[] = [];
groupedConversations.forEach(([groupName, convos]) => {
@@ -154,26 +136,25 @@ const Conversations: FC<ConversationsProps> = ({
</CellMeasurer>
);
}
let rendering: JSX.Element;
if (item.type === 'header') {
rendering = <DateLabel groupName={item.groupName} />;
} else if (item.type === 'convo') {
rendering = (
<MemoizedConvo conversation={item.convo} retainView={moveToTop} toggleNav={toggleNav} />
);
}
return (
<CellMeasurer cache={cache} columnIndex={0} key={key} parent={parent} rowIndex={index}>
{({ registerChild }) => (
<div ref={registerChild} style={style}>
{item.type === 'header' ? (
<DateLabel groupName={item.groupName} />
) : item.type === 'convo' ? (
<MemoizedConvo
conversation={item.convo}
retainView={moveToTop}
toggleNav={toggleNav}
isLatestConvo={item.convo.conversationId === firstTodayConvoId}
/>
) : null}
{rendering}
</div>
)}
</CellMeasurer>
);
},
[cache, flattenedItems, firstTodayConvoId, moveToTop, toggleNav],
[cache, flattenedItems, moveToTop, toggleNav],
);
const getRowHeight = useCallback(

View File

@@ -11,23 +11,17 @@ import { useGetEndpointsQuery } from '~/data-provider';
import { NotificationSeverity } from '~/common';
import { ConvoOptions } from './ConvoOptions';
import RenameForm from './RenameForm';
import { cn, logger } from '~/utils';
import ConvoLink from './ConvoLink';
import { cn } from '~/utils';
import store from '~/store';
interface ConversationProps {
conversation: TConversation;
retainView: () => void;
toggleNav: () => void;
isLatestConvo: boolean;
}
export default function Conversation({
conversation,
retainView,
toggleNav,
isLatestConvo,
}: ConversationProps) {
export default function Conversation({ conversation, retainView, toggleNav }: ConversationProps) {
const params = useParams();
const localize = useLocalize();
const { showToast } = useToastContext();
@@ -84,6 +78,7 @@ export default function Conversation({
});
setRenaming(false);
} catch (error) {
logger.error('Error renaming conversation', error);
setTitleInput(title as string);
showToast({
message: localize('com_ui_rename_failed'),

View File

@@ -1,14 +1,13 @@
import { useRecoilValue } from 'recoil';
import { ArrowIcon } from '@librechat/client';
import { useCallback, useMemo, memo } from 'react';
import type { TMessage, TMessageContentParts, TConversationCosts } from 'librechat-data-provider';
import type { TMessage, TMessageContentParts } from 'librechat-data-provider';
import type { TMessageProps, TMessageIcon } from '~/common';
import ContentParts from '~/components/Chat/Messages/Content/ContentParts';
import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow';
import { useAttachments, useMessageActions, useLocalize } from '~/hooks';
import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch';
import HoverButtons from '~/components/Chat/Messages/HoverButtons';
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
import { useAttachments, useMessageActions } from '~/hooks';
import SubRow from '~/components/Chat/Messages/SubRow';
import { cn, logger } from '~/utils';
import store from '~/store';
@@ -18,7 +17,6 @@ type ContentRenderProps = {
isCard?: boolean;
isMultiMessage?: boolean;
isSubmittingFamily?: boolean;
costs?: TConversationCosts;
} & Pick<
TMessageProps,
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
@@ -35,9 +33,7 @@ const ContentRender = memo(
isMultiMessage = false,
setCurrentEditId,
isSubmittingFamily = false,
costs,
}: ContentRenderProps) => {
const localize = useLocalize();
const { attachments, searchResults } = useAttachments({
messageId: msg?.messageId,
attachments: msg?.attachments,
@@ -66,14 +62,6 @@ const ContentRender = memo(
});
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
const fontSize = useRecoilValue(store.fontSize);
const showCostTracking = useRecoilValue(store.showCostTracking);
const perMessageCost = useMemo(() => {
if (!showCostTracking || !costs || !costs.perMessage || !msg?.messageId) {
return null;
}
return costs.perMessage.find((p) => p.messageId === msg.messageId) ?? null;
}, [showCostTracking, costs, msg?.messageId]);
const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]);
const isLast = useMemo(
@@ -171,26 +159,7 @@ const ContentRender = memo(
msg.isCreatedByUser ? 'user-turn' : 'agent-turn',
)}
>
<h2 className={cn('select-none font-semibold', fontSize)}>
{messageLabel}
{perMessageCost && (
<span className="ml-2 inline-flex items-center gap-2 px-2 py-0.5 text-xs text-muted-foreground">
{perMessageCost.tokenCount > 0 && (
<span className="mr-2">
{perMessageCost.tokenType === 'prompt' ? (
<ArrowIcon direction="up" className="inline" />
) : (
<ArrowIcon direction="down" className="inline" />
)}
{localize('com_ui_token_abbreviation', {
0: perMessageCost.tokenCount,
})}
</span>
)}
<span className="whitespace-pre">${Math.abs(perMessageCost.usd).toFixed(6)}</span>
</span>
)}
</h2>
<h2 className={cn('select-none font-semibold', fontSize)}>{messageLabel}</h2>
<div className="flex flex-col gap-1">
<div className="flex max-w-full flex-grow flex-col gap-0">
@@ -204,6 +173,7 @@ const ContentRender = memo(
isSubmitting={isSubmitting}
searchResults={searchResults}
setSiblingIdx={setSiblingIdx}
isLatestMessage={isLatestMessage}
isCreatedByUser={msg.isCreatedByUser}
conversationId={conversation?.conversationId}
content={msg.content as Array<TMessageContentParts | undefined>}

View File

@@ -1,6 +1,5 @@
import React from 'react';
import { useMessageProcess } from '~/hooks';
import type { TConversationCosts } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
// eslint-disable-next-line import/no-cycle
import MultiMessage from '~/components/Chat/Messages/MultiMessage';
@@ -26,7 +25,7 @@ const MessageContainer = React.memo(
},
);
export default function MessageContent(props: TMessageProps & { costs?: TConversationCosts }) {
export default function MessageContent(props: TMessageProps) {
const {
showSibling,
conversation,
@@ -35,7 +34,7 @@ export default function MessageContent(props: TMessageProps & { costs?: TConvers
latestMultiMessage,
isSubmittingFamily,
} = useMessageProcess({ message: props.message });
const { message, currentEditId, setCurrentEditId, costs } = props;
const { message, currentEditId, setCurrentEditId } = props;
if (!message || typeof message !== 'object') {
return null;
@@ -54,7 +53,6 @@ export default function MessageContent(props: TMessageProps & { costs?: TConvers
message={message}
isSubmittingFamily={isSubmittingFamily}
isCard
costs={costs}
/>
<ContentRender
{...props}
@@ -62,13 +60,12 @@ export default function MessageContent(props: TMessageProps & { costs?: TConvers
isCard
message={siblingMessage ?? latestMultiMessage ?? undefined}
isSubmittingFamily={isSubmittingFamily}
costs={costs}
/>
</div>
</div>
) : (
<div className="m-auto justify-center p-4 py-2 md:gap-6">
<ContentRender {...props} costs={costs} />
<div className="m-auto justify-center p-4 py-2 md:gap-6 ">
<ContentRender {...props} />
</div>
)}
</MessageContainer>
@@ -79,7 +76,6 @@ export default function MessageContent(props: TMessageProps & { costs?: TConvers
messagesTree={children ?? []}
currentEditId={currentEditId}
setCurrentEditId={setCurrentEditId}
costs={costs}
/>
</>
);

View File

@@ -76,13 +76,6 @@ const toggleSwitchConfigs = [
hoverCardText: undefined,
key: 'modularChat',
},
{
stateAtom: store.showCostTracking,
localizationKey: 'com_nav_show_cost_tracking',
switchId: 'showCostTracking',
hoverCardText: 'com_nav_info_show_cost_tracking',
key: 'showCostTracking',
},
];
function Chat() {

View File

@@ -76,6 +76,8 @@ export default function Message(props: TMessageProps) {
messageId,
isExpanded: false,
conversationId: conversation?.conversationId,
isSubmitting: false, // Share view is always read-only
isLatestMessage: false, // No concept of latest message in share view
}}
>
{/* Legacy Plugins */}

View File

@@ -1,4 +1,4 @@
import React, { useState, useMemo, useCallback } from 'react';
import React, { useState, useMemo, useCallback, useEffect } from 'react';
import { ChevronLeft, Trash2 } from 'lucide-react';
import { useQueryClient } from '@tanstack/react-query';
import { Button, useToastContext } from '@librechat/client';
@@ -12,6 +12,8 @@ import { useLocalize, useMCPConnectionStatus } from '~/hooks';
import { useGetStartupConfig } from '~/data-provider';
import MCPPanelSkeleton from './MCPPanelSkeleton';
const POLL_FOR_CONNECTION_STATUS_INTERVAL = 2_000; // ms
function MCPPanelContent() {
const localize = useLocalize();
const queryClient = useQueryClient();
@@ -26,6 +28,29 @@ function MCPPanelContent() {
null,
);
// Check if any connections are in 'connecting' state
const hasConnectingServers = useMemo(() => {
if (!connectionStatus) {
return false;
}
return Object.values(connectionStatus).some(
(status) => status?.connectionState === 'connecting',
);
}, [connectionStatus]);
// Set up polling when servers are connecting
useEffect(() => {
if (!hasConnectingServers) {
return;
}
const intervalId = setInterval(() => {
queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
}, POLL_FOR_CONNECTION_STATUS_INTERVAL);
return () => clearInterval(intervalId);
}, [hasConnectingServers, queryClient]);
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
onSuccess: async () => {
showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' });

View File

@@ -0,0 +1,440 @@
import { renderHook } from '@testing-library/react';
import { Tools, Constants } from 'librechat-data-provider';
import useAgentToolPermissions from '../useAgentToolPermissions';
// Mock dependencies
jest.mock('~/data-provider', () => ({
useGetAgentByIdQuery: jest.fn(),
}));
jest.mock('~/Providers', () => ({
useAgentsMapContext: jest.fn(),
}));
// Import mocked functions after mocking
import { useGetAgentByIdQuery } from '~/data-provider';
import { useAgentsMapContext } from '~/Providers';
describe('useAgentToolPermissions', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('Ephemeral Agent Scenarios', () => {
it('should return true for all tools when agentId is null', () => {
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(null));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toBeUndefined();
});
it('should return true for all tools when agentId is undefined', () => {
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(undefined));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toBeUndefined();
});
it('should return true for all tools when agentId is empty string', () => {
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(''));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toBeUndefined();
});
it('should return true for all tools when agentId is EPHEMERAL_AGENT_ID', () => {
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() =>
useAgentToolPermissions(Constants.EPHEMERAL_AGENT_ID)
);
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toBeUndefined();
});
});
describe('Regular Agent with Tools', () => {
it('should allow file_search when agent has the tool', () => {
const agentId = 'agent-123';
const mockAgent = {
id: agentId,
tools: [Tools.file_search, 'other_tool'],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toEqual([Tools.file_search, 'other_tool']);
});
it('should allow execute_code when agent has the tool', () => {
const agentId = 'agent-456';
const mockAgent = {
id: agentId,
tools: [Tools.execute_code, 'another_tool'],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toEqual([Tools.execute_code, 'another_tool']);
});
it('should allow both tools when agent has both', () => {
const agentId = 'agent-789';
const mockAgent = {
id: agentId,
tools: [Tools.file_search, Tools.execute_code, 'custom_tool'],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toEqual([Tools.file_search, Tools.execute_code, 'custom_tool']);
});
it('should disallow both tools when agent has neither', () => {
const agentId = 'agent-no-tools';
const mockAgent = {
id: agentId,
tools: ['custom_tool1', 'custom_tool2'],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toEqual(['custom_tool1', 'custom_tool2']);
});
it('should handle agent with empty tools array', () => {
const agentId = 'agent-empty-tools';
const mockAgent = {
id: agentId,
tools: [],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toEqual([]);
});
it('should handle agent with undefined tools', () => {
const agentId = 'agent-undefined-tools';
const mockAgent = {
id: agentId,
tools: undefined,
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
});
describe('Agent Data from Query', () => {
it('should prioritize agentData tools over selectedAgent tools', () => {
const agentId = 'agent-with-query-data';
const mockAgent = {
id: agentId,
tools: ['old_tool'],
};
const mockAgentData = {
id: agentId,
tools: [Tools.file_search, Tools.execute_code],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: mockAgentData });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
expect(result.current.tools).toEqual([Tools.file_search, Tools.execute_code]);
});
it('should fallback to selectedAgent tools when agentData has no tools', () => {
const agentId = 'agent-fallback';
const mockAgent = {
id: agentId,
tools: [Tools.file_search],
};
const mockAgentData = {
id: agentId,
tools: undefined,
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: mockAgentData });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toEqual([Tools.file_search]);
});
});
describe('Agent Not Found Scenarios', () => {
it('should disallow all tools when agent is not found in map', () => {
const agentId = 'non-existent-agent';
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
it('should disallow all tools when agentsMap is null', () => {
const agentId = 'agent-with-null-map';
(useAgentsMapContext as jest.Mock).mockReturnValue(null);
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
it('should disallow all tools when agentsMap is undefined', () => {
const agentId = 'agent-with-undefined-map';
(useAgentsMapContext as jest.Mock).mockReturnValue(undefined);
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
});
describe('Memoization and Performance', () => {
it('should memoize results when inputs do not change', () => {
const agentId = 'memoized-agent';
const mockAgent = {
id: agentId,
tools: [Tools.file_search],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result, rerender } = renderHook(() => useAgentToolPermissions(agentId));
const firstResult = result.current;
// Rerender without changing inputs
rerender();
const secondResult = result.current;
// The hook returns a new object each time, but the values should be equal
expect(firstResult.fileSearchAllowedByAgent).toBe(secondResult.fileSearchAllowedByAgent);
expect(firstResult.codeAllowedByAgent).toBe(secondResult.codeAllowedByAgent);
// Tools array reference should be the same since it comes from useMemo
expect(firstResult.tools).toBe(secondResult.tools);
// Verify the actual values are correct
expect(secondResult.fileSearchAllowedByAgent).toBe(true);
expect(secondResult.codeAllowedByAgent).toBe(false);
expect(secondResult.tools).toEqual([Tools.file_search]);
});
it('should recompute when agentId changes', () => {
const agentId1 = 'agent-1';
const agentId2 = 'agent-2';
const mockAgents = {
[agentId1]: { id: agentId1, tools: [Tools.file_search] },
[agentId2]: { id: agentId2, tools: [Tools.execute_code] },
};
(useAgentsMapContext as jest.Mock).mockReturnValue(mockAgents);
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result, rerender } = renderHook(
({ agentId }) => useAgentToolPermissions(agentId),
{ initialProps: { agentId: agentId1 } }
);
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(false);
// Change agentId
rerender({ agentId: agentId2 });
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(true);
});
it('should handle switching between ephemeral and regular agents', () => {
const regularAgentId = 'regular-agent';
const mockAgent = {
id: regularAgentId,
tools: [],
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[regularAgentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result, rerender } = renderHook(
({ agentId }) => useAgentToolPermissions(agentId),
{ initialProps: { agentId: null } }
);
// Start with ephemeral agent (null)
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
// Switch to regular agent
rerender({ agentId: regularAgentId });
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
// Switch back to ephemeral
rerender({ agentId: '' });
expect(result.current.fileSearchAllowedByAgent).toBe(true);
expect(result.current.codeAllowedByAgent).toBe(true);
});
});
describe('Edge Cases', () => {
it('should handle agents with null tools gracefully', () => {
const agentId = 'agent-null-tools';
const mockAgent = {
id: agentId,
tools: null as any,
};
(useAgentsMapContext as jest.Mock).mockReturnValue({
[agentId]: mockAgent,
});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(agentId));
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeNull();
});
it('should handle whitespace-only agentId as ephemeral', () => {
// Note: Based on the current implementation, only empty string is treated as ephemeral
// Whitespace-only strings would be treated as regular agent IDs
const whitespaceId = ' ';
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
const { result } = renderHook(() => useAgentToolPermissions(whitespaceId));
// Whitespace ID is not considered ephemeral in current implementation
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
});
it('should handle query loading state', () => {
const agentId = 'loading-agent';
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({
data: undefined,
isLoading: true,
error: null,
});
const { result } = renderHook(() => useAgentToolPermissions(agentId));
// During loading, should return false for non-ephemeral agents
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
it('should handle query error state', () => {
const agentId = 'error-agent';
(useAgentsMapContext as jest.Mock).mockReturnValue({});
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({
data: undefined,
isLoading: false,
error: new Error('Failed to fetch agent'),
});
const { result } = renderHook(() => useAgentToolPermissions(agentId));
// On error, should return false for non-ephemeral agents
expect(result.current.fileSearchAllowedByAgent).toBe(false);
expect(result.current.codeAllowedByAgent).toBe(false);
expect(result.current.tools).toBeUndefined();
});
});
});

View File

@@ -1,5 +1,5 @@
import { useMemo } from 'react';
import { Tools } from 'librechat-data-provider';
import { Tools, Constants } from 'librechat-data-provider';
import { useGetAgentByIdQuery } from '~/data-provider';
import { useAgentsMapContext } from '~/Providers';
@@ -9,6 +9,10 @@ interface AgentToolPermissionsResult {
tools: string[] | undefined;
}
function isEphemeralAgent(agentId: string | null | undefined): boolean {
return agentId == null || agentId === '' || agentId === Constants.EPHEMERAL_AGENT_ID;
}
/**
* Hook to determine whether specific tools are allowed for a given agent.
*
@@ -33,8 +37,8 @@ export default function useAgentToolPermissions(
);
const fileSearchAllowedByAgent = useMemo(() => {
// If no agentId, allow for ephemeral agents
if (!agentId) return true;
// Allow for ephemeral agents
if (isEphemeralAgent(agentId)) return true;
// If agentId exists but agent not found, disallow
if (!selectedAgent) return false;
// Check if the agent has the file_search tool
@@ -42,8 +46,8 @@ export default function useAgentToolPermissions(
}, [agentId, selectedAgent, tools]);
const codeAllowedByAgent = useMemo(() => {
// If no agentId, allow for ephemeral agents
if (!agentId) return true;
// Allow for ephemeral agents
if (isEphemeralAgent(agentId)) return true;
// If agentId exists but agent not found, disallow
if (!selectedAgent) return false;
// Check if the agent has the execute_code tool

View File

@@ -0,0 +1,495 @@
import React from 'react';
import { Provider, createStore } from 'jotai';
import { renderHook, act, waitFor } from '@testing-library/react';
import { RecoilRoot, useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
import { ephemeralAgentByConvoId } from '~/store';
import { setTimestamp } from '~/utils/timestamps';
import { useMCPSelect } from '../useMCPSelect';
// Mock dependencies
jest.mock('~/utils/timestamps', () => ({
setTimestamp: jest.fn(),
}));
jest.mock('lodash/isEqual', () => jest.fn((a, b) => JSON.stringify(a) === JSON.stringify(b)));
const createWrapper = () => {
// Create a new Jotai store for each test to ensure clean state
const store = createStore();
const Wrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => (
<RecoilRoot>
<Provider store={store}>{children}</Provider>
</RecoilRoot>
);
return Wrapper;
};
describe('useMCPSelect', () => {
beforeEach(() => {
jest.clearAllMocks();
localStorage.clear();
});
describe('Basic Functionality', () => {
it('should initialize with default values', () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
expect(result.current.isPinned).toBe(true); // Default value from mcpPinnedAtom is true
expect(typeof result.current.setMCPValues).toBe('function');
expect(typeof result.current.setIsPinned).toBe('function');
});
it('should use conversationId when provided', () => {
const conversationId = 'test-convo-123';
const { result } = renderHook(() => useMCPSelect({ conversationId }), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
});
it('should use NEW_CONVO constant when conversationId is null', () => {
const { result } = renderHook(() => useMCPSelect({ conversationId: null }), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
});
});
describe('State Updates', () => {
it('should update mcpValues when setMCPValues is called', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
const newValues = ['value1', 'value2'];
act(() => {
result.current.setMCPValues(newValues);
});
await waitFor(() => {
expect(result.current.mcpValues).toEqual(newValues);
});
});
it('should not update mcpValues if non-array is passed', () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
act(() => {
// @ts-ignore - Testing invalid input
result.current.setMCPValues('not-an-array');
});
expect(result.current.mcpValues).toEqual([]);
});
it('should update isPinned state', () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
// Default is true
expect(result.current.isPinned).toBe(true);
// Toggle to false
act(() => {
result.current.setIsPinned(false);
});
expect(result.current.isPinned).toBe(false);
// Toggle back to true
act(() => {
result.current.setIsPinned(true);
});
expect(result.current.isPinned).toBe(true);
});
});
describe('Timestamp Management', () => {
it('should set timestamp when mcpValues is updated with values', async () => {
const conversationId = 'test-convo';
const { result } = renderHook(() => useMCPSelect({ conversationId }), {
wrapper: createWrapper(),
});
const newValues = ['value1', 'value2'];
act(() => {
result.current.setMCPValues(newValues);
});
await waitFor(() => {
const expectedKey = `${LocalStorageKeys.LAST_MCP_}${conversationId}`;
expect(setTimestamp).toHaveBeenCalledWith(expectedKey);
});
});
it('should not set timestamp when mcpValues is empty', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
act(() => {
result.current.setMCPValues([]);
});
await waitFor(() => {
expect(setTimestamp).not.toHaveBeenCalled();
});
});
});
describe('Race Conditions and Infinite Loops Prevention', () => {
it('should not create infinite loop when syncing between Jotai and Recoil states', async () => {
const { result, rerender } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
let renderCount = 0;
const maxRenders = 10;
// Track renders to detect infinite loops
const trackRender = () => {
renderCount++;
if (renderCount > maxRenders) {
throw new Error('Potential infinite loop detected');
}
};
// Set initial value
act(() => {
trackRender();
result.current.setMCPValues(['initial']);
});
// Trigger multiple rerenders
for (let i = 0; i < 3; i++) {
rerender();
trackRender();
}
// Should not exceed max renders
expect(renderCount).toBeLessThanOrEqual(maxRenders);
});
it('should handle rapid consecutive updates without race conditions', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
const updates = [
['value1'],
['value1', 'value2'],
['value1', 'value2', 'value3'],
['value4'],
[],
];
// Rapid fire updates
act(() => {
updates.forEach((update) => {
result.current.setMCPValues(update);
});
});
await waitFor(() => {
// Should settle on the last update
expect(result.current.mcpValues).toEqual([]);
});
});
it('should maintain stable setter function reference', () => {
const { result, rerender } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
const firstSetMCPValues = result.current.setMCPValues;
// Trigger multiple rerenders
rerender();
rerender();
rerender();
// Setter should remain the same reference (memoized)
expect(result.current.setMCPValues).toBe(firstSetMCPValues);
});
it('should handle switching conversation IDs without issues', async () => {
const { result, rerender } = renderHook(
({ conversationId }) => useMCPSelect({ conversationId }),
{
wrapper: createWrapper(),
initialProps: { conversationId: 'convo1' },
},
);
// Set values for first conversation
act(() => {
result.current.setMCPValues(['convo1-value']);
});
await waitFor(() => {
expect(result.current.mcpValues).toEqual(['convo1-value']);
});
// Switch to different conversation
rerender({ conversationId: 'convo2' });
// Should have different state for new conversation
expect(result.current.mcpValues).toEqual([]);
// Set values for second conversation
act(() => {
result.current.setMCPValues(['convo2-value']);
});
await waitFor(() => {
expect(result.current.mcpValues).toEqual(['convo2-value']);
});
// Switch back to first conversation
rerender({ conversationId: 'convo1' });
// Should maintain separate state
await waitFor(() => {
expect(result.current.mcpValues).toEqual(['convo1-value']);
});
});
});
describe('Ephemeral Agent Synchronization', () => {
it('should sync mcpValues when ephemeralAgent is updated externally', async () => {
// Create a shared wrapper for both hooks to share the same Recoil/Jotai context
const wrapper = createWrapper();
// Create a component that uses both hooks to ensure they share state
const TestComponent = () => {
const mcpHook = useMCPSelect({});
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(
ephemeralAgentByConvoId(Constants.NEW_CONVO),
);
return { mcpHook, ephemeralAgent, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper });
// Simulate external update to ephemeralAgent (e.g., from another component)
const externalMcpValues = ['external-value1', 'external-value2'];
act(() => {
result.current.setEphemeralAgent({
mcp: externalMcpValues,
});
});
// The hook should sync with the external update
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(externalMcpValues);
});
});
it('should update ephemeralAgent when mcpValues changes through hook', async () => {
// Create a shared wrapper for both hooks
const wrapper = createWrapper();
// Create a component that uses both the hook and accesses Recoil state
const TestComponent = () => {
const mcpHook = useMCPSelect({});
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, ephemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper });
const newValues = ['hook-value1', 'hook-value2'];
act(() => {
result.current.mcpHook.setMCPValues(newValues);
});
// Verify both mcpValues and ephemeralAgent are updated
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(newValues);
expect(result.current.ephemeralAgent?.mcp).toEqual(newValues);
});
});
it('should handle empty ephemeralAgent.mcp array correctly', async () => {
// Create a shared wrapper
const wrapper = createWrapper();
// Create a component that uses both hooks
const TestComponent = () => {
const mcpHook = useMCPSelect({});
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper });
// Set initial values
act(() => {
result.current.mcpHook.setMCPValues(['initial-value']);
});
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
});
// Try to set empty array externally
act(() => {
result.current.setEphemeralAgent({
mcp: [],
});
});
// Values should remain unchanged since empty mcp array doesn't trigger update
// (due to the condition: ephemeralAgent?.mcp && ephemeralAgent.mcp.length > 0)
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
});
it('should properly sync non-empty arrays from ephemeralAgent', async () => {
// Additional test to ensure non-empty arrays DO sync
const wrapper = createWrapper();
const TestComponent = () => {
const mcpHook = useMCPSelect({});
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
return { mcpHook, setEphemeralAgent };
};
const { result } = renderHook(() => TestComponent(), { wrapper });
// Set initial values through ephemeralAgent with non-empty array
const initialValues = ['value1', 'value2'];
act(() => {
result.current.setEphemeralAgent({
mcp: initialValues,
});
});
// Should sync since it's non-empty
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(initialValues);
});
// Update with different non-empty values
const updatedValues = ['value3', 'value4', 'value5'];
act(() => {
result.current.setEphemeralAgent({
mcp: updatedValues,
});
});
// Should sync the new values
await waitFor(() => {
expect(result.current.mcpHook.mcpValues).toEqual(updatedValues);
});
});
});
describe('Edge Cases', () => {
it('should handle undefined conversationId', () => {
const { result } = renderHook(() => useMCPSelect({ conversationId: undefined }), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
act(() => {
result.current.setMCPValues(['test']);
});
expect(() => result.current).not.toThrow();
});
it('should handle empty string conversationId', () => {
const { result } = renderHook(() => useMCPSelect({ conversationId: '' }), {
wrapper: createWrapper(),
});
expect(result.current.mcpValues).toEqual([]);
});
it('should handle very large arrays without performance issues', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
const largeArray = Array.from({ length: 1000 }, (_, i) => `value-${i}`);
const startTime = performance.now();
act(() => {
result.current.setMCPValues(largeArray);
});
const endTime = performance.now();
const executionTime = endTime - startTime;
// Should complete within reasonable time (< 100ms)
expect(executionTime).toBeLessThan(100);
await waitFor(() => {
expect(result.current.mcpValues).toEqual(largeArray);
});
});
it('should cleanup properly on unmount', () => {
const { unmount } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
// Should unmount without errors
expect(() => unmount()).not.toThrow();
});
});
describe('Memory Leak Prevention', () => {
it('should not leak memory on repeated updates', async () => {
const { result } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
// Perform many updates to test for memory leaks
for (let i = 0; i < 100; i++) {
act(() => {
result.current.setMCPValues([`value-${i}`]);
});
}
// If we get here without crashing, memory management is likely OK
expect(result.current.mcpValues).toEqual(['value-99']);
});
it('should handle component remounting', () => {
const { result, unmount } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
act(() => {
result.current.setMCPValues(['before-unmount']);
});
unmount();
// Remount
const { result: newResult } = renderHook(() => useMCPSelect({}), {
wrapper: createWrapper(),
});
// Should handle remounting gracefully
expect(newResult.current.mcpValues).toBeDefined();
});
});
});

View File

@@ -1,5 +1,6 @@
import { useCallback, useEffect } from 'react';
import { useAtom } from 'jotai';
import isEqual from 'lodash/isEqual';
import { useRecoilState } from 'recoil';
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
import { ephemeralAgentByConvoId, mcpValuesAtomFamily, mcpPinnedAtom } from '~/store';
@@ -19,15 +20,14 @@ export function useMCPSelect({ conversationId }: { conversationId?: string | nul
}
}, [ephemeralAgent?.mcp, setMCPValuesRaw]);
// Update ephemeral agent when Jotai state changes
useEffect(() => {
if (mcpValues.length > 0 && JSON.stringify(mcpValues) !== JSON.stringify(ephemeralAgent?.mcp)) {
setEphemeralAgent((prev) => ({
...prev,
mcp: mcpValues,
}));
}
}, [mcpValues, ephemeralAgent?.mcp, setEphemeralAgent]);
setEphemeralAgent((prev) => {
if (!isEqual(prev?.mcp, mcpValues)) {
return { ...(prev ?? {}), mcp: mcpValues };
}
return prev;
});
}, [mcpValues, setEphemeralAgent]);
useEffect(() => {
const mcpStorageKey = `${LocalStorageKeys.LAST_MCP_}${key}`;

View File

@@ -2,7 +2,7 @@ import throttle from 'lodash/throttle';
import { useEffect, useRef, useCallback, useMemo } from 'react';
import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider';
import type { TMessageProps } from '~/common';
import { useChatContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
import { useMessagesViewContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
import useCopyToClipboard from './useCopyToClipboard';
import { getTextKey, logger } from '~/utils';
@@ -20,9 +20,9 @@ export default function useMessageHelpers(props: TMessageProps) {
setAbortScroll,
handleContinue,
setLatestMessage,
} = useChatContext();
const assistantMap = useAssistantsMapContext();
} = useMessagesViewContext();
const agentsMap = useAgentsMapContext();
const assistantMap = useAssistantsMapContext();
const { text, content, children, messageId = null, isCreatedByUser } = message ?? {};
const edit = messageId === currentEditId;

View File

@@ -3,7 +3,7 @@ import { useRecoilValue } from 'recoil';
import { Constants } from 'librechat-data-provider';
import { useEffect, useRef, useCallback, useMemo, useState } from 'react';
import type { TMessage } from 'librechat-data-provider';
import { useChatContext, useAddedChatContext } from '~/Providers';
import { useMessagesViewContext } from '~/Providers';
import { getTextKey, logger } from '~/utils';
import store from '~/store';
@@ -18,14 +18,9 @@ export default function useMessageProcess({ message }: { message?: TMessage | nu
latestMessage,
setAbortScroll,
setLatestMessage,
isSubmitting: isSubmittingRoot,
} = useChatContext();
const { isSubmitting: isSubmittingAdditional } = useAddedChatContext();
isSubmittingFamily,
} = useMessagesViewContext();
const latestMultiMessage = useRecoilValue(store.latestMessageFamily(index + 1));
const isSubmittingFamily = useMemo(
() => isSubmittingRoot || isSubmittingAdditional,
[isSubmittingRoot, isSubmittingAdditional],
);
useEffect(() => {
const convoId = conversation?.conversationId;

View File

@@ -2,8 +2,8 @@ import { useRecoilValue } from 'recoil';
import { Constants } from 'librechat-data-provider';
import { useState, useRef, useCallback, useEffect } from 'react';
import type { TMessage } from 'librechat-data-provider';
import { useMessagesConversation, useMessagesSubmission } from '~/Providers';
import useScrollToRef from '~/hooks/useScrollToRef';
import { useChatContext } from '~/Providers';
import store from '~/store';
const threshold = 0.85;
@@ -15,11 +15,10 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
const scrollableRef = useRef<HTMLDivElement | null>(null);
const messagesEndRef = useRef<HTMLDivElement | null>(null);
const [showScrollButton, setShowScrollButton] = useState(false);
const { conversation, setAbortScroll, isSubmitting, abortScroll } = useChatContext();
const { conversationId } = conversation ?? {};
const { conversation, conversationId } = useMessagesConversation();
const { setAbortScroll, isSubmitting, abortScroll } = useMessagesSubmission();
const timeoutIdRef = useRef<NodeJS.Timeout>();
const prevIsSubmittingRef = useRef<boolean>(false);
const debouncedSetShowScrollButton = useCallback((value: boolean) => {
clearTimeout(timeoutIdRef.current);
@@ -61,10 +60,7 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
}
}, [debouncedSetShowScrollButton]);
const scrollCallback = useCallback(
() => debouncedSetShowScrollButton(false),
[debouncedSetShowScrollButton],
);
const scrollCallback = () => debouncedSetShowScrollButton(false);
const { scrollToRef: scrollToBottom, handleSmoothToRef } = useScrollToRef({
targetRef: messagesEndRef,
@@ -75,18 +71,6 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
},
});
const smoothScrollToBottom = useCallback(() => {
if (messagesEndRef.current) {
messagesEndRef.current.scrollIntoView({
behavior: 'smooth',
block: 'end',
inline: 'nearest',
});
scrollCallback();
setAbortScroll(false);
}
}, [scrollCallback, setAbortScroll]);
useEffect(() => {
if (!messagesTree || messagesTree.length === 0) {
return;
@@ -107,20 +91,6 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
};
}, [isSubmitting, messagesTree, scrollToBottom, abortScroll]);
useEffect(() => {
if (!messagesEndRef.current || !scrollableRef.current) {
return;
}
if (prevIsSubmittingRef.current && !isSubmitting && abortScroll !== true) {
setTimeout(() => {
smoothScrollToBottom();
}, 100);
}
prevIsSubmittingRef.current = isSubmitting;
}, [isSubmitting, smoothScrollToBottom, abortScroll]);
useEffect(() => {
if (!messagesEndRef.current || !scrollableRef.current) {
return;

View File

@@ -232,14 +232,8 @@ export default function useEventHandlers({
},
]);
}
if (userMessage?.conversationId) {
queryClient.invalidateQueries({
queryKey: [QueryKeys.conversation, userMessage.conversationId, 'costs'],
});
}
},
[setMessages, announcePolite, setIsSubmitting, queryClient],
[setMessages, announcePolite, setIsSubmitting],
);
const cancelHandler = useCallback(
@@ -281,12 +275,6 @@ export default function useEventHandlers({
});
}
if (convoUpdate?.conversationId) {
queryClient.invalidateQueries({
queryKey: [QueryKeys.conversation, convoUpdate.conversationId, 'costs'],
});
}
setIsSubmitting(false);
},
[setMessages, setConversation, genTitle, isAddedRequest, queryClient, setIsSubmitting],
@@ -353,12 +341,6 @@ export default function useEventHandlers({
if (resetLatestMessage) {
resetLatestMessage();
}
if (conversationId) {
queryClient.invalidateQueries({
queryKey: [QueryKeys.conversation, conversationId, 'costs'],
});
}
},
[
queryClient,
@@ -545,12 +527,6 @@ export default function useEventHandlers({
);
}
if (conversation.conversationId) {
queryClient.invalidateQueries({
queryKey: [QueryKeys.conversation, conversation.conversationId, 'costs'],
});
}
if (isNewConvo && submissionConvo.conversationId) {
removeConvoFromAllQueries(queryClient, submissionConvo.conversationId);
}

View File

@@ -31,7 +31,7 @@ export default function useScrollToRef({
// eslint-disable-next-line react-hooks/exhaustive-deps
const scrollToRef = useCallback(
throttle(() => logAndScroll('instant', callback), 100, { leading: true }),
throttle(() => logAndScroll('instant', callback), 145, { leading: true }),
[targetRef],
);

View File

@@ -833,7 +833,6 @@
"com_ui_delete_tool": "Werkzeug löschen",
"com_ui_delete_tool_confirm": "Bist du sicher, dass du dieses Werkzeug löschen möchtest?",
"com_ui_delete_tool_error": "Fehler beim Löschen des Tools: {{error}}",
"com_ui_delete_tool_success": "Tool erfolgreich gelöscht",
"com_ui_deleted": "Gelöscht",
"com_ui_deleting_file": "Lösche Datei...",
"com_ui_descending": "Absteigend",

View File

@@ -568,8 +568,6 @@
"com_nav_settings": "Settings",
"com_nav_shared_links": "Shared links",
"com_nav_show_code": "Always show code when using code interpreter",
"com_nav_show_cost_tracking": "Show cost tracking",
"com_nav_info_show_cost_tracking": "Display conversation costs and per-message cost breakdowns",
"com_nav_show_thinking": "Open Thinking Dropdowns by Default",
"com_nav_slash_command": "/-Command",
"com_nav_slash_command_description": "Toggle command \"/\" for selecting a prompt via keyboard",
@@ -1199,7 +1197,6 @@
"com_ui_thinking": "Thinking...",
"com_ui_thoughts": "Thoughts",
"com_ui_token": "token",
"com_ui_token_abbreviation": "{{0}}t",
"com_ui_token_exchange_method": "Token Exchange Method",
"com_ui_token_url": "Token URL",
"com_ui_tokens": "tokens",

View File

@@ -337,7 +337,7 @@
"com_endpoint_prompt_prefix_assistants": "Papildu instrukcijas",
"com_endpoint_prompt_prefix_assistants_placeholder": "Iestatiet papildu norādījumus vai kontekstu virs Asistenta galvenajiem norādījumiem. Ja lauks ir tukšs, tas tiek ignorēts.",
"com_endpoint_prompt_prefix_placeholder": "Iestatiet pielāgotas instrukcijas vai kontekstu. Ja lauks ir tukšs, tas tiek ignorēts.",
"com_endpoint_reasoning_effort": "Spriešanas piepūle",
"com_endpoint_reasoning_effort": "Spriešanas līmenis",
"com_endpoint_reasoning_summary": "Spriešanas kopsavilkums",
"com_endpoint_save_as_preset": "Saglabāt kā iestatījumu",
"com_endpoint_search": "Meklēt galapunktu pēc nosaukuma",
@@ -834,7 +834,7 @@
"com_ui_delete_tool": "Dzēst rīku",
"com_ui_delete_tool_confirm": "Vai tiešām vēlaties dzēst šo rīku?",
"com_ui_delete_tool_error": "Kļūda, dzēšot rīku: {{error}}",
"com_ui_delete_tool_success": "Rīks veiksmīgi izdzēsts",
"com_ui_delete_tool_save_reminder": "Rīks noņemts. Saglabājiet aģentu, lai piemērotu izmaiņas.",
"com_ui_deleted": "Dzēsts",
"com_ui_deleting_file": "Dzēšu failu...",
"com_ui_descending": "Dilstošs",
@@ -1012,7 +1012,7 @@
"com_ui_memory_would_exceed": "Nevar saglabāt - pārsniegtu tokenu limitu par {{tokens}}. Izdzēsiet esošās atmiņas, lai atbrīvotu vietu.",
"com_ui_mention": "Pieminiet galapunktu, assistentu vai iestatījumu, lai ātri uz to pārslēgtos",
"com_ui_min_tags": "Nevar noņemt vairāk vērtību, vismaz {{0}} ir nepieciešamas.",
"com_ui_minimal": "Minimāla",
"com_ui_minimal": "Minimāls",
"com_ui_misc": "Dažādi",
"com_ui_model": "Modelis",
"com_ui_model_parameters": "Modeļa Parametrus",
@@ -1035,7 +1035,7 @@
"com_ui_no_results_found": "Nav atrastu rezultātu",
"com_ui_no_terms_content": "Nav noteikumu un nosacījumu satura, ko parādīt",
"com_ui_no_valid_items": "Nav rezultātu",
"com_ui_none": "Neviens",
"com_ui_none": "Nekāds",
"com_ui_not_used": "Nav izmantots",
"com_ui_nothing_found": "Nekas nav atrasts",
"com_ui_oauth": "OAuth",

View File

@@ -834,7 +834,6 @@
"com_ui_delete_tool": "Slett verktøy",
"com_ui_delete_tool_confirm": "Er du sikker på at du vil slette dette verktøyet?",
"com_ui_delete_tool_error": "En feil oppstod ved sletting av verktøyet: {{error}}",
"com_ui_delete_tool_success": "Verktøyet ble slettet",
"com_ui_deleted": "Slettet",
"com_ui_deleting_file": "Sletter fil ...",
"com_ui_descending": "Synkende",

View File

@@ -27,6 +27,13 @@
"com_agents_file_search_disabled": "Maak eerst een Agent aan voordat je bestanden uploadt voor File Search.",
"com_agents_file_search_info": "Als deze functie is ingeschakeld, krijgt de agent informatie over de exacte bestandsnamen die hieronder staan vermeld, zodat deze relevante context uit deze bestanden kan ophalen.",
"com_agents_instructions_placeholder": "De systeeminstructies die de agent gebruikt",
"com_agents_link_copied": "Link gekopieerd",
"com_agents_link_copy_failed": "Link niet gekopieerd",
"com_agents_load_more_label": "Laad meer agenten van {{category}} categorie",
"com_agents_loading": "Aan het laden...",
"com_agents_mcp_icon_size": "Minimum formaat 128 x 128 px",
"com_agents_mcp_info": "MCP-servers toevoegen aan je agent zodat deze taken kan uitvoeren en kan communiceren met externe services",
"com_agents_mcp_name_placeholder": "Aangepast hulpmiddel",
"com_agents_missing_provider_model": "Selecteer een provider en model voordat je een agent aanmaakt.",
"com_agents_name_placeholder": "De naam van de agent",
"com_agents_no_access": "Je hebt geen toegang om deze agent te bewerken.",

View File

@@ -834,7 +834,6 @@
"com_ui_delete_tool": "Видалити інструмент",
"com_ui_delete_tool_confirm": "Ви дійсно хочете видалити цей інструмент?",
"com_ui_delete_tool_error": "Помилка під час видалення інструменту: {{error}}",
"com_ui_delete_tool_success": "Інструмент успішно видалено",
"com_ui_deleted": "Видалено",
"com_ui_deleting_file": "Видалення файлу...",
"com_ui_descending": "За спаданням",

View File

@@ -834,7 +834,6 @@
"com_ui_delete_tool": "删除工具",
"com_ui_delete_tool_confirm": "您确定要删除此工具吗?",
"com_ui_delete_tool_error": "删除工具时发生错误:{{error}}",
"com_ui_delete_tool_success": "工具删除成功",
"com_ui_deleted": "已删除",
"com_ui_deleting_file": "删除文件中...",
"com_ui_descending": "降序",

View File

@@ -6,7 +6,6 @@ import { useGetModelsQuery } from 'librechat-data-provider/react-query';
import type { TPreset } from 'librechat-data-provider';
import { useGetConvoIdQuery, useGetStartupConfig, useGetEndpointsQuery } from '~/data-provider';
import { useNewConvo, useAppStartup, useAssistantListMap, useIdChangeEffect } from '~/hooks';
import { useGetModelCostsQuery } from 'librechat-data-provider/react-query';
import { getDefaultModelSpec, getModelSpecPreset, logger } from '~/utils';
import { ToolCallsMapProvider } from '~/Providers';
import ChatView from '~/components/Chat/ChatView';
@@ -45,10 +44,6 @@ export default function ChatRoute() {
const endpointsQuery = useGetEndpointsQuery({ enabled: isAuthenticated });
const assistantListMap = useAssistantListMap();
const modelCostsQuery = useGetModelCostsQuery(initialConvoQuery.data?.modelHistory || [], {
enabled: !!initialConvoQuery.data?.modelHistory?.length,
});
const isTemporaryChat = conversation && conversation.expiredAt ? true : false;
useEffect(() => {
@@ -153,7 +148,7 @@ export default function ChatRoute() {
return (
<ToolCallsMapProvider conversationId={conversation.conversationId ?? ''}>
<ChatView index={index} modelCosts={modelCostsQuery.data} />
<ChatView index={index} />
</ToolCallsMapProvider>
);
}

View File

@@ -34,7 +34,6 @@ const localStorageAtoms = {
showCode: atomWithLocalStorage(LocalStorageKeys.SHOW_ANALYSIS_CODE, true),
saveDrafts: atomWithLocalStorage('saveDrafts', true),
showScrollButton: atomWithLocalStorage('showScrollButton', true),
showCostTracking: atomWithLocalStorage('showCostTracking', true),
forkSetting: atomWithLocalStorage('forkSetting', ''),
splitAtTarget: atomWithLocalStorage('splitAtTarget', false),
rememberDefaultFork: atomWithLocalStorage(LocalStorageKeys.REMEMBER_FORK_OPTION, false),

View File

@@ -49,7 +49,7 @@ export default defineConfig(({ command }) => ({
],
globIgnores: ['images/**/*', '**/*.map', 'index.html'],
maximumFileSizeToCacheInBytes: 4 * 1024 * 1024,
navigateFallbackDenylist: [/^\/oauth/, /^\/api/],
navigateFallbackDenylist: [/^\/oauth/, /^\/api/, /^\/admin\/openid/],
},
includeAssets: [],
manifest: {

View File

@@ -10,6 +10,9 @@ jest.mock('form-data', () => {
getLength: jest.fn().mockReturnValue(100),
}));
});
jest.mock('https-proxy-agent', () => ({
HttpsProxyAgent: jest.fn().mockImplementation((url) => ({ proxyUrl: url })),
}));
jest.mock('axios', () => {
const mockAxiosInstance = {
get: jest.fn().mockResolvedValue({ data: {} }),
@@ -44,6 +47,7 @@ jest.mock('~/utils/axios', () => ({
import * as fs from 'fs';
import axios from 'axios';
import { HttpsProxyAgent } from 'https-proxy-agent';
import type { Readable } from 'stream';
import type {
MistralFileUploadResponse,
@@ -1182,6 +1186,8 @@ describe('MistralOCR Service', () => {
describe('Mixed env var and hardcoded configuration', () => {
beforeEach(() => {
// Clean up any PROXY env var from previous tests
delete process.env.PROXY;
const mockReadStream: MockReadStream = {
on: jest.fn().mockImplementation(function (
this: MockReadStream,
@@ -1708,9 +1714,403 @@ describe('MistralOCR Service', () => {
});
});
describe('Proxy Configuration', () => {
const originalProxy = process.env.PROXY;
beforeEach(() => {
// Reset the HttpsProxyAgent mock to its default implementation
(HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url }));
// Clear any previous axios mock calls
mockAxios.post!.mockClear();
mockAxios.get!.mockClear();
mockAxios.delete!.mockClear();
});
afterEach(() => {
if (originalProxy) {
process.env.PROXY = originalProxy;
} else {
delete process.env.PROXY;
}
// Clear mocks after each test to prevent leaking
mockAxios.post!.mockClear();
mockAxios.get!.mockClear();
mockAxios.delete!.mockClear();
});
describe('uploadDocumentToMistral with proxy', () => {
beforeEach(() => {
const mockReadStream: MockReadStream = {
on: jest.fn().mockImplementation(function (
this: MockReadStream,
event: string,
handler: () => void,
) {
if (event === 'end') {
handler();
}
return this;
}),
pipe: jest.fn().mockImplementation(function (this: MockReadStream) {
return this;
}),
pause: jest.fn(),
resume: jest.fn(),
emit: jest.fn(),
once: jest.fn(),
destroy: jest.fn(),
path: '/path/to/test.pdf',
fd: 1,
flags: 'r',
mode: 0o666,
autoClose: true,
bytesRead: 0,
closed: false,
pending: false,
};
(jest.mocked(fs).createReadStream as jest.Mock).mockReturnValue(mockReadStream);
});
it('should use proxy configuration when PROXY env var is set', async () => {
process.env.PROXY = 'http://proxy.example.com:8080';
const mockResponse: { data: MistralFileUploadResponse } = {
data: {
id: 'file-proxy-123',
object: 'file',
bytes: 1024,
created_at: Date.now(),
filename: 'test.pdf',
purpose: 'ocr',
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://proxy.example.com:8080',
}),
}),
);
});
it('should handle proxy URL with authentication', async () => {
process.env.PROXY = 'http://user:pass@proxy.example.com:8080';
const mockResponse: { data: MistralFileUploadResponse } = {
data: {
id: 'file-proxy-auth-123',
object: 'file',
bytes: 1024,
created_at: Date.now(),
filename: 'test.pdf',
purpose: 'ocr',
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://user:pass@proxy.example.com:8080',
}),
}),
);
});
it('should handle IPv6 proxy addresses', async () => {
process.env.PROXY = 'http://[::1]:8080';
const mockResponse: { data: MistralFileUploadResponse } = {
data: {
id: 'file-proxy-ipv6-123',
object: 'file',
bytes: 1024,
created_at: Date.now(),
filename: 'test.pdf',
purpose: 'ocr',
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://[::1]:8080',
}),
}),
);
});
it('should not use proxy when PROXY env var is not set', async () => {
delete process.env.PROXY;
const mockResponse: { data: MistralFileUploadResponse } = {
data: {
id: 'file-no-proxy-123',
object: 'file',
bytes: 1024,
created_at: Date.now(),
filename: 'test.pdf',
purpose: 'ocr',
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await uploadDocumentToMistral({
filePath: '/path/to/test.pdf',
fileName: 'test.pdf',
apiKey: 'test-api-key',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files',
expect.anything(),
expect.not.objectContaining({
httpsAgent: expect.anything(),
}),
);
});
});
describe('performOCR with proxy', () => {
it('should use proxy configuration when PROXY env var is set', async () => {
process.env.PROXY = 'http://proxy.example.com:3128';
const mockResponse: { data: OCRResult } = {
data: {
model: 'mistral-ocr-latest',
pages: [
{
index: 0,
markdown: 'Proxy test content',
images: [],
dimensions: { dpi: 300, height: 1100, width: 850 },
},
],
document_annotation: '',
usage_info: {
pages_processed: 1,
doc_size_bytes: 1024,
},
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await performOCR({
apiKey: 'test-api-key',
url: 'https://document-url.com',
model: 'mistral-ocr-latest',
documentType: 'document_url',
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/ocr',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://proxy.example.com:3128',
}),
}),
);
});
it('should handle malformed proxy URLs gracefully', async () => {
(HttpsProxyAgent as unknown as jest.Mock).mockImplementationOnce(() => {
throw new Error('Invalid URL');
});
process.env.PROXY = 'not-a-valid-url';
const mockResponse: { data: OCRResult } = {
data: {
model: 'mistral-ocr-latest',
pages: [
{
index: 0,
markdown: 'Test content',
images: [],
dimensions: { dpi: 300, height: 1100, width: 850 },
},
],
document_annotation: '',
usage_info: {
pages_processed: 1,
doc_size_bytes: 1024,
},
},
};
mockAxios.post!.mockResolvedValueOnce(mockResponse);
await expect(
performOCR({
apiKey: 'test-api-key',
url: 'https://document-url.com',
}),
).rejects.toThrow('Invalid URL');
});
});
describe('Azure Mistral OCR with proxy', () => {
beforeEach(() => {
(jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(
Buffer.from('mock-file-content'),
);
});
it('should use proxy for Azure Mistral OCR requests', async () => {
process.env.PROXY = 'http://proxy.example.com:8080';
mockLoadAuthValues.mockResolvedValue({
OCR_API_KEY: 'azure-api-key',
OCR_BASEURL: 'https://azure.mistral.ai/v1',
});
mockAxios.post!.mockResolvedValueOnce({
data: {
model: 'mistral-ocr-latest',
pages: [
{
index: 0,
markdown: 'Azure OCR with proxy',
images: [],
dimensions: { dpi: 300, height: 1100, width: 850 },
},
],
document_annotation: '',
usage_info: {
pages_processed: 1,
doc_size_bytes: 1024,
},
},
});
const req = {
user: { id: 'user123' },
config: {
ocr: {
apiKey: '${OCR_API_KEY}',
baseURL: '${OCR_BASEURL}',
mistralModel: 'mistral-ocr-latest',
},
},
} as unknown as ServerRequest;
const file = {
path: '/tmp/upload/azure-file.pdf',
originalname: 'azure-document.pdf',
mimetype: 'application/pdf',
} as Express.Multer.File;
await uploadAzureMistralOCR({
req,
file,
loadAuthValues: mockLoadAuthValues,
});
expect(mockAxios.post).toHaveBeenCalledWith(
'https://azure.mistral.ai/v1/ocr',
expect.anything(),
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'http://proxy.example.com:8080',
}),
}),
);
});
});
describe('getSignedUrl with proxy', () => {
it('should use proxy configuration when PROXY env var is set', async () => {
process.env.PROXY = 'https://secure-proxy.example.com:443';
const mockResponse: { data: MistralSignedUrlResponse } = {
data: {
url: 'https://signed-url.com',
expires_at: Date.now() + 86400000,
},
};
mockAxios.get!.mockResolvedValueOnce(mockResponse);
await getSignedUrl({
fileId: 'file-123',
apiKey: 'test-api-key',
});
expect(mockAxios.get).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'https://secure-proxy.example.com:443',
}),
}),
);
});
});
describe('deleteMistralFile with proxy', () => {
it('should use proxy configuration when PROXY env var is set', async () => {
process.env.PROXY = 'socks5://proxy.example.com:1080';
mockAxios.delete!.mockResolvedValueOnce({ data: {} });
await deleteMistralFile({
fileId: 'file-123',
apiKey: 'test-api-key',
});
expect(mockAxios.delete).toHaveBeenCalledWith(
'https://api.mistral.ai/v1/files/file-123',
expect.objectContaining({
httpsAgent: expect.objectContaining({
proxyUrl: 'socks5://proxy.example.com:1080',
}),
}),
);
});
});
});
describe('uploadAzureMistralOCR', () => {
beforeEach(() => {
(jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(Buffer.from('mock-file-content'));
// Reset the HttpsProxyAgent mock to its default implementation for Azure tests
(HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url }));
// Clean up any PROXY env var from previous tests
delete process.env.PROXY;
// Reset axios mocks completely to clear any queued responses
mockAxios.post!.mockReset();
mockAxios.get!.mockReset();
mockAxios.delete!.mockReset();
// Re-establish default resolved values
mockAxios.post!.mockResolvedValue({ data: {} });
mockAxios.get!.mockResolvedValue({ data: {} });
mockAxios.delete!.mockResolvedValue({ data: {} });
});
it('should process OCR using Azure Mistral with base64 encoding', async () => {
@@ -1796,6 +2196,11 @@ describe('MistralOCR Service', () => {
});
describe('Mixed env var and hardcoded configuration', () => {
beforeEach(() => {
// Clean up any PROXY env var from previous tests
delete process.env.PROXY;
});
it('should preserve hardcoded baseURL when only apiKey is an env var', async () => {
// This test demonstrates the current bug
mockLoadAuthValues.mockResolvedValue({

View File

@@ -2,6 +2,7 @@ import * as fs from 'fs';
import * as path from 'path';
import FormData from 'form-data';
import { logger } from '@librechat/data-schemas';
import { HttpsProxyAgent } from 'https-proxy-agent';
import {
FileSources,
envVarRegex,
@@ -9,7 +10,7 @@ import {
extractVariableName,
} from 'librechat-data-provider';
import type { TCustomConfig } from 'librechat-data-provider';
import type { AxiosError } from 'axios';
import type { AxiosError, AxiosRequestConfig } from 'axios';
import type {
MistralFileUploadResponse,
MistralSignedUrlResponse,
@@ -77,15 +78,21 @@ export async function uploadDocumentToMistral({
const fileStream = fs.createReadStream(filePath);
form.append('file', fileStream, { filename: actualFileName });
const config: AxiosRequestConfig = {
headers: {
Authorization: `Bearer ${apiKey}`,
...form.getHeaders(),
},
maxBodyLength: Infinity,
maxContentLength: Infinity,
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
return axios
.post(`${baseURL}/files`, form, {
headers: {
Authorization: `Bearer ${apiKey}`,
...form.getHeaders(),
},
maxBodyLength: Infinity,
maxContentLength: Infinity,
})
.post(`${baseURL}/files`, form, config)
.then((res) => res.data)
.catch((error) => {
throw error;
@@ -103,12 +110,18 @@ export async function getSignedUrl({
expiry?: number;
baseURL?: string;
}): Promise<MistralSignedUrlResponse> {
const config: AxiosRequestConfig = {
headers: {
Authorization: `Bearer ${apiKey}`,
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
return axios
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
})
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, config)
.then((res) => res.data)
.catch((error) => {
logger.error('Error fetching signed URL:', error.message);
@@ -139,6 +152,18 @@ export async function performOCR({
documentType?: 'document_url' | 'image_url';
}): Promise<OCRResult> {
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
return axios
.post(
`${baseURL}/ocr`,
@@ -151,12 +176,7 @@ export async function performOCR({
[documentKey]: url,
},
},
{
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${apiKey}`,
},
},
config,
)
.then((res) => res.data)
.catch((error) => {
@@ -182,12 +202,18 @@ export async function deleteMistralFile({
apiKey: string;
baseURL?: string;
}): Promise<void> {
const config: AxiosRequestConfig = {
headers: {
Authorization: `Bearer ${apiKey}`,
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
try {
const result = await axios.delete(`${baseURL}/files/${fileId}`, {
headers: {
Authorization: `Bearer ${apiKey}`,
},
});
const result = await axios.delete(`${baseURL}/files/${fileId}`, config);
logger.debug(`Mistral file ${fileId} deleted successfully:`, result.data);
} catch (error) {
logger.error(`Error deleting Mistral file ${fileId}:`, error);
@@ -543,17 +569,23 @@ async function createJWT(serviceKey: GoogleServiceAccount): Promise<string> {
* Exchanges JWT for access token
*/
async function exchangeJWTForAccessToken(jwt: string): Promise<string> {
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
const response = await axios.post(
'https://oauth2.googleapis.com/token',
new URLSearchParams({
grant_type: 'urn:ietf:params:oauth:grant-type:jwt-bearer',
assertion: jwt,
}),
{
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
},
config,
);
if (!response.data?.access_token) {
@@ -608,14 +640,20 @@ async function performGoogleVertexOCR({
},
});
const config: AxiosRequestConfig = {
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${accessToken}`,
Accept: 'application/json',
},
};
if (process.env.PROXY) {
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
}
return axios
.post(baseURL, requestBody, {
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${accessToken}`,
Accept: 'application/json',
},
})
.post(baseURL, requestBody, config)
.then((res) => {
logger.debug('Google Vertex AI response received');
return res.data;

View File

@@ -14,6 +14,7 @@ export * from './utils';
export * from './db/utils';
/* OAuth */
export * from './oauth';
export * from './mcp/oauth/OAuthReconnectionManager';
/* Crypto */
export * from './crypto';
/* Flow */

View File

@@ -6,6 +6,7 @@ import type { FlowStateManager } from '~/flow/manager';
import type { FlowMetadata } from '~/flow/types';
import type * as t from './types';
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
import { sanitizeUrlForLogging } from './utils';
import { MCPConnection } from './connection';
import { processMCPEnv } from '~/utils';
@@ -308,7 +309,9 @@ export class MCPConnectionFactory {
metadata?: OAuthMetadata;
} | null> {
const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`);
logger.debug(
`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
);
if (!this.flowManager || !serverUrl) {
logger.error(

View File

@@ -7,6 +7,7 @@ import type { JsonSchemaType } from '~/types';
import type * as t from '~/mcp/types';
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
import { detectOAuthRequirement } from '~/mcp/oauth';
import { sanitizeUrlForLogging } from '~/mcp/utils';
import { processMCPEnv } from '~/utils';
import { CONSTANTS } from '~/mcp/enum';
@@ -183,7 +184,7 @@ export class MCPServersRegistry {
const prefix = this.prefix(serverName);
const config = this.parsedConfigs[serverName];
logger.info(`${prefix} -------------------------------------------------┐`);
logger.info(`${prefix} URL: ${config.url}`);
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
logger.info(`${prefix} Tools: ${config.tools}`);

View File

@@ -1,15 +1,15 @@
import { EventEmitter } from 'events';
import { logger } from '@librechat/data-schemas';
import { fetch as undiciFetch, Agent } from 'undici';
import {
StdioClientTransport,
getDefaultEnvironment,
} from '@modelcontextprotocol/sdk/client/stdio.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
import { logger } from '@librechat/data-schemas';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
import type {
@@ -18,8 +18,9 @@ import type {
Response as UndiciResponse,
} from 'undici';
import type { MCPOAuthTokens } from './oauth/types';
import { mcpConfig } from './mcpConfig';
import type * as t from './types';
import { sanitizeUrlForLogging } from './utils';
import { mcpConfig } from './mcpConfig';
type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;
@@ -238,7 +239,9 @@ export class MCPConnection extends EventEmitter {
}
this.url = options.url;
const url = new URL(options.url);
logger.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
logger.info(
`${this.getLogPrefix()} Creating SSE transport: ${sanitizeUrlForLogging(url)}`,
);
const abortController = new AbortController();
/** Add OAuth token to headers if available */
@@ -293,7 +296,7 @@ export class MCPConnection extends EventEmitter {
this.url = options.url;
const url = new URL(options.url);
logger.info(
`${this.getLogPrefix()} Creating streamable-http transport: ${url.toString()}`,
`${this.getLogPrefix()} Creating streamable-http transport: ${sanitizeUrlForLogging(url)}`,
);
const abortController = new AbortController();
@@ -473,7 +476,9 @@ export class MCPConnection extends EventEmitter {
logger.warn(`${this.getLogPrefix()} OAuth authentication required`);
this.oauthRequired = true;
const serverUrl = this.url;
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
logger.debug(
`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
);
const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
/** Promise that will resolve when OAuth is handled */

View File

@@ -0,0 +1,294 @@
import { TokenMethods } from '@librechat/data-schemas';
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
import { MCPManager } from '../MCPManager';
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
jest.mock('@librechat/data-schemas', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
},
}));
jest.mock('../MCPManager');
describe('OAuthReconnectionManager', () => {
let flowManager: jest.Mocked<FlowStateManager<null>>;
let tokenMethods: jest.Mocked<TokenMethods>;
let mockMCPManager: jest.Mocked<MCPManager>;
let reconnectionManager: OAuthReconnectionManager;
beforeEach(() => {
jest.clearAllMocks();
// Reset singleton instance
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(OAuthReconnectionManager as any).instance = null;
// Setup mock flow manager
flowManager = {
createFlow: jest.fn(),
completeFlow: jest.fn(),
failFlow: jest.fn(),
deleteFlow: jest.fn(),
getFlow: jest.fn(),
} as unknown as jest.Mocked<FlowStateManager<null>>;
// Setup mock token methods
tokenMethods = {
findToken: jest.fn(),
createToken: jest.fn(),
updateToken: jest.fn(),
deleteToken: jest.fn(),
} as unknown as jest.Mocked<TokenMethods>;
// Setup mock MCP Manager
mockMCPManager = {
getOAuthServers: jest.fn(),
getUserConnection: jest.fn(),
getUserConnections: jest.fn(),
disconnectUserConnection: jest.fn(),
getRawConfig: jest.fn(),
} as unknown as jest.Mocked<MCPManager>;
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
});
afterEach(() => {
jest.clearAllMocks();
});
describe('Singleton Pattern', () => {
it('should create instance successfully', async () => {
const instance = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
expect(instance).toBeInstanceOf(OAuthReconnectionManager);
});
it('should throw error when creating instance twice', async () => {
await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
await expect(
OAuthReconnectionManager.createInstance(flowManager, tokenMethods),
).rejects.toThrow('OAuthReconnectionManager already initialized');
});
it('should throw error when getting instance before creation', () => {
expect(() => OAuthReconnectionManager.getInstance()).toThrow(
'OAuthReconnectionManager not initialized',
);
});
});
describe('isReconnecting', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should return true when server is actively reconnecting', () => {
const userId = 'user-123';
const serverName = 'test-server';
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
reconnectionTracker.setActive(userId, serverName);
const result = reconnectionManager.isReconnecting(userId, serverName);
expect(result).toBe(true);
});
it('should return false when server is not reconnecting', () => {
const userId = 'user-123';
const serverName = 'test-server';
const result = reconnectionManager.isReconnecting(userId, serverName);
expect(result).toBe(false);
});
});
describe('clearReconnection', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should clear both failed and active reconnection states', () => {
const userId = 'user-123';
const serverName = 'test-server';
reconnectionTracker.setFailed(userId, serverName);
reconnectionTracker.setActive(userId, serverName);
reconnectionManager.clearReconnection(userId, serverName);
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false);
expect(reconnectionTracker.isActive(userId, serverName)).toBe(false);
});
});
describe('reconnectServers', () => {
let reconnectionTracker: OAuthReconnectionTracker;
beforeEach(async () => {
reconnectionTracker = new OAuthReconnectionTracker();
reconnectionManager = await OAuthReconnectionManager.createInstance(
flowManager,
tokenMethods,
reconnectionTracker,
);
});
it('should reconnect eligible servers', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1', 'server2', 'server3']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has failed reconnection
reconnectionTracker.setFailed(userId, 'server1');
// server2: already connected
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(true),
};
const userConnections = new Map([['server2', mockConnection]]);
mockMCPManager.getUserConnections.mockReturnValue(
userConnections as unknown as Map<string, MCPConnection>,
);
// server3: has valid token and not connected
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
if (identifier === 'mcp:server3') {
return {
userId,
identifier,
expiresAt: new Date(Date.now() + 3600000), // 1 hour from now
} as unknown as MCPOAuthTokens;
}
return null;
});
// Mock successful reconnection
const mockNewConnection = {
isConnected: jest.fn().mockResolvedValue(true),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockNewConnection as unknown as MCPConnection,
);
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Verify server3 was marked as active
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(true);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify reconnection was attempted for server3
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith({
serverName: 'server3',
user: { id: userId },
flowManager,
tokenMethods,
forceNew: false,
connectionTimeout: 5000,
returnOnOAuth: true,
});
// Verify successful reconnection cleared the states
expect(reconnectionTracker.isFailed(userId, 'server3')).toBe(false);
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(false);
});
it('should handle failed reconnection attempts', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has valid token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() + 3600000),
} as unknown as MCPOAuthTokens);
// Mock failed connection
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify failure handling
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
it('should not reconnect servers with expired tokens', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
// server1: has expired token
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
} as unknown as MCPOAuthTokens);
await reconnectionManager.reconnectServers(userId);
// Verify no reconnection attempt was made
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
});
it('should handle connection that returns but is not connected', async () => {
const userId = 'user-123';
const oauthServers = new Set(['server1']);
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
tokenMethods.findToken.mockResolvedValue({
userId,
identifier: 'mcp:server1',
expiresAt: new Date(Date.now() + 3600000),
} as unknown as MCPOAuthTokens);
// Mock connection that returns but is not connected
const mockConnection = {
isConnected: jest.fn().mockResolvedValue(false),
disconnect: jest.fn(),
};
mockMCPManager.getUserConnection.mockResolvedValue(
mockConnection as unknown as MCPConnection,
);
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
await reconnectionManager.reconnectServers(userId);
// Wait for async tryReconnect to complete
await new Promise((resolve) => setTimeout(resolve, 100));
// Verify failure handling
expect(mockConnection.disconnect).toHaveBeenCalled();
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
});
});
});

View File

@@ -0,0 +1,163 @@
import { logger } from '@librechat/data-schemas';
import type { TokenMethods } from '@librechat/data-schemas';
import type { TUser } from 'librechat-data-provider';
import type { MCPOAuthTokens } from './types';
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
import { FlowStateManager } from '~/flow/manager';
import { MCPManager } from '~/mcp/MCPManager';
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
export class OAuthReconnectionManager {
private static instance: OAuthReconnectionManager | null = null;
protected readonly flowManager: FlowStateManager<MCPOAuthTokens | null>;
protected readonly tokenMethods: TokenMethods;
private readonly reconnectionsTracker: OAuthReconnectionTracker;
public static getInstance(): OAuthReconnectionManager {
if (!OAuthReconnectionManager.instance) {
throw new Error('OAuthReconnectionManager not initialized');
}
return OAuthReconnectionManager.instance;
}
public static async createInstance(
flowManager: FlowStateManager<MCPOAuthTokens | null>,
tokenMethods: TokenMethods,
reconnections?: OAuthReconnectionTracker,
): Promise<OAuthReconnectionManager> {
if (OAuthReconnectionManager.instance != null) {
throw new Error('OAuthReconnectionManager already initialized');
}
const manager = new OAuthReconnectionManager(flowManager, tokenMethods, reconnections);
OAuthReconnectionManager.instance = manager;
return manager;
}
public constructor(
flowManager: FlowStateManager<MCPOAuthTokens | null>,
tokenMethods: TokenMethods,
reconnections?: OAuthReconnectionTracker,
) {
this.flowManager = flowManager;
this.tokenMethods = tokenMethods;
this.reconnectionsTracker = reconnections ?? new OAuthReconnectionTracker();
}
public isReconnecting(userId: string, serverName: string): boolean {
return this.reconnectionsTracker.isActive(userId, serverName);
}
public async reconnectServers(userId: string) {
const mcpManager = MCPManager.getInstance();
// 1. derive the servers to reconnect
const serversToReconnect = [];
for (const serverName of mcpManager.getOAuthServers() ?? []) {
const canReconnect = await this.canReconnect(userId, serverName);
if (canReconnect) {
serversToReconnect.push(serverName);
}
}
// 2. mark the servers as reconnecting
for (const serverName of serversToReconnect) {
this.reconnectionsTracker.setActive(userId, serverName);
}
// 3. attempt to reconnect the servers
for (const serverName of serversToReconnect) {
void this.tryReconnect(userId, serverName);
}
}
public clearReconnection(userId: string, serverName: string) {
this.reconnectionsTracker.removeFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
}
private async tryReconnect(userId: string, serverName: string) {
const mcpManager = MCPManager.getInstance();
const logPrefix = `[tryReconnectOAuthMCPServer][User: ${userId}][${serverName}]`;
logger.info(`${logPrefix} Attempting reconnection`);
const config = mcpManager.getRawConfig(serverName);
const cleanupOnFailedReconnect = () => {
this.reconnectionsTracker.setFailed(userId, serverName);
this.reconnectionsTracker.removeActive(userId, serverName);
mcpManager.disconnectUserConnection(userId, serverName);
};
try {
// attempt to get connection (this will use existing tokens and refresh if needed)
const connection = await mcpManager.getUserConnection({
serverName,
user: { id: userId } as TUser,
flowManager: this.flowManager,
tokenMethods: this.tokenMethods,
// don't force new connection, let it reuse existing or create new as needed
forceNew: false,
// set a reasonable timeout for reconnection attempts
connectionTimeout: config?.initTimeout ?? DEFAULT_CONNECTION_TIMEOUT_MS,
// don't trigger OAuth flow during reconnection
returnOnOAuth: true,
});
if (connection && (await connection.isConnected())) {
logger.info(`${logPrefix} Successfully reconnected`);
this.clearReconnection(userId, serverName);
} else {
logger.warn(`${logPrefix} Failed to reconnect`);
await connection?.disconnect();
cleanupOnFailedReconnect();
}
} catch (error) {
logger.warn(`${logPrefix} Failed to reconnect: ${error}`);
cleanupOnFailedReconnect();
}
}
private async canReconnect(userId: string, serverName: string) {
const mcpManager = MCPManager.getInstance();
// if the server has failed reconnection, don't attempt to reconnect
if (this.reconnectionsTracker.isFailed(userId, serverName)) {
return false;
}
// if the server is already connected, don't attempt to reconnect
const existingConnections = mcpManager.getUserConnections(userId);
if (existingConnections?.has(serverName)) {
const isConnected = await existingConnections.get(serverName)?.isConnected();
if (isConnected) {
return false;
}
}
// if the server has no tokens for the user, don't attempt to reconnect
const accessToken = await this.tokenMethods.findToken({
userId,
type: 'mcp_oauth',
identifier: `mcp:${serverName}`,
});
if (accessToken == null) {
return false;
}
// if the token has expired, don't attempt to reconnect
const now = new Date();
if (accessToken.expiresAt && accessToken.expiresAt < now) {
return false;
}
// …otherwise, we're good to go with the reconnect attempt
return true;
}
}

View File

@@ -0,0 +1,181 @@
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
describe('OAuthReconnectTracker', () => {
let tracker: OAuthReconnectionTracker;
const userId = 'user123';
const serverName = 'test-server';
const anotherServer = 'another-server';
beforeEach(() => {
tracker = new OAuthReconnectionTracker();
});
describe('setFailed', () => {
it('should record a failed reconnection attempt', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
});
it('should track multiple servers for the same user', () => {
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, anotherServer);
expect(tracker.isFailed(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
});
it('should track different users independently', () => {
const anotherUserId = 'user456';
tracker.setFailed(userId, serverName);
tracker.setFailed(anotherUserId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
expect(tracker.isFailed(anotherUserId, serverName)).toBe(true);
});
});
describe('isFailed', () => {
it('should return false when no failed attempt is recorded', () => {
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should return true after a failed attempt is recorded', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
});
it('should return false for a different server even after another server failed', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, anotherServer)).toBe(false);
});
});
describe('removeFailed', () => {
it('should clear a failed reconnect record', () => {
tracker.setFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(true);
tracker.removeFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should only clear the specific server for the user', () => {
tracker.setFailed(userId, serverName);
tracker.setFailed(userId, anotherServer);
tracker.removeFailed(userId, serverName);
expect(tracker.isFailed(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
});
it('should handle clearing non-existent records gracefully', () => {
expect(() => tracker.removeFailed(userId, serverName)).not.toThrow();
});
});
describe('setActive', () => {
it('should mark a server as reconnecting', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
});
it('should track multiple reconnecting servers', () => {
tracker.setActive(userId, serverName);
tracker.setActive(userId, anotherServer);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isActive(userId, anotherServer)).toBe(true);
});
});
describe('isActive', () => {
it('should return false when server is not reconnecting', () => {
expect(tracker.isActive(userId, serverName)).toBe(false);
});
it('should return true when server is marked as reconnecting', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
});
it('should handle non-existent user gracefully', () => {
expect(tracker.isActive('non-existent-user', serverName)).toBe(false);
});
});
describe('removeActive', () => {
it('should clear reconnecting state for a server', () => {
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
});
it('should only clear specific server state', () => {
tracker.setActive(userId, serverName);
tracker.setActive(userId, anotherServer);
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isActive(userId, anotherServer)).toBe(true);
});
it('should handle clearing non-existent state gracefully', () => {
expect(() => tracker.removeActive(userId, serverName)).not.toThrow();
});
});
describe('cleanup behavior', () => {
it('should clean up empty user sets for failed reconnects', () => {
tracker.setFailed(userId, serverName);
tracker.removeFailed(userId, serverName);
// Record and clear another user to ensure internal cleanup
const anotherUserId = 'user456';
tracker.setFailed(anotherUserId, serverName);
// Original user should still be able to reconnect
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
it('should clean up empty user sets for active reconnections', () => {
tracker.setActive(userId, serverName);
tracker.removeActive(userId, serverName);
// Mark another user to ensure internal cleanup
const anotherUserId = 'user456';
tracker.setActive(anotherUserId, serverName);
// Original user should not be reconnecting
expect(tracker.isActive(userId, serverName)).toBe(false);
});
});
describe('combined state management', () => {
it('should handle both failed and reconnecting states independently', () => {
// Mark as reconnecting
tracker.setActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, serverName)).toBe(false);
// Record failed attempt
tracker.setFailed(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(true);
expect(tracker.isFailed(userId, serverName)).toBe(true);
// Clear reconnecting state
tracker.removeActive(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, serverName)).toBe(true);
// Clear failed state
tracker.removeFailed(userId, serverName);
expect(tracker.isActive(userId, serverName)).toBe(false);
expect(tracker.isFailed(userId, serverName)).toBe(false);
});
});
});

View File

@@ -0,0 +1,46 @@
export class OAuthReconnectionTracker {
// Map of userId -> Set of serverNames that have failed reconnection
private failed: Map<string, Set<string>> = new Map();
// Map of userId -> Set of serverNames that are actively reconnecting
private active: Map<string, Set<string>> = new Map();
public isFailed(userId: string, serverName: string): boolean {
return this.failed.get(userId)?.has(serverName) ?? false;
}
public isActive(userId: string, serverName: string): boolean {
return this.active.get(userId)?.has(serverName) ?? false;
}
public setFailed(userId: string, serverName: string): void {
if (!this.failed.has(userId)) {
this.failed.set(userId, new Set());
}
this.failed.get(userId)?.add(serverName);
}
public setActive(userId: string, serverName: string): void {
if (!this.active.has(userId)) {
this.active.set(userId, new Set());
}
this.active.get(userId)?.add(serverName);
}
public removeFailed(userId: string, serverName: string): void {
const userServers = this.failed.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.failed.delete(userId);
}
}
public removeActive(userId: string, serverName: string): void {
const userServers = this.active.get(userId);
userServers?.delete(serverName);
if (userServers?.size === 0) {
this.active.delete(userId);
}
}
}

View File

@@ -17,6 +17,7 @@ import type {
MCPOAuthTokens,
OAuthMetadata,
} from './types';
import { sanitizeUrlForLogging } from '~/mcp/utils';
/** Type for the OAuth metadata from the SDK */
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
@@ -33,7 +34,9 @@ export class MCPOAuthHandler {
resourceMetadata?: OAuthProtectedResourceMetadata;
authServerUrl: URL;
}> {
logger.debug(`[MCPOAuth] discoverMetadata called with serverUrl: ${serverUrl}`);
logger.debug(
`[MCPOAuth] discoverMetadata called with serverUrl: ${sanitizeUrlForLogging(serverUrl)}`,
);
let authServerUrl = new URL(serverUrl);
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
@@ -60,11 +63,15 @@ export class MCPOAuthHandler {
}
// Discover OAuth metadata
logger.debug(`[MCPOAuth] Discovering OAuth metadata from ${authServerUrl}`);
logger.debug(
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
);
const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl);
if (!rawMetadata) {
logger.error(`[MCPOAuth] Failed to discover OAuth metadata from ${authServerUrl}`);
logger.error(
`[MCPOAuth] Failed to discover OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
);
throw new Error('Failed to discover OAuth metadata');
}
@@ -88,12 +95,15 @@ export class MCPOAuthHandler {
resourceMetadata?: OAuthProtectedResourceMetadata,
redirectUri?: string,
): Promise<OAuthClientInformation> {
logger.debug(`[MCPOAuth] Starting client registration for ${serverUrl}, server metadata:`, {
grant_types_supported: metadata.grant_types_supported,
response_types_supported: metadata.response_types_supported,
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
scopes_supported: metadata.scopes_supported,
});
logger.debug(
`[MCPOAuth] Starting client registration for ${sanitizeUrlForLogging(serverUrl)}, server metadata:`,
{
grant_types_supported: metadata.grant_types_supported,
response_types_supported: metadata.response_types_supported,
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
scopes_supported: metadata.scopes_supported,
},
);
/** Client metadata based on what the server supports */
const clientMetadata = {
@@ -114,7 +124,9 @@ export class MCPOAuthHandler {
`[MCPOAuth] Server ${serverUrl} supports \`refresh_token\` grant type, adding to request`,
);
} else {
logger.debug(`[MCPOAuth] Server ${serverUrl} does not support \`refresh_token\` grant type`);
logger.debug(
`[MCPOAuth] Server ${sanitizeUrlForLogging(serverUrl)} does not support \`refresh_token\` grant type`,
);
}
clientMetadata.grant_types = requestedGrantTypes;
@@ -139,19 +151,25 @@ export class MCPOAuthHandler {
clientMetadata.scope = availableScopes.join(' ');
}
logger.debug(`[MCPOAuth] Registering client for ${serverUrl} with metadata:`, clientMetadata);
logger.debug(
`[MCPOAuth] Registering client for ${sanitizeUrlForLogging(serverUrl)} with metadata:`,
clientMetadata,
);
const clientInfo = await registerClient(serverUrl, {
metadata: metadata as unknown as SDKOAuthMetadata,
clientMetadata,
});
logger.debug(`[MCPOAuth] Client registered successfully for ${serverUrl}:`, {
client_id: clientInfo.client_id,
has_client_secret: !!clientInfo.client_secret,
grant_types: clientInfo.grant_types,
scope: clientInfo.scope,
});
logger.debug(
`[MCPOAuth] Client registered successfully for ${sanitizeUrlForLogging(serverUrl)}:`,
{
client_id: clientInfo.client_id,
has_client_secret: !!clientInfo.client_secret,
grant_types: clientInfo.grant_types,
scope: clientInfo.scope,
},
);
return clientInfo;
}
@@ -165,7 +183,9 @@ export class MCPOAuthHandler {
userId: string,
config: MCPOptions['oauth'] | undefined,
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
logger.debug(`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${serverUrl}`);
logger.debug(
`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`,
);
const flowId = this.generateFlowId(userId, serverName);
const state = this.generateState();
@@ -226,7 +246,9 @@ export class MCPOAuthHandler {
metadata,
};
logger.debug(`[MCPOAuth] Authorization URL generated: ${authorizationUrl.toString()}`);
logger.debug(
`[MCPOAuth] Authorization URL generated: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
);
return {
authorizationUrl: authorizationUrl.toString(),
flowId,
@@ -234,10 +256,14 @@ export class MCPOAuthHandler {
};
}
logger.debug(`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${serverUrl}`);
logger.debug(
`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${sanitizeUrlForLogging(serverUrl)}`,
);
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(serverUrl);
logger.debug(`[MCPOAuth] OAuth metadata discovered, auth server URL: ${authServerUrl}`);
logger.debug(
`[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`,
);
/** Dynamic client registration based on the discovered metadata */
const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName);
@@ -276,7 +302,9 @@ export class MCPOAuthHandler {
codeVerifier = authResult.codeVerifier;
logger.debug(`[MCPOAuth] startAuthorization completed successfully`);
logger.debug(`[MCPOAuth] Authorization URL: ${authorizationUrl.toString()}`);
logger.debug(
`[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
);
/** Add state parameter with flowId to the authorization URL */
authorizationUrl.searchParams.set('state', flowId);
@@ -515,7 +543,7 @@ export class MCPOAuthHandler {
body.append('client_id', metadata.clientInfo.client_id);
}
logger.debug(`[MCPOAuth] Refresh request to: ${tokenUrl}`, {
logger.debug(`[MCPOAuth] Refresh request to: ${sanitizeUrlForLogging(tokenUrl)}`, {
body: body.toString(),
headers,
});
@@ -695,7 +723,9 @@ export class MCPOAuthHandler {
}
// perform the revoke request
logger.info(`[MCPOAuth] Revoking tokens for ${serverName} via ${revokeUrl.toString()}`);
logger.info(
`[MCPOAuth] Revoking tokens for ${serverName} via ${sanitizeUrlForLogging(revokeUrl.toString())}`,
);
const response = await fetch(revokeUrl, {
method: 'POST',
body: body.toString(),

View File

@@ -31,3 +31,17 @@ export function normalizeServerName(serverName: string): string {
return normalized;
}
/**
* Sanitizes a URL by removing query parameters to prevent credential leakage in logs.
* @param url - The URL to sanitize (string or URL object)
* @returns The sanitized URL string without query parameters
*/
export function sanitizeUrlForLogging(url: string | URL): string {
try {
const urlObj = typeof url === 'string' ? new URL(url) : url;
return `${urlObj.protocol}//${urlObj.host}${urlObj.pathname}`;
} catch {
return '[invalid URL]';
}
}

View File

@@ -1,42 +0,0 @@
interface ArrowIconProps {
direction: 'up' | 'down';
className?: string;
}
export default function ArrowIcon({ direction, className = 'inline' }: ArrowIconProps) {
if (direction === 'up') {
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1em"
height="1em"
fill="currentColor"
viewBox="0 0 24 24"
className={className}
>
<path
fillRule="evenodd"
d="M11.293 5.293a1 1 0 0 1 1.414 0l5 5a1 1 0 0 1-1.414 1.414L13 8.414V18a1 1 0 1 1-2 0V8.414l-3.293 3.293a1 1 0 0 1-1.414-1.414l5-5Z"
clipRule="evenodd"
/>
</svg>
);
}
return (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1em"
height="1em"
fill="currentColor"
viewBox="0 0 24 24"
className={className}
>
<path
fillRule="evenodd"
d="M12.707 18.707a1 1 0 0 1-1.414 0l-5-5a1 1 0 1 1 1.414-1.414L11 15.586V6a1 1 0 1 1 2 0v9.586l3.293-3.293a1 1 0 0 1 1.414 1.414l-5 5Z"
clipRule="evenodd"
/>
</svg>
);
}

View File

@@ -1,5 +1,4 @@
export { default as ArchiveIcon } from './ArchiveIcon';
export { default as ArrowIcon } from './ArrowIcon';
export { default as Blocks } from './Blocks';
export { default as Plugin } from './Plugin';
export { default as GPTIcon } from './GPTIcon';

View File

@@ -66,8 +66,6 @@ export const messages = (params: q.MessagesListParams) => {
export const messagesArtifacts = (messageId: string) => `${messagesRoot}/artifacts/${messageId}`;
export const costs = () => `/api/messages/costs`;
const shareRoot = `${BASE_URL}/api/share`;
export const shareMessages = (shareId: string) => `${shareRoot}/${shareId}`;
export const getSharedLink = (conversationId: string) => `${shareRoot}/link/${conversationId}`;

View File

@@ -51,7 +51,6 @@ export const excludedKeys = new Set([
'_id',
'tools',
'model',
'modelHistory',
'files',
'spec',
'disableParams',

View File

@@ -697,12 +697,6 @@ export function getMessagesByConvoId(conversationId: string): Promise<s.TMessage
return request.get(endpoints.messages({ conversationId }));
}
export function getModelCosts(
modelHistory: Array<{ model: string; endpoint: string }>,
): Promise<t.TModelCosts> {
return request.post(endpoints.costs(), { modelHistory });
}
export function getPrompt(id: string): Promise<{ prompt: t.TPrompt }> {
return request.get(endpoints.getPrompt(id));
}

View File

@@ -6,7 +6,6 @@ export enum QueryKeys {
archivedConversations = 'archivedConversations',
searchConversations = 'searchConversations',
conversation = 'conversation',
modelCosts = 'modelCosts',
searchEnabled = 'searchEnabled',
user = 'user',
name = 'name', // user key name

View File

@@ -77,23 +77,6 @@ export const useGetConversationByIdQuery = (
);
};
export const useGetModelCostsQuery = (
modelHistory: Array<{ model: string; endpoint: string }>,
config?: UseQueryOptions<t.TModelCosts>,
): QueryObserverResult<t.TModelCosts> => {
return useQuery<t.TModelCosts>(
[QueryKeys.modelCosts, modelHistory],
() => dataService.getModelCosts(modelHistory),
{
refetchOnWindowFocus: false,
refetchOnReconnect: false,
refetchOnMount: false,
enabled: !!modelHistory && modelHistory.length > 0,
...config,
},
);
};
//This isn't ideal because its just a query and we're using mutation, but it was the only way
//to make it work with how the Chat component is structured
export const useGetConversationByIdMutation = (id: string): UseMutationResult<s.TConversation> => {

View File

@@ -98,12 +98,6 @@ if (typeof window !== 'undefined') {
if (originalRequest.url?.includes('/api/auth/logout') === true) {
return Promise.reject(error);
}
if (originalRequest.url?.includes('/api/auth/refresh') === true) {
// Refresh token itself failed - redirect to login
console.log('Refresh token request failed, redirecting to login...');
window.location.href = '/login';
return Promise.reject(error);
}
if (error.response.status === 401 && !originalRequest._retry) {
console.warn('401 error, refreshing token');
@@ -124,7 +118,10 @@ if (typeof window !== 'undefined') {
isRefreshing = true;
try {
const response = await refreshToken();
const response = await refreshToken(
// Handle edge case where we get a blank screen if the initial 401 error is from a refresh token request
originalRequest.url?.includes('api/auth/refresh') === true ? true : false,
);
const token = response?.token ?? '';

View File

@@ -518,7 +518,6 @@ export const tMessageSchema = z.object({
overrideParentMessageId: z.string().nullable().optional(),
bg: z.string().nullable().optional(),
model: z.string().nullable().optional(),
targetModel: z.string().nullable().optional(),
title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'),
sender: z.string().optional(),
text: z.string(),
@@ -632,7 +631,6 @@ export const tConversationSchema = z.object({
modelLabel: z.string().nullable().optional(),
userLabel: z.string().optional(),
model: z.string().nullable().optional(),
modelHistory: z.array(z.object({ model: z.string(), endpoint: z.string() })).optional(),
promptPrefix: z.string().nullable().optional(),
temperature: z.number().nullable().optional(),
topP: z.number().optional(),

View File

@@ -653,15 +653,3 @@ export type TBalanceResponse = {
lastRefill?: Date;
refillAmount?: number;
};
export type TConversationCosts = {
totals: {
prompt: { usd: number; tokenCount: number };
completion: { usd: number; tokenCount: number };
total: { usd: number; tokenCount: number };
};
};
export type TModelCosts = {
modelCostTable: Record<string, { prompt: number; completion: number }>;
};

View File

@@ -155,14 +155,4 @@ export const conversationPreset = {
verbosity: {
type: String,
},
/** Track all unique models used in this conversation with their endpoints */
modelHistory: {
type: [
{
model: { type: String, required: true },
endpoint: { type: String, required: true },
},
],
default: [],
},
};

View File

@@ -26,10 +26,6 @@ const messageSchema: Schema<IMessage> = new Schema(
type: String,
default: null,
},
targetModel: {
type: String,
default: null,
},
endpoint: {
type: String,
},

View File

@@ -11,7 +11,6 @@ export interface IConversation extends Document {
endpoint?: string;
endpointType?: string;
model?: string;
modelHistory?: Array<{ model: string; endpoint: string }>;
region?: string;
chatGptLabel?: string;
examples?: unknown[];

View File

@@ -7,7 +7,6 @@ export interface IMessage extends Document {
conversationId: string;
user: string;
model?: string;
targetModel?: string;
endpoint?: string;
conversationSignature?: string;
clientId?: string;