Compare commits
17 Commits
feat/price
...
feat/admin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
278590d0bb | ||
|
|
41a4674469 | ||
|
|
e7a9cf88ac | ||
|
|
f6925f906b | ||
|
|
e90fd1df15 | ||
|
|
a1f9f3dd39 | ||
|
|
fbe0def2fa | ||
|
|
d04da60b3b | ||
|
|
0e94d97bfb | ||
|
|
45ab4d4503 | ||
|
|
0ceef12eea | ||
|
|
6738360051 | ||
|
|
52b65492d5 | ||
|
|
7a9a99d2a0 | ||
|
|
5bfb06b417 | ||
|
|
2ce8f1f686 | ||
|
|
1a47601533 |
@@ -233,7 +233,6 @@ class BaseClient {
|
||||
sender: 'User',
|
||||
text,
|
||||
isCreatedByUser: true,
|
||||
targetModel: this.modelOptions?.model ?? this.model,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const { EventSource } = require('eventsource');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { MCPManager, FlowStateManager, OAuthReconnectionManager } = require('@librechat/api');
|
||||
const logger = require('./winston');
|
||||
|
||||
global.EventSource = EventSource;
|
||||
@@ -26,4 +26,6 @@ module.exports = {
|
||||
createMCPManager: MCPManager.createInstance,
|
||||
getMCPManager: MCPManager.getInstance,
|
||||
getFlowStateManager,
|
||||
createOAuthReconnectionManager: OAuthReconnectionManager.createInstance,
|
||||
getOAuthReconnectionManager: OAuthReconnectionManager.getInstance,
|
||||
};
|
||||
|
||||
@@ -112,17 +112,8 @@ module.exports = {
|
||||
update.expiredAt = null;
|
||||
}
|
||||
|
||||
/** @type {{ $set: Partial<TConversation>; $addToSet?: Record<string, any>; $unset?: Record<keyof TConversation, number> }} */
|
||||
/** @type {{ $set: Partial<TConversation>; $unset?: Record<keyof TConversation, number> }} */
|
||||
const updateOperation = { $set: update };
|
||||
|
||||
if (convo.model && convo.endpoint) {
|
||||
updateOperation.$addToSet = {
|
||||
modelHistory: {
|
||||
model: convo.model,
|
||||
endpoint: convo.endpoint,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (metadata && metadata.unsetFields && Object.keys(metadata.unsetFields).length > 0) {
|
||||
updateOperation.$unset = metadata.unsetFields;
|
||||
}
|
||||
|
||||
@@ -11,8 +11,9 @@ const {
|
||||
registerUser,
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
const { getOAuthReconnectionManager } = require('~/config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
@@ -96,14 +97,25 @@ const refreshController = async (req, res) => {
|
||||
return res.status(200).send({ token, user });
|
||||
}
|
||||
|
||||
// Find the session with the hashed refresh token
|
||||
const session = await findSession({
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
});
|
||||
/** Session with the hashed refresh token */
|
||||
const session = await findSession(
|
||||
{
|
||||
userId: userId,
|
||||
refreshToken: refreshToken,
|
||||
},
|
||||
{ lean: false },
|
||||
);
|
||||
|
||||
if (session && session.expiration > new Date()) {
|
||||
const token = await setAuthTokens(userId, res, session._id);
|
||||
const token = await setAuthTokens(userId, res, session);
|
||||
|
||||
// trigger OAuth MCP server reconnection asynchronously (best effort)
|
||||
void getOAuthReconnectionManager()
|
||||
.reconnectServers(userId)
|
||||
.catch((err) => {
|
||||
logger.error('Error reconnecting OAuth MCP servers:', err);
|
||||
});
|
||||
|
||||
res.status(200).send({ token, user });
|
||||
} else if (req?.query?.retry) {
|
||||
// Retrying from a refresh token request that failed (401)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generate2FATempToken } = require('~/server/services/twoFactorService');
|
||||
const { setAuthTokens } = require('~/server/services/AuthService');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const loginController = async (req, res) => {
|
||||
try {
|
||||
|
||||
50
api/server/controllers/auth/oauth.js
Normal file
50
api/server/controllers/auth/oauth.js
Normal file
@@ -0,0 +1,50 @@
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
|
||||
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||
const { checkBan } = require('~/server/middleware');
|
||||
|
||||
const domains = {
|
||||
client: process.env.DOMAIN_CLIENT,
|
||||
server: process.env.DOMAIN_SERVER,
|
||||
};
|
||||
|
||||
function createOAuthHandler(redirectUri = domains.client) {
|
||||
/**
|
||||
* A handler to process OAuth authentication results.
|
||||
* @type {Function}
|
||||
* @param {ServerRequest} req - Express request object.
|
||||
* @param {ServerResponse} res - Express response object.
|
||||
* @param {NextFunction} next - Express next middleware function.
|
||||
*/
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
if (res.headersSent) {
|
||||
return;
|
||||
}
|
||||
|
||||
await checkBan(req, res);
|
||||
if (req.banned) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
req.user &&
|
||||
req.user.provider == 'openid' &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
}
|
||||
res.redirect(redirectUri);
|
||||
} catch (err) {
|
||||
logger.error('Error in setting authentication tokens:', err);
|
||||
next(err);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
createOAuthHandler,
|
||||
};
|
||||
@@ -12,6 +12,7 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const mongoSanitize = require('express-mongo-sanitize');
|
||||
const { isEnabled, ErrorController } = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const initializeOAuthReconnectManager = require('./services/initializeOAuthReconnectManager');
|
||||
const createValidateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { updateInterfacePermissions } = require('~/models/interface');
|
||||
@@ -108,6 +109,7 @@ const startServer = async () => {
|
||||
app.use('/oauth', routes.oauth);
|
||||
/* API Endpoints */
|
||||
app.use('/api/auth', routes.auth);
|
||||
app.use('/api/admin', routes.adminAuth);
|
||||
app.use('/api/actions', routes.actions);
|
||||
app.use('/api/keys', routes.keys);
|
||||
app.use('/api/user', routes.user);
|
||||
@@ -154,7 +156,7 @@ const startServer = async () => {
|
||||
res.send(updatedIndexHtml);
|
||||
});
|
||||
|
||||
app.listen(port, host, () => {
|
||||
app.listen(port, host, async () => {
|
||||
if (host === '0.0.0.0') {
|
||||
logger.info(
|
||||
`Server listening on all interfaces at port ${port}. Use http://localhost:${port} to access it`,
|
||||
@@ -163,7 +165,9 @@ const startServer = async () => {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
initializeMCPs().then(() => checkMigrations());
|
||||
await initializeMCPs();
|
||||
await initializeOAuthReconnectManager();
|
||||
await checkMigrations();
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ const checkInviteUser = require('./checkInviteUser');
|
||||
const requireJwtAuth = require('./requireJwtAuth');
|
||||
const configMiddleware = require('./config/app');
|
||||
const validateModel = require('./validateModel');
|
||||
const requireAdmin = require('./requireAdmin');
|
||||
const moderateText = require('./moderateText');
|
||||
const logHeaders = require('./logHeaders');
|
||||
const setHeaders = require('./setHeaders');
|
||||
@@ -36,6 +37,7 @@ module.exports = {
|
||||
setHeaders,
|
||||
logHeaders,
|
||||
moderateText,
|
||||
requireAdmin,
|
||||
validateModel,
|
||||
requireJwtAuth,
|
||||
checkInviteUser,
|
||||
|
||||
22
api/server/middleware/requireAdmin.js
Normal file
22
api/server/middleware/requireAdmin.js
Normal file
@@ -0,0 +1,22 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
|
||||
/**
|
||||
* Middleware to check if authenticated user has admin role
|
||||
* Should be used AFTER authentication middleware (requireJwtAuth, requireLocalAuth, etc.)
|
||||
*/
|
||||
const requireAdmin = (req, res, next) => {
|
||||
if (!req.user) {
|
||||
logger.warn('[requireAdmin] No user found in request');
|
||||
return res.status(401).json({ message: 'Authentication required' });
|
||||
}
|
||||
|
||||
if (!req.user.role || req.user.role !== SystemRoles.ADMIN) {
|
||||
logger.debug('[requireAdmin] Access denied for non-admin user:', req.user.email);
|
||||
return res.status(403).json({ message: 'Access denied: Admin privileges required' });
|
||||
}
|
||||
|
||||
next();
|
||||
};
|
||||
|
||||
module.exports = requireAdmin;
|
||||
@@ -1,6 +1,6 @@
|
||||
const passport = require('passport');
|
||||
const cookies = require('cookie');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const passport = require('passport');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
|
||||
/**
|
||||
* Custom Middleware to handle JWT authentication, with support for OpenID token reuse
|
||||
|
||||
@@ -1,680 +0,0 @@
|
||||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
debug: jest.fn(),
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
createMethods: jest.fn(() => ({})),
|
||||
createModels: jest.fn(() => ({})),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/middleware', () => ({
|
||||
requireJwtAuth: (req, res, next) => next(),
|
||||
validateMessageReq: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getConvo: jest.fn(),
|
||||
saveConvo: jest.fn(),
|
||||
saveMessage: jest.fn(),
|
||||
getMessage: jest.fn(),
|
||||
getMessages: jest.fn(),
|
||||
updateMessage: jest.fn(),
|
||||
deleteMessages: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/db/models', () => {
|
||||
let User, Message, Transaction, Conversation;
|
||||
|
||||
return {
|
||||
get User() {
|
||||
return User;
|
||||
},
|
||||
get Message() {
|
||||
return Message;
|
||||
},
|
||||
get Transaction() {
|
||||
return Transaction;
|
||||
},
|
||||
get Conversation() {
|
||||
return Conversation;
|
||||
},
|
||||
setUser: (model) => {
|
||||
User = model;
|
||||
},
|
||||
setMessage: (model) => {
|
||||
Message = model;
|
||||
},
|
||||
setTransaction: (model) => {
|
||||
Transaction = model;
|
||||
},
|
||||
setConversation: (model) => {
|
||||
Conversation = model;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
describe('Costs Endpoint', () => {
|
||||
let app;
|
||||
let mongoServer;
|
||||
let messagesRouter;
|
||||
let User, Message, Transaction, Conversation;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
|
||||
const userSchema = new mongoose.Schema({
|
||||
_id: String,
|
||||
name: String,
|
||||
email: String,
|
||||
});
|
||||
|
||||
const conversationSchema = new mongoose.Schema({
|
||||
conversationId: String,
|
||||
user: String,
|
||||
title: String,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
const messageSchema = new mongoose.Schema({
|
||||
messageId: String,
|
||||
conversationId: String,
|
||||
user: String,
|
||||
isCreatedByUser: Boolean,
|
||||
tokenCount: Number,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
const transactionSchema = new mongoose.Schema({
|
||||
conversationId: String,
|
||||
user: String,
|
||||
tokenType: String,
|
||||
tokenValue: Number,
|
||||
createdAt: Date,
|
||||
});
|
||||
|
||||
User = mongoose.model('User', userSchema);
|
||||
Conversation = mongoose.model('Conversation', conversationSchema);
|
||||
Message = mongoose.model('Message', messageSchema);
|
||||
Transaction = mongoose.model('Transaction', transactionSchema);
|
||||
|
||||
const dbModels = require('~/db/models');
|
||||
dbModels.setUser(User);
|
||||
dbModels.setMessage(Message);
|
||||
dbModels.setTransaction(Transaction);
|
||||
dbModels.setConversation(Conversation);
|
||||
|
||||
require('~/db/models');
|
||||
|
||||
try {
|
||||
messagesRouter = require('../messages');
|
||||
} catch (error) {
|
||||
console.error('Error loading messages router:', error);
|
||||
throw error;
|
||||
}
|
||||
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: 'test-user-id' };
|
||||
next();
|
||||
});
|
||||
app.use('/api/messages', messagesRouter);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await User.deleteMany({});
|
||||
await Conversation.deleteMany({});
|
||||
await Message.deleteMany({});
|
||||
await Transaction.deleteMany({});
|
||||
});
|
||||
|
||||
describe('GET /:conversationId/costs', () => {
|
||||
const conversationId = 'test-conversation-123';
|
||||
const userId = 'test-user-id';
|
||||
|
||||
it('should return cost data for valid conversation', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const aiMessage = new Message({
|
||||
messageId: 'ai-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: 150,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), aiMessage.save()]);
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const completionTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'completion',
|
||||
tokenValue: 750000,
|
||||
createdAt: new Date('2024-01-01T10:01:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction.save(), completionTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toMatchObject({
|
||||
conversationId,
|
||||
totals: {
|
||||
prompt: { usd: 0.5, tokenCount: 100 },
|
||||
completion: { usd: 0.75, tokenCount: 150 },
|
||||
total: { usd: 1.25, tokenCount: 250 },
|
||||
},
|
||||
perMessage: [
|
||||
{ messageId: 'user-msg-1', tokenType: 'prompt', tokenCount: 100, usd: 0.5 },
|
||||
{ messageId: 'ai-msg-1', tokenType: 'completion', tokenCount: 150, usd: 0.75 },
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should return empty data for conversation with no messages', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toMatchObject({
|
||||
conversationId,
|
||||
totals: {
|
||||
prompt: { usd: 0, tokenCount: 0 },
|
||||
completion: { usd: 0, tokenCount: 0 },
|
||||
total: { usd: 0, tokenCount: 0 },
|
||||
},
|
||||
perMessage: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle messages without transactions', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const aiMessage = new Message({
|
||||
messageId: 'ai-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: false,
|
||||
tokenCount: 150,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), aiMessage.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0);
|
||||
expect(response.body.totals.completion.usd).toBe(0);
|
||||
expect(response.body.totals.total.usd).toBe(0);
|
||||
});
|
||||
|
||||
it('should aggregate multiple transactions correctly', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction1 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 300000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const promptTransaction2 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 200000,
|
||||
createdAt: new Date('2024-01-01T10:00:45Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.perMessage[0].usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should handle null tokenCount values', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: null,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.tokenCount).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle null tokenValue in transactions', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: null,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await promptTransaction.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle negative tokenValue using Math.abs', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: -500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await promptTransaction.save();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should filter by user correctly', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const otherUserId = 'other-user-id';
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const otherUserMessage = new Message({
|
||||
messageId: 'other-user-msg-1',
|
||||
conversationId,
|
||||
user: otherUserId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 200,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage.save(), otherUserMessage.save()]);
|
||||
|
||||
const userTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const otherUserTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: otherUserId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userTransaction.save(), otherUserTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.perMessage).toHaveLength(1);
|
||||
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
|
||||
});
|
||||
|
||||
it('should filter transactions by tokenType', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
await userMessage.save();
|
||||
|
||||
const promptTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const otherTransaction = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'other',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction.save(), otherTransaction.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.totals.prompt.usd).toBe(0.5);
|
||||
expect(response.body.totals.completion.usd).toBe(0);
|
||||
expect(response.body.totals.total.usd).toBe(0.5);
|
||||
});
|
||||
|
||||
it('should map transactions to messages chronologically', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
const userMessage1 = new Message({
|
||||
messageId: 'user-msg-1',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 100,
|
||||
createdAt: new Date('2024-01-01T10:00:00Z'),
|
||||
});
|
||||
|
||||
const userMessage2 = new Message({
|
||||
messageId: 'user-msg-2',
|
||||
conversationId,
|
||||
user: userId,
|
||||
isCreatedByUser: true,
|
||||
tokenCount: 200,
|
||||
createdAt: new Date('2024-01-01T10:01:00Z'),
|
||||
});
|
||||
|
||||
await Promise.all([userMessage1.save(), userMessage2.save()]);
|
||||
|
||||
const promptTransaction1 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 500000,
|
||||
createdAt: new Date('2024-01-01T10:00:30Z'),
|
||||
});
|
||||
|
||||
const promptTransaction2 = new Transaction({
|
||||
conversationId,
|
||||
user: userId,
|
||||
tokenType: 'prompt',
|
||||
tokenValue: 1000000,
|
||||
createdAt: new Date('2024-01-01T10:01:30Z'),
|
||||
});
|
||||
|
||||
await Promise.all([promptTransaction1.save(), promptTransaction2.save()]);
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.perMessage).toHaveLength(2);
|
||||
expect(response.body.perMessage[0].messageId).toBe('user-msg-1');
|
||||
expect(response.body.perMessage[0].usd).toBe(0.5);
|
||||
expect(response.body.perMessage[1].messageId).toBe('user-msg-2');
|
||||
expect(response.body.perMessage[1].usd).toBe(1.0);
|
||||
});
|
||||
|
||||
it('should handle database errors', async () => {
|
||||
const { getConvo } = require('~/models');
|
||||
getConvo.mockResolvedValue({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
});
|
||||
|
||||
const conversation = new Conversation({
|
||||
conversationId,
|
||||
user: userId,
|
||||
title: 'Test Conversation',
|
||||
createdAt: new Date('2024-01-01T09:00:00Z'),
|
||||
});
|
||||
|
||||
await conversation.save();
|
||||
|
||||
await mongoose.connection.close();
|
||||
|
||||
const response = await request(app).get(`/api/messages/${conversationId}/costs`);
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(response.body).toHaveProperty('error');
|
||||
});
|
||||
});
|
||||
});
|
||||
66
api/server/routes/admin/auth.js
Normal file
66
api/server/routes/admin/auth.js
Normal file
@@ -0,0 +1,66 @@
|
||||
const express = require('express');
|
||||
const passport = require('passport');
|
||||
const { randomState } = require('openid-client');
|
||||
const { createSetBalanceConfig } = require('@librechat/api');
|
||||
const { loginController } = require('~/server/controllers/auth/LoginController');
|
||||
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const middleware = require('~/server/middleware');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
const setBalanceConfig = createSetBalanceConfig({
|
||||
getAppConfig,
|
||||
Balance,
|
||||
});
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post(
|
||||
'/login/local',
|
||||
middleware.logHeaders,
|
||||
middleware.loginLimiter,
|
||||
middleware.checkBan,
|
||||
middleware.requireLocalAuth,
|
||||
middleware.requireAdmin,
|
||||
setBalanceConfig,
|
||||
loginController,
|
||||
);
|
||||
|
||||
router.get('/verify', middleware.requireJwtAuth, middleware.requireAdmin, (req, res) => {
|
||||
const { password: _p, totpSecret: _t, __v, ...user } = req.user;
|
||||
user.id = user._id.toString();
|
||||
res.status(200).json({ user });
|
||||
});
|
||||
|
||||
router.get('/oauth/openid/check', (req, res) => {
|
||||
const openidConfig = getOpenIdConfig();
|
||||
if (!openidConfig) {
|
||||
return res.status(404).json({ message: 'OpenID configuration not found' });
|
||||
}
|
||||
res.status(200).json({ message: 'OpenID check successful' });
|
||||
});
|
||||
|
||||
router.get('/oauth/openid', (req, res, next) => {
|
||||
return passport.authenticate('openidAdmin', {
|
||||
session: false,
|
||||
state: randomState(),
|
||||
})(req, res, next);
|
||||
});
|
||||
|
||||
router.get(
|
||||
'/oauth/openid/callback',
|
||||
passport.authenticate('openidAdmin', {
|
||||
failureRedirect: `${process.env.DOMAIN_CLIENT}/oauth/error`,
|
||||
failureMessage: true,
|
||||
session: false,
|
||||
}),
|
||||
middleware.requireAdmin,
|
||||
setBalanceConfig,
|
||||
middleware.checkDomainAllowed,
|
||||
createOAuthHandler(
|
||||
(process.env.ADMIN_PANEL_URL || 'http://localhost:3000') + '/auth/openid/callback',
|
||||
),
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,6 +1,7 @@
|
||||
const accessPermissions = require('./accessPermissions');
|
||||
const assistants = require('./assistants');
|
||||
const categories = require('./categories');
|
||||
const adminAuth = require('./admin/auth');
|
||||
const tokenizer = require('./tokenizer');
|
||||
const endpoints = require('./endpoints');
|
||||
const staticRoute = require('./static');
|
||||
@@ -32,6 +33,7 @@ module.exports = {
|
||||
mcp,
|
||||
edit,
|
||||
auth,
|
||||
adminAuth,
|
||||
keys,
|
||||
user,
|
||||
tags,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
const { Router } = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { findPluginAuthsByKeys } = require('~/models');
|
||||
@@ -144,6 +144,10 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
||||
);
|
||||
|
||||
// clear any reconnection attempts
|
||||
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||
oauthReconnectionManager.clearReconnection(flowState.userId, serverName);
|
||||
|
||||
const tools = await userConnection.fetchTools();
|
||||
await updateMCPUserTools({
|
||||
userId: flowState.userId,
|
||||
|
||||
@@ -11,7 +11,6 @@ const {
|
||||
} = require('~/models');
|
||||
const { findAllArtifacts, replaceArtifactContent } = require('~/server/services/Artifacts/update');
|
||||
const { requireJwtAuth, validateMessageReq } = require('~/server/middleware');
|
||||
const { tokenValues, getValueKey, defaultRate } = require('~/models/tx');
|
||||
const { cleanUpPrimaryKeyValue } = require('~/lib/utils/misc');
|
||||
const { getConvosQueried } = require('~/models/Conversation');
|
||||
const { countTokens } = require('~/server/utils');
|
||||
@@ -161,41 +160,6 @@ router.post('/artifact/:messageId', async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* POST /costs
|
||||
* Get cost information for models in modelHistory array
|
||||
*/
|
||||
router.post('/costs', async (req, res) => {
|
||||
try {
|
||||
const { modelHistory } = req.body;
|
||||
|
||||
if (!Array.isArray(modelHistory)) {
|
||||
return res.status(400).json({ error: 'modelHistory must be an array' });
|
||||
}
|
||||
|
||||
const modelCostTable = {};
|
||||
|
||||
modelHistory.forEach((modelEntry) => {
|
||||
if (modelEntry && typeof modelEntry === 'object' && modelEntry.model && modelEntry.endpoint) {
|
||||
const { model, endpoint } = modelEntry;
|
||||
|
||||
const valueKey = getValueKey(model, endpoint);
|
||||
const pricing = tokenValues[valueKey];
|
||||
|
||||
modelCostTable[model] = {
|
||||
prompt: pricing?.prompt ?? defaultRate,
|
||||
completion: pricing?.completion ?? defaultRate,
|
||||
};
|
||||
}
|
||||
});
|
||||
|
||||
res.status(200).json({ modelCostTable });
|
||||
} catch (error) {
|
||||
logger.error('Error fetching model costs:', error);
|
||||
res.status(500).json({ error: 'Internal server error' });
|
||||
}
|
||||
});
|
||||
|
||||
/* Note: It's necessary to add `validateMessageReq` within route definition for correct params */
|
||||
router.get('/:conversationId', validateMessageReq, async (req, res) => {
|
||||
try {
|
||||
|
||||
@@ -4,10 +4,9 @@ const passport = require('passport');
|
||||
const { randomState } = require('openid-client');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const { isEnabled, createSetBalanceConfig } = require('@librechat/api');
|
||||
const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware');
|
||||
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
|
||||
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||
const { createSetBalanceConfig } = require('@librechat/api');
|
||||
const { checkDomainAllowed, loginLimiter, logHeaders } = require('~/server/middleware');
|
||||
const { createOAuthHandler } = require('~/server/controllers/auth/oauth');
|
||||
const { getAppConfig } = require('~/server/services/Config');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
@@ -26,32 +25,7 @@ const domains = {
|
||||
router.use(logHeaders);
|
||||
router.use(loginLimiter);
|
||||
|
||||
const oauthHandler = async (req, res, next) => {
|
||||
try {
|
||||
if (res.headersSent) {
|
||||
return;
|
||||
}
|
||||
|
||||
await checkBan(req, res);
|
||||
if (req.banned) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
req.user &&
|
||||
req.user.provider == 'openid' &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res, req.user._id.toString());
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
}
|
||||
res.redirect(domains.client);
|
||||
} catch (err) {
|
||||
logger.error('Error in setting authentication tokens:', err);
|
||||
next(err);
|
||||
}
|
||||
};
|
||||
const oauthHandler = createOAuthHandler();
|
||||
|
||||
router.get('/error', (req, res) => {
|
||||
/** A single error message is pushed by passport when authentication fails. */
|
||||
|
||||
@@ -357,23 +357,18 @@ const resetPassword = async (userId, token, password) => {
|
||||
|
||||
/**
|
||||
* Set Auth Tokens
|
||||
*
|
||||
* @param {String | ObjectId} userId
|
||||
* @param {Object} res
|
||||
* @param {String} sessionId
|
||||
* @param {ServerResponse} res
|
||||
* @param {ISession | null} [session=null]
|
||||
* @returns
|
||||
*/
|
||||
const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
const setAuthTokens = async (userId, res, _session = null) => {
|
||||
try {
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
|
||||
let session;
|
||||
let session = _session;
|
||||
let refreshToken;
|
||||
let refreshTokenExpires;
|
||||
|
||||
if (sessionId) {
|
||||
session = await findSession({ sessionId: sessionId }, { lean: false });
|
||||
if (session && session._id && session.expiration != null) {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
refreshToken = await generateRefreshToken(session);
|
||||
} else {
|
||||
@@ -383,6 +378,9 @@ const setAuthTokens = async (userId, res, sessionId = null) => {
|
||||
refreshTokenExpires = session.expiration.getTime();
|
||||
}
|
||||
|
||||
const user = await getUserById(userId);
|
||||
const token = await generateToken(user);
|
||||
|
||||
res.cookie('refreshToken', refreshToken, {
|
||||
expires: new Date(refreshTokenExpires),
|
||||
httpOnly: true,
|
||||
|
||||
@@ -20,8 +20,8 @@ const {
|
||||
ContentTypes,
|
||||
isAssistantsEndpoint,
|
||||
} = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager, getOAuthReconnectionManager } = require('~/config');
|
||||
const { findToken, createToken, updateToken } = require('~/models');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getCachedTools, getAppConfig } = require('./Config');
|
||||
const { reinitMCPServer } = require('./Tools/mcp');
|
||||
const { getLogStores } = require('~/cache');
|
||||
@@ -538,13 +538,20 @@ async function getServerConnectionStatus(
|
||||
const baseConnectionState = getConnectionState();
|
||||
let finalConnectionState = baseConnectionState;
|
||||
|
||||
// connection state overrides specific to OAuth servers
|
||||
if (baseConnectionState === 'disconnected' && oauthServers.has(serverName)) {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
// check if server is actively being reconnected
|
||||
const oauthReconnectionManager = getOAuthReconnectionManager();
|
||||
if (oauthReconnectionManager.isReconnecting(userId, serverName)) {
|
||||
finalConnectionState = 'connecting';
|
||||
} else {
|
||||
const { hasActiveFlow, hasFailedFlow } = await checkOAuthFlowStatus(userId, serverName);
|
||||
|
||||
if (hasFailedFlow) {
|
||||
finalConnectionState = 'error';
|
||||
} else if (hasActiveFlow) {
|
||||
finalConnectionState = 'connecting';
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ jest.mock('./Config', () => ({
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(),
|
||||
getFlowStateManager: jest.fn(),
|
||||
getOAuthReconnectionManager: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
@@ -48,6 +49,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
let mockGetMCPManager;
|
||||
let mockGetFlowStateManager;
|
||||
let mockGetLogStores;
|
||||
let mockGetOAuthReconnectionManager;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
@@ -56,6 +58,7 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
mockGetMCPManager = require('~/config').getMCPManager;
|
||||
mockGetFlowStateManager = require('~/config').getFlowStateManager;
|
||||
mockGetLogStores = require('~/cache').getLogStores;
|
||||
mockGetOAuthReconnectionManager = require('~/config').getOAuthReconnectionManager;
|
||||
});
|
||||
|
||||
describe('getMCPSetupData', () => {
|
||||
@@ -354,6 +357,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
@@ -370,6 +379,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return failed flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
@@ -401,6 +416,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return active flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => ({
|
||||
@@ -432,6 +453,12 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => false),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
// Mock flow state to return no flow
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(() => null),
|
||||
@@ -454,6 +481,35 @@ describe('tests for the new helper functions used by the MCP connection status e
|
||||
});
|
||||
});
|
||||
|
||||
it('should return connecting state when OAuth server is reconnecting', async () => {
|
||||
const appConnections = new Map();
|
||||
const userConnections = new Map();
|
||||
const oauthServers = new Set([mockServerName]);
|
||||
|
||||
// Mock OAuthReconnectionManager to return true for isReconnecting
|
||||
const mockOAuthReconnectionManager = {
|
||||
isReconnecting: jest.fn(() => true),
|
||||
};
|
||||
mockGetOAuthReconnectionManager.mockReturnValue(mockOAuthReconnectionManager);
|
||||
|
||||
const result = await getServerConnectionStatus(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
appConnections,
|
||||
userConnections,
|
||||
oauthServers,
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
requiresOAuth: true,
|
||||
connectionState: 'connecting',
|
||||
});
|
||||
expect(mockOAuthReconnectionManager.isReconnecting).toHaveBeenCalledWith(
|
||||
mockUserId,
|
||||
mockServerName,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not check OAuth flow status when server is connected', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn(),
|
||||
|
||||
@@ -313,7 +313,7 @@ const ensurePrincipalExists = async function (principal) {
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
};
|
||||
|
||||
const userId = await createUser(userData, true, false);
|
||||
const userId = await createUser(userData, true, true);
|
||||
return userId.toString();
|
||||
}
|
||||
|
||||
|
||||
@@ -88,7 +88,6 @@ async function saveUserMessage(req, params) {
|
||||
parentMessageId: params.parentMessageId ?? Constants.NO_PARENT,
|
||||
/* For messages, use the assistant_id instead of model */
|
||||
model: params.assistant_id,
|
||||
targetModel: params.model,
|
||||
thread_id: params.thread_id,
|
||||
sender: 'User',
|
||||
text: params.text,
|
||||
|
||||
26
api/server/services/initializeOAuthReconnectManager.js
Normal file
26
api/server/services/initializeOAuthReconnectManager.js
Normal file
@@ -0,0 +1,26 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { createOAuthReconnectionManager, getFlowStateManager } = require('~/config');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Initialize OAuth reconnect manager
|
||||
*/
|
||||
async function initializeOAuthReconnectManager() {
|
||||
try {
|
||||
const flowManager = getFlowStateManager(getLogStores(CacheKeys.FLOWS));
|
||||
const tokenMethods = {
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
deleteTokens,
|
||||
};
|
||||
await createOAuthReconnectionManager(flowManager, tokenMethods);
|
||||
logger.info(`OAuth reconnect manager initialized successfully.`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize OAuth reconnect manager:', error);
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = initializeOAuthReconnectManager;
|
||||
@@ -1,14 +1,14 @@
|
||||
const appleLogin = require('./appleStrategy');
|
||||
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
|
||||
const openIdJwtLogin = require('./openIdJwtStrategy');
|
||||
const facebookLogin = require('./facebookStrategy');
|
||||
const discordLogin = require('./discordStrategy');
|
||||
const passportLogin = require('./localStrategy');
|
||||
const googleLogin = require('./googleStrategy');
|
||||
const githubLogin = require('./githubStrategy');
|
||||
const discordLogin = require('./discordStrategy');
|
||||
const facebookLogin = require('./facebookStrategy');
|
||||
const { setupOpenId, getOpenIdConfig } = require('./openidStrategy');
|
||||
const jwtLogin = require('./jwtStrategy');
|
||||
const ldapLogin = require('./ldapStrategy');
|
||||
const { setupSaml } = require('./samlStrategy');
|
||||
const openIdJwtLogin = require('./openIdJwtStrategy');
|
||||
const appleLogin = require('./appleStrategy');
|
||||
const ldapLogin = require('./ldapStrategy');
|
||||
const jwtLogin = require('./jwtStrategy');
|
||||
|
||||
module.exports = {
|
||||
appleLogin,
|
||||
|
||||
@@ -281,6 +281,221 @@ function convertToUsername(input, defaultValue = '') {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process OpenID authentication tokenset and userinfo
|
||||
* This is the core logic extracted from the passport strategy callback
|
||||
* Can be reused by both the passport strategy and proxy authentication
|
||||
*
|
||||
* @param {Object} tokenset - The OpenID tokenset containing access_token, id_token, etc.
|
||||
* @param {boolean} existingUsersOnly - If true, only existing users will be processed
|
||||
* @returns {Promise<Object>} The authenticated user object with tokenset
|
||||
*/
|
||||
async function processOpenIDAuth(tokenset, existingUsersOnly = false) {
|
||||
const claims = tokenset.claims ? tokenset.claims() : tokenset;
|
||||
const userinfo = {
|
||||
...claims,
|
||||
};
|
||||
|
||||
// Get userinfo from provider if we have access_token and haven't already
|
||||
if (tokenset.access_token) {
|
||||
const providerUserinfo = await getUserInfo(openidConfig, tokenset.access_token, claims.sub);
|
||||
Object.assign(userinfo, providerUserinfo);
|
||||
}
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[OpenID Auth] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
|
||||
);
|
||||
throw new Error('Email domain not allowed');
|
||||
}
|
||||
|
||||
const result = await findOpenIDUser({
|
||||
openidId: claims.sub || userinfo.sub,
|
||||
email: claims.email || userinfo.email,
|
||||
strategyName: 'openidStrategy',
|
||||
findUser,
|
||||
});
|
||||
let user = result.user;
|
||||
const error = result.error;
|
||||
|
||||
if (error) {
|
||||
throw new Error(ErrorTypes.AUTH_FAILED);
|
||||
}
|
||||
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
/** Required role if configured */
|
||||
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
|
||||
if (requiredRole) {
|
||||
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
|
||||
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
|
||||
|
||||
let decodedToken = '';
|
||||
if (requiredRoleTokenKind === 'access' && tokenset.access_token) {
|
||||
decodedToken = jwtDecode(tokenset.access_token);
|
||||
} else if (requiredRoleTokenKind === 'id' && tokenset.id_token) {
|
||||
decodedToken = jwtDecode(tokenset.id_token);
|
||||
} else if (userinfo.roles) {
|
||||
// If roles are already in userinfo, use them directly
|
||||
const roles = Array.isArray(userinfo.roles) ? userinfo.roles : [userinfo.roles];
|
||||
if (!roles.includes(requiredRole)) {
|
||||
throw new Error(`You must have the "${requiredRole}" role to log in.`);
|
||||
}
|
||||
} else if (requiredRoleParameterPath) {
|
||||
const pathParts = requiredRoleParameterPath.split('.');
|
||||
let found = true;
|
||||
let roles = pathParts.reduce((o, key) => {
|
||||
if (o === null || o === undefined || !(key in o)) {
|
||||
found = false;
|
||||
return [];
|
||||
}
|
||||
return o[key];
|
||||
}, decodedToken);
|
||||
|
||||
if (!found) {
|
||||
logger.error(
|
||||
`[OpenID Auth] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!roles.includes(requiredRole)) {
|
||||
throw new Error(`You must have the "${requiredRole}" role to log in.`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let username = '';
|
||||
if (process.env.OPENID_USERNAME_CLAIM) {
|
||||
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
|
||||
} else {
|
||||
username = convertToUsername(
|
||||
userinfo.preferred_username || userinfo.username || userinfo.email,
|
||||
);
|
||||
}
|
||||
|
||||
if (existingUsersOnly && !user) {
|
||||
throw new Error('User does not exist');
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username,
|
||||
email: userinfo.email || '',
|
||||
emailVerified: userinfo.email_verified || false,
|
||||
name: fullName,
|
||||
idOnTheSource: userinfo.oid,
|
||||
};
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
user.provider = 'openid';
|
||||
user.openidId = userinfo.sub;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
user.idOnTheSource = userinfo.oid;
|
||||
}
|
||||
|
||||
// Handle avatar
|
||||
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
|
||||
const imageUrl = userinfo.picture;
|
||||
let fileName;
|
||||
if (crypto) {
|
||||
fileName = (await hashToken(userinfo.sub)) + '.png';
|
||||
} else {
|
||||
fileName = userinfo.sub + '.png';
|
||||
}
|
||||
|
||||
const imageBuffer = await downloadImage(
|
||||
imageUrl,
|
||||
openidConfig,
|
||||
tokenset.access_token,
|
||||
userinfo.sub,
|
||||
);
|
||||
if (imageBuffer) {
|
||||
const { saveBuffer } = getStrategyFunctions(
|
||||
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
|
||||
);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
buffer: imageBuffer,
|
||||
});
|
||||
user.avatar = imagePath ?? '';
|
||||
}
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
|
||||
logger.info(
|
||||
`[OpenID Auth] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username}`,
|
||||
{
|
||||
user: {
|
||||
openidId: user.openidId,
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
return { ...user, tokenset };
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {boolean | undefined} [existingUsersOnly]
|
||||
*/
|
||||
function createOpenIDCallback(existingUsersOnly) {
|
||||
return async (tokenset, done) => {
|
||||
try {
|
||||
const user = await processOpenIDAuth(tokenset, existingUsersOnly);
|
||||
done(null, user);
|
||||
} catch (err) {
|
||||
if (err.message === 'Email domain not allowed') {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
if (err.message === ErrorTypes.AUTH_FAILED) {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
if (err.message && err.message.includes('role to log in')) {
|
||||
return done(null, false, { message: err.message });
|
||||
}
|
||||
logger.error('[openidStrategy] login failed', err);
|
||||
done(err);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets up the OpenID strategy specifically for admin authentication.
|
||||
* @param {Configuration} openidConfig
|
||||
*/
|
||||
const setupOpenIdAdmin = (openidConfig) => {
|
||||
try {
|
||||
if (!openidConfig) {
|
||||
throw new Error('OpenID configuration not initialized');
|
||||
}
|
||||
|
||||
const openidAdminLogin = new CustomOpenIDStrategy(
|
||||
{
|
||||
config: openidConfig,
|
||||
scope: process.env.OPENID_SCOPE,
|
||||
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
|
||||
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
|
||||
callbackURL: process.env.DOMAIN_SERVER + '/api/admin/oauth/openid/callback',
|
||||
},
|
||||
createOpenIDCallback(true),
|
||||
);
|
||||
|
||||
passport.use('openidAdmin', openidAdminLogin);
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy] setupOpenIdAdmin', err);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Sets up the OpenID strategy for authentication.
|
||||
* This function configures the OpenID client, handles proxy settings,
|
||||
@@ -318,10 +533,6 @@ async function setupOpenId() {
|
||||
},
|
||||
);
|
||||
|
||||
const requiredRole = process.env.OPENID_REQUIRED_ROLE;
|
||||
const requiredRoleParameterPath = process.env.OPENID_REQUIRED_ROLE_PARAMETER_PATH;
|
||||
const requiredRoleTokenKind = process.env.OPENID_REQUIRED_ROLE_TOKEN_KIND;
|
||||
const usePKCE = isEnabled(process.env.OPENID_USE_PKCE);
|
||||
logger.info(`[openidStrategy] OpenID authentication configuration`, {
|
||||
generateNonce: shouldGenerateNonce,
|
||||
reason: shouldGenerateNonce
|
||||
@@ -335,159 +546,19 @@ async function setupOpenId() {
|
||||
scope: process.env.OPENID_SCOPE,
|
||||
callbackURL: process.env.DOMAIN_SERVER + process.env.OPENID_CALLBACK_URL,
|
||||
clockTolerance: process.env.OPENID_CLOCK_TOLERANCE || 300,
|
||||
usePKCE,
|
||||
},
|
||||
async (tokenset, done) => {
|
||||
try {
|
||||
const claims = tokenset.claims();
|
||||
const userinfo = {
|
||||
...claims,
|
||||
...(await getUserInfo(openidConfig, tokenset.access_token, claims.sub)),
|
||||
};
|
||||
|
||||
const appConfig = await getAppConfig();
|
||||
if (!isEmailDomainAllowed(userinfo.email, appConfig?.registration?.allowedDomains)) {
|
||||
logger.error(
|
||||
`[OpenID Strategy] Authentication blocked - email domain not allowed [Email: ${userinfo.email}]`,
|
||||
);
|
||||
return done(null, false, { message: 'Email domain not allowed' });
|
||||
}
|
||||
|
||||
const result = await findOpenIDUser({
|
||||
openidId: claims.sub,
|
||||
email: claims.email,
|
||||
strategyName: 'openidStrategy',
|
||||
findUser,
|
||||
});
|
||||
let user = result.user;
|
||||
const error = result.error;
|
||||
|
||||
if (error) {
|
||||
return done(null, false, {
|
||||
message: ErrorTypes.AUTH_FAILED,
|
||||
});
|
||||
}
|
||||
|
||||
const fullName = getFullName(userinfo);
|
||||
|
||||
if (requiredRole) {
|
||||
let decodedToken = '';
|
||||
if (requiredRoleTokenKind === 'access') {
|
||||
decodedToken = jwtDecode(tokenset.access_token);
|
||||
} else if (requiredRoleTokenKind === 'id') {
|
||||
decodedToken = jwtDecode(tokenset.id_token);
|
||||
}
|
||||
const pathParts = requiredRoleParameterPath.split('.');
|
||||
let found = true;
|
||||
let roles = pathParts.reduce((o, key) => {
|
||||
if (o === null || o === undefined || !(key in o)) {
|
||||
found = false;
|
||||
return [];
|
||||
}
|
||||
return o[key];
|
||||
}, decodedToken);
|
||||
|
||||
if (!found) {
|
||||
logger.error(
|
||||
`[openidStrategy] Key '${requiredRoleParameterPath}' not found in ${requiredRoleTokenKind} token!`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!roles.includes(requiredRole)) {
|
||||
return done(null, false, {
|
||||
message: `You must have the "${requiredRole}" role to log in.`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let username = '';
|
||||
if (process.env.OPENID_USERNAME_CLAIM) {
|
||||
username = userinfo[process.env.OPENID_USERNAME_CLAIM];
|
||||
} else {
|
||||
username = convertToUsername(
|
||||
userinfo.preferred_username || userinfo.username || userinfo.email,
|
||||
);
|
||||
}
|
||||
|
||||
if (!user) {
|
||||
user = {
|
||||
provider: 'openid',
|
||||
openidId: userinfo.sub,
|
||||
username,
|
||||
email: userinfo.email || '',
|
||||
emailVerified: userinfo.email_verified || false,
|
||||
name: fullName,
|
||||
idOnTheSource: userinfo.oid,
|
||||
};
|
||||
|
||||
const balanceConfig = getBalanceConfig(appConfig);
|
||||
user = await createUser(user, balanceConfig, true, true);
|
||||
} else {
|
||||
user.provider = 'openid';
|
||||
user.openidId = userinfo.sub;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
user.idOnTheSource = userinfo.oid;
|
||||
}
|
||||
|
||||
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
|
||||
/** @type {string | undefined} */
|
||||
const imageUrl = userinfo.picture;
|
||||
|
||||
let fileName;
|
||||
if (crypto) {
|
||||
fileName = (await hashToken(userinfo.sub)) + '.png';
|
||||
} else {
|
||||
fileName = userinfo.sub + '.png';
|
||||
}
|
||||
|
||||
const imageBuffer = await downloadImage(
|
||||
imageUrl,
|
||||
openidConfig,
|
||||
tokenset.access_token,
|
||||
userinfo.sub,
|
||||
);
|
||||
if (imageBuffer) {
|
||||
const { saveBuffer } = getStrategyFunctions(
|
||||
appConfig?.fileStrategy ?? process.env.CDN_PROVIDER,
|
||||
);
|
||||
const imagePath = await saveBuffer({
|
||||
fileName,
|
||||
userId: user._id.toString(),
|
||||
buffer: imageBuffer,
|
||||
});
|
||||
user.avatar = imagePath ?? '';
|
||||
}
|
||||
}
|
||||
|
||||
user = await updateUser(user._id, user);
|
||||
|
||||
logger.info(
|
||||
`[openidStrategy] login success openidId: ${user.openidId} | email: ${user.email} | username: ${user.username} `,
|
||||
{
|
||||
user: {
|
||||
openidId: user.openidId,
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
done(null, { ...user, tokenset });
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy] login failed', err);
|
||||
done(err);
|
||||
}
|
||||
usePKCE: isEnabled(process.env.OPENID_USE_PKCE),
|
||||
},
|
||||
createOpenIDCallback(),
|
||||
);
|
||||
passport.use('openid', openidLogin);
|
||||
setupOpenIdAdmin(openidConfig);
|
||||
return openidConfig;
|
||||
} catch (err) {
|
||||
logger.error('[openidStrategy]', err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @function getOpenIdConfig
|
||||
* @description Returns the OpenID client instance.
|
||||
|
||||
@@ -873,6 +873,13 @@
|
||||
* @typedef {import('@librechat/data-schemas').IMongoFile} MongoFile
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports ISession
|
||||
* @typedef {import('@librechat/data-schemas').ISession} ISession
|
||||
* @memberof typedefs
|
||||
*/
|
||||
|
||||
/**
|
||||
* @exports IBalance
|
||||
* @typedef {import('@librechat/data-schemas').IBalance} IBalance
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import { createContext, useContext } from 'react';
|
||||
|
||||
type MessageContext = {
|
||||
messageId: string;
|
||||
nextType?: string;
|
||||
partIndex?: number;
|
||||
isExpanded: boolean;
|
||||
conversationId?: string | null;
|
||||
/** Submission state for cursor display - only true for latest message when submitting */
|
||||
isSubmitting?: boolean;
|
||||
/** Whether this is the latest message in the conversation */
|
||||
isLatestMessage?: boolean;
|
||||
};
|
||||
|
||||
export const MessageContext = createContext<MessageContext>({} as MessageContext);
|
||||
|
||||
150
client/src/Providers/MessagesViewContext.tsx
Normal file
150
client/src/Providers/MessagesViewContext.tsx
Normal file
@@ -0,0 +1,150 @@
|
||||
import React, { createContext, useContext, useMemo } from 'react';
|
||||
import { useAddedChatContext } from './AddedChatContext';
|
||||
import { useChatContext } from './ChatContext';
|
||||
|
||||
interface MessagesViewContextValue {
|
||||
/** Core conversation data */
|
||||
conversation: ReturnType<typeof useChatContext>['conversation'];
|
||||
conversationId: string | null | undefined;
|
||||
|
||||
/** Submission and control states */
|
||||
isSubmitting: ReturnType<typeof useChatContext>['isSubmitting'];
|
||||
isSubmittingFamily: boolean;
|
||||
abortScroll: ReturnType<typeof useChatContext>['abortScroll'];
|
||||
setAbortScroll: ReturnType<typeof useChatContext>['setAbortScroll'];
|
||||
|
||||
/** Message operations */
|
||||
ask: ReturnType<typeof useChatContext>['ask'];
|
||||
regenerate: ReturnType<typeof useChatContext>['regenerate'];
|
||||
handleContinue: ReturnType<typeof useChatContext>['handleContinue'];
|
||||
|
||||
/** Message state management */
|
||||
index: ReturnType<typeof useChatContext>['index'];
|
||||
latestMessage: ReturnType<typeof useChatContext>['latestMessage'];
|
||||
setLatestMessage: ReturnType<typeof useChatContext>['setLatestMessage'];
|
||||
getMessages: ReturnType<typeof useChatContext>['getMessages'];
|
||||
setMessages: ReturnType<typeof useChatContext>['setMessages'];
|
||||
}
|
||||
|
||||
const MessagesViewContext = createContext<MessagesViewContextValue | undefined>(undefined);
|
||||
|
||||
export function MessagesViewProvider({ children }: { children: React.ReactNode }) {
|
||||
const chatContext = useChatContext();
|
||||
const addedChatContext = useAddedChatContext();
|
||||
|
||||
const {
|
||||
ask,
|
||||
index,
|
||||
regenerate,
|
||||
isSubmitting: isSubmittingRoot,
|
||||
conversation,
|
||||
latestMessage,
|
||||
setAbortScroll,
|
||||
handleContinue,
|
||||
setLatestMessage,
|
||||
abortScroll,
|
||||
getMessages,
|
||||
setMessages,
|
||||
} = chatContext;
|
||||
|
||||
const { isSubmitting: isSubmittingAdditional } = addedChatContext;
|
||||
|
||||
/** Memoize conversation-related values */
|
||||
const conversationValues = useMemo(
|
||||
() => ({
|
||||
conversation,
|
||||
conversationId: conversation?.conversationId,
|
||||
}),
|
||||
[conversation],
|
||||
);
|
||||
|
||||
/** Memoize submission states */
|
||||
const submissionStates = useMemo(
|
||||
() => ({
|
||||
isSubmitting: isSubmittingRoot,
|
||||
isSubmittingFamily: isSubmittingRoot || isSubmittingAdditional,
|
||||
abortScroll,
|
||||
setAbortScroll,
|
||||
}),
|
||||
[isSubmittingRoot, isSubmittingAdditional, abortScroll, setAbortScroll],
|
||||
);
|
||||
|
||||
/** Memoize message operations (these are typically stable references) */
|
||||
const messageOperations = useMemo(
|
||||
() => ({
|
||||
ask,
|
||||
regenerate,
|
||||
getMessages,
|
||||
setMessages,
|
||||
handleContinue,
|
||||
}),
|
||||
[ask, regenerate, handleContinue, getMessages, setMessages],
|
||||
);
|
||||
|
||||
/** Memoize message state values */
|
||||
const messageState = useMemo(
|
||||
() => ({
|
||||
index,
|
||||
latestMessage,
|
||||
setLatestMessage,
|
||||
}),
|
||||
[index, latestMessage, setLatestMessage],
|
||||
);
|
||||
|
||||
/** Combine all values into final context value */
|
||||
const contextValue = useMemo<MessagesViewContextValue>(
|
||||
() => ({
|
||||
...conversationValues,
|
||||
...submissionStates,
|
||||
...messageOperations,
|
||||
...messageState,
|
||||
}),
|
||||
[conversationValues, submissionStates, messageOperations, messageState],
|
||||
);
|
||||
|
||||
return (
|
||||
<MessagesViewContext.Provider value={contextValue}>{children}</MessagesViewContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useMessagesViewContext() {
|
||||
const context = useContext(MessagesViewContext);
|
||||
if (!context) {
|
||||
throw new Error('useMessagesViewContext must be used within MessagesViewProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
/** Hook for components that only need conversation data */
|
||||
export function useMessagesConversation() {
|
||||
const { conversation, conversationId } = useMessagesViewContext();
|
||||
return useMemo(() => ({ conversation, conversationId }), [conversation, conversationId]);
|
||||
}
|
||||
|
||||
/** Hook for components that only need submission states */
|
||||
export function useMessagesSubmission() {
|
||||
const { isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll } =
|
||||
useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll }),
|
||||
[isSubmitting, isSubmittingFamily, abortScroll, setAbortScroll],
|
||||
);
|
||||
}
|
||||
|
||||
/** Hook for components that only need message operations */
|
||||
export function useMessagesOperations() {
|
||||
const { ask, regenerate, handleContinue, getMessages, setMessages } = useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ ask, regenerate, handleContinue, getMessages, setMessages }),
|
||||
[ask, regenerate, handleContinue, getMessages, setMessages],
|
||||
);
|
||||
}
|
||||
|
||||
/** Hook for components that only need message state */
|
||||
export function useMessagesState() {
|
||||
const { index, latestMessage, setLatestMessage } = useMessagesViewContext();
|
||||
return useMemo(
|
||||
() => ({ index, latestMessage, setLatestMessage }),
|
||||
[index, latestMessage, setLatestMessage],
|
||||
);
|
||||
}
|
||||
@@ -26,4 +26,5 @@ export * from './SidePanelContext';
|
||||
export * from './MCPPanelContext';
|
||||
export * from './ArtifactsContext';
|
||||
export * from './PromptGroupsContext';
|
||||
export * from './MessagesViewContext';
|
||||
export { default as BadgeRowProvider } from './BadgeRowContext';
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { memo, useCallback, useState, useEffect, useRef } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { Spinner } from '@librechat/client';
|
||||
@@ -13,7 +13,6 @@ import { useGetMessagesByConvoId } from '~/data-provider';
|
||||
import MessagesView from './Messages/MessagesView';
|
||||
import Presentation from './Presentation';
|
||||
import ChatForm from './Input/ChatForm';
|
||||
import CostBar from './CostBar';
|
||||
import Landing from './Landing';
|
||||
import Header from './Header';
|
||||
import Footer from './Footer';
|
||||
@@ -30,13 +29,7 @@ function LoadingSpinner() {
|
||||
);
|
||||
}
|
||||
|
||||
function ChatView({
|
||||
index = 0,
|
||||
modelCosts,
|
||||
}: {
|
||||
index?: number;
|
||||
modelCosts?: { modelCostTable: Record<string, { prompt: number; completion: number }> };
|
||||
}) {
|
||||
function ChatView({ index = 0 }: { index?: number }) {
|
||||
const { conversationId } = useParams();
|
||||
const rootSubmission = useRecoilValue(store.submissionByIndex(index));
|
||||
const addedSubmission = useRecoilValue(store.submissionByIndex(index + 1));
|
||||
@@ -44,9 +37,6 @@ function ChatView({
|
||||
|
||||
const fileMap = useFileMapContext();
|
||||
|
||||
const [showCostBar, setShowCostBar] = useState(false);
|
||||
const lastScrollY = useRef(0);
|
||||
|
||||
const { data: messagesTree = null, isLoading } = useGetMessagesByConvoId(conversationId ?? '', {
|
||||
select: useCallback(
|
||||
(data: TMessage[]) => {
|
||||
@@ -64,58 +54,6 @@ function ChatView({
|
||||
useSSE(rootSubmission, chatHelpers, false);
|
||||
useSSE(addedSubmission, addedChatHelpers, true);
|
||||
|
||||
const checkIfAtBottom = useCallback(
|
||||
(container: HTMLElement) => {
|
||||
const currentScrollY = container.scrollTop;
|
||||
const scrollHeight = container.scrollHeight;
|
||||
const clientHeight = container.clientHeight;
|
||||
|
||||
const distanceFromBottom = scrollHeight - currentScrollY - clientHeight;
|
||||
const isAtBottom = distanceFromBottom < 10;
|
||||
|
||||
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
|
||||
setShowCostBar(isAtBottom && !isStreaming);
|
||||
lastScrollY.current = currentScrollY;
|
||||
},
|
||||
[chatHelpers.isSubmitting, addedChatHelpers.isSubmitting],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const handleScroll = (event: Event) => {
|
||||
const target = event.target as HTMLElement;
|
||||
checkIfAtBottom(target);
|
||||
};
|
||||
|
||||
const findAndAttachScrollListener = () => {
|
||||
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
|
||||
if (messagesContainer) {
|
||||
checkIfAtBottom(messagesContainer as HTMLElement);
|
||||
|
||||
messagesContainer.addEventListener('scroll', handleScroll, { passive: true });
|
||||
return () => {
|
||||
messagesContainer.removeEventListener('scroll', handleScroll);
|
||||
};
|
||||
}
|
||||
setTimeout(findAndAttachScrollListener, 100);
|
||||
};
|
||||
|
||||
const cleanup = findAndAttachScrollListener();
|
||||
|
||||
return cleanup;
|
||||
}, [messagesTree, checkIfAtBottom]);
|
||||
|
||||
useEffect(() => {
|
||||
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
|
||||
if (isStreaming) {
|
||||
setShowCostBar(false);
|
||||
} else {
|
||||
const messagesContainer = document.querySelector('[class*="scrollbar-gutter-stable"]');
|
||||
if (messagesContainer) {
|
||||
checkIfAtBottom(messagesContainer as HTMLElement);
|
||||
}
|
||||
}
|
||||
}, [chatHelpers.isSubmitting, addedChatHelpers.isSubmitting, checkIfAtBottom]);
|
||||
|
||||
const methods = useForm<ChatFormValues>({
|
||||
defaultValues: { text: '' },
|
||||
});
|
||||
@@ -131,22 +69,7 @@ function ChatView({
|
||||
} else if ((isLoading || isNavigating) && !isLandingPage) {
|
||||
content = <LoadingSpinner />;
|
||||
} else if (!isLandingPage) {
|
||||
const isStreaming = chatHelpers.isSubmitting || addedChatHelpers.isSubmitting;
|
||||
content = (
|
||||
<MessagesView
|
||||
messagesTree={messagesTree}
|
||||
costBar={
|
||||
!isLandingPage &&
|
||||
modelCosts && (
|
||||
<CostBar
|
||||
messagesTree={messagesTree}
|
||||
modelCosts={modelCosts}
|
||||
showCostBar={showCostBar && !isStreaming}
|
||||
/>
|
||||
)
|
||||
}
|
||||
/>
|
||||
);
|
||||
content = <MessagesView messagesTree={messagesTree} />;
|
||||
} else {
|
||||
content = <Landing centerFormOnLanding={centerFormOnLanding} />;
|
||||
}
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
import { useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { ArrowIcon } from '@librechat/client';
|
||||
import { TModelCosts, TMessage } from 'librechat-data-provider';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
interface CostBarProps {
|
||||
messagesTree: TMessage[];
|
||||
modelCosts: TModelCosts;
|
||||
showCostBar: boolean;
|
||||
}
|
||||
|
||||
export default function CostBar({ messagesTree, modelCosts, showCostBar }: CostBarProps) {
|
||||
const localize = useLocalize();
|
||||
const showCostTracking = useRecoilValue(store.showCostTracking);
|
||||
|
||||
const conversationCosts = useMemo(() => {
|
||||
if (!modelCosts?.modelCostTable || !messagesTree) {
|
||||
return null;
|
||||
}
|
||||
|
||||
let totalPromptTokens = 0;
|
||||
let totalCompletionTokens = 0;
|
||||
let totalPromptUSD = 0;
|
||||
let totalCompletionUSD = 0;
|
||||
|
||||
const flattenMessages = (messages: TMessage[]) => {
|
||||
const flattened: TMessage[] = [];
|
||||
messages.forEach((message: TMessage) => {
|
||||
flattened.push(message);
|
||||
if (message.children && message.children.length > 0) {
|
||||
flattened.push(...flattenMessages(message.children));
|
||||
}
|
||||
});
|
||||
return flattened;
|
||||
};
|
||||
|
||||
const allMessages = flattenMessages(messagesTree);
|
||||
|
||||
allMessages.forEach((message) => {
|
||||
if (!message.tokenCount) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const modelToUse = message.isCreatedByUser ? message.targetModel : message.model;
|
||||
|
||||
const modelPricing = modelCosts.modelCostTable[modelToUse];
|
||||
if (message.isCreatedByUser) {
|
||||
totalPromptTokens += message.tokenCount;
|
||||
totalPromptUSD += (message.tokenCount / 1000000) * modelPricing.prompt;
|
||||
} else {
|
||||
totalCompletionTokens += message.tokenCount;
|
||||
totalCompletionUSD += (message.tokenCount / 1000000) * modelPricing.completion;
|
||||
}
|
||||
});
|
||||
|
||||
const totalTokens = totalPromptTokens + totalCompletionTokens;
|
||||
const totalUSD = totalPromptUSD + totalCompletionUSD;
|
||||
|
||||
return {
|
||||
totals: {
|
||||
prompt: { tokenCount: totalPromptTokens, usd: totalPromptUSD },
|
||||
completion: { tokenCount: totalCompletionTokens, usd: totalCompletionUSD },
|
||||
total: { tokenCount: totalTokens, usd: totalUSD },
|
||||
},
|
||||
};
|
||||
}, [modelCosts, messagesTree]);
|
||||
|
||||
if (!showCostTracking || !conversationCosts || !conversationCosts.totals) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'mx-auto w-full max-w-md px-4 text-xs text-muted-foreground transition-all duration-300 ease-in-out',
|
||||
showCostBar ? 'opacity-100' : 'opacity-0',
|
||||
)}
|
||||
>
|
||||
<div className="grid grid-cols-3 gap-2 text-center">
|
||||
<div>
|
||||
<div>
|
||||
<ArrowIcon direction="up" />
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: conversationCosts.totals.prompt.tokenCount,
|
||||
})}
|
||||
</div>
|
||||
<div>${Math.abs(conversationCosts.totals.prompt.usd).toFixed(6)}</div>
|
||||
</div>
|
||||
<div>
|
||||
<div>
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: conversationCosts.totals.total.tokenCount,
|
||||
})}
|
||||
</div>
|
||||
<div>${Math.abs(conversationCosts.totals.total.usd).toFixed(6)}</div>
|
||||
</div>
|
||||
<div>
|
||||
<div>
|
||||
<ArrowIcon direction="down" />
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: conversationCosts.totals.completion.tokenCount,
|
||||
})}
|
||||
</div>
|
||||
<div>${Math.abs(conversationCosts.totals.completion.usd).toFixed(6)}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -26,6 +26,7 @@ type ContentPartsProps = {
|
||||
isCreatedByUser: boolean;
|
||||
isLast: boolean;
|
||||
isSubmitting: boolean;
|
||||
isLatestMessage?: boolean;
|
||||
edit?: boolean;
|
||||
enterEdit?: (cancel?: boolean) => void | null | undefined;
|
||||
siblingIdx?: number;
|
||||
@@ -45,6 +46,7 @@ const ContentParts = memo(
|
||||
isCreatedByUser,
|
||||
isLast,
|
||||
isSubmitting,
|
||||
isLatestMessage,
|
||||
edit,
|
||||
enterEdit,
|
||||
siblingIdx,
|
||||
@@ -55,6 +57,8 @@ const ContentParts = memo(
|
||||
const [isExpanded, setIsExpanded] = useState(showThinking);
|
||||
const attachmentMap = useMemo(() => mapAttachments(attachments ?? []), [attachments]);
|
||||
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
const hasReasoningParts = useMemo(() => {
|
||||
const hasThinkPart = content?.some((part) => part?.type === ContentTypes.THINK) ?? false;
|
||||
const allThinkPartsHaveContent =
|
||||
@@ -134,7 +138,9 @@ const ContentParts = memo(
|
||||
})
|
||||
}
|
||||
label={
|
||||
isSubmitting && isLast ? localize('com_ui_thinking') : localize('com_ui_thoughts')
|
||||
effectiveIsSubmitting && isLast
|
||||
? localize('com_ui_thinking')
|
||||
: localize('com_ui_thoughts')
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
@@ -155,12 +161,14 @@ const ContentParts = memo(
|
||||
conversationId,
|
||||
partIndex: idx,
|
||||
nextType: content[idx + 1]?.type,
|
||||
isSubmitting: effectiveIsSubmitting,
|
||||
isLatestMessage,
|
||||
}}
|
||||
>
|
||||
<Part
|
||||
part={part}
|
||||
attachments={attachments}
|
||||
isSubmitting={isSubmitting}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
key={`part-${messageId}-${idx}`}
|
||||
isCreatedByUser={isCreatedByUser}
|
||||
isLast={idx === content.length - 1}
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { TextareaAutosize, TooltipAnchor } from '@librechat/client';
|
||||
import { useUpdateMessageMutation } from 'librechat-data-provider/react-query';
|
||||
import type { TEditProps } from '~/common';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import Container from './Container';
|
||||
@@ -22,7 +22,8 @@ const EditMessage = ({
|
||||
const { addedIndex } = useAddedChatContext();
|
||||
const saveButtonRef = useRef<HTMLButtonElement | null>(null);
|
||||
const submitButtonRef = useRef<HTMLButtonElement | null>(null);
|
||||
const { getMessages, setMessages, conversation } = useChatContext();
|
||||
const { conversation } = useMessagesConversation();
|
||||
const { getMessages, setMessages } = useMessagesOperations();
|
||||
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
|
||||
store.latestMessageFamily(addedIndex),
|
||||
);
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { TMessage } from 'librechat-data-provider';
|
||||
import type { TMessageContentProps, TDisplayProps } from '~/common';
|
||||
import Error from '~/components/Messages/Content/Error';
|
||||
import Thinking from '~/components/Artifacts/Thinking';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { useMessageContext } from '~/Providers';
|
||||
import MarkdownLite from './MarkdownLite';
|
||||
import EditMessage from './EditMessage';
|
||||
import { useLocalize } from '~/hooks';
|
||||
@@ -70,16 +70,12 @@ export const ErrorMessage = ({
|
||||
};
|
||||
|
||||
const DisplayMessage = ({ text, isCreatedByUser, message, showCursor }: TDisplayProps) => {
|
||||
const { isSubmitting, latestMessage } = useChatContext();
|
||||
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
|
||||
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
|
||||
const showCursorState = useMemo(
|
||||
() => showCursor === true && isSubmitting,
|
||||
[showCursor, isSubmitting],
|
||||
);
|
||||
const isLatestMessage = useMemo(
|
||||
() => message.messageId === latestMessage?.messageId,
|
||||
[message.messageId, latestMessage?.messageId],
|
||||
);
|
||||
|
||||
let content: React.ReactElement;
|
||||
if (!isCreatedByUser) {
|
||||
|
||||
@@ -85,13 +85,14 @@ const Part = memo(
|
||||
|
||||
const isToolCall =
|
||||
'args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL);
|
||||
if (isToolCall && toolCall.name === Tools.execute_code && toolCall.args) {
|
||||
if (isToolCall && toolCall.name === Tools.execute_code) {
|
||||
return (
|
||||
<ExecuteCode
|
||||
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
|
||||
attachments={attachments}
|
||||
isSubmitting={isSubmitting}
|
||||
output={toolCall.output ?? ''}
|
||||
initialProgress={toolCall.progress ?? 0.1}
|
||||
attachments={attachments}
|
||||
args={typeof toolCall.args === 'string' ? toolCall.args : ''}
|
||||
/>
|
||||
);
|
||||
} else if (
|
||||
|
||||
@@ -6,8 +6,8 @@ import { useRecoilState, useRecoilValue } from 'recoil';
|
||||
import { useUpdateMessageContentMutation } from 'librechat-data-provider/react-query';
|
||||
import type { Agents } from 'librechat-data-provider';
|
||||
import type { TEditProps } from '~/common';
|
||||
import { useMessagesOperations, useMessagesConversation, useAddedChatContext } from '~/Providers';
|
||||
import Container from '~/components/Chat/Messages/Content/Container';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { cn, removeFocusRings } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import store from '~/store';
|
||||
@@ -25,7 +25,8 @@ const EditTextPart = ({
|
||||
}) => {
|
||||
const localize = useLocalize();
|
||||
const { addedIndex } = useAddedChatContext();
|
||||
const { ask, getMessages, setMessages, conversation } = useChatContext();
|
||||
const { conversation } = useMessagesConversation();
|
||||
const { ask, getMessages, setMessages } = useMessagesOperations();
|
||||
const [latestMultiMessage, setLatestMultiMessage] = useRecoilState(
|
||||
store.latestMessageFamily(addedIndex),
|
||||
);
|
||||
|
||||
@@ -45,26 +45,28 @@ export function useParseArgs(args?: string): ParsedArgs | null {
|
||||
}
|
||||
|
||||
export default function ExecuteCode({
|
||||
isSubmitting,
|
||||
initialProgress = 0.1,
|
||||
args,
|
||||
output = '',
|
||||
attachments,
|
||||
}: {
|
||||
initialProgress: number;
|
||||
isSubmitting: boolean;
|
||||
args?: string;
|
||||
output?: string;
|
||||
attachments?: TAttachment[];
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const showAnalysisCode = useRecoilValue(store.showCode);
|
||||
const [showCode, setShowCode] = useState(showAnalysisCode);
|
||||
const codeContentRef = useRef<HTMLDivElement>(null);
|
||||
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
|
||||
const [isAnimating, setIsAnimating] = useState(false);
|
||||
const hasOutput = output.length > 0;
|
||||
const outputRef = useRef<string>(output);
|
||||
const prevShowCodeRef = useRef<boolean>(showCode);
|
||||
const codeContentRef = useRef<HTMLDivElement>(null);
|
||||
const [isAnimating, setIsAnimating] = useState(false);
|
||||
const showAnalysisCode = useRecoilValue(store.showCode);
|
||||
const [showCode, setShowCode] = useState(showAnalysisCode);
|
||||
const [contentHeight, setContentHeight] = useState<number | undefined>(0);
|
||||
|
||||
const prevShowCodeRef = useRef<boolean>(showCode);
|
||||
const { lang, code } = useParseArgs(args) ?? ({} as ParsedArgs);
|
||||
const progress = useProgress(initialProgress);
|
||||
|
||||
@@ -136,6 +138,8 @@ export default function ExecuteCode({
|
||||
};
|
||||
}, [showCode, isAnimating]);
|
||||
|
||||
const cancelled = !isSubmitting && progress < 1;
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="relative my-2.5 flex size-5 shrink-0 items-center gap-2.5">
|
||||
@@ -143,9 +147,12 @@ export default function ExecuteCode({
|
||||
progress={progress}
|
||||
onClick={() => setShowCode((prev) => !prev)}
|
||||
inProgressText={localize('com_ui_analyzing')}
|
||||
finishedText={localize('com_ui_analyzing_finished')}
|
||||
finishedText={
|
||||
cancelled ? localize('com_ui_cancelled') : localize('com_ui_analyzing_finished')
|
||||
}
|
||||
hasInput={!!code?.length}
|
||||
isExpanded={showCode}
|
||||
error={cancelled}
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
|
||||
@@ -2,7 +2,7 @@ import { memo, useMemo, ReactElement } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import MarkdownLite from '~/components/Chat/Messages/Content/MarkdownLite';
|
||||
import Markdown from '~/components/Chat/Messages/Content/Markdown';
|
||||
import { useChatContext, useMessageContext } from '~/Providers';
|
||||
import { useMessageContext } from '~/Providers';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
@@ -18,14 +18,9 @@ type ContentType =
|
||||
| ReactElement;
|
||||
|
||||
const TextPart = memo(({ text, isCreatedByUser, showCursor }: TextPartProps) => {
|
||||
const { messageId } = useMessageContext();
|
||||
const { isSubmitting, latestMessage } = useChatContext();
|
||||
const { isSubmitting = false, isLatestMessage = false } = useMessageContext();
|
||||
const enableUserMsgMarkdown = useRecoilValue(store.enableUserMsgMarkdown);
|
||||
const showCursorState = useMemo(() => showCursor && isSubmitting, [showCursor, isSubmitting]);
|
||||
const isLatestMessage = useMemo(
|
||||
() => messageId === latestMessage?.messageId,
|
||||
[messageId, latestMessage?.messageId],
|
||||
);
|
||||
|
||||
const content: ContentType = useMemo(() => {
|
||||
if (!isCreatedByUser) {
|
||||
|
||||
@@ -21,7 +21,7 @@ type THoverButtons = {
|
||||
latestMessage: TMessage | null;
|
||||
isLast: boolean;
|
||||
index: number;
|
||||
handleFeedback: ({ feedback }: { feedback: TFeedback | undefined }) => void;
|
||||
handleFeedback?: ({ feedback }: { feedback: TFeedback | undefined }) => void;
|
||||
};
|
||||
|
||||
type HoverButtonProps = {
|
||||
@@ -238,7 +238,7 @@ const HoverButtons = ({
|
||||
/>
|
||||
|
||||
{/* Feedback Buttons */}
|
||||
{!isCreatedByUser && (
|
||||
{!isCreatedByUser && handleFeedback != null && (
|
||||
<Feedback handleFeedback={handleFeedback} feedback={message.feedback} isLast={isLast} />
|
||||
)}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import React from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { useMessageProcess } from '~/hooks';
|
||||
import type { TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
import MessageRender from './ui/MessageRender';
|
||||
// eslint-disable-next-line import/no-cycle
|
||||
@@ -29,7 +28,7 @@ const MessageContainer = React.memo(
|
||||
},
|
||||
);
|
||||
|
||||
export default function Message(props: TMessageProps & { costs?: TConversationCosts }) {
|
||||
export default function Message(props: TMessageProps) {
|
||||
const {
|
||||
showSibling,
|
||||
conversation,
|
||||
@@ -38,7 +37,7 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
|
||||
latestMultiMessage,
|
||||
isSubmittingFamily,
|
||||
} = useMessageProcess({ message: props.message });
|
||||
const { message, currentEditId, setCurrentEditId, costs } = props;
|
||||
const { message, currentEditId, setCurrentEditId } = props;
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
|
||||
if (!message || typeof message !== 'object') {
|
||||
@@ -63,7 +62,6 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
|
||||
message={message}
|
||||
isSubmittingFamily={isSubmittingFamily}
|
||||
isCard
|
||||
costs={costs}
|
||||
/>
|
||||
<MessageRender
|
||||
{...props}
|
||||
@@ -71,13 +69,12 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
|
||||
isCard
|
||||
message={siblingMessage ?? latestMultiMessage ?? undefined}
|
||||
isSubmittingFamily={isSubmittingFamily}
|
||||
costs={costs}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6">
|
||||
<MessageRender {...props} costs={costs} />
|
||||
<MessageRender {...props} />
|
||||
</div>
|
||||
)}
|
||||
</MessageContainer>
|
||||
@@ -88,7 +85,6 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
|
||||
messagesTree={children ?? []}
|
||||
currentEditId={currentEditId}
|
||||
setCurrentEditId={setCurrentEditId}
|
||||
costs={costs}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import React, { useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import type { TMessageContentParts, TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageContentParts } from 'librechat-data-provider';
|
||||
import type { TMessageProps, TMessageIcon } from '~/common';
|
||||
import { useMessageHelpers, useLocalize, useAttachments } from '~/hooks';
|
||||
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
|
||||
@@ -12,17 +12,10 @@ import SubRow from './SubRow';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
export default function Message(props: TMessageProps & { costs?: TConversationCosts }) {
|
||||
export default function Message(props: TMessageProps) {
|
||||
const localize = useLocalize();
|
||||
const {
|
||||
message,
|
||||
siblingIdx,
|
||||
siblingCount,
|
||||
setSiblingIdx,
|
||||
currentEditId,
|
||||
setCurrentEditId,
|
||||
costs,
|
||||
} = props;
|
||||
const { message, siblingIdx, siblingCount, setSiblingIdx, currentEditId, setCurrentEditId } =
|
||||
props;
|
||||
const { attachments, searchResults } = useAttachments({
|
||||
messageId: message?.messageId,
|
||||
attachments: message?.attachments,
|
||||
@@ -132,6 +125,7 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
isCreatedByUser={message.isCreatedByUser}
|
||||
conversationId={conversation?.conversationId}
|
||||
isLatestMessage={messageId === latestMessage?.messageId}
|
||||
content={message.content as Array<TMessageContentParts | undefined>}
|
||||
/>
|
||||
</div>
|
||||
@@ -171,7 +165,6 @@ export default function Message(props: TMessageProps & { costs?: TConversationCo
|
||||
messagesTree={children ?? []}
|
||||
currentEditId={currentEditId}
|
||||
setCurrentEditId={setCurrentEditId}
|
||||
costs={costs}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
import { useState } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { CSSTransition } from 'react-transition-group';
|
||||
import type { TMessage, TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { useScreenshot, useMessageScrolling, useLocalize } from '~/hooks';
|
||||
import ScrollToBottom from '~/components/Messages/ScrollToBottom';
|
||||
import { MessagesViewProvider } from '~/Providers';
|
||||
import MultiMessage from './MultiMessage';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
export default function MessagesView({
|
||||
function MessagesViewContent({
|
||||
messagesTree: _messagesTree,
|
||||
costBar,
|
||||
costs,
|
||||
}: {
|
||||
messagesTree?: TMessage[] | null;
|
||||
costBar?: React.ReactNode;
|
||||
costs?: TConversationCosts;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const fontSize = useRecoilValue(store.fontSize);
|
||||
@@ -48,7 +45,7 @@ export default function MessagesView({
|
||||
width: '100%',
|
||||
}}
|
||||
>
|
||||
<div className="flex flex-col dark:bg-transparent">
|
||||
<div className="flex flex-col pb-9 dark:bg-transparent">
|
||||
{(_messagesTree && _messagesTree.length == 0) || _messagesTree === null ? (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -67,25 +64,18 @@ export default function MessagesView({
|
||||
messageId={conversationId ?? null}
|
||||
setCurrentEditId={setCurrentEditId}
|
||||
currentEditId={currentEditId ?? null}
|
||||
costs={costs}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div
|
||||
id="messages-end"
|
||||
className="group h-1 w-full flex-shrink-0 pb-7"
|
||||
className="group h-0 w-full flex-shrink-0"
|
||||
ref={messagesEndRef}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{costBar && (
|
||||
<div className="pointer-events-none absolute bottom-2 left-1/2 z-10 -translate-x-1/2">
|
||||
{costBar}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<CSSTransition
|
||||
in={showScrollButton && scrollButtonPreference}
|
||||
timeout={{
|
||||
@@ -103,3 +93,11 @@ export default function MessagesView({
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function MessagesView({ messagesTree }: { messagesTree?: TMessage[] | null }) {
|
||||
return (
|
||||
<MessagesViewProvider>
|
||||
<MessagesViewContent messagesTree={messagesTree} />
|
||||
</MessagesViewProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { useRecoilState } from 'recoil';
|
||||
import { useEffect, useCallback } from 'react';
|
||||
import { isAssistantsEndpoint } from 'librechat-data-provider';
|
||||
import type { TMessage, TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
import MessageContent from '~/components/Messages/MessageContent';
|
||||
import MessageParts from './MessageParts';
|
||||
@@ -14,8 +14,7 @@ export default function MultiMessage({
|
||||
messagesTree,
|
||||
currentEditId,
|
||||
setCurrentEditId,
|
||||
costs,
|
||||
}: TMessageProps & { costs?: TConversationCosts }) {
|
||||
}: TMessageProps) {
|
||||
const [siblingIdx, setSiblingIdx] = useRecoilState(store.messagesSiblingIdxFamily(messageId));
|
||||
|
||||
const setSiblingIdxRev = useCallback(
|
||||
@@ -28,7 +27,7 @@ export default function MultiMessage({
|
||||
useEffect(() => {
|
||||
// reset siblingIdx when the tree changes, mostly when a new message is submitting.
|
||||
setSiblingIdx(0);
|
||||
}, [messagesTree?.length]);
|
||||
}, [messagesTree?.length, setSiblingIdx]);
|
||||
|
||||
useEffect(() => {
|
||||
if (messagesTree?.length && siblingIdx >= messagesTree.length) {
|
||||
@@ -56,7 +55,6 @@ export default function MultiMessage({
|
||||
siblingIdx={messagesTree.length - siblingIdx - 1}
|
||||
siblingCount={messagesTree.length}
|
||||
setSiblingIdx={setSiblingIdxRev}
|
||||
costs={costs}
|
||||
/>
|
||||
);
|
||||
} else if (message.content) {
|
||||
@@ -69,7 +67,6 @@ export default function MultiMessage({
|
||||
siblingIdx={messagesTree.length - siblingIdx - 1}
|
||||
siblingCount={messagesTree.length}
|
||||
setSiblingIdx={setSiblingIdxRev}
|
||||
costs={costs}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -83,7 +80,6 @@ export default function MultiMessage({
|
||||
siblingIdx={messagesTree.length - siblingIdx - 1}
|
||||
siblingCount={messagesTree.length}
|
||||
setSiblingIdx={setSiblingIdxRev}
|
||||
costs={costs}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
import React, { useCallback, useMemo, memo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { ArrowIcon } from '@librechat/client';
|
||||
import { type TMessage, TConversationCosts } from 'librechat-data-provider';
|
||||
import { type TMessage } from 'librechat-data-provider';
|
||||
import type { TMessageProps, TMessageIcon } from '~/common';
|
||||
import MessageContent from '~/components/Chat/Messages/Content/MessageContent';
|
||||
import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow';
|
||||
import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch';
|
||||
import HoverButtons from '~/components/Chat/Messages/HoverButtons';
|
||||
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
|
||||
import { useMessageActions, useLocalize } from '~/hooks';
|
||||
import { Plugin } from '~/components/Messages/Content';
|
||||
import SubRow from '~/components/Chat/Messages/SubRow';
|
||||
import { MessageContext } from '~/Providers';
|
||||
import { useMessageActions } from '~/hooks';
|
||||
import { cn, logger } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
@@ -20,7 +19,6 @@ type MessageRenderProps = {
|
||||
isCard?: boolean;
|
||||
isMultiMessage?: boolean;
|
||||
isSubmittingFamily?: boolean;
|
||||
costs?: TConversationCosts;
|
||||
} & Pick<
|
||||
TMessageProps,
|
||||
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
|
||||
@@ -37,9 +35,7 @@ const MessageRender = memo(
|
||||
isMultiMessage = false,
|
||||
setCurrentEditId,
|
||||
isSubmittingFamily = false,
|
||||
costs,
|
||||
}: MessageRenderProps) => {
|
||||
const localize = useLocalize();
|
||||
const {
|
||||
ask,
|
||||
edit,
|
||||
@@ -64,18 +60,6 @@ const MessageRender = memo(
|
||||
});
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
const fontSize = useRecoilValue(store.fontSize);
|
||||
const showCostTracking = useRecoilValue(store.showCostTracking);
|
||||
|
||||
const perMessageCost = useMemo(() => {
|
||||
if (!showCostTracking || !costs || !costs.perMessage || !msg?.messageId) {
|
||||
return null;
|
||||
}
|
||||
const entry = costs.perMessage.find((p) => p.messageId === msg.messageId);
|
||||
if (!entry) {
|
||||
return null;
|
||||
}
|
||||
return entry;
|
||||
}, [showCostTracking, costs, msg?.messageId]);
|
||||
|
||||
const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]);
|
||||
const hasNoChildren = !(msg?.children?.length ?? 0);
|
||||
@@ -87,6 +71,9 @@ const MessageRender = memo(
|
||||
const showCardRender = isLast && !isSubmittingFamily && isCard;
|
||||
const isLatestCard = isCard && !isSubmittingFamily && isLatestMessage;
|
||||
|
||||
/** Only pass isSubmitting to the latest message to prevent unnecessary re-renders */
|
||||
const effectiveIsSubmitting = isLatestMessage ? isSubmitting : false;
|
||||
|
||||
const iconData: TMessageIcon = useMemo(
|
||||
() => ({
|
||||
endpoint: msg?.endpoint ?? conversation?.endpoint,
|
||||
@@ -173,26 +160,7 @@ const MessageRender = memo(
|
||||
msg.isCreatedByUser ? 'user-turn' : 'agent-turn',
|
||||
)}
|
||||
>
|
||||
<h2 className={cn('select-none font-semibold', fontSize)}>
|
||||
{messageLabel}
|
||||
{perMessageCost && (
|
||||
<span className="ml-2 inline-flex items-center gap-2 px-2 py-0.5 text-xs text-muted-foreground">
|
||||
{perMessageCost.tokenCount > 0 && (
|
||||
<span>
|
||||
{perMessageCost.tokenType === 'prompt' ? (
|
||||
<ArrowIcon direction="up" className="inline" />
|
||||
) : (
|
||||
<ArrowIcon direction="down" className="inline" />
|
||||
)}
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: perMessageCost.tokenCount,
|
||||
})}
|
||||
</span>
|
||||
)}
|
||||
<span className="whitespace-pre">${Math.abs(perMessageCost.usd).toFixed(6)}</span>
|
||||
</span>
|
||||
)}
|
||||
</h2>
|
||||
<h2 className={cn('select-none font-semibold', fontSize)}>{messageLabel}</h2>
|
||||
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex max-w-full flex-grow flex-col gap-0">
|
||||
@@ -201,6 +169,8 @@ const MessageRender = memo(
|
||||
messageId: msg.messageId,
|
||||
conversationId: conversation?.conversationId,
|
||||
isExpanded: false,
|
||||
isSubmitting: effectiveIsSubmitting,
|
||||
isLatestMessage,
|
||||
}}
|
||||
>
|
||||
{msg.plugin && <Plugin plugin={msg.plugin} />}
|
||||
@@ -212,7 +182,7 @@ const MessageRender = memo(
|
||||
message={msg}
|
||||
enterEdit={enterEdit}
|
||||
error={!!(msg.error ?? false)}
|
||||
isSubmitting={isSubmitting}
|
||||
isSubmitting={effectiveIsSubmitting}
|
||||
unfinished={msg.unfinished ?? false}
|
||||
isCreatedByUser={msg.isCreatedByUser ?? true}
|
||||
siblingIdx={siblingIdx ?? 0}
|
||||
@@ -221,7 +191,7 @@ const MessageRender = memo(
|
||||
</MessageContext.Provider>
|
||||
</div>
|
||||
|
||||
{hasNoChildren && (isSubmittingFamily === true || isSubmitting) ? (
|
||||
{hasNoChildren && (isSubmittingFamily === true || effectiveIsSubmitting) ? (
|
||||
<PlaceholderRow isCard={isCard} />
|
||||
) : (
|
||||
<SubRow classes="text-xs">
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useMemo, memo, type FC, useCallback } from 'react';
|
||||
import throttle from 'lodash/throttle';
|
||||
import { parseISO, isToday } from 'date-fns';
|
||||
import { Spinner, useMediaQuery } from '@librechat/client';
|
||||
import { List, AutoSizer, CellMeasurer, CellMeasurerCache } from 'react-virtualized';
|
||||
import { TConversation } from 'librechat-data-provider';
|
||||
@@ -50,27 +49,17 @@ const MemoizedConvo = memo(
|
||||
conversation,
|
||||
retainView,
|
||||
toggleNav,
|
||||
isLatestConvo,
|
||||
}: {
|
||||
conversation: TConversation;
|
||||
retainView: () => void;
|
||||
toggleNav: () => void;
|
||||
isLatestConvo: boolean;
|
||||
}) => {
|
||||
return (
|
||||
<Convo
|
||||
conversation={conversation}
|
||||
retainView={retainView}
|
||||
toggleNav={toggleNav}
|
||||
isLatestConvo={isLatestConvo}
|
||||
/>
|
||||
);
|
||||
return <Convo conversation={conversation} retainView={retainView} toggleNav={toggleNav} />;
|
||||
},
|
||||
(prevProps, nextProps) => {
|
||||
return (
|
||||
prevProps.conversation.conversationId === nextProps.conversation.conversationId &&
|
||||
prevProps.conversation.title === nextProps.conversation.title &&
|
||||
prevProps.isLatestConvo === nextProps.isLatestConvo &&
|
||||
prevProps.conversation.endpoint === nextProps.conversation.endpoint
|
||||
);
|
||||
},
|
||||
@@ -98,13 +87,6 @@ const Conversations: FC<ConversationsProps> = ({
|
||||
[filteredConversations],
|
||||
);
|
||||
|
||||
const firstTodayConvoId = useMemo(
|
||||
() =>
|
||||
filteredConversations.find((convo) => convo.updatedAt && isToday(parseISO(convo.updatedAt)))
|
||||
?.conversationId ?? undefined,
|
||||
[filteredConversations],
|
||||
);
|
||||
|
||||
const flattenedItems = useMemo(() => {
|
||||
const items: FlattenedItem[] = [];
|
||||
groupedConversations.forEach(([groupName, convos]) => {
|
||||
@@ -154,26 +136,25 @@ const Conversations: FC<ConversationsProps> = ({
|
||||
</CellMeasurer>
|
||||
);
|
||||
}
|
||||
let rendering: JSX.Element;
|
||||
if (item.type === 'header') {
|
||||
rendering = <DateLabel groupName={item.groupName} />;
|
||||
} else if (item.type === 'convo') {
|
||||
rendering = (
|
||||
<MemoizedConvo conversation={item.convo} retainView={moveToTop} toggleNav={toggleNav} />
|
||||
);
|
||||
}
|
||||
return (
|
||||
<CellMeasurer cache={cache} columnIndex={0} key={key} parent={parent} rowIndex={index}>
|
||||
{({ registerChild }) => (
|
||||
<div ref={registerChild} style={style}>
|
||||
{item.type === 'header' ? (
|
||||
<DateLabel groupName={item.groupName} />
|
||||
) : item.type === 'convo' ? (
|
||||
<MemoizedConvo
|
||||
conversation={item.convo}
|
||||
retainView={moveToTop}
|
||||
toggleNav={toggleNav}
|
||||
isLatestConvo={item.convo.conversationId === firstTodayConvoId}
|
||||
/>
|
||||
) : null}
|
||||
{rendering}
|
||||
</div>
|
||||
)}
|
||||
</CellMeasurer>
|
||||
);
|
||||
},
|
||||
[cache, flattenedItems, firstTodayConvoId, moveToTop, toggleNav],
|
||||
[cache, flattenedItems, moveToTop, toggleNav],
|
||||
);
|
||||
|
||||
const getRowHeight = useCallback(
|
||||
|
||||
@@ -11,23 +11,17 @@ import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { NotificationSeverity } from '~/common';
|
||||
import { ConvoOptions } from './ConvoOptions';
|
||||
import RenameForm from './RenameForm';
|
||||
import { cn, logger } from '~/utils';
|
||||
import ConvoLink from './ConvoLink';
|
||||
import { cn } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
interface ConversationProps {
|
||||
conversation: TConversation;
|
||||
retainView: () => void;
|
||||
toggleNav: () => void;
|
||||
isLatestConvo: boolean;
|
||||
}
|
||||
|
||||
export default function Conversation({
|
||||
conversation,
|
||||
retainView,
|
||||
toggleNav,
|
||||
isLatestConvo,
|
||||
}: ConversationProps) {
|
||||
export default function Conversation({ conversation, retainView, toggleNav }: ConversationProps) {
|
||||
const params = useParams();
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
@@ -84,6 +78,7 @@ export default function Conversation({
|
||||
});
|
||||
setRenaming(false);
|
||||
} catch (error) {
|
||||
logger.error('Error renaming conversation', error);
|
||||
setTitleInput(title as string);
|
||||
showToast({
|
||||
message: localize('com_ui_rename_failed'),
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { ArrowIcon } from '@librechat/client';
|
||||
import { useCallback, useMemo, memo } from 'react';
|
||||
import type { TMessage, TMessageContentParts, TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessage, TMessageContentParts } from 'librechat-data-provider';
|
||||
import type { TMessageProps, TMessageIcon } from '~/common';
|
||||
import ContentParts from '~/components/Chat/Messages/Content/ContentParts';
|
||||
import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow';
|
||||
import { useAttachments, useMessageActions, useLocalize } from '~/hooks';
|
||||
import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch';
|
||||
import HoverButtons from '~/components/Chat/Messages/HoverButtons';
|
||||
import MessageIcon from '~/components/Chat/Messages/MessageIcon';
|
||||
import { useAttachments, useMessageActions } from '~/hooks';
|
||||
import SubRow from '~/components/Chat/Messages/SubRow';
|
||||
import { cn, logger } from '~/utils';
|
||||
import store from '~/store';
|
||||
@@ -18,7 +17,6 @@ type ContentRenderProps = {
|
||||
isCard?: boolean;
|
||||
isMultiMessage?: boolean;
|
||||
isSubmittingFamily?: boolean;
|
||||
costs?: TConversationCosts;
|
||||
} & Pick<
|
||||
TMessageProps,
|
||||
'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount'
|
||||
@@ -35,9 +33,7 @@ const ContentRender = memo(
|
||||
isMultiMessage = false,
|
||||
setCurrentEditId,
|
||||
isSubmittingFamily = false,
|
||||
costs,
|
||||
}: ContentRenderProps) => {
|
||||
const localize = useLocalize();
|
||||
const { attachments, searchResults } = useAttachments({
|
||||
messageId: msg?.messageId,
|
||||
attachments: msg?.attachments,
|
||||
@@ -66,14 +62,6 @@ const ContentRender = memo(
|
||||
});
|
||||
const maximizeChatSpace = useRecoilValue(store.maximizeChatSpace);
|
||||
const fontSize = useRecoilValue(store.fontSize);
|
||||
const showCostTracking = useRecoilValue(store.showCostTracking);
|
||||
|
||||
const perMessageCost = useMemo(() => {
|
||||
if (!showCostTracking || !costs || !costs.perMessage || !msg?.messageId) {
|
||||
return null;
|
||||
}
|
||||
return costs.perMessage.find((p) => p.messageId === msg.messageId) ?? null;
|
||||
}, [showCostTracking, costs, msg?.messageId]);
|
||||
|
||||
const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]);
|
||||
const isLast = useMemo(
|
||||
@@ -171,26 +159,7 @@ const ContentRender = memo(
|
||||
msg.isCreatedByUser ? 'user-turn' : 'agent-turn',
|
||||
)}
|
||||
>
|
||||
<h2 className={cn('select-none font-semibold', fontSize)}>
|
||||
{messageLabel}
|
||||
{perMessageCost && (
|
||||
<span className="ml-2 inline-flex items-center gap-2 px-2 py-0.5 text-xs text-muted-foreground">
|
||||
{perMessageCost.tokenCount > 0 && (
|
||||
<span className="mr-2">
|
||||
{perMessageCost.tokenType === 'prompt' ? (
|
||||
<ArrowIcon direction="up" className="inline" />
|
||||
) : (
|
||||
<ArrowIcon direction="down" className="inline" />
|
||||
)}
|
||||
{localize('com_ui_token_abbreviation', {
|
||||
0: perMessageCost.tokenCount,
|
||||
})}
|
||||
</span>
|
||||
)}
|
||||
<span className="whitespace-pre">${Math.abs(perMessageCost.usd).toFixed(6)}</span>
|
||||
</span>
|
||||
)}
|
||||
</h2>
|
||||
<h2 className={cn('select-none font-semibold', fontSize)}>{messageLabel}</h2>
|
||||
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="flex max-w-full flex-grow flex-col gap-0">
|
||||
@@ -204,6 +173,7 @@ const ContentRender = memo(
|
||||
isSubmitting={isSubmitting}
|
||||
searchResults={searchResults}
|
||||
setSiblingIdx={setSiblingIdx}
|
||||
isLatestMessage={isLatestMessage}
|
||||
isCreatedByUser={msg.isCreatedByUser}
|
||||
conversationId={conversation?.conversationId}
|
||||
content={msg.content as Array<TMessageContentParts | undefined>}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import React from 'react';
|
||||
import { useMessageProcess } from '~/hooks';
|
||||
import type { TConversationCosts } from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
// eslint-disable-next-line import/no-cycle
|
||||
import MultiMessage from '~/components/Chat/Messages/MultiMessage';
|
||||
@@ -26,7 +25,7 @@ const MessageContainer = React.memo(
|
||||
},
|
||||
);
|
||||
|
||||
export default function MessageContent(props: TMessageProps & { costs?: TConversationCosts }) {
|
||||
export default function MessageContent(props: TMessageProps) {
|
||||
const {
|
||||
showSibling,
|
||||
conversation,
|
||||
@@ -35,7 +34,7 @@ export default function MessageContent(props: TMessageProps & { costs?: TConvers
|
||||
latestMultiMessage,
|
||||
isSubmittingFamily,
|
||||
} = useMessageProcess({ message: props.message });
|
||||
const { message, currentEditId, setCurrentEditId, costs } = props;
|
||||
const { message, currentEditId, setCurrentEditId } = props;
|
||||
|
||||
if (!message || typeof message !== 'object') {
|
||||
return null;
|
||||
@@ -54,7 +53,6 @@ export default function MessageContent(props: TMessageProps & { costs?: TConvers
|
||||
message={message}
|
||||
isSubmittingFamily={isSubmittingFamily}
|
||||
isCard
|
||||
costs={costs}
|
||||
/>
|
||||
<ContentRender
|
||||
{...props}
|
||||
@@ -62,13 +60,12 @@ export default function MessageContent(props: TMessageProps & { costs?: TConvers
|
||||
isCard
|
||||
message={siblingMessage ?? latestMultiMessage ?? undefined}
|
||||
isSubmittingFamily={isSubmittingFamily}
|
||||
costs={costs}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6">
|
||||
<ContentRender {...props} costs={costs} />
|
||||
<div className="m-auto justify-center p-4 py-2 md:gap-6 ">
|
||||
<ContentRender {...props} />
|
||||
</div>
|
||||
)}
|
||||
</MessageContainer>
|
||||
@@ -79,7 +76,6 @@ export default function MessageContent(props: TMessageProps & { costs?: TConvers
|
||||
messagesTree={children ?? []}
|
||||
currentEditId={currentEditId}
|
||||
setCurrentEditId={setCurrentEditId}
|
||||
costs={costs}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -76,13 +76,6 @@ const toggleSwitchConfigs = [
|
||||
hoverCardText: undefined,
|
||||
key: 'modularChat',
|
||||
},
|
||||
{
|
||||
stateAtom: store.showCostTracking,
|
||||
localizationKey: 'com_nav_show_cost_tracking',
|
||||
switchId: 'showCostTracking',
|
||||
hoverCardText: 'com_nav_info_show_cost_tracking',
|
||||
key: 'showCostTracking',
|
||||
},
|
||||
];
|
||||
|
||||
function Chat() {
|
||||
|
||||
@@ -76,6 +76,8 @@ export default function Message(props: TMessageProps) {
|
||||
messageId,
|
||||
isExpanded: false,
|
||||
conversationId: conversation?.conversationId,
|
||||
isSubmitting: false, // Share view is always read-only
|
||||
isLatestMessage: false, // No concept of latest message in share view
|
||||
}}
|
||||
>
|
||||
{/* Legacy Plugins */}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useState, useMemo, useCallback } from 'react';
|
||||
import React, { useState, useMemo, useCallback, useEffect } from 'react';
|
||||
import { ChevronLeft, Trash2 } from 'lucide-react';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { Button, useToastContext } from '@librechat/client';
|
||||
@@ -12,6 +12,8 @@ import { useLocalize, useMCPConnectionStatus } from '~/hooks';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import MCPPanelSkeleton from './MCPPanelSkeleton';
|
||||
|
||||
const POLL_FOR_CONNECTION_STATUS_INTERVAL = 2_000; // ms
|
||||
|
||||
function MCPPanelContent() {
|
||||
const localize = useLocalize();
|
||||
const queryClient = useQueryClient();
|
||||
@@ -26,6 +28,29 @@ function MCPPanelContent() {
|
||||
null,
|
||||
);
|
||||
|
||||
// Check if any connections are in 'connecting' state
|
||||
const hasConnectingServers = useMemo(() => {
|
||||
if (!connectionStatus) {
|
||||
return false;
|
||||
}
|
||||
return Object.values(connectionStatus).some(
|
||||
(status) => status?.connectionState === 'connecting',
|
||||
);
|
||||
}, [connectionStatus]);
|
||||
|
||||
// Set up polling when servers are connecting
|
||||
useEffect(() => {
|
||||
if (!hasConnectingServers) {
|
||||
return;
|
||||
}
|
||||
|
||||
const intervalId = setInterval(() => {
|
||||
queryClient.invalidateQueries([QueryKeys.mcpConnectionStatus]);
|
||||
}, POLL_FOR_CONNECTION_STATUS_INTERVAL);
|
||||
|
||||
return () => clearInterval(intervalId);
|
||||
}, [hasConnectingServers, queryClient]);
|
||||
|
||||
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
|
||||
onSuccess: async () => {
|
||||
showToast({ message: localize('com_nav_mcp_vars_updated'), status: 'success' });
|
||||
|
||||
@@ -0,0 +1,440 @@
|
||||
import { renderHook } from '@testing-library/react';
|
||||
import { Tools, Constants } from 'librechat-data-provider';
|
||||
import useAgentToolPermissions from '../useAgentToolPermissions';
|
||||
|
||||
// Mock dependencies
|
||||
jest.mock('~/data-provider', () => ({
|
||||
useGetAgentByIdQuery: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/Providers', () => ({
|
||||
useAgentsMapContext: jest.fn(),
|
||||
}));
|
||||
|
||||
// Import mocked functions after mocking
|
||||
import { useGetAgentByIdQuery } from '~/data-provider';
|
||||
import { useAgentsMapContext } from '~/Providers';
|
||||
|
||||
describe('useAgentToolPermissions', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Ephemeral Agent Scenarios', () => {
|
||||
it('should return true for all tools when agentId is null', () => {
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(null));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return true for all tools when agentId is undefined', () => {
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(undefined));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return true for all tools when agentId is empty string', () => {
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(''));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return true for all tools when agentId is EPHEMERAL_AGENT_ID', () => {
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useAgentToolPermissions(Constants.EPHEMERAL_AGENT_ID)
|
||||
);
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Regular Agent with Tools', () => {
|
||||
it('should allow file_search when agent has the tool', () => {
|
||||
const agentId = 'agent-123';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search, 'other_tool'],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toEqual([Tools.file_search, 'other_tool']);
|
||||
});
|
||||
|
||||
it('should allow execute_code when agent has the tool', () => {
|
||||
const agentId = 'agent-456';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.execute_code, 'another_tool'],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toEqual([Tools.execute_code, 'another_tool']);
|
||||
});
|
||||
|
||||
it('should allow both tools when agent has both', () => {
|
||||
const agentId = 'agent-789';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search, Tools.execute_code, 'custom_tool'],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toEqual([Tools.file_search, Tools.execute_code, 'custom_tool']);
|
||||
});
|
||||
|
||||
it('should disallow both tools when agent has neither', () => {
|
||||
const agentId = 'agent-no-tools';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: ['custom_tool1', 'custom_tool2'],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toEqual(['custom_tool1', 'custom_tool2']);
|
||||
});
|
||||
|
||||
it('should handle agent with empty tools array', () => {
|
||||
const agentId = 'agent-empty-tools';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle agent with undefined tools', () => {
|
||||
const agentId = 'agent-undefined-tools';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: undefined,
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Agent Data from Query', () => {
|
||||
it('should prioritize agentData tools over selectedAgent tools', () => {
|
||||
const agentId = 'agent-with-query-data';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: ['old_tool'],
|
||||
};
|
||||
const mockAgentData = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search, Tools.execute_code],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: mockAgentData });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
expect(result.current.tools).toEqual([Tools.file_search, Tools.execute_code]);
|
||||
});
|
||||
|
||||
it('should fallback to selectedAgent tools when agentData has no tools', () => {
|
||||
const agentId = 'agent-fallback';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search],
|
||||
};
|
||||
const mockAgentData = {
|
||||
id: agentId,
|
||||
tools: undefined,
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: mockAgentData });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toEqual([Tools.file_search]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Agent Not Found Scenarios', () => {
|
||||
it('should disallow all tools when agent is not found in map', () => {
|
||||
const agentId = 'non-existent-agent';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should disallow all tools when agentsMap is null', () => {
|
||||
const agentId = 'agent-with-null-map';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue(null);
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should disallow all tools when agentsMap is undefined', () => {
|
||||
const agentId = 'agent-with-undefined-map';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue(undefined);
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Memoization and Performance', () => {
|
||||
it('should memoize results when inputs do not change', () => {
|
||||
const agentId = 'memoized-agent';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: [Tools.file_search],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result, rerender } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
const firstResult = result.current;
|
||||
|
||||
// Rerender without changing inputs
|
||||
rerender();
|
||||
|
||||
const secondResult = result.current;
|
||||
|
||||
// The hook returns a new object each time, but the values should be equal
|
||||
expect(firstResult.fileSearchAllowedByAgent).toBe(secondResult.fileSearchAllowedByAgent);
|
||||
expect(firstResult.codeAllowedByAgent).toBe(secondResult.codeAllowedByAgent);
|
||||
// Tools array reference should be the same since it comes from useMemo
|
||||
expect(firstResult.tools).toBe(secondResult.tools);
|
||||
|
||||
// Verify the actual values are correct
|
||||
expect(secondResult.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(secondResult.codeAllowedByAgent).toBe(false);
|
||||
expect(secondResult.tools).toEqual([Tools.file_search]);
|
||||
});
|
||||
|
||||
it('should recompute when agentId changes', () => {
|
||||
const agentId1 = 'agent-1';
|
||||
const agentId2 = 'agent-2';
|
||||
const mockAgents = {
|
||||
[agentId1]: { id: agentId1, tools: [Tools.file_search] },
|
||||
[agentId2]: { id: agentId2, tools: [Tools.execute_code] },
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue(mockAgents);
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ agentId }) => useAgentToolPermissions(agentId),
|
||||
{ initialProps: { agentId: agentId1 } }
|
||||
);
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
|
||||
// Change agentId
|
||||
rerender({ agentId: agentId2 });
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle switching between ephemeral and regular agents', () => {
|
||||
const regularAgentId = 'regular-agent';
|
||||
const mockAgent = {
|
||||
id: regularAgentId,
|
||||
tools: [],
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[regularAgentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ agentId }) => useAgentToolPermissions(agentId),
|
||||
{ initialProps: { agentId: null } }
|
||||
);
|
||||
|
||||
// Start with ephemeral agent (null)
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
|
||||
// Switch to regular agent
|
||||
rerender({ agentId: regularAgentId });
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
|
||||
// Switch back to ephemeral
|
||||
rerender({ agentId: '' });
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(true);
|
||||
expect(result.current.codeAllowedByAgent).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle agents with null tools gracefully', () => {
|
||||
const agentId = 'agent-null-tools';
|
||||
const mockAgent = {
|
||||
id: agentId,
|
||||
tools: null as any,
|
||||
};
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({
|
||||
[agentId]: mockAgent,
|
||||
});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeNull();
|
||||
});
|
||||
|
||||
it('should handle whitespace-only agentId as ephemeral', () => {
|
||||
// Note: Based on the current implementation, only empty string is treated as ephemeral
|
||||
// Whitespace-only strings would be treated as regular agent IDs
|
||||
const whitespaceId = ' ';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({ data: undefined });
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(whitespaceId));
|
||||
|
||||
// Whitespace ID is not considered ephemeral in current implementation
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle query loading state', () => {
|
||||
const agentId = 'loading-agent';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
isLoading: true,
|
||||
error: null,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
// During loading, should return false for non-ephemeral agents
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle query error state', () => {
|
||||
const agentId = 'error-agent';
|
||||
|
||||
(useAgentsMapContext as jest.Mock).mockReturnValue({});
|
||||
(useGetAgentByIdQuery as jest.Mock).mockReturnValue({
|
||||
data: undefined,
|
||||
isLoading: false,
|
||||
error: new Error('Failed to fetch agent'),
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useAgentToolPermissions(agentId));
|
||||
|
||||
// On error, should return false for non-ephemeral agents
|
||||
expect(result.current.fileSearchAllowedByAgent).toBe(false);
|
||||
expect(result.current.codeAllowedByAgent).toBe(false);
|
||||
expect(result.current.tools).toBeUndefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useMemo } from 'react';
|
||||
import { Tools } from 'librechat-data-provider';
|
||||
import { Tools, Constants } from 'librechat-data-provider';
|
||||
import { useGetAgentByIdQuery } from '~/data-provider';
|
||||
import { useAgentsMapContext } from '~/Providers';
|
||||
|
||||
@@ -9,6 +9,10 @@ interface AgentToolPermissionsResult {
|
||||
tools: string[] | undefined;
|
||||
}
|
||||
|
||||
function isEphemeralAgent(agentId: string | null | undefined): boolean {
|
||||
return agentId == null || agentId === '' || agentId === Constants.EPHEMERAL_AGENT_ID;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to determine whether specific tools are allowed for a given agent.
|
||||
*
|
||||
@@ -33,8 +37,8 @@ export default function useAgentToolPermissions(
|
||||
);
|
||||
|
||||
const fileSearchAllowedByAgent = useMemo(() => {
|
||||
// If no agentId, allow for ephemeral agents
|
||||
if (!agentId) return true;
|
||||
// Allow for ephemeral agents
|
||||
if (isEphemeralAgent(agentId)) return true;
|
||||
// If agentId exists but agent not found, disallow
|
||||
if (!selectedAgent) return false;
|
||||
// Check if the agent has the file_search tool
|
||||
@@ -42,8 +46,8 @@ export default function useAgentToolPermissions(
|
||||
}, [agentId, selectedAgent, tools]);
|
||||
|
||||
const codeAllowedByAgent = useMemo(() => {
|
||||
// If no agentId, allow for ephemeral agents
|
||||
if (!agentId) return true;
|
||||
// Allow for ephemeral agents
|
||||
if (isEphemeralAgent(agentId)) return true;
|
||||
// If agentId exists but agent not found, disallow
|
||||
if (!selectedAgent) return false;
|
||||
// Check if the agent has the execute_code tool
|
||||
|
||||
495
client/src/hooks/MCP/__tests__/useMCPSelect.test.tsx
Normal file
495
client/src/hooks/MCP/__tests__/useMCPSelect.test.tsx
Normal file
@@ -0,0 +1,495 @@
|
||||
import React from 'react';
|
||||
import { Provider, createStore } from 'jotai';
|
||||
import { renderHook, act, waitFor } from '@testing-library/react';
|
||||
import { RecoilRoot, useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
|
||||
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { setTimestamp } from '~/utils/timestamps';
|
||||
import { useMCPSelect } from '../useMCPSelect';
|
||||
|
||||
// Mock dependencies
|
||||
jest.mock('~/utils/timestamps', () => ({
|
||||
setTimestamp: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('lodash/isEqual', () => jest.fn((a, b) => JSON.stringify(a) === JSON.stringify(b)));
|
||||
|
||||
const createWrapper = () => {
|
||||
// Create a new Jotai store for each test to ensure clean state
|
||||
const store = createStore();
|
||||
|
||||
const Wrapper: React.FC<{ children: React.ReactNode }> = ({ children }) => (
|
||||
<RecoilRoot>
|
||||
<Provider store={store}>{children}</Provider>
|
||||
</RecoilRoot>
|
||||
);
|
||||
return Wrapper;
|
||||
};
|
||||
|
||||
describe('useMCPSelect', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
describe('Basic Functionality', () => {
|
||||
it('should initialize with default values', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
expect(result.current.isPinned).toBe(true); // Default value from mcpPinnedAtom is true
|
||||
expect(typeof result.current.setMCPValues).toBe('function');
|
||||
expect(typeof result.current.setIsPinned).toBe('function');
|
||||
});
|
||||
|
||||
it('should use conversationId when provided', () => {
|
||||
const conversationId = 'test-convo-123';
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
|
||||
it('should use NEW_CONVO constant when conversationId is null', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId: null }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('State Updates', () => {
|
||||
it('should update mcpValues when setMCPValues is called', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const newValues = ['value1', 'value2'];
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(newValues);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(newValues);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not update mcpValues if non-array is passed', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
// @ts-ignore - Testing invalid input
|
||||
result.current.setMCPValues('not-an-array');
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
|
||||
it('should update isPinned state', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Default is true
|
||||
expect(result.current.isPinned).toBe(true);
|
||||
|
||||
// Toggle to false
|
||||
act(() => {
|
||||
result.current.setIsPinned(false);
|
||||
});
|
||||
|
||||
expect(result.current.isPinned).toBe(false);
|
||||
|
||||
// Toggle back to true
|
||||
act(() => {
|
||||
result.current.setIsPinned(true);
|
||||
});
|
||||
|
||||
expect(result.current.isPinned).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Timestamp Management', () => {
|
||||
it('should set timestamp when mcpValues is updated with values', async () => {
|
||||
const conversationId = 'test-convo';
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const newValues = ['value1', 'value2'];
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(newValues);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const expectedKey = `${LocalStorageKeys.LAST_MCP_}${conversationId}`;
|
||||
expect(setTimestamp).toHaveBeenCalledWith(expectedKey);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not set timestamp when mcpValues is empty', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues([]);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(setTimestamp).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Race Conditions and Infinite Loops Prevention', () => {
|
||||
it('should not create infinite loop when syncing between Jotai and Recoil states', async () => {
|
||||
const { result, rerender } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
let renderCount = 0;
|
||||
const maxRenders = 10;
|
||||
|
||||
// Track renders to detect infinite loops
|
||||
const trackRender = () => {
|
||||
renderCount++;
|
||||
if (renderCount > maxRenders) {
|
||||
throw new Error('Potential infinite loop detected');
|
||||
}
|
||||
};
|
||||
|
||||
// Set initial value
|
||||
act(() => {
|
||||
trackRender();
|
||||
result.current.setMCPValues(['initial']);
|
||||
});
|
||||
|
||||
// Trigger multiple rerenders
|
||||
for (let i = 0; i < 3; i++) {
|
||||
rerender();
|
||||
trackRender();
|
||||
}
|
||||
|
||||
// Should not exceed max renders
|
||||
expect(renderCount).toBeLessThanOrEqual(maxRenders);
|
||||
});
|
||||
|
||||
it('should handle rapid consecutive updates without race conditions', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const updates = [
|
||||
['value1'],
|
||||
['value1', 'value2'],
|
||||
['value1', 'value2', 'value3'],
|
||||
['value4'],
|
||||
[],
|
||||
];
|
||||
|
||||
// Rapid fire updates
|
||||
act(() => {
|
||||
updates.forEach((update) => {
|
||||
result.current.setMCPValues(update);
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
// Should settle on the last update
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
it('should maintain stable setter function reference', () => {
|
||||
const { result, rerender } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const firstSetMCPValues = result.current.setMCPValues;
|
||||
|
||||
// Trigger multiple rerenders
|
||||
rerender();
|
||||
rerender();
|
||||
rerender();
|
||||
|
||||
// Setter should remain the same reference (memoized)
|
||||
expect(result.current.setMCPValues).toBe(firstSetMCPValues);
|
||||
});
|
||||
|
||||
it('should handle switching conversation IDs without issues', async () => {
|
||||
const { result, rerender } = renderHook(
|
||||
({ conversationId }) => useMCPSelect({ conversationId }),
|
||||
{
|
||||
wrapper: createWrapper(),
|
||||
initialProps: { conversationId: 'convo1' },
|
||||
},
|
||||
);
|
||||
|
||||
// Set values for first conversation
|
||||
act(() => {
|
||||
result.current.setMCPValues(['convo1-value']);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(['convo1-value']);
|
||||
});
|
||||
|
||||
// Switch to different conversation
|
||||
rerender({ conversationId: 'convo2' });
|
||||
|
||||
// Should have different state for new conversation
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
|
||||
// Set values for second conversation
|
||||
act(() => {
|
||||
result.current.setMCPValues(['convo2-value']);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(['convo2-value']);
|
||||
});
|
||||
|
||||
// Switch back to first conversation
|
||||
rerender({ conversationId: 'convo1' });
|
||||
|
||||
// Should maintain separate state
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(['convo1-value']);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Ephemeral Agent Synchronization', () => {
|
||||
it('should sync mcpValues when ephemeralAgent is updated externally', async () => {
|
||||
// Create a shared wrapper for both hooks to share the same Recoil/Jotai context
|
||||
const wrapper = createWrapper();
|
||||
|
||||
// Create a component that uses both hooks to ensure they share state
|
||||
const TestComponent = () => {
|
||||
const mcpHook = useMCPSelect({});
|
||||
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(
|
||||
ephemeralAgentByConvoId(Constants.NEW_CONVO),
|
||||
);
|
||||
return { mcpHook, ephemeralAgent, setEphemeralAgent };
|
||||
};
|
||||
|
||||
const { result } = renderHook(() => TestComponent(), { wrapper });
|
||||
|
||||
// Simulate external update to ephemeralAgent (e.g., from another component)
|
||||
const externalMcpValues = ['external-value1', 'external-value2'];
|
||||
act(() => {
|
||||
result.current.setEphemeralAgent({
|
||||
mcp: externalMcpValues,
|
||||
});
|
||||
});
|
||||
|
||||
// The hook should sync with the external update
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(externalMcpValues);
|
||||
});
|
||||
});
|
||||
|
||||
it('should update ephemeralAgent when mcpValues changes through hook', async () => {
|
||||
// Create a shared wrapper for both hooks
|
||||
const wrapper = createWrapper();
|
||||
|
||||
// Create a component that uses both the hook and accesses Recoil state
|
||||
const TestComponent = () => {
|
||||
const mcpHook = useMCPSelect({});
|
||||
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(Constants.NEW_CONVO));
|
||||
return { mcpHook, ephemeralAgent };
|
||||
};
|
||||
|
||||
const { result } = renderHook(() => TestComponent(), { wrapper });
|
||||
|
||||
const newValues = ['hook-value1', 'hook-value2'];
|
||||
|
||||
act(() => {
|
||||
result.current.mcpHook.setMCPValues(newValues);
|
||||
});
|
||||
|
||||
// Verify both mcpValues and ephemeralAgent are updated
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(newValues);
|
||||
expect(result.current.ephemeralAgent?.mcp).toEqual(newValues);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle empty ephemeralAgent.mcp array correctly', async () => {
|
||||
// Create a shared wrapper
|
||||
const wrapper = createWrapper();
|
||||
|
||||
// Create a component that uses both hooks
|
||||
const TestComponent = () => {
|
||||
const mcpHook = useMCPSelect({});
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
|
||||
return { mcpHook, setEphemeralAgent };
|
||||
};
|
||||
|
||||
const { result } = renderHook(() => TestComponent(), { wrapper });
|
||||
|
||||
// Set initial values
|
||||
act(() => {
|
||||
result.current.mcpHook.setMCPValues(['initial-value']);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
|
||||
});
|
||||
|
||||
// Try to set empty array externally
|
||||
act(() => {
|
||||
result.current.setEphemeralAgent({
|
||||
mcp: [],
|
||||
});
|
||||
});
|
||||
|
||||
// Values should remain unchanged since empty mcp array doesn't trigger update
|
||||
// (due to the condition: ephemeralAgent?.mcp && ephemeralAgent.mcp.length > 0)
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(['initial-value']);
|
||||
});
|
||||
|
||||
it('should properly sync non-empty arrays from ephemeralAgent', async () => {
|
||||
// Additional test to ensure non-empty arrays DO sync
|
||||
const wrapper = createWrapper();
|
||||
|
||||
const TestComponent = () => {
|
||||
const mcpHook = useMCPSelect({});
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(Constants.NEW_CONVO));
|
||||
return { mcpHook, setEphemeralAgent };
|
||||
};
|
||||
|
||||
const { result } = renderHook(() => TestComponent(), { wrapper });
|
||||
|
||||
// Set initial values through ephemeralAgent with non-empty array
|
||||
const initialValues = ['value1', 'value2'];
|
||||
act(() => {
|
||||
result.current.setEphemeralAgent({
|
||||
mcp: initialValues,
|
||||
});
|
||||
});
|
||||
|
||||
// Should sync since it's non-empty
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(initialValues);
|
||||
});
|
||||
|
||||
// Update with different non-empty values
|
||||
const updatedValues = ['value3', 'value4', 'value5'];
|
||||
act(() => {
|
||||
result.current.setEphemeralAgent({
|
||||
mcp: updatedValues,
|
||||
});
|
||||
});
|
||||
|
||||
// Should sync the new values
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpHook.mcpValues).toEqual(updatedValues);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle undefined conversationId', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId: undefined }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(['test']);
|
||||
});
|
||||
|
||||
expect(() => result.current).not.toThrow();
|
||||
});
|
||||
|
||||
it('should handle empty string conversationId', () => {
|
||||
const { result } = renderHook(() => useMCPSelect({ conversationId: '' }), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
expect(result.current.mcpValues).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle very large arrays without performance issues', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
const largeArray = Array.from({ length: 1000 }, (_, i) => `value-${i}`);
|
||||
|
||||
const startTime = performance.now();
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(largeArray);
|
||||
});
|
||||
|
||||
const endTime = performance.now();
|
||||
const executionTime = endTime - startTime;
|
||||
|
||||
// Should complete within reasonable time (< 100ms)
|
||||
expect(executionTime).toBeLessThan(100);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.mcpValues).toEqual(largeArray);
|
||||
});
|
||||
});
|
||||
|
||||
it('should cleanup properly on unmount', () => {
|
||||
const { unmount } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Should unmount without errors
|
||||
expect(() => unmount()).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Memory Leak Prevention', () => {
|
||||
it('should not leak memory on repeated updates', async () => {
|
||||
const { result } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Perform many updates to test for memory leaks
|
||||
for (let i = 0; i < 100; i++) {
|
||||
act(() => {
|
||||
result.current.setMCPValues([`value-${i}`]);
|
||||
});
|
||||
}
|
||||
|
||||
// If we get here without crashing, memory management is likely OK
|
||||
expect(result.current.mcpValues).toEqual(['value-99']);
|
||||
});
|
||||
|
||||
it('should handle component remounting', () => {
|
||||
const { result, unmount } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.setMCPValues(['before-unmount']);
|
||||
});
|
||||
|
||||
unmount();
|
||||
|
||||
// Remount
|
||||
const { result: newResult } = renderHook(() => useMCPSelect({}), {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
// Should handle remounting gracefully
|
||||
expect(newResult.current.mcpValues).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import { useAtom } from 'jotai';
|
||||
import isEqual from 'lodash/isEqual';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import { ephemeralAgentByConvoId, mcpValuesAtomFamily, mcpPinnedAtom } from '~/store';
|
||||
@@ -19,15 +20,14 @@ export function useMCPSelect({ conversationId }: { conversationId?: string | nul
|
||||
}
|
||||
}, [ephemeralAgent?.mcp, setMCPValuesRaw]);
|
||||
|
||||
// Update ephemeral agent when Jotai state changes
|
||||
useEffect(() => {
|
||||
if (mcpValues.length > 0 && JSON.stringify(mcpValues) !== JSON.stringify(ephemeralAgent?.mcp)) {
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
mcp: mcpValues,
|
||||
}));
|
||||
}
|
||||
}, [mcpValues, ephemeralAgent?.mcp, setEphemeralAgent]);
|
||||
setEphemeralAgent((prev) => {
|
||||
if (!isEqual(prev?.mcp, mcpValues)) {
|
||||
return { ...(prev ?? {}), mcp: mcpValues };
|
||||
}
|
||||
return prev;
|
||||
});
|
||||
}, [mcpValues, setEphemeralAgent]);
|
||||
|
||||
useEffect(() => {
|
||||
const mcpStorageKey = `${LocalStorageKeys.LAST_MCP_}${key}`;
|
||||
|
||||
@@ -2,7 +2,7 @@ import throttle from 'lodash/throttle';
|
||||
import { useEffect, useRef, useCallback, useMemo } from 'react';
|
||||
import { Constants, isAssistantsEndpoint, isAgentsEndpoint } from 'librechat-data-provider';
|
||||
import type { TMessageProps } from '~/common';
|
||||
import { useChatContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
|
||||
import { useMessagesViewContext, useAssistantsMapContext, useAgentsMapContext } from '~/Providers';
|
||||
import useCopyToClipboard from './useCopyToClipboard';
|
||||
import { getTextKey, logger } from '~/utils';
|
||||
|
||||
@@ -20,9 +20,9 @@ export default function useMessageHelpers(props: TMessageProps) {
|
||||
setAbortScroll,
|
||||
handleContinue,
|
||||
setLatestMessage,
|
||||
} = useChatContext();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
} = useMessagesViewContext();
|
||||
const agentsMap = useAgentsMapContext();
|
||||
const assistantMap = useAssistantsMapContext();
|
||||
|
||||
const { text, content, children, messageId = null, isCreatedByUser } = message ?? {};
|
||||
const edit = messageId === currentEditId;
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useRecoilValue } from 'recoil';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useEffect, useRef, useCallback, useMemo, useState } from 'react';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { useChatContext, useAddedChatContext } from '~/Providers';
|
||||
import { useMessagesViewContext } from '~/Providers';
|
||||
import { getTextKey, logger } from '~/utils';
|
||||
import store from '~/store';
|
||||
|
||||
@@ -18,14 +18,9 @@ export default function useMessageProcess({ message }: { message?: TMessage | nu
|
||||
latestMessage,
|
||||
setAbortScroll,
|
||||
setLatestMessage,
|
||||
isSubmitting: isSubmittingRoot,
|
||||
} = useChatContext();
|
||||
const { isSubmitting: isSubmittingAdditional } = useAddedChatContext();
|
||||
isSubmittingFamily,
|
||||
} = useMessagesViewContext();
|
||||
const latestMultiMessage = useRecoilValue(store.latestMessageFamily(index + 1));
|
||||
const isSubmittingFamily = useMemo(
|
||||
() => isSubmittingRoot || isSubmittingAdditional,
|
||||
[isSubmittingRoot, isSubmittingAdditional],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const convoId = conversation?.conversationId;
|
||||
|
||||
@@ -2,8 +2,8 @@ import { useRecoilValue } from 'recoil';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import type { TMessage } from 'librechat-data-provider';
|
||||
import { useMessagesConversation, useMessagesSubmission } from '~/Providers';
|
||||
import useScrollToRef from '~/hooks/useScrollToRef';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import store from '~/store';
|
||||
|
||||
const threshold = 0.85;
|
||||
@@ -15,11 +15,10 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
|
||||
const scrollableRef = useRef<HTMLDivElement | null>(null);
|
||||
const messagesEndRef = useRef<HTMLDivElement | null>(null);
|
||||
const [showScrollButton, setShowScrollButton] = useState(false);
|
||||
const { conversation, setAbortScroll, isSubmitting, abortScroll } = useChatContext();
|
||||
const { conversationId } = conversation ?? {};
|
||||
const { conversation, conversationId } = useMessagesConversation();
|
||||
const { setAbortScroll, isSubmitting, abortScroll } = useMessagesSubmission();
|
||||
|
||||
const timeoutIdRef = useRef<NodeJS.Timeout>();
|
||||
const prevIsSubmittingRef = useRef<boolean>(false);
|
||||
|
||||
const debouncedSetShowScrollButton = useCallback((value: boolean) => {
|
||||
clearTimeout(timeoutIdRef.current);
|
||||
@@ -61,10 +60,7 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
|
||||
}
|
||||
}, [debouncedSetShowScrollButton]);
|
||||
|
||||
const scrollCallback = useCallback(
|
||||
() => debouncedSetShowScrollButton(false),
|
||||
[debouncedSetShowScrollButton],
|
||||
);
|
||||
const scrollCallback = () => debouncedSetShowScrollButton(false);
|
||||
|
||||
const { scrollToRef: scrollToBottom, handleSmoothToRef } = useScrollToRef({
|
||||
targetRef: messagesEndRef,
|
||||
@@ -75,18 +71,6 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
|
||||
},
|
||||
});
|
||||
|
||||
const smoothScrollToBottom = useCallback(() => {
|
||||
if (messagesEndRef.current) {
|
||||
messagesEndRef.current.scrollIntoView({
|
||||
behavior: 'smooth',
|
||||
block: 'end',
|
||||
inline: 'nearest',
|
||||
});
|
||||
scrollCallback();
|
||||
setAbortScroll(false);
|
||||
}
|
||||
}, [scrollCallback, setAbortScroll]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!messagesTree || messagesTree.length === 0) {
|
||||
return;
|
||||
@@ -107,20 +91,6 @@ export default function useMessageScrolling(messagesTree?: TMessage[] | null) {
|
||||
};
|
||||
}, [isSubmitting, messagesTree, scrollToBottom, abortScroll]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!messagesEndRef.current || !scrollableRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (prevIsSubmittingRef.current && !isSubmitting && abortScroll !== true) {
|
||||
setTimeout(() => {
|
||||
smoothScrollToBottom();
|
||||
}, 100);
|
||||
}
|
||||
|
||||
prevIsSubmittingRef.current = isSubmitting;
|
||||
}, [isSubmitting, smoothScrollToBottom, abortScroll]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!messagesEndRef.current || !scrollableRef.current) {
|
||||
return;
|
||||
|
||||
@@ -232,14 +232,8 @@ export default function useEventHandlers({
|
||||
},
|
||||
]);
|
||||
}
|
||||
|
||||
if (userMessage?.conversationId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: [QueryKeys.conversation, userMessage.conversationId, 'costs'],
|
||||
});
|
||||
}
|
||||
},
|
||||
[setMessages, announcePolite, setIsSubmitting, queryClient],
|
||||
[setMessages, announcePolite, setIsSubmitting],
|
||||
);
|
||||
|
||||
const cancelHandler = useCallback(
|
||||
@@ -281,12 +275,6 @@ export default function useEventHandlers({
|
||||
});
|
||||
}
|
||||
|
||||
if (convoUpdate?.conversationId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: [QueryKeys.conversation, convoUpdate.conversationId, 'costs'],
|
||||
});
|
||||
}
|
||||
|
||||
setIsSubmitting(false);
|
||||
},
|
||||
[setMessages, setConversation, genTitle, isAddedRequest, queryClient, setIsSubmitting],
|
||||
@@ -353,12 +341,6 @@ export default function useEventHandlers({
|
||||
if (resetLatestMessage) {
|
||||
resetLatestMessage();
|
||||
}
|
||||
|
||||
if (conversationId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: [QueryKeys.conversation, conversationId, 'costs'],
|
||||
});
|
||||
}
|
||||
},
|
||||
[
|
||||
queryClient,
|
||||
@@ -545,12 +527,6 @@ export default function useEventHandlers({
|
||||
);
|
||||
}
|
||||
|
||||
if (conversation.conversationId) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: [QueryKeys.conversation, conversation.conversationId, 'costs'],
|
||||
});
|
||||
}
|
||||
|
||||
if (isNewConvo && submissionConvo.conversationId) {
|
||||
removeConvoFromAllQueries(queryClient, submissionConvo.conversationId);
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ export default function useScrollToRef({
|
||||
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
const scrollToRef = useCallback(
|
||||
throttle(() => logAndScroll('instant', callback), 100, { leading: true }),
|
||||
throttle(() => logAndScroll('instant', callback), 145, { leading: true }),
|
||||
[targetRef],
|
||||
);
|
||||
|
||||
|
||||
@@ -833,7 +833,6 @@
|
||||
"com_ui_delete_tool": "Werkzeug löschen",
|
||||
"com_ui_delete_tool_confirm": "Bist du sicher, dass du dieses Werkzeug löschen möchtest?",
|
||||
"com_ui_delete_tool_error": "Fehler beim Löschen des Tools: {{error}}",
|
||||
"com_ui_delete_tool_success": "Tool erfolgreich gelöscht",
|
||||
"com_ui_deleted": "Gelöscht",
|
||||
"com_ui_deleting_file": "Lösche Datei...",
|
||||
"com_ui_descending": "Absteigend",
|
||||
|
||||
@@ -568,8 +568,6 @@
|
||||
"com_nav_settings": "Settings",
|
||||
"com_nav_shared_links": "Shared links",
|
||||
"com_nav_show_code": "Always show code when using code interpreter",
|
||||
"com_nav_show_cost_tracking": "Show cost tracking",
|
||||
"com_nav_info_show_cost_tracking": "Display conversation costs and per-message cost breakdowns",
|
||||
"com_nav_show_thinking": "Open Thinking Dropdowns by Default",
|
||||
"com_nav_slash_command": "/-Command",
|
||||
"com_nav_slash_command_description": "Toggle command \"/\" for selecting a prompt via keyboard",
|
||||
@@ -1199,7 +1197,6 @@
|
||||
"com_ui_thinking": "Thinking...",
|
||||
"com_ui_thoughts": "Thoughts",
|
||||
"com_ui_token": "token",
|
||||
"com_ui_token_abbreviation": "{{0}}t",
|
||||
"com_ui_token_exchange_method": "Token Exchange Method",
|
||||
"com_ui_token_url": "Token URL",
|
||||
"com_ui_tokens": "tokens",
|
||||
|
||||
@@ -337,7 +337,7 @@
|
||||
"com_endpoint_prompt_prefix_assistants": "Papildu instrukcijas",
|
||||
"com_endpoint_prompt_prefix_assistants_placeholder": "Iestatiet papildu norādījumus vai kontekstu virs Asistenta galvenajiem norādījumiem. Ja lauks ir tukšs, tas tiek ignorēts.",
|
||||
"com_endpoint_prompt_prefix_placeholder": "Iestatiet pielāgotas instrukcijas vai kontekstu. Ja lauks ir tukšs, tas tiek ignorēts.",
|
||||
"com_endpoint_reasoning_effort": "Spriešanas piepūle",
|
||||
"com_endpoint_reasoning_effort": "Spriešanas līmenis",
|
||||
"com_endpoint_reasoning_summary": "Spriešanas kopsavilkums",
|
||||
"com_endpoint_save_as_preset": "Saglabāt kā iestatījumu",
|
||||
"com_endpoint_search": "Meklēt galapunktu pēc nosaukuma",
|
||||
@@ -834,7 +834,7 @@
|
||||
"com_ui_delete_tool": "Dzēst rīku",
|
||||
"com_ui_delete_tool_confirm": "Vai tiešām vēlaties dzēst šo rīku?",
|
||||
"com_ui_delete_tool_error": "Kļūda, dzēšot rīku: {{error}}",
|
||||
"com_ui_delete_tool_success": "Rīks veiksmīgi izdzēsts",
|
||||
"com_ui_delete_tool_save_reminder": "Rīks noņemts. Saglabājiet aģentu, lai piemērotu izmaiņas.",
|
||||
"com_ui_deleted": "Dzēsts",
|
||||
"com_ui_deleting_file": "Dzēšu failu...",
|
||||
"com_ui_descending": "Dilstošs",
|
||||
@@ -1012,7 +1012,7 @@
|
||||
"com_ui_memory_would_exceed": "Nevar saglabāt - pārsniegtu tokenu limitu par {{tokens}}. Izdzēsiet esošās atmiņas, lai atbrīvotu vietu.",
|
||||
"com_ui_mention": "Pieminiet galapunktu, assistentu vai iestatījumu, lai ātri uz to pārslēgtos",
|
||||
"com_ui_min_tags": "Nevar noņemt vairāk vērtību, vismaz {{0}} ir nepieciešamas.",
|
||||
"com_ui_minimal": "Minimāla",
|
||||
"com_ui_minimal": "Minimāls",
|
||||
"com_ui_misc": "Dažādi",
|
||||
"com_ui_model": "Modelis",
|
||||
"com_ui_model_parameters": "Modeļa Parametrus",
|
||||
@@ -1035,7 +1035,7 @@
|
||||
"com_ui_no_results_found": "Nav atrastu rezultātu",
|
||||
"com_ui_no_terms_content": "Nav noteikumu un nosacījumu satura, ko parādīt",
|
||||
"com_ui_no_valid_items": "Nav rezultātu",
|
||||
"com_ui_none": "Neviens",
|
||||
"com_ui_none": "Nekāds",
|
||||
"com_ui_not_used": "Nav izmantots",
|
||||
"com_ui_nothing_found": "Nekas nav atrasts",
|
||||
"com_ui_oauth": "OAuth",
|
||||
|
||||
@@ -834,7 +834,6 @@
|
||||
"com_ui_delete_tool": "Slett verktøy",
|
||||
"com_ui_delete_tool_confirm": "Er du sikker på at du vil slette dette verktøyet?",
|
||||
"com_ui_delete_tool_error": "En feil oppstod ved sletting av verktøyet: {{error}}",
|
||||
"com_ui_delete_tool_success": "Verktøyet ble slettet",
|
||||
"com_ui_deleted": "Slettet",
|
||||
"com_ui_deleting_file": "Sletter fil ...",
|
||||
"com_ui_descending": "Synkende",
|
||||
|
||||
@@ -27,6 +27,13 @@
|
||||
"com_agents_file_search_disabled": "Maak eerst een Agent aan voordat je bestanden uploadt voor File Search.",
|
||||
"com_agents_file_search_info": "Als deze functie is ingeschakeld, krijgt de agent informatie over de exacte bestandsnamen die hieronder staan vermeld, zodat deze relevante context uit deze bestanden kan ophalen.",
|
||||
"com_agents_instructions_placeholder": "De systeeminstructies die de agent gebruikt",
|
||||
"com_agents_link_copied": "Link gekopieerd",
|
||||
"com_agents_link_copy_failed": "Link niet gekopieerd",
|
||||
"com_agents_load_more_label": "Laad meer agenten van {{category}} categorie",
|
||||
"com_agents_loading": "Aan het laden...",
|
||||
"com_agents_mcp_icon_size": "Minimum formaat 128 x 128 px",
|
||||
"com_agents_mcp_info": "MCP-servers toevoegen aan je agent zodat deze taken kan uitvoeren en kan communiceren met externe services",
|
||||
"com_agents_mcp_name_placeholder": "Aangepast hulpmiddel",
|
||||
"com_agents_missing_provider_model": "Selecteer een provider en model voordat je een agent aanmaakt.",
|
||||
"com_agents_name_placeholder": "De naam van de agent",
|
||||
"com_agents_no_access": "Je hebt geen toegang om deze agent te bewerken.",
|
||||
|
||||
@@ -834,7 +834,6 @@
|
||||
"com_ui_delete_tool": "Видалити інструмент",
|
||||
"com_ui_delete_tool_confirm": "Ви дійсно хочете видалити цей інструмент?",
|
||||
"com_ui_delete_tool_error": "Помилка під час видалення інструменту: {{error}}",
|
||||
"com_ui_delete_tool_success": "Інструмент успішно видалено",
|
||||
"com_ui_deleted": "Видалено",
|
||||
"com_ui_deleting_file": "Видалення файлу...",
|
||||
"com_ui_descending": "За спаданням",
|
||||
|
||||
@@ -834,7 +834,6 @@
|
||||
"com_ui_delete_tool": "删除工具",
|
||||
"com_ui_delete_tool_confirm": "您确定要删除此工具吗?",
|
||||
"com_ui_delete_tool_error": "删除工具时发生错误:{{error}}",
|
||||
"com_ui_delete_tool_success": "工具删除成功",
|
||||
"com_ui_deleted": "已删除",
|
||||
"com_ui_deleting_file": "删除文件中...",
|
||||
"com_ui_descending": "降序",
|
||||
|
||||
@@ -6,7 +6,6 @@ import { useGetModelsQuery } from 'librechat-data-provider/react-query';
|
||||
import type { TPreset } from 'librechat-data-provider';
|
||||
import { useGetConvoIdQuery, useGetStartupConfig, useGetEndpointsQuery } from '~/data-provider';
|
||||
import { useNewConvo, useAppStartup, useAssistantListMap, useIdChangeEffect } from '~/hooks';
|
||||
import { useGetModelCostsQuery } from 'librechat-data-provider/react-query';
|
||||
import { getDefaultModelSpec, getModelSpecPreset, logger } from '~/utils';
|
||||
import { ToolCallsMapProvider } from '~/Providers';
|
||||
import ChatView from '~/components/Chat/ChatView';
|
||||
@@ -45,10 +44,6 @@ export default function ChatRoute() {
|
||||
const endpointsQuery = useGetEndpointsQuery({ enabled: isAuthenticated });
|
||||
const assistantListMap = useAssistantListMap();
|
||||
|
||||
const modelCostsQuery = useGetModelCostsQuery(initialConvoQuery.data?.modelHistory || [], {
|
||||
enabled: !!initialConvoQuery.data?.modelHistory?.length,
|
||||
});
|
||||
|
||||
const isTemporaryChat = conversation && conversation.expiredAt ? true : false;
|
||||
|
||||
useEffect(() => {
|
||||
@@ -153,7 +148,7 @@ export default function ChatRoute() {
|
||||
|
||||
return (
|
||||
<ToolCallsMapProvider conversationId={conversation.conversationId ?? ''}>
|
||||
<ChatView index={index} modelCosts={modelCostsQuery.data} />
|
||||
<ChatView index={index} />
|
||||
</ToolCallsMapProvider>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -34,7 +34,6 @@ const localStorageAtoms = {
|
||||
showCode: atomWithLocalStorage(LocalStorageKeys.SHOW_ANALYSIS_CODE, true),
|
||||
saveDrafts: atomWithLocalStorage('saveDrafts', true),
|
||||
showScrollButton: atomWithLocalStorage('showScrollButton', true),
|
||||
showCostTracking: atomWithLocalStorage('showCostTracking', true),
|
||||
forkSetting: atomWithLocalStorage('forkSetting', ''),
|
||||
splitAtTarget: atomWithLocalStorage('splitAtTarget', false),
|
||||
rememberDefaultFork: atomWithLocalStorage(LocalStorageKeys.REMEMBER_FORK_OPTION, false),
|
||||
|
||||
@@ -49,7 +49,7 @@ export default defineConfig(({ command }) => ({
|
||||
],
|
||||
globIgnores: ['images/**/*', '**/*.map', 'index.html'],
|
||||
maximumFileSizeToCacheInBytes: 4 * 1024 * 1024,
|
||||
navigateFallbackDenylist: [/^\/oauth/, /^\/api/],
|
||||
navigateFallbackDenylist: [/^\/oauth/, /^\/api/, /^\/admin\/openid/],
|
||||
},
|
||||
includeAssets: [],
|
||||
manifest: {
|
||||
|
||||
@@ -10,6 +10,9 @@ jest.mock('form-data', () => {
|
||||
getLength: jest.fn().mockReturnValue(100),
|
||||
}));
|
||||
});
|
||||
jest.mock('https-proxy-agent', () => ({
|
||||
HttpsProxyAgent: jest.fn().mockImplementation((url) => ({ proxyUrl: url })),
|
||||
}));
|
||||
jest.mock('axios', () => {
|
||||
const mockAxiosInstance = {
|
||||
get: jest.fn().mockResolvedValue({ data: {} }),
|
||||
@@ -44,6 +47,7 @@ jest.mock('~/utils/axios', () => ({
|
||||
|
||||
import * as fs from 'fs';
|
||||
import axios from 'axios';
|
||||
import { HttpsProxyAgent } from 'https-proxy-agent';
|
||||
import type { Readable } from 'stream';
|
||||
import type {
|
||||
MistralFileUploadResponse,
|
||||
@@ -1182,6 +1186,8 @@ describe('MistralOCR Service', () => {
|
||||
|
||||
describe('Mixed env var and hardcoded configuration', () => {
|
||||
beforeEach(() => {
|
||||
// Clean up any PROXY env var from previous tests
|
||||
delete process.env.PROXY;
|
||||
const mockReadStream: MockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (
|
||||
this: MockReadStream,
|
||||
@@ -1708,9 +1714,403 @@ describe('MistralOCR Service', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Proxy Configuration', () => {
|
||||
const originalProxy = process.env.PROXY;
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset the HttpsProxyAgent mock to its default implementation
|
||||
(HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url }));
|
||||
// Clear any previous axios mock calls
|
||||
mockAxios.post!.mockClear();
|
||||
mockAxios.get!.mockClear();
|
||||
mockAxios.delete!.mockClear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
if (originalProxy) {
|
||||
process.env.PROXY = originalProxy;
|
||||
} else {
|
||||
delete process.env.PROXY;
|
||||
}
|
||||
// Clear mocks after each test to prevent leaking
|
||||
mockAxios.post!.mockClear();
|
||||
mockAxios.get!.mockClear();
|
||||
mockAxios.delete!.mockClear();
|
||||
});
|
||||
|
||||
describe('uploadDocumentToMistral with proxy', () => {
|
||||
beforeEach(() => {
|
||||
const mockReadStream: MockReadStream = {
|
||||
on: jest.fn().mockImplementation(function (
|
||||
this: MockReadStream,
|
||||
event: string,
|
||||
handler: () => void,
|
||||
) {
|
||||
if (event === 'end') {
|
||||
handler();
|
||||
}
|
||||
return this;
|
||||
}),
|
||||
pipe: jest.fn().mockImplementation(function (this: MockReadStream) {
|
||||
return this;
|
||||
}),
|
||||
pause: jest.fn(),
|
||||
resume: jest.fn(),
|
||||
emit: jest.fn(),
|
||||
once: jest.fn(),
|
||||
destroy: jest.fn(),
|
||||
path: '/path/to/test.pdf',
|
||||
fd: 1,
|
||||
flags: 'r',
|
||||
mode: 0o666,
|
||||
autoClose: true,
|
||||
bytesRead: 0,
|
||||
closed: false,
|
||||
pending: false,
|
||||
};
|
||||
|
||||
(jest.mocked(fs).createReadStream as jest.Mock).mockReturnValue(mockReadStream);
|
||||
});
|
||||
|
||||
it('should use proxy configuration when PROXY env var is set', async () => {
|
||||
process.env.PROXY = 'http://proxy.example.com:8080';
|
||||
|
||||
const mockResponse: { data: MistralFileUploadResponse } = {
|
||||
data: {
|
||||
id: 'file-proxy-123',
|
||||
object: 'file',
|
||||
bytes: 1024,
|
||||
created_at: Date.now(),
|
||||
filename: 'test.pdf',
|
||||
purpose: 'ocr',
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://proxy.example.com:8080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle proxy URL with authentication', async () => {
|
||||
process.env.PROXY = 'http://user:pass@proxy.example.com:8080';
|
||||
|
||||
const mockResponse: { data: MistralFileUploadResponse } = {
|
||||
data: {
|
||||
id: 'file-proxy-auth-123',
|
||||
object: 'file',
|
||||
bytes: 1024,
|
||||
created_at: Date.now(),
|
||||
filename: 'test.pdf',
|
||||
purpose: 'ocr',
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://user:pass@proxy.example.com:8080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle IPv6 proxy addresses', async () => {
|
||||
process.env.PROXY = 'http://[::1]:8080';
|
||||
|
||||
const mockResponse: { data: MistralFileUploadResponse } = {
|
||||
data: {
|
||||
id: 'file-proxy-ipv6-123',
|
||||
object: 'file',
|
||||
bytes: 1024,
|
||||
created_at: Date.now(),
|
||||
filename: 'test.pdf',
|
||||
purpose: 'ocr',
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://[::1]:8080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should not use proxy when PROXY env var is not set', async () => {
|
||||
delete process.env.PROXY;
|
||||
|
||||
const mockResponse: { data: MistralFileUploadResponse } = {
|
||||
data: {
|
||||
id: 'file-no-proxy-123',
|
||||
object: 'file',
|
||||
bytes: 1024,
|
||||
created_at: Date.now(),
|
||||
filename: 'test.pdf',
|
||||
purpose: 'ocr',
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await uploadDocumentToMistral({
|
||||
filePath: '/path/to/test.pdf',
|
||||
fileName: 'test.pdf',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files',
|
||||
expect.anything(),
|
||||
expect.not.objectContaining({
|
||||
httpsAgent: expect.anything(),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('performOCR with proxy', () => {
|
||||
it('should use proxy configuration when PROXY env var is set', async () => {
|
||||
process.env.PROXY = 'http://proxy.example.com:3128';
|
||||
|
||||
const mockResponse: { data: OCRResult } = {
|
||||
data: {
|
||||
model: 'mistral-ocr-latest',
|
||||
pages: [
|
||||
{
|
||||
index: 0,
|
||||
markdown: 'Proxy test content',
|
||||
images: [],
|
||||
dimensions: { dpi: 300, height: 1100, width: 850 },
|
||||
},
|
||||
],
|
||||
document_annotation: '',
|
||||
usage_info: {
|
||||
pages_processed: 1,
|
||||
doc_size_bytes: 1024,
|
||||
},
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://document-url.com',
|
||||
model: 'mistral-ocr-latest',
|
||||
documentType: 'document_url',
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/ocr',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://proxy.example.com:3128',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle malformed proxy URLs gracefully', async () => {
|
||||
(HttpsProxyAgent as unknown as jest.Mock).mockImplementationOnce(() => {
|
||||
throw new Error('Invalid URL');
|
||||
});
|
||||
process.env.PROXY = 'not-a-valid-url';
|
||||
|
||||
const mockResponse: { data: OCRResult } = {
|
||||
data: {
|
||||
model: 'mistral-ocr-latest',
|
||||
pages: [
|
||||
{
|
||||
index: 0,
|
||||
markdown: 'Test content',
|
||||
images: [],
|
||||
dimensions: { dpi: 300, height: 1100, width: 850 },
|
||||
},
|
||||
],
|
||||
document_annotation: '',
|
||||
usage_info: {
|
||||
pages_processed: 1,
|
||||
doc_size_bytes: 1024,
|
||||
},
|
||||
},
|
||||
};
|
||||
mockAxios.post!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await expect(
|
||||
performOCR({
|
||||
apiKey: 'test-api-key',
|
||||
url: 'https://document-url.com',
|
||||
}),
|
||||
).rejects.toThrow('Invalid URL');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Azure Mistral OCR with proxy', () => {
|
||||
beforeEach(() => {
|
||||
(jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(
|
||||
Buffer.from('mock-file-content'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use proxy for Azure Mistral OCR requests', async () => {
|
||||
process.env.PROXY = 'http://proxy.example.com:8080';
|
||||
|
||||
mockLoadAuthValues.mockResolvedValue({
|
||||
OCR_API_KEY: 'azure-api-key',
|
||||
OCR_BASEURL: 'https://azure.mistral.ai/v1',
|
||||
});
|
||||
|
||||
mockAxios.post!.mockResolvedValueOnce({
|
||||
data: {
|
||||
model: 'mistral-ocr-latest',
|
||||
pages: [
|
||||
{
|
||||
index: 0,
|
||||
markdown: 'Azure OCR with proxy',
|
||||
images: [],
|
||||
dimensions: { dpi: 300, height: 1100, width: 850 },
|
||||
},
|
||||
],
|
||||
document_annotation: '',
|
||||
usage_info: {
|
||||
pages_processed: 1,
|
||||
doc_size_bytes: 1024,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const req = {
|
||||
user: { id: 'user123' },
|
||||
config: {
|
||||
ocr: {
|
||||
apiKey: '${OCR_API_KEY}',
|
||||
baseURL: '${OCR_BASEURL}',
|
||||
mistralModel: 'mistral-ocr-latest',
|
||||
},
|
||||
},
|
||||
} as unknown as ServerRequest;
|
||||
|
||||
const file = {
|
||||
path: '/tmp/upload/azure-file.pdf',
|
||||
originalname: 'azure-document.pdf',
|
||||
mimetype: 'application/pdf',
|
||||
} as Express.Multer.File;
|
||||
|
||||
await uploadAzureMistralOCR({
|
||||
req,
|
||||
file,
|
||||
loadAuthValues: mockLoadAuthValues,
|
||||
});
|
||||
|
||||
expect(mockAxios.post).toHaveBeenCalledWith(
|
||||
'https://azure.mistral.ai/v1/ocr',
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'http://proxy.example.com:8080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSignedUrl with proxy', () => {
|
||||
it('should use proxy configuration when PROXY env var is set', async () => {
|
||||
process.env.PROXY = 'https://secure-proxy.example.com:443';
|
||||
|
||||
const mockResponse: { data: MistralSignedUrlResponse } = {
|
||||
data: {
|
||||
url: 'https://signed-url.com',
|
||||
expires_at: Date.now() + 86400000,
|
||||
},
|
||||
};
|
||||
mockAxios.get!.mockResolvedValueOnce(mockResponse);
|
||||
|
||||
await getSignedUrl({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.get).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files/file-123/url?expiry=24',
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'https://secure-proxy.example.com:443',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteMistralFile with proxy', () => {
|
||||
it('should use proxy configuration when PROXY env var is set', async () => {
|
||||
process.env.PROXY = 'socks5://proxy.example.com:1080';
|
||||
|
||||
mockAxios.delete!.mockResolvedValueOnce({ data: {} });
|
||||
|
||||
await deleteMistralFile({
|
||||
fileId: 'file-123',
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
expect(mockAxios.delete).toHaveBeenCalledWith(
|
||||
'https://api.mistral.ai/v1/files/file-123',
|
||||
expect.objectContaining({
|
||||
httpsAgent: expect.objectContaining({
|
||||
proxyUrl: 'socks5://proxy.example.com:1080',
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('uploadAzureMistralOCR', () => {
|
||||
beforeEach(() => {
|
||||
(jest.mocked(fs).readFileSync as jest.Mock).mockReturnValue(Buffer.from('mock-file-content'));
|
||||
// Reset the HttpsProxyAgent mock to its default implementation for Azure tests
|
||||
(HttpsProxyAgent as unknown as jest.Mock).mockImplementation((url) => ({ proxyUrl: url }));
|
||||
// Clean up any PROXY env var from previous tests
|
||||
delete process.env.PROXY;
|
||||
// Reset axios mocks completely to clear any queued responses
|
||||
mockAxios.post!.mockReset();
|
||||
mockAxios.get!.mockReset();
|
||||
mockAxios.delete!.mockReset();
|
||||
// Re-establish default resolved values
|
||||
mockAxios.post!.mockResolvedValue({ data: {} });
|
||||
mockAxios.get!.mockResolvedValue({ data: {} });
|
||||
mockAxios.delete!.mockResolvedValue({ data: {} });
|
||||
});
|
||||
|
||||
it('should process OCR using Azure Mistral with base64 encoding', async () => {
|
||||
@@ -1796,6 +2196,11 @@ describe('MistralOCR Service', () => {
|
||||
});
|
||||
|
||||
describe('Mixed env var and hardcoded configuration', () => {
|
||||
beforeEach(() => {
|
||||
// Clean up any PROXY env var from previous tests
|
||||
delete process.env.PROXY;
|
||||
});
|
||||
|
||||
it('should preserve hardcoded baseURL when only apiKey is an env var', async () => {
|
||||
// This test demonstrates the current bug
|
||||
mockLoadAuthValues.mockResolvedValue({
|
||||
|
||||
@@ -2,6 +2,7 @@ import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
import FormData from 'form-data';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { HttpsProxyAgent } from 'https-proxy-agent';
|
||||
import {
|
||||
FileSources,
|
||||
envVarRegex,
|
||||
@@ -9,7 +10,7 @@ import {
|
||||
extractVariableName,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TCustomConfig } from 'librechat-data-provider';
|
||||
import type { AxiosError } from 'axios';
|
||||
import type { AxiosError, AxiosRequestConfig } from 'axios';
|
||||
import type {
|
||||
MistralFileUploadResponse,
|
||||
MistralSignedUrlResponse,
|
||||
@@ -77,15 +78,21 @@ export async function uploadDocumentToMistral({
|
||||
const fileStream = fs.createReadStream(filePath);
|
||||
form.append('file', fileStream, { filename: actualFileName });
|
||||
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return axios
|
||||
.post(`${baseURL}/files`, form, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
...form.getHeaders(),
|
||||
},
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity,
|
||||
})
|
||||
.post(`${baseURL}/files`, form, config)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
throw error;
|
||||
@@ -103,12 +110,18 @@ export async function getSignedUrl({
|
||||
expiry?: number;
|
||||
baseURL?: string;
|
||||
}): Promise<MistralSignedUrlResponse> {
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return axios
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
})
|
||||
.get(`${baseURL}/files/${fileId}/url?expiry=${expiry}`, config)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
logger.error('Error fetching signed URL:', error.message);
|
||||
@@ -139,6 +152,18 @@ export async function performOCR({
|
||||
documentType?: 'document_url' | 'image_url';
|
||||
}): Promise<OCRResult> {
|
||||
const documentKey = documentType === 'image_url' ? 'image_url' : 'document_url';
|
||||
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return axios
|
||||
.post(
|
||||
`${baseURL}/ocr`,
|
||||
@@ -151,12 +176,7 @@ export async function performOCR({
|
||||
[documentKey]: url,
|
||||
},
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
},
|
||||
config,
|
||||
)
|
||||
.then((res) => res.data)
|
||||
.catch((error) => {
|
||||
@@ -182,12 +202,18 @@ export async function deleteMistralFile({
|
||||
apiKey: string;
|
||||
baseURL?: string;
|
||||
}): Promise<void> {
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
try {
|
||||
const result = await axios.delete(`${baseURL}/files/${fileId}`, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${apiKey}`,
|
||||
},
|
||||
});
|
||||
const result = await axios.delete(`${baseURL}/files/${fileId}`, config);
|
||||
logger.debug(`Mistral file ${fileId} deleted successfully:`, result.data);
|
||||
} catch (error) {
|
||||
logger.error(`Error deleting Mistral file ${fileId}:`, error);
|
||||
@@ -543,17 +569,23 @@ async function createJWT(serviceKey: GoogleServiceAccount): Promise<string> {
|
||||
* Exchanges JWT for access token
|
||||
*/
|
||||
async function exchangeJWTForAccessToken(jwt: string): Promise<string> {
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
const response = await axios.post(
|
||||
'https://oauth2.googleapis.com/token',
|
||||
new URLSearchParams({
|
||||
grant_type: 'urn:ietf:params:oauth:grant-type:jwt-bearer',
|
||||
assertion: jwt,
|
||||
}),
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
},
|
||||
config,
|
||||
);
|
||||
|
||||
if (!response.data?.access_token) {
|
||||
@@ -608,14 +640,20 @@ async function performGoogleVertexOCR({
|
||||
},
|
||||
});
|
||||
|
||||
const config: AxiosRequestConfig = {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
Accept: 'application/json',
|
||||
},
|
||||
};
|
||||
|
||||
if (process.env.PROXY) {
|
||||
config.httpsAgent = new HttpsProxyAgent(process.env.PROXY);
|
||||
}
|
||||
|
||||
return axios
|
||||
.post(baseURL, requestBody, {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
Accept: 'application/json',
|
||||
},
|
||||
})
|
||||
.post(baseURL, requestBody, config)
|
||||
.then((res) => {
|
||||
logger.debug('Google Vertex AI response received');
|
||||
return res.data;
|
||||
|
||||
@@ -14,6 +14,7 @@ export * from './utils';
|
||||
export * from './db/utils';
|
||||
/* OAuth */
|
||||
export * from './oauth';
|
||||
export * from './mcp/oauth/OAuthReconnectionManager';
|
||||
/* Crypto */
|
||||
export * from './crypto';
|
||||
/* Flow */
|
||||
|
||||
@@ -6,6 +6,7 @@ import type { FlowStateManager } from '~/flow/manager';
|
||||
import type { FlowMetadata } from '~/flow/types';
|
||||
import type * as t from './types';
|
||||
import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { MCPConnection } from './connection';
|
||||
import { processMCPEnv } from '~/utils';
|
||||
|
||||
@@ -308,7 +309,9 @@ export class MCPConnectionFactory {
|
||||
metadata?: OAuthMetadata;
|
||||
} | null> {
|
||||
const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
|
||||
logger.debug(`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl}`);
|
||||
logger.debug(
|
||||
`${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
|
||||
);
|
||||
|
||||
if (!this.flowManager || !serverUrl) {
|
||||
logger.error(
|
||||
|
||||
@@ -7,6 +7,7 @@ import type { JsonSchemaType } from '~/types';
|
||||
import type * as t from '~/mcp/types';
|
||||
import { ConnectionsRepository } from '~/mcp/ConnectionsRepository';
|
||||
import { detectOAuthRequirement } from '~/mcp/oauth';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
import { processMCPEnv } from '~/utils';
|
||||
import { CONSTANTS } from '~/mcp/enum';
|
||||
|
||||
@@ -183,7 +184,7 @@ export class MCPServersRegistry {
|
||||
const prefix = this.prefix(serverName);
|
||||
const config = this.parsedConfigs[serverName];
|
||||
logger.info(`${prefix} -------------------------------------------------┐`);
|
||||
logger.info(`${prefix} URL: ${config.url}`);
|
||||
logger.info(`${prefix} URL: ${config.url ? sanitizeUrlForLogging(config.url) : 'N/A'}`);
|
||||
logger.info(`${prefix} OAuth Required: ${config.requiresOAuth}`);
|
||||
logger.info(`${prefix} Capabilities: ${config.capabilities}`);
|
||||
logger.info(`${prefix} Tools: ${config.tools}`);
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import { EventEmitter } from 'events';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { fetch as undiciFetch, Agent } from 'undici';
|
||||
import {
|
||||
StdioClientTransport,
|
||||
getDefaultEnvironment,
|
||||
} from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js';
|
||||
import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js';
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type {
|
||||
@@ -18,8 +18,9 @@ import type {
|
||||
Response as UndiciResponse,
|
||||
} from 'undici';
|
||||
import type { MCPOAuthTokens } from './oauth/types';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
import type * as t from './types';
|
||||
import { sanitizeUrlForLogging } from './utils';
|
||||
import { mcpConfig } from './mcpConfig';
|
||||
|
||||
type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;
|
||||
|
||||
@@ -238,7 +239,9 @@ export class MCPConnection extends EventEmitter {
|
||||
}
|
||||
this.url = options.url;
|
||||
const url = new URL(options.url);
|
||||
logger.info(`${this.getLogPrefix()} Creating SSE transport: ${url.toString()}`);
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} Creating SSE transport: ${sanitizeUrlForLogging(url)}`,
|
||||
);
|
||||
const abortController = new AbortController();
|
||||
|
||||
/** Add OAuth token to headers if available */
|
||||
@@ -293,7 +296,7 @@ export class MCPConnection extends EventEmitter {
|
||||
this.url = options.url;
|
||||
const url = new URL(options.url);
|
||||
logger.info(
|
||||
`${this.getLogPrefix()} Creating streamable-http transport: ${url.toString()}`,
|
||||
`${this.getLogPrefix()} Creating streamable-http transport: ${sanitizeUrlForLogging(url)}`,
|
||||
);
|
||||
const abortController = new AbortController();
|
||||
|
||||
@@ -473,7 +476,9 @@ export class MCPConnection extends EventEmitter {
|
||||
logger.warn(`${this.getLogPrefix()} OAuth authentication required`);
|
||||
this.oauthRequired = true;
|
||||
const serverUrl = this.url;
|
||||
logger.debug(`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl}`);
|
||||
logger.debug(
|
||||
`${this.getLogPrefix()} Server URL for OAuth: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
|
||||
);
|
||||
|
||||
const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
|
||||
/** Promise that will resolve when OAuth is handled */
|
||||
|
||||
294
packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
Normal file
294
packages/api/src/mcp/oauth/OAuthReconnectionManager.test.ts
Normal file
@@ -0,0 +1,294 @@
|
||||
import { TokenMethods } from '@librechat/data-schemas';
|
||||
import { FlowStateManager, MCPConnection, MCPOAuthTokens, MCPOptions } from '../..';
|
||||
import { MCPManager } from '../MCPManager';
|
||||
import { OAuthReconnectionManager } from './OAuthReconnectionManager';
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
logger: {
|
||||
info: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock('../MCPManager');
|
||||
|
||||
describe('OAuthReconnectionManager', () => {
|
||||
let flowManager: jest.Mocked<FlowStateManager<null>>;
|
||||
let tokenMethods: jest.Mocked<TokenMethods>;
|
||||
let mockMCPManager: jest.Mocked<MCPManager>;
|
||||
let reconnectionManager: OAuthReconnectionManager;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Reset singleton instance
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
(OAuthReconnectionManager as any).instance = null;
|
||||
|
||||
// Setup mock flow manager
|
||||
flowManager = {
|
||||
createFlow: jest.fn(),
|
||||
completeFlow: jest.fn(),
|
||||
failFlow: jest.fn(),
|
||||
deleteFlow: jest.fn(),
|
||||
getFlow: jest.fn(),
|
||||
} as unknown as jest.Mocked<FlowStateManager<null>>;
|
||||
|
||||
// Setup mock token methods
|
||||
tokenMethods = {
|
||||
findToken: jest.fn(),
|
||||
createToken: jest.fn(),
|
||||
updateToken: jest.fn(),
|
||||
deleteToken: jest.fn(),
|
||||
} as unknown as jest.Mocked<TokenMethods>;
|
||||
|
||||
// Setup mock MCP Manager
|
||||
mockMCPManager = {
|
||||
getOAuthServers: jest.fn(),
|
||||
getUserConnection: jest.fn(),
|
||||
getUserConnections: jest.fn(),
|
||||
disconnectUserConnection: jest.fn(),
|
||||
getRawConfig: jest.fn(),
|
||||
} as unknown as jest.Mocked<MCPManager>;
|
||||
|
||||
(MCPManager.getInstance as jest.Mock).mockReturnValue(mockMCPManager);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('Singleton Pattern', () => {
|
||||
it('should create instance successfully', async () => {
|
||||
const instance = await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
|
||||
expect(instance).toBeInstanceOf(OAuthReconnectionManager);
|
||||
});
|
||||
|
||||
it('should throw error when creating instance twice', async () => {
|
||||
await OAuthReconnectionManager.createInstance(flowManager, tokenMethods);
|
||||
await expect(
|
||||
OAuthReconnectionManager.createInstance(flowManager, tokenMethods),
|
||||
).rejects.toThrow('OAuthReconnectionManager already initialized');
|
||||
});
|
||||
|
||||
it('should throw error when getting instance before creation', () => {
|
||||
expect(() => OAuthReconnectionManager.getInstance()).toThrow(
|
||||
'OAuthReconnectionManager not initialized',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isReconnecting', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
beforeEach(async () => {
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return true when server is actively reconnecting', () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'test-server';
|
||||
|
||||
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
|
||||
|
||||
reconnectionTracker.setActive(userId, serverName);
|
||||
const result = reconnectionManager.isReconnecting(userId, serverName);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false when server is not reconnecting', () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'test-server';
|
||||
|
||||
const result = reconnectionManager.isReconnecting(userId, serverName);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('clearReconnection', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
beforeEach(async () => {
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
it('should clear both failed and active reconnection states', () => {
|
||||
const userId = 'user-123';
|
||||
const serverName = 'test-server';
|
||||
|
||||
reconnectionTracker.setFailed(userId, serverName);
|
||||
reconnectionTracker.setActive(userId, serverName);
|
||||
|
||||
reconnectionManager.clearReconnection(userId, serverName);
|
||||
|
||||
expect(reconnectionManager.isReconnecting(userId, serverName)).toBe(false);
|
||||
expect(reconnectionTracker.isFailed(userId, serverName)).toBe(false);
|
||||
expect(reconnectionTracker.isActive(userId, serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('reconnectServers', () => {
|
||||
let reconnectionTracker: OAuthReconnectionTracker;
|
||||
beforeEach(async () => {
|
||||
reconnectionTracker = new OAuthReconnectionTracker();
|
||||
reconnectionManager = await OAuthReconnectionManager.createInstance(
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
reconnectionTracker,
|
||||
);
|
||||
});
|
||||
|
||||
it('should reconnect eligible servers', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1', 'server2', 'server3']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has failed reconnection
|
||||
reconnectionTracker.setFailed(userId, 'server1');
|
||||
|
||||
// server2: already connected
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
};
|
||||
const userConnections = new Map([['server2', mockConnection]]);
|
||||
mockMCPManager.getUserConnections.mockReturnValue(
|
||||
userConnections as unknown as Map<string, MCPConnection>,
|
||||
);
|
||||
|
||||
// server3: has valid token and not connected
|
||||
tokenMethods.findToken.mockImplementation(async ({ identifier }) => {
|
||||
if (identifier === 'mcp:server3') {
|
||||
return {
|
||||
userId,
|
||||
identifier,
|
||||
expiresAt: new Date(Date.now() + 3600000), // 1 hour from now
|
||||
} as unknown as MCPOAuthTokens;
|
||||
}
|
||||
return null;
|
||||
});
|
||||
|
||||
// Mock successful reconnection
|
||||
const mockNewConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(true),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockNewConnection as unknown as MCPConnection,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({ initTimeout: 5000 } as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Verify server3 was marked as active
|
||||
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(true);
|
||||
|
||||
// Wait for async tryReconnect to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Verify reconnection was attempted for server3
|
||||
expect(mockMCPManager.getUserConnection).toHaveBeenCalledWith({
|
||||
serverName: 'server3',
|
||||
user: { id: userId },
|
||||
flowManager,
|
||||
tokenMethods,
|
||||
forceNew: false,
|
||||
connectionTimeout: 5000,
|
||||
returnOnOAuth: true,
|
||||
});
|
||||
|
||||
// Verify successful reconnection cleared the states
|
||||
expect(reconnectionTracker.isFailed(userId, 'server3')).toBe(false);
|
||||
expect(reconnectionTracker.isActive(userId, 'server3')).toBe(false);
|
||||
});
|
||||
|
||||
it('should handle failed reconnection attempts', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has valid token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() + 3600000),
|
||||
} as unknown as MCPOAuthTokens);
|
||||
|
||||
// Mock failed connection
|
||||
mockMCPManager.getUserConnection.mockRejectedValue(new Error('Connection failed'));
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Wait for async tryReconnect to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Verify failure handling
|
||||
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||
});
|
||||
|
||||
it('should not reconnect servers with expired tokens', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
// server1: has expired token
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() - 3600000), // 1 hour ago
|
||||
} as unknown as MCPOAuthTokens);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Verify no reconnection attempt was made
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||
expect(mockMCPManager.getUserConnection).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle connection that returns but is not connected', async () => {
|
||||
const userId = 'user-123';
|
||||
const oauthServers = new Set(['server1']);
|
||||
mockMCPManager.getOAuthServers.mockReturnValue(oauthServers);
|
||||
|
||||
tokenMethods.findToken.mockResolvedValue({
|
||||
userId,
|
||||
identifier: 'mcp:server1',
|
||||
expiresAt: new Date(Date.now() + 3600000),
|
||||
} as unknown as MCPOAuthTokens);
|
||||
|
||||
// Mock connection that returns but is not connected
|
||||
const mockConnection = {
|
||||
isConnected: jest.fn().mockResolvedValue(false),
|
||||
disconnect: jest.fn(),
|
||||
};
|
||||
mockMCPManager.getUserConnection.mockResolvedValue(
|
||||
mockConnection as unknown as MCPConnection,
|
||||
);
|
||||
mockMCPManager.getRawConfig.mockReturnValue({} as unknown as MCPOptions);
|
||||
|
||||
await reconnectionManager.reconnectServers(userId);
|
||||
|
||||
// Wait for async tryReconnect to complete
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
|
||||
// Verify failure handling
|
||||
expect(mockConnection.disconnect).toHaveBeenCalled();
|
||||
expect(reconnectionTracker.isFailed(userId, 'server1')).toBe(true);
|
||||
expect(reconnectionTracker.isActive(userId, 'server1')).toBe(false);
|
||||
expect(mockMCPManager.disconnectUserConnection).toHaveBeenCalledWith(userId, 'server1');
|
||||
});
|
||||
});
|
||||
});
|
||||
163
packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
Normal file
163
packages/api/src/mcp/oauth/OAuthReconnectionManager.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
import { logger } from '@librechat/data-schemas';
|
||||
import type { TokenMethods } from '@librechat/data-schemas';
|
||||
import type { TUser } from 'librechat-data-provider';
|
||||
import type { MCPOAuthTokens } from './types';
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
import { FlowStateManager } from '~/flow/manager';
|
||||
import { MCPManager } from '~/mcp/MCPManager';
|
||||
|
||||
const DEFAULT_CONNECTION_TIMEOUT_MS = 10_000; // ms
|
||||
|
||||
export class OAuthReconnectionManager {
|
||||
private static instance: OAuthReconnectionManager | null = null;
|
||||
|
||||
protected readonly flowManager: FlowStateManager<MCPOAuthTokens | null>;
|
||||
protected readonly tokenMethods: TokenMethods;
|
||||
|
||||
private readonly reconnectionsTracker: OAuthReconnectionTracker;
|
||||
|
||||
public static getInstance(): OAuthReconnectionManager {
|
||||
if (!OAuthReconnectionManager.instance) {
|
||||
throw new Error('OAuthReconnectionManager not initialized');
|
||||
}
|
||||
return OAuthReconnectionManager.instance;
|
||||
}
|
||||
|
||||
public static async createInstance(
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
tokenMethods: TokenMethods,
|
||||
reconnections?: OAuthReconnectionTracker,
|
||||
): Promise<OAuthReconnectionManager> {
|
||||
if (OAuthReconnectionManager.instance != null) {
|
||||
throw new Error('OAuthReconnectionManager already initialized');
|
||||
}
|
||||
|
||||
const manager = new OAuthReconnectionManager(flowManager, tokenMethods, reconnections);
|
||||
OAuthReconnectionManager.instance = manager;
|
||||
|
||||
return manager;
|
||||
}
|
||||
|
||||
public constructor(
|
||||
flowManager: FlowStateManager<MCPOAuthTokens | null>,
|
||||
tokenMethods: TokenMethods,
|
||||
reconnections?: OAuthReconnectionTracker,
|
||||
) {
|
||||
this.flowManager = flowManager;
|
||||
this.tokenMethods = tokenMethods;
|
||||
this.reconnectionsTracker = reconnections ?? new OAuthReconnectionTracker();
|
||||
}
|
||||
|
||||
public isReconnecting(userId: string, serverName: string): boolean {
|
||||
return this.reconnectionsTracker.isActive(userId, serverName);
|
||||
}
|
||||
|
||||
public async reconnectServers(userId: string) {
|
||||
const mcpManager = MCPManager.getInstance();
|
||||
|
||||
// 1. derive the servers to reconnect
|
||||
const serversToReconnect = [];
|
||||
for (const serverName of mcpManager.getOAuthServers() ?? []) {
|
||||
const canReconnect = await this.canReconnect(userId, serverName);
|
||||
if (canReconnect) {
|
||||
serversToReconnect.push(serverName);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. mark the servers as reconnecting
|
||||
for (const serverName of serversToReconnect) {
|
||||
this.reconnectionsTracker.setActive(userId, serverName);
|
||||
}
|
||||
|
||||
// 3. attempt to reconnect the servers
|
||||
for (const serverName of serversToReconnect) {
|
||||
void this.tryReconnect(userId, serverName);
|
||||
}
|
||||
}
|
||||
|
||||
public clearReconnection(userId: string, serverName: string) {
|
||||
this.reconnectionsTracker.removeFailed(userId, serverName);
|
||||
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||
}
|
||||
|
||||
private async tryReconnect(userId: string, serverName: string) {
|
||||
const mcpManager = MCPManager.getInstance();
|
||||
|
||||
const logPrefix = `[tryReconnectOAuthMCPServer][User: ${userId}][${serverName}]`;
|
||||
|
||||
logger.info(`${logPrefix} Attempting reconnection`);
|
||||
|
||||
const config = mcpManager.getRawConfig(serverName);
|
||||
|
||||
const cleanupOnFailedReconnect = () => {
|
||||
this.reconnectionsTracker.setFailed(userId, serverName);
|
||||
this.reconnectionsTracker.removeActive(userId, serverName);
|
||||
mcpManager.disconnectUserConnection(userId, serverName);
|
||||
};
|
||||
|
||||
try {
|
||||
// attempt to get connection (this will use existing tokens and refresh if needed)
|
||||
const connection = await mcpManager.getUserConnection({
|
||||
serverName,
|
||||
user: { id: userId } as TUser,
|
||||
flowManager: this.flowManager,
|
||||
tokenMethods: this.tokenMethods,
|
||||
// don't force new connection, let it reuse existing or create new as needed
|
||||
forceNew: false,
|
||||
// set a reasonable timeout for reconnection attempts
|
||||
connectionTimeout: config?.initTimeout ?? DEFAULT_CONNECTION_TIMEOUT_MS,
|
||||
// don't trigger OAuth flow during reconnection
|
||||
returnOnOAuth: true,
|
||||
});
|
||||
|
||||
if (connection && (await connection.isConnected())) {
|
||||
logger.info(`${logPrefix} Successfully reconnected`);
|
||||
this.clearReconnection(userId, serverName);
|
||||
} else {
|
||||
logger.warn(`${logPrefix} Failed to reconnect`);
|
||||
await connection?.disconnect();
|
||||
cleanupOnFailedReconnect();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`${logPrefix} Failed to reconnect: ${error}`);
|
||||
cleanupOnFailedReconnect();
|
||||
}
|
||||
}
|
||||
|
||||
private async canReconnect(userId: string, serverName: string) {
|
||||
const mcpManager = MCPManager.getInstance();
|
||||
|
||||
// if the server has failed reconnection, don't attempt to reconnect
|
||||
if (this.reconnectionsTracker.isFailed(userId, serverName)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the server is already connected, don't attempt to reconnect
|
||||
const existingConnections = mcpManager.getUserConnections(userId);
|
||||
if (existingConnections?.has(serverName)) {
|
||||
const isConnected = await existingConnections.get(serverName)?.isConnected();
|
||||
if (isConnected) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// if the server has no tokens for the user, don't attempt to reconnect
|
||||
const accessToken = await this.tokenMethods.findToken({
|
||||
userId,
|
||||
type: 'mcp_oauth',
|
||||
identifier: `mcp:${serverName}`,
|
||||
});
|
||||
if (accessToken == null) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if the token has expired, don't attempt to reconnect
|
||||
const now = new Date();
|
||||
if (accessToken.expiresAt && accessToken.expiresAt < now) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// …otherwise, we're good to go with the reconnect attempt
|
||||
return true;
|
||||
}
|
||||
}
|
||||
181
packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
Normal file
181
packages/api/src/mcp/oauth/OAuthReconnectionTracker.test.ts
Normal file
@@ -0,0 +1,181 @@
|
||||
import { OAuthReconnectionTracker } from './OAuthReconnectionTracker';
|
||||
|
||||
describe('OAuthReconnectTracker', () => {
|
||||
let tracker: OAuthReconnectionTracker;
|
||||
const userId = 'user123';
|
||||
const serverName = 'test-server';
|
||||
const anotherServer = 'another-server';
|
||||
|
||||
beforeEach(() => {
|
||||
tracker = new OAuthReconnectionTracker();
|
||||
});
|
||||
|
||||
describe('setFailed', () => {
|
||||
it('should record a failed reconnection attempt', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should track multiple servers for the same user', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(userId, anotherServer);
|
||||
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
|
||||
});
|
||||
|
||||
it('should track different users independently', () => {
|
||||
const anotherUserId = 'user456';
|
||||
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(anotherUserId, serverName);
|
||||
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
expect(tracker.isFailed(anotherUserId, serverName)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isFailed', () => {
|
||||
it('should return false when no failed attempt is recorded', () => {
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true after a failed attempt is recorded', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for a different server even after another server failed', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, anotherServer)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeFailed', () => {
|
||||
it('should clear a failed reconnect record', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
tracker.removeFailed(userId, serverName);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should only clear the specific server for the user', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.setFailed(userId, anotherServer);
|
||||
|
||||
tracker.removeFailed(userId, serverName);
|
||||
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
expect(tracker.isFailed(userId, anotherServer)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle clearing non-existent records gracefully', () => {
|
||||
expect(() => tracker.removeFailed(userId, serverName)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('setActive', () => {
|
||||
it('should mark a server as reconnecting', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should track multiple reconnecting servers', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
tracker.setActive(userId, anotherServer);
|
||||
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
expect(tracker.isActive(userId, anotherServer)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isActive', () => {
|
||||
it('should return false when server is not reconnecting', () => {
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true when server is marked as reconnecting', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle non-existent user gracefully', () => {
|
||||
expect(tracker.isActive('non-existent-user', serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeActive', () => {
|
||||
it('should clear reconnecting state for a server', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
|
||||
tracker.removeActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should only clear specific server state', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
tracker.setActive(userId, anotherServer);
|
||||
|
||||
tracker.removeActive(userId, serverName);
|
||||
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
expect(tracker.isActive(userId, anotherServer)).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle clearing non-existent state gracefully', () => {
|
||||
expect(() => tracker.removeActive(userId, serverName)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('cleanup behavior', () => {
|
||||
it('should clean up empty user sets for failed reconnects', () => {
|
||||
tracker.setFailed(userId, serverName);
|
||||
tracker.removeFailed(userId, serverName);
|
||||
|
||||
// Record and clear another user to ensure internal cleanup
|
||||
const anotherUserId = 'user456';
|
||||
tracker.setFailed(anotherUserId, serverName);
|
||||
|
||||
// Original user should still be able to reconnect
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
|
||||
it('should clean up empty user sets for active reconnections', () => {
|
||||
tracker.setActive(userId, serverName);
|
||||
tracker.removeActive(userId, serverName);
|
||||
|
||||
// Mark another user to ensure internal cleanup
|
||||
const anotherUserId = 'user456';
|
||||
tracker.setActive(anotherUserId, serverName);
|
||||
|
||||
// Original user should not be reconnecting
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('combined state management', () => {
|
||||
it('should handle both failed and reconnecting states independently', () => {
|
||||
// Mark as reconnecting
|
||||
tracker.setActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
|
||||
// Record failed attempt
|
||||
tracker.setFailed(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(true);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
// Clear reconnecting state
|
||||
tracker.removeActive(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(true);
|
||||
|
||||
// Clear failed state
|
||||
tracker.removeFailed(userId, serverName);
|
||||
expect(tracker.isActive(userId, serverName)).toBe(false);
|
||||
expect(tracker.isFailed(userId, serverName)).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
46
packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
Normal file
46
packages/api/src/mcp/oauth/OAuthReconnectionTracker.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
export class OAuthReconnectionTracker {
|
||||
// Map of userId -> Set of serverNames that have failed reconnection
|
||||
private failed: Map<string, Set<string>> = new Map();
|
||||
// Map of userId -> Set of serverNames that are actively reconnecting
|
||||
private active: Map<string, Set<string>> = new Map();
|
||||
|
||||
public isFailed(userId: string, serverName: string): boolean {
|
||||
return this.failed.get(userId)?.has(serverName) ?? false;
|
||||
}
|
||||
|
||||
public isActive(userId: string, serverName: string): boolean {
|
||||
return this.active.get(userId)?.has(serverName) ?? false;
|
||||
}
|
||||
|
||||
public setFailed(userId: string, serverName: string): void {
|
||||
if (!this.failed.has(userId)) {
|
||||
this.failed.set(userId, new Set());
|
||||
}
|
||||
|
||||
this.failed.get(userId)?.add(serverName);
|
||||
}
|
||||
|
||||
public setActive(userId: string, serverName: string): void {
|
||||
if (!this.active.has(userId)) {
|
||||
this.active.set(userId, new Set());
|
||||
}
|
||||
|
||||
this.active.get(userId)?.add(serverName);
|
||||
}
|
||||
|
||||
public removeFailed(userId: string, serverName: string): void {
|
||||
const userServers = this.failed.get(userId);
|
||||
userServers?.delete(serverName);
|
||||
if (userServers?.size === 0) {
|
||||
this.failed.delete(userId);
|
||||
}
|
||||
}
|
||||
|
||||
public removeActive(userId: string, serverName: string): void {
|
||||
const userServers = this.active.get(userId);
|
||||
userServers?.delete(serverName);
|
||||
if (userServers?.size === 0) {
|
||||
this.active.delete(userId);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import type {
|
||||
MCPOAuthTokens,
|
||||
OAuthMetadata,
|
||||
} from './types';
|
||||
import { sanitizeUrlForLogging } from '~/mcp/utils';
|
||||
|
||||
/** Type for the OAuth metadata from the SDK */
|
||||
type SDKOAuthMetadata = Parameters<typeof registerClient>[1]['metadata'];
|
||||
@@ -33,7 +34,9 @@ export class MCPOAuthHandler {
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata;
|
||||
authServerUrl: URL;
|
||||
}> {
|
||||
logger.debug(`[MCPOAuth] discoverMetadata called with serverUrl: ${serverUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] discoverMetadata called with serverUrl: ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
);
|
||||
|
||||
let authServerUrl = new URL(serverUrl);
|
||||
let resourceMetadata: OAuthProtectedResourceMetadata | undefined;
|
||||
@@ -60,11 +63,15 @@ export class MCPOAuthHandler {
|
||||
}
|
||||
|
||||
// Discover OAuth metadata
|
||||
logger.debug(`[MCPOAuth] Discovering OAuth metadata from ${authServerUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Discovering OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
);
|
||||
const rawMetadata = await discoverAuthorizationServerMetadata(authServerUrl);
|
||||
|
||||
if (!rawMetadata) {
|
||||
logger.error(`[MCPOAuth] Failed to discover OAuth metadata from ${authServerUrl}`);
|
||||
logger.error(
|
||||
`[MCPOAuth] Failed to discover OAuth metadata from ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
);
|
||||
throw new Error('Failed to discover OAuth metadata');
|
||||
}
|
||||
|
||||
@@ -88,12 +95,15 @@ export class MCPOAuthHandler {
|
||||
resourceMetadata?: OAuthProtectedResourceMetadata,
|
||||
redirectUri?: string,
|
||||
): Promise<OAuthClientInformation> {
|
||||
logger.debug(`[MCPOAuth] Starting client registration for ${serverUrl}, server metadata:`, {
|
||||
grant_types_supported: metadata.grant_types_supported,
|
||||
response_types_supported: metadata.response_types_supported,
|
||||
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
|
||||
scopes_supported: metadata.scopes_supported,
|
||||
});
|
||||
logger.debug(
|
||||
`[MCPOAuth] Starting client registration for ${sanitizeUrlForLogging(serverUrl)}, server metadata:`,
|
||||
{
|
||||
grant_types_supported: metadata.grant_types_supported,
|
||||
response_types_supported: metadata.response_types_supported,
|
||||
token_endpoint_auth_methods_supported: metadata.token_endpoint_auth_methods_supported,
|
||||
scopes_supported: metadata.scopes_supported,
|
||||
},
|
||||
);
|
||||
|
||||
/** Client metadata based on what the server supports */
|
||||
const clientMetadata = {
|
||||
@@ -114,7 +124,9 @@ export class MCPOAuthHandler {
|
||||
`[MCPOAuth] Server ${serverUrl} supports \`refresh_token\` grant type, adding to request`,
|
||||
);
|
||||
} else {
|
||||
logger.debug(`[MCPOAuth] Server ${serverUrl} does not support \`refresh_token\` grant type`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Server ${sanitizeUrlForLogging(serverUrl)} does not support \`refresh_token\` grant type`,
|
||||
);
|
||||
}
|
||||
clientMetadata.grant_types = requestedGrantTypes;
|
||||
|
||||
@@ -139,19 +151,25 @@ export class MCPOAuthHandler {
|
||||
clientMetadata.scope = availableScopes.join(' ');
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Registering client for ${serverUrl} with metadata:`, clientMetadata);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Registering client for ${sanitizeUrlForLogging(serverUrl)} with metadata:`,
|
||||
clientMetadata,
|
||||
);
|
||||
|
||||
const clientInfo = await registerClient(serverUrl, {
|
||||
metadata: metadata as unknown as SDKOAuthMetadata,
|
||||
clientMetadata,
|
||||
});
|
||||
|
||||
logger.debug(`[MCPOAuth] Client registered successfully for ${serverUrl}:`, {
|
||||
client_id: clientInfo.client_id,
|
||||
has_client_secret: !!clientInfo.client_secret,
|
||||
grant_types: clientInfo.grant_types,
|
||||
scope: clientInfo.scope,
|
||||
});
|
||||
logger.debug(
|
||||
`[MCPOAuth] Client registered successfully for ${sanitizeUrlForLogging(serverUrl)}:`,
|
||||
{
|
||||
client_id: clientInfo.client_id,
|
||||
has_client_secret: !!clientInfo.client_secret,
|
||||
grant_types: clientInfo.grant_types,
|
||||
scope: clientInfo.scope,
|
||||
},
|
||||
);
|
||||
|
||||
return clientInfo;
|
||||
}
|
||||
@@ -165,7 +183,9 @@ export class MCPOAuthHandler {
|
||||
userId: string,
|
||||
config: MCPOptions['oauth'] | undefined,
|
||||
): Promise<{ authorizationUrl: string; flowId: string; flowMetadata: MCPOAuthFlowMetadata }> {
|
||||
logger.debug(`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${serverUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] initiateOAuthFlow called for ${serverName} with URL: ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
);
|
||||
|
||||
const flowId = this.generateFlowId(userId, serverName);
|
||||
const state = this.generateState();
|
||||
@@ -226,7 +246,9 @@ export class MCPOAuthHandler {
|
||||
metadata,
|
||||
};
|
||||
|
||||
logger.debug(`[MCPOAuth] Authorization URL generated: ${authorizationUrl.toString()}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Authorization URL generated: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
|
||||
);
|
||||
return {
|
||||
authorizationUrl: authorizationUrl.toString(),
|
||||
flowId,
|
||||
@@ -234,10 +256,14 @@ export class MCPOAuthHandler {
|
||||
};
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${serverUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Starting auto-discovery of OAuth metadata from ${sanitizeUrlForLogging(serverUrl)}`,
|
||||
);
|
||||
const { metadata, resourceMetadata, authServerUrl } = await this.discoverMetadata(serverUrl);
|
||||
|
||||
logger.debug(`[MCPOAuth] OAuth metadata discovered, auth server URL: ${authServerUrl}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] OAuth metadata discovered, auth server URL: ${sanitizeUrlForLogging(authServerUrl)}`,
|
||||
);
|
||||
|
||||
/** Dynamic client registration based on the discovered metadata */
|
||||
const redirectUri = config?.redirect_uri || this.getDefaultRedirectUri(serverName);
|
||||
@@ -276,7 +302,9 @@ export class MCPOAuthHandler {
|
||||
codeVerifier = authResult.codeVerifier;
|
||||
|
||||
logger.debug(`[MCPOAuth] startAuthorization completed successfully`);
|
||||
logger.debug(`[MCPOAuth] Authorization URL: ${authorizationUrl.toString()}`);
|
||||
logger.debug(
|
||||
`[MCPOAuth] Authorization URL: ${sanitizeUrlForLogging(authorizationUrl.toString())}`,
|
||||
);
|
||||
|
||||
/** Add state parameter with flowId to the authorization URL */
|
||||
authorizationUrl.searchParams.set('state', flowId);
|
||||
@@ -515,7 +543,7 @@ export class MCPOAuthHandler {
|
||||
body.append('client_id', metadata.clientInfo.client_id);
|
||||
}
|
||||
|
||||
logger.debug(`[MCPOAuth] Refresh request to: ${tokenUrl}`, {
|
||||
logger.debug(`[MCPOAuth] Refresh request to: ${sanitizeUrlForLogging(tokenUrl)}`, {
|
||||
body: body.toString(),
|
||||
headers,
|
||||
});
|
||||
@@ -695,7 +723,9 @@ export class MCPOAuthHandler {
|
||||
}
|
||||
|
||||
// perform the revoke request
|
||||
logger.info(`[MCPOAuth] Revoking tokens for ${serverName} via ${revokeUrl.toString()}`);
|
||||
logger.info(
|
||||
`[MCPOAuth] Revoking tokens for ${serverName} via ${sanitizeUrlForLogging(revokeUrl.toString())}`,
|
||||
);
|
||||
const response = await fetch(revokeUrl, {
|
||||
method: 'POST',
|
||||
body: body.toString(),
|
||||
|
||||
@@ -31,3 +31,17 @@ export function normalizeServerName(serverName: string): string {
|
||||
|
||||
return normalized;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitizes a URL by removing query parameters to prevent credential leakage in logs.
|
||||
* @param url - The URL to sanitize (string or URL object)
|
||||
* @returns The sanitized URL string without query parameters
|
||||
*/
|
||||
export function sanitizeUrlForLogging(url: string | URL): string {
|
||||
try {
|
||||
const urlObj = typeof url === 'string' ? new URL(url) : url;
|
||||
return `${urlObj.protocol}//${urlObj.host}${urlObj.pathname}`;
|
||||
} catch {
|
||||
return '[invalid URL]';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
interface ArrowIconProps {
|
||||
direction: 'up' | 'down';
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function ArrowIcon({ direction, className = 'inline' }: ArrowIconProps) {
|
||||
if (direction === 'up') {
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="1em"
|
||||
height="1em"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
className={className}
|
||||
>
|
||||
<path
|
||||
fillRule="evenodd"
|
||||
d="M11.293 5.293a1 1 0 0 1 1.414 0l5 5a1 1 0 0 1-1.414 1.414L13 8.414V18a1 1 0 1 1-2 0V8.414l-3.293 3.293a1 1 0 0 1-1.414-1.414l5-5Z"
|
||||
clipRule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="1em"
|
||||
height="1em"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
className={className}
|
||||
>
|
||||
<path
|
||||
fillRule="evenodd"
|
||||
d="M12.707 18.707a1 1 0 0 1-1.414 0l-5-5a1 1 0 1 1 1.414-1.414L11 15.586V6a1 1 0 1 1 2 0v9.586l3.293-3.293a1 1 0 0 1 1.414 1.414l-5 5Z"
|
||||
clipRule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,4 @@
|
||||
export { default as ArchiveIcon } from './ArchiveIcon';
|
||||
export { default as ArrowIcon } from './ArrowIcon';
|
||||
export { default as Blocks } from './Blocks';
|
||||
export { default as Plugin } from './Plugin';
|
||||
export { default as GPTIcon } from './GPTIcon';
|
||||
|
||||
@@ -66,8 +66,6 @@ export const messages = (params: q.MessagesListParams) => {
|
||||
|
||||
export const messagesArtifacts = (messageId: string) => `${messagesRoot}/artifacts/${messageId}`;
|
||||
|
||||
export const costs = () => `/api/messages/costs`;
|
||||
|
||||
const shareRoot = `${BASE_URL}/api/share`;
|
||||
export const shareMessages = (shareId: string) => `${shareRoot}/${shareId}`;
|
||||
export const getSharedLink = (conversationId: string) => `${shareRoot}/link/${conversationId}`;
|
||||
|
||||
@@ -51,7 +51,6 @@ export const excludedKeys = new Set([
|
||||
'_id',
|
||||
'tools',
|
||||
'model',
|
||||
'modelHistory',
|
||||
'files',
|
||||
'spec',
|
||||
'disableParams',
|
||||
|
||||
@@ -697,12 +697,6 @@ export function getMessagesByConvoId(conversationId: string): Promise<s.TMessage
|
||||
return request.get(endpoints.messages({ conversationId }));
|
||||
}
|
||||
|
||||
export function getModelCosts(
|
||||
modelHistory: Array<{ model: string; endpoint: string }>,
|
||||
): Promise<t.TModelCosts> {
|
||||
return request.post(endpoints.costs(), { modelHistory });
|
||||
}
|
||||
|
||||
export function getPrompt(id: string): Promise<{ prompt: t.TPrompt }> {
|
||||
return request.get(endpoints.getPrompt(id));
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ export enum QueryKeys {
|
||||
archivedConversations = 'archivedConversations',
|
||||
searchConversations = 'searchConversations',
|
||||
conversation = 'conversation',
|
||||
modelCosts = 'modelCosts',
|
||||
searchEnabled = 'searchEnabled',
|
||||
user = 'user',
|
||||
name = 'name', // user key name
|
||||
|
||||
@@ -77,23 +77,6 @@ export const useGetConversationByIdQuery = (
|
||||
);
|
||||
};
|
||||
|
||||
export const useGetModelCostsQuery = (
|
||||
modelHistory: Array<{ model: string; endpoint: string }>,
|
||||
config?: UseQueryOptions<t.TModelCosts>,
|
||||
): QueryObserverResult<t.TModelCosts> => {
|
||||
return useQuery<t.TModelCosts>(
|
||||
[QueryKeys.modelCosts, modelHistory],
|
||||
() => dataService.getModelCosts(modelHistory),
|
||||
{
|
||||
refetchOnWindowFocus: false,
|
||||
refetchOnReconnect: false,
|
||||
refetchOnMount: false,
|
||||
enabled: !!modelHistory && modelHistory.length > 0,
|
||||
...config,
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
//This isn't ideal because its just a query and we're using mutation, but it was the only way
|
||||
//to make it work with how the Chat component is structured
|
||||
export const useGetConversationByIdMutation = (id: string): UseMutationResult<s.TConversation> => {
|
||||
|
||||
@@ -98,12 +98,6 @@ if (typeof window !== 'undefined') {
|
||||
if (originalRequest.url?.includes('/api/auth/logout') === true) {
|
||||
return Promise.reject(error);
|
||||
}
|
||||
if (originalRequest.url?.includes('/api/auth/refresh') === true) {
|
||||
// Refresh token itself failed - redirect to login
|
||||
console.log('Refresh token request failed, redirecting to login...');
|
||||
window.location.href = '/login';
|
||||
return Promise.reject(error);
|
||||
}
|
||||
|
||||
if (error.response.status === 401 && !originalRequest._retry) {
|
||||
console.warn('401 error, refreshing token');
|
||||
@@ -124,7 +118,10 @@ if (typeof window !== 'undefined') {
|
||||
isRefreshing = true;
|
||||
|
||||
try {
|
||||
const response = await refreshToken();
|
||||
const response = await refreshToken(
|
||||
// Handle edge case where we get a blank screen if the initial 401 error is from a refresh token request
|
||||
originalRequest.url?.includes('api/auth/refresh') === true ? true : false,
|
||||
);
|
||||
|
||||
const token = response?.token ?? '';
|
||||
|
||||
|
||||
@@ -518,7 +518,6 @@ export const tMessageSchema = z.object({
|
||||
overrideParentMessageId: z.string().nullable().optional(),
|
||||
bg: z.string().nullable().optional(),
|
||||
model: z.string().nullable().optional(),
|
||||
targetModel: z.string().nullable().optional(),
|
||||
title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'),
|
||||
sender: z.string().optional(),
|
||||
text: z.string(),
|
||||
@@ -632,7 +631,6 @@ export const tConversationSchema = z.object({
|
||||
modelLabel: z.string().nullable().optional(),
|
||||
userLabel: z.string().optional(),
|
||||
model: z.string().nullable().optional(),
|
||||
modelHistory: z.array(z.object({ model: z.string(), endpoint: z.string() })).optional(),
|
||||
promptPrefix: z.string().nullable().optional(),
|
||||
temperature: z.number().nullable().optional(),
|
||||
topP: z.number().optional(),
|
||||
|
||||
@@ -653,15 +653,3 @@ export type TBalanceResponse = {
|
||||
lastRefill?: Date;
|
||||
refillAmount?: number;
|
||||
};
|
||||
|
||||
export type TConversationCosts = {
|
||||
totals: {
|
||||
prompt: { usd: number; tokenCount: number };
|
||||
completion: { usd: number; tokenCount: number };
|
||||
total: { usd: number; tokenCount: number };
|
||||
};
|
||||
};
|
||||
|
||||
export type TModelCosts = {
|
||||
modelCostTable: Record<string, { prompt: number; completion: number }>;
|
||||
};
|
||||
|
||||
@@ -155,14 +155,4 @@ export const conversationPreset = {
|
||||
verbosity: {
|
||||
type: String,
|
||||
},
|
||||
/** Track all unique models used in this conversation with their endpoints */
|
||||
modelHistory: {
|
||||
type: [
|
||||
{
|
||||
model: { type: String, required: true },
|
||||
endpoint: { type: String, required: true },
|
||||
},
|
||||
],
|
||||
default: [],
|
||||
},
|
||||
};
|
||||
|
||||
@@ -26,10 +26,6 @@ const messageSchema: Schema<IMessage> = new Schema(
|
||||
type: String,
|
||||
default: null,
|
||||
},
|
||||
targetModel: {
|
||||
type: String,
|
||||
default: null,
|
||||
},
|
||||
endpoint: {
|
||||
type: String,
|
||||
},
|
||||
|
||||
@@ -11,7 +11,6 @@ export interface IConversation extends Document {
|
||||
endpoint?: string;
|
||||
endpointType?: string;
|
||||
model?: string;
|
||||
modelHistory?: Array<{ model: string; endpoint: string }>;
|
||||
region?: string;
|
||||
chatGptLabel?: string;
|
||||
examples?: unknown[];
|
||||
|
||||
@@ -7,7 +7,6 @@ export interface IMessage extends Document {
|
||||
conversationId: string;
|
||||
user: string;
|
||||
model?: string;
|
||||
targetModel?: string;
|
||||
endpoint?: string;
|
||||
conversationSignature?: string;
|
||||
clientId?: string;
|
||||
|
||||
Reference in New Issue
Block a user