Compare commits
5 Commits
feat/granu
...
feat/group
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
65c81955f0 | ||
|
|
01e9b196bc | ||
|
|
d835f48307 | ||
|
|
0587a1cc7c | ||
|
|
72cd159a37 |
15
.env.example
15
.env.example
@@ -485,6 +485,21 @@ SAML_IMAGE_URL=
|
||||
# SAML_USE_AUTHN_RESPONSE_SIGNED=
|
||||
|
||||
|
||||
#===============================================#
|
||||
# Microsoft Graph API / Entra ID Integration #
|
||||
#===============================================#
|
||||
|
||||
# Enable Entra ID people search integration in permissions/sharing system
|
||||
# When enabled, the people picker will search both local database and Entra ID
|
||||
USE_ENTRA_ID_FOR_PEOPLE_SEARCH=false
|
||||
|
||||
# When enabled, entra id groups owners will be considered as members of the group
|
||||
ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS=false
|
||||
|
||||
# Microsoft Graph API scopes needed for people/group search
|
||||
# Default scopes provide access to user profiles and group memberships
|
||||
OPENID_GRAPH_SCOPES=User.Read,People.Read,GroupMember.Read.All
|
||||
|
||||
# LDAP
|
||||
LDAP_URL=
|
||||
LDAP_BIND_DN=
|
||||
|
||||
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@@ -8,7 +8,8 @@
|
||||
"skipFiles": ["<node_internals>/**"],
|
||||
"program": "${workspaceFolder}/api/server/index.js",
|
||||
"env": {
|
||||
"NODE_ENV": "production"
|
||||
"NODE_ENV": "production",
|
||||
"NODE_TLS_REJECT_UNAUTHORIZED": "0"
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"envFile": "${workspaceFolder}/.env"
|
||||
|
||||
@@ -792,7 +792,8 @@ class BaseClient {
|
||||
|
||||
userMessage.tokenCount = userMessageTokenCount;
|
||||
/*
|
||||
Note: `AskController` saves the user message, so we update the count of its `userMessage` reference
|
||||
Note: `AgentController` saves the user message if not saved here
|
||||
(noted by `savedMessageIds`), so we update the count of its `userMessage` reference
|
||||
*/
|
||||
if (typeof opts?.getReqData === 'function') {
|
||||
opts.getReqData({
|
||||
@@ -801,7 +802,8 @@ class BaseClient {
|
||||
}
|
||||
/*
|
||||
Note: we update the user message to be sure it gets the calculated token count;
|
||||
though `AskController` saves the user message, EditController does not
|
||||
though `AgentController` saves the user message if not saved here
|
||||
(noted by `savedMessageIds`), EditController does not
|
||||
*/
|
||||
await userMessagePromise;
|
||||
await this.updateMessageInDatabase({
|
||||
|
||||
@@ -96,35 +96,35 @@ function createContextHandlers(req, userMessageContent) {
|
||||
resolvedQueries.length === 0
|
||||
? '\n\tThe semantic search did not return any results.'
|
||||
: resolvedQueries
|
||||
.map((queryResult, index) => {
|
||||
const file = processedFiles[index];
|
||||
let contextItems = queryResult.data;
|
||||
.map((queryResult, index) => {
|
||||
const file = processedFiles[index];
|
||||
let contextItems = queryResult.data;
|
||||
|
||||
const generateContext = (currentContext) =>
|
||||
`
|
||||
const generateContext = (currentContext) =>
|
||||
`
|
||||
<file>
|
||||
<filename>${file.filename}</filename>
|
||||
<context>${currentContext}
|
||||
</context>
|
||||
</file>`;
|
||||
|
||||
if (useFullContext) {
|
||||
return generateContext(`\n${contextItems}`);
|
||||
}
|
||||
if (useFullContext) {
|
||||
return generateContext(`\n${contextItems}`);
|
||||
}
|
||||
|
||||
contextItems = queryResult.data
|
||||
.map((item) => {
|
||||
const pageContent = item[0].page_content;
|
||||
return `
|
||||
contextItems = queryResult.data
|
||||
.map((item) => {
|
||||
const pageContent = item[0].page_content;
|
||||
return `
|
||||
<contextItem>
|
||||
<![CDATA[${pageContent?.trim()}]]>
|
||||
</contextItem>`;
|
||||
})
|
||||
.join('');
|
||||
})
|
||||
.join('');
|
||||
|
||||
return generateContext(contextItems);
|
||||
})
|
||||
.join('');
|
||||
return generateContext(contextItems);
|
||||
})
|
||||
.join('');
|
||||
|
||||
if (useFullContext) {
|
||||
const prompt = `${header}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MeiliSearch } = require('meilisearch');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { FlowStateManager } = require('@librechat/api');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const Conversation = mongoose.models.Conversation;
|
||||
const Message = mongoose.models.Message;
|
||||
@@ -28,43 +31,123 @@ class MeiliSearchClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs the actual sync operations for messages and conversations
|
||||
*/
|
||||
async function performSync() {
|
||||
const client = MeiliSearchClient.getInstance();
|
||||
|
||||
const { status } = await client.health();
|
||||
if (status !== 'available') {
|
||||
throw new Error('Meilisearch not available');
|
||||
}
|
||||
|
||||
if (indexingDisabled === true) {
|
||||
logger.info('[indexSync] Indexing is disabled, skipping...');
|
||||
return { messagesSync: false, convosSync: false };
|
||||
}
|
||||
|
||||
let messagesSync = false;
|
||||
let convosSync = false;
|
||||
|
||||
// Check if we need to sync messages
|
||||
const messageProgress = await Message.getSyncProgress();
|
||||
if (!messageProgress.isComplete) {
|
||||
logger.info(
|
||||
`[indexSync] Messages need syncing: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments} indexed`,
|
||||
);
|
||||
|
||||
// Check if we should do a full sync or incremental
|
||||
const messageCount = await Message.countDocuments();
|
||||
const messagesIndexed = messageProgress.totalProcessed;
|
||||
const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10);
|
||||
|
||||
if (messageCount - messagesIndexed > syncThreshold) {
|
||||
logger.info('[indexSync] Starting full message sync due to large difference');
|
||||
await Message.syncWithMeili();
|
||||
messagesSync = true;
|
||||
} else if (messageCount !== messagesIndexed) {
|
||||
logger.warn('[indexSync] Messages out of sync, performing incremental sync');
|
||||
await Message.syncWithMeili();
|
||||
messagesSync = true;
|
||||
}
|
||||
} else {
|
||||
logger.info(
|
||||
`[indexSync] Messages are fully synced: ${messageProgress.totalProcessed}/${messageProgress.totalDocuments}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check if we need to sync conversations
|
||||
const convoProgress = await Conversation.getSyncProgress();
|
||||
if (!convoProgress.isComplete) {
|
||||
logger.info(
|
||||
`[indexSync] Conversations need syncing: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments} indexed`,
|
||||
);
|
||||
|
||||
const convoCount = await Conversation.countDocuments();
|
||||
const convosIndexed = convoProgress.totalProcessed;
|
||||
const syncThreshold = parseInt(process.env.MEILI_SYNC_THRESHOLD || '1000', 10);
|
||||
|
||||
if (convoCount - convosIndexed > syncThreshold) {
|
||||
logger.info('[indexSync] Starting full conversation sync due to large difference');
|
||||
await Conversation.syncWithMeili();
|
||||
convosSync = true;
|
||||
} else if (convoCount !== convosIndexed) {
|
||||
logger.warn('[indexSync] Convos out of sync, performing incremental sync');
|
||||
await Conversation.syncWithMeili();
|
||||
convosSync = true;
|
||||
}
|
||||
} else {
|
||||
logger.info(
|
||||
`[indexSync] Conversations are fully synced: ${convoProgress.totalProcessed}/${convoProgress.totalDocuments}`,
|
||||
);
|
||||
}
|
||||
|
||||
return { messagesSync, convosSync };
|
||||
}
|
||||
|
||||
/**
|
||||
* Main index sync function that uses FlowStateManager to prevent concurrent execution
|
||||
*/
|
||||
async function indexSync() {
|
||||
if (!searchEnabled) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const client = MeiliSearchClient.getInstance();
|
||||
|
||||
const { status } = await client.health();
|
||||
if (status !== 'available') {
|
||||
throw new Error('Meilisearch not available');
|
||||
logger.info('[indexSync] Starting index synchronization check...');
|
||||
|
||||
try {
|
||||
// Get or create FlowStateManager instance
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
if (!flowsCache) {
|
||||
logger.warn('[indexSync] Flows cache not available, falling back to direct sync');
|
||||
return await performSync();
|
||||
}
|
||||
|
||||
if (indexingDisabled === true) {
|
||||
logger.info('[indexSync] Indexing is disabled, skipping...');
|
||||
const flowManager = new FlowStateManager(flowsCache, {
|
||||
ttl: 60000 * 10, // 10 minutes TTL for sync operations
|
||||
});
|
||||
|
||||
// Use a unique flow ID for the sync operation
|
||||
const flowId = 'meili-index-sync';
|
||||
const flowType = 'MEILI_SYNC';
|
||||
|
||||
// This will only execute the handler if no other instance is running the sync
|
||||
const result = await flowManager.createFlowWithHandler(flowId, flowType, performSync);
|
||||
|
||||
if (result.messagesSync || result.convosSync) {
|
||||
logger.info('[indexSync] Sync completed successfully');
|
||||
} else {
|
||||
logger.debug('[indexSync] No sync was needed');
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (err) {
|
||||
if (err.message.includes('flow already exists')) {
|
||||
logger.info('[indexSync] Sync already running on another instance');
|
||||
return;
|
||||
}
|
||||
|
||||
const messageCount = await Message.countDocuments();
|
||||
const convoCount = await Conversation.countDocuments();
|
||||
const messages = await client.index('messages').getStats();
|
||||
const convos = await client.index('convos').getStats();
|
||||
const messagesIndexed = messages.numberOfDocuments;
|
||||
const convosIndexed = convos.numberOfDocuments;
|
||||
|
||||
logger.debug(`[indexSync] There are ${messageCount} messages and ${messagesIndexed} indexed`);
|
||||
logger.debug(`[indexSync] There are ${convoCount} convos and ${convosIndexed} indexed`);
|
||||
|
||||
if (messageCount !== messagesIndexed) {
|
||||
logger.debug('[indexSync] Messages out of sync, indexing');
|
||||
Message.syncWithMeili();
|
||||
}
|
||||
|
||||
if (convoCount !== convosIndexed) {
|
||||
logger.debug('[indexSync] Convos out of sync, indexing');
|
||||
Conversation.syncWithMeili();
|
||||
}
|
||||
} catch (err) {
|
||||
if (err.message.includes('not found')) {
|
||||
logger.debug('[indexSync] Creating indices...');
|
||||
currentTimeout = setTimeout(async () => {
|
||||
|
||||
@@ -4,7 +4,6 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
|
||||
require('librechat-data-provider').Constants;
|
||||
const { CONFIG_STORE, STARTUP_CONFIG } = require('librechat-data-provider').CacheKeys;
|
||||
const {
|
||||
getProjectByName,
|
||||
addAgentIdsToProject,
|
||||
@@ -12,7 +11,6 @@ const {
|
||||
removeAgentFromAllProjects,
|
||||
} = require('./Project');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getActions } = require('./Action');
|
||||
const { Agent } = require('~/db/models');
|
||||
|
||||
@@ -23,7 +21,7 @@ const { Agent } = require('~/db/models');
|
||||
* @throws {Error} If the agent creation fails.
|
||||
*/
|
||||
const createAgent = async (agentData) => {
|
||||
const { author, ...versionData } = agentData;
|
||||
const { author: _author, ...versionData } = agentData;
|
||||
const timestamp = new Date();
|
||||
const initialAgentData = {
|
||||
...agentData,
|
||||
@@ -70,6 +68,9 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
if (ephemeralAgent?.execute_code === true) {
|
||||
tools.push(Tools.execute_code);
|
||||
}
|
||||
if (ephemeralAgent?.file_search === true) {
|
||||
tools.push(Tools.file_search);
|
||||
}
|
||||
if (ephemeralAgent?.web_search === true) {
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
@@ -123,29 +124,7 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
|
||||
}
|
||||
|
||||
agent.version = agent.versions ? agent.versions.length : 0;
|
||||
|
||||
if (agent.author.toString() === req.user.id) {
|
||||
return agent;
|
||||
}
|
||||
|
||||
if (!agent.projectIds) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const cache = getLogStores(CONFIG_STORE);
|
||||
/** @type {TStartupConfig} */
|
||||
const cachedStartupConfig = await cache.get(STARTUP_CONFIG);
|
||||
let { instanceProjectId } = cachedStartupConfig ?? {};
|
||||
if (!instanceProjectId) {
|
||||
instanceProjectId = (await getProjectByName(GLOBAL_PROJECT_NAME, '_id'))._id.toString();
|
||||
}
|
||||
|
||||
for (const projectObjectId of agent.projectIds) {
|
||||
const projectId = projectObjectId.toString();
|
||||
if (projectId === instanceProjectId) {
|
||||
return agent;
|
||||
}
|
||||
}
|
||||
return agent;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -175,7 +154,7 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
|
||||
'actionsHash', // Exclude actionsHash from direct comparison
|
||||
];
|
||||
|
||||
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
|
||||
const { $push: _$push, $pull: _$pull, $addToSet: _$addToSet, ...directUpdates } = updateData;
|
||||
|
||||
if (Object.keys(directUpdates).length === 0 && !actionsHash) {
|
||||
return null;
|
||||
@@ -270,7 +249,14 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
|
||||
|
||||
const currentAgent = await Agent.findOne(searchParameter);
|
||||
if (currentAgent) {
|
||||
const { __v, _id, id, versions, author, ...versionData } = currentAgent.toObject();
|
||||
const {
|
||||
__v,
|
||||
_id,
|
||||
id: __id,
|
||||
versions,
|
||||
author: _author,
|
||||
...versionData
|
||||
} = currentAgent.toObject();
|
||||
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
|
||||
|
||||
let actionsHash = null;
|
||||
@@ -461,8 +447,110 @@ const deleteAgent = async (searchParameter) => {
|
||||
return agent;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get agents by accessible IDs with optional cursor-based pagination.
|
||||
* @param {Object} params - The parameters for getting accessible agents.
|
||||
* @param {Array} [params.accessibleIds] - Array of agent ObjectIds the user has ACL access to.
|
||||
* @param {Object} [params.otherParams] - Additional query parameters (including author filter).
|
||||
* @param {number} [params.limit] - Number of agents to return (max 100). If not provided, returns all agents.
|
||||
* @param {string} [params.after] - Cursor for pagination - get agents after this cursor. // base64 encoded JSON string with updatedAt and _id.
|
||||
* @returns {Promise<Object>} A promise that resolves to an object containing the agents data and pagination info.
|
||||
*/
|
||||
const getListAgentsByAccess = async ({
|
||||
accessibleIds = [],
|
||||
otherParams = {},
|
||||
limit = null,
|
||||
after = null,
|
||||
}) => {
|
||||
const isPaginated = limit !== null && limit !== undefined;
|
||||
const normalizedLimit = isPaginated ? Math.min(Math.max(1, parseInt(limit) || 20), 100) : null;
|
||||
|
||||
// Build base query combining ACL accessible agents with other filters
|
||||
const baseQuery = { ...otherParams };
|
||||
|
||||
if (accessibleIds.length > 0) {
|
||||
baseQuery._id = { $in: accessibleIds };
|
||||
}
|
||||
|
||||
// Add cursor condition
|
||||
if (after) {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
const cursorCondition = {
|
||||
$or: [
|
||||
{ updatedAt: { $lt: new Date(updatedAt) } },
|
||||
{ updatedAt: new Date(updatedAt), _id: { $gt: mongoose.Types.ObjectId(_id) } },
|
||||
],
|
||||
};
|
||||
|
||||
// Merge cursor condition with base query
|
||||
if (Object.keys(baseQuery).length > 0) {
|
||||
baseQuery.$and = [{ ...baseQuery }, cursorCondition];
|
||||
// Remove the original conditions from baseQuery to avoid duplication
|
||||
Object.keys(baseQuery).forEach((key) => {
|
||||
if (key !== '$and') delete baseQuery[key];
|
||||
});
|
||||
} else {
|
||||
Object.assign(baseQuery, cursorCondition);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Invalid cursor:', error.message);
|
||||
}
|
||||
}
|
||||
|
||||
let query = Agent.find(baseQuery, {
|
||||
id: 1,
|
||||
_id: 1,
|
||||
name: 1,
|
||||
avatar: 1,
|
||||
author: 1,
|
||||
projectIds: 1,
|
||||
description: 1,
|
||||
updatedAt: 1,
|
||||
}).sort({ updatedAt: -1, _id: 1 });
|
||||
|
||||
// Only apply limit if pagination is requested
|
||||
if (isPaginated) {
|
||||
query = query.limit(normalizedLimit + 1);
|
||||
}
|
||||
|
||||
const agents = await query.lean();
|
||||
|
||||
const hasMore = isPaginated ? agents.length > normalizedLimit : false;
|
||||
const data = (isPaginated ? agents.slice(0, normalizedLimit) : agents).map((agent) => {
|
||||
if (agent.author) {
|
||||
agent.author = agent.author.toString();
|
||||
}
|
||||
return agent;
|
||||
});
|
||||
|
||||
// Generate next cursor only if paginated
|
||||
let nextCursor = null;
|
||||
if (isPaginated && hasMore && data.length > 0) {
|
||||
const lastAgent = agents[normalizedLimit - 1];
|
||||
nextCursor = Buffer.from(
|
||||
JSON.stringify({
|
||||
updatedAt: lastAgent.updatedAt.toISOString(),
|
||||
_id: lastAgent._id.toString(),
|
||||
}),
|
||||
).toString('base64');
|
||||
}
|
||||
|
||||
return {
|
||||
object: 'list',
|
||||
data,
|
||||
first_id: data.length > 0 ? data[0].id : null,
|
||||
last_id: data.length > 0 ? data[data.length - 1].id : null,
|
||||
has_more: hasMore,
|
||||
after: nextCursor,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all agents.
|
||||
* @deprecated Use getListAgentsByAccess for ACL-aware agent listing
|
||||
* @param {Object} searchParameter - The search parameters to find matching agents.
|
||||
* @param {string} searchParameter.author - The user ID of the agent's author.
|
||||
* @returns {Promise<Object>} A promise that resolves to an object containing the agents data and pagination info.
|
||||
@@ -481,12 +569,13 @@ const getListAgents = async (searchParameter) => {
|
||||
const agents = (
|
||||
await Agent.find(query, {
|
||||
id: 1,
|
||||
_id: 0,
|
||||
_id: 1,
|
||||
name: 1,
|
||||
avatar: 1,
|
||||
author: 1,
|
||||
projectIds: 1,
|
||||
description: 1,
|
||||
// @deprecated - isCollaborative replaced by ACL permissions
|
||||
isCollaborative: 1,
|
||||
}).lean()
|
||||
).map((agent) => {
|
||||
@@ -670,6 +759,7 @@ module.exports = {
|
||||
revertAgentVersion,
|
||||
updateAgentProjects,
|
||||
addAgentResourceFile,
|
||||
getListAgentsByAccess,
|
||||
removeAgentResourceFiles,
|
||||
generateActionMetadataHash,
|
||||
};
|
||||
|
||||
@@ -43,7 +43,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -413,7 +413,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -670,7 +670,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -1332,7 +1332,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -1514,7 +1514,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -1633,7 +1633,7 @@ describe('models/Agent', () => {
|
||||
expect(result.version).toBe(1);
|
||||
});
|
||||
|
||||
test('should return null when user is not author and agent has no projectIds', async () => {
|
||||
test('should return agent even when user is not author (permissions checked at route level)', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
@@ -1654,7 +1654,11 @@ describe('models/Agent', () => {
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
});
|
||||
|
||||
expect(result).toBeFalsy();
|
||||
// With the new permission system, loadAgent returns the agent regardless of permissions
|
||||
// Permission checks are handled at the route level via middleware
|
||||
expect(result).toBeTruthy();
|
||||
expect(result.id).toBe(agentId);
|
||||
expect(result.name).toBe('Test Agent');
|
||||
});
|
||||
|
||||
test('should handle ephemeral agent with no MCP servers', async () => {
|
||||
@@ -1762,7 +1766,7 @@ describe('models/Agent', () => {
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle loadAgent with agent from different project', async () => {
|
||||
test('should return agent from different project (permissions checked at route level)', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
@@ -1785,7 +1789,11 @@ describe('models/Agent', () => {
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
});
|
||||
|
||||
expect(result).toBeFalsy();
|
||||
// With the new permission system, loadAgent returns the agent regardless of permissions
|
||||
// Permission checks are handled at the route level via middleware
|
||||
expect(result).toBeTruthy();
|
||||
expect(result.id).toBe(agentId);
|
||||
expect(result.name).toBe('Project Agent');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1798,7 +1806,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
@@ -2350,7 +2358,7 @@ describe('models/Agent', () => {
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
|
||||
@@ -52,6 +52,7 @@
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@node-saml/passport-saml": "^5.0.0",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
"axios": "^1.8.2",
|
||||
"bcryptjs": "^2.4.3",
|
||||
|
||||
@@ -1,282 +0,0 @@
|
||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
handleAbortError,
|
||||
createAbortController,
|
||||
cleanupAbortController,
|
||||
} = require('~/server/middleware');
|
||||
const {
|
||||
disposeClient,
|
||||
processReqData,
|
||||
clientRegistry,
|
||||
requestDataMap,
|
||||
} = require('~/server/cleanup');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { saveMessage } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const AskController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
text,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
modelDisplayLabel,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
let client = null;
|
||||
let abortKey = null;
|
||||
let cleanupHandlers = [];
|
||||
let clientRef = null;
|
||||
|
||||
logger.debug('[AskController]', {
|
||||
text,
|
||||
conversationId,
|
||||
...endpointOption,
|
||||
modelsConfig: endpointOption?.modelsConfig ? 'exists' : '',
|
||||
});
|
||||
|
||||
let userMessage = null;
|
||||
let userMessagePromise = null;
|
||||
let promptTokens = null;
|
||||
let userMessageId = null;
|
||||
let responseMessageId = null;
|
||||
let getAbortData = null;
|
||||
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
model: endpointOption.modelOptions.model,
|
||||
modelDisplayLabel,
|
||||
});
|
||||
const initialConversationId = conversationId;
|
||||
const newConvo = !initialConversationId;
|
||||
const userId = req.user.id;
|
||||
|
||||
let reqDataContext = {
|
||||
userMessage,
|
||||
userMessagePromise,
|
||||
responseMessageId,
|
||||
promptTokens,
|
||||
conversationId,
|
||||
userMessageId,
|
||||
};
|
||||
|
||||
const updateReqData = (data = {}) => {
|
||||
reqDataContext = processReqData(data, reqDataContext);
|
||||
abortKey = reqDataContext.abortKey;
|
||||
userMessage = reqDataContext.userMessage;
|
||||
userMessagePromise = reqDataContext.userMessagePromise;
|
||||
responseMessageId = reqDataContext.responseMessageId;
|
||||
promptTokens = reqDataContext.promptTokens;
|
||||
conversationId = reqDataContext.conversationId;
|
||||
userMessageId = reqDataContext.userMessageId;
|
||||
};
|
||||
|
||||
let { onProgress: progressCallback, getPartialText } = createOnProgress();
|
||||
|
||||
const performCleanup = () => {
|
||||
logger.debug('[AskController] Performing cleanup');
|
||||
if (Array.isArray(cleanupHandlers)) {
|
||||
for (const handler of cleanupHandlers) {
|
||||
try {
|
||||
if (typeof handler === 'function') {
|
||||
handler();
|
||||
}
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (abortKey) {
|
||||
logger.debug('[AskController] Cleaning up abort controller');
|
||||
cleanupAbortController(abortKey);
|
||||
abortKey = null;
|
||||
}
|
||||
|
||||
if (client) {
|
||||
disposeClient(client);
|
||||
client = null;
|
||||
}
|
||||
|
||||
reqDataContext = null;
|
||||
userMessage = null;
|
||||
userMessagePromise = null;
|
||||
promptTokens = null;
|
||||
getAbortData = null;
|
||||
progressCallback = null;
|
||||
endpointOption = null;
|
||||
cleanupHandlers = null;
|
||||
addTitle = null;
|
||||
|
||||
if (requestDataMap.has(req)) {
|
||||
requestDataMap.delete(req);
|
||||
}
|
||||
logger.debug('[AskController] Cleanup completed');
|
||||
};
|
||||
|
||||
try {
|
||||
({ client } = await initializeClient({ req, res, endpointOption }));
|
||||
if (clientRegistry && client) {
|
||||
clientRegistry.register(client, { userId }, client);
|
||||
}
|
||||
|
||||
if (client) {
|
||||
requestDataMap.set(req, { client });
|
||||
}
|
||||
|
||||
clientRef = new WeakRef(client);
|
||||
|
||||
getAbortData = () => {
|
||||
const currentClient = clientRef?.deref();
|
||||
const currentText =
|
||||
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
|
||||
|
||||
return {
|
||||
sender,
|
||||
conversationId,
|
||||
messageId: reqDataContext.responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: currentText,
|
||||
userMessage: userMessage,
|
||||
userMessagePromise: userMessagePromise,
|
||||
promptTokens: reqDataContext.promptTokens,
|
||||
};
|
||||
};
|
||||
|
||||
const { onStart, abortController } = createAbortController(
|
||||
req,
|
||||
res,
|
||||
getAbortData,
|
||||
updateReqData,
|
||||
);
|
||||
|
||||
const closeHandler = () => {
|
||||
logger.debug('[AskController] Request closed');
|
||||
if (!abortController || abortController.signal.aborted || abortController.requestCompleted) {
|
||||
return;
|
||||
}
|
||||
abortController.abort();
|
||||
logger.debug('[AskController] Request aborted on close');
|
||||
};
|
||||
|
||||
res.on('close', closeHandler);
|
||||
cleanupHandlers.push(() => {
|
||||
try {
|
||||
res.removeListener('close', closeHandler);
|
||||
} catch (e) {
|
||||
// Ignore
|
||||
}
|
||||
});
|
||||
|
||||
const messageOptions = {
|
||||
user: userId,
|
||||
parentMessageId,
|
||||
conversationId: reqDataContext.conversationId,
|
||||
overrideParentMessageId,
|
||||
getReqData: updateReqData,
|
||||
onStart,
|
||||
abortController,
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
},
|
||||
};
|
||||
|
||||
/** @type {TMessage} */
|
||||
let response = await client.sendMessage(text, messageOptions);
|
||||
response.endpoint = endpointOption.endpoint;
|
||||
|
||||
const databasePromise = response.databasePromise;
|
||||
delete response.databasePromise;
|
||||
|
||||
const { conversation: convoData = {} } = await databasePromise;
|
||||
const conversation = { ...convoData };
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
const latestUserMessage = reqDataContext.userMessage;
|
||||
|
||||
if (client?.options?.attachments && latestUserMessage) {
|
||||
latestUserMessage.files = client.options.attachments;
|
||||
if (endpointOption?.modelOptions?.model) {
|
||||
conversation.model = endpointOption.modelOptions.model;
|
||||
}
|
||||
delete latestUserMessage.image_urls;
|
||||
}
|
||||
|
||||
if (!abortController.signal.aborted) {
|
||||
const finalResponseMessage = { ...response };
|
||||
|
||||
sendMessage(res, {
|
||||
final: true,
|
||||
conversation,
|
||||
title: conversation.title,
|
||||
requestMessage: latestUserMessage,
|
||||
responseMessage: finalResponseMessage,
|
||||
});
|
||||
res.end();
|
||||
|
||||
if (client?.savedMessageIds && !client.savedMessageIds.has(response.messageId)) {
|
||||
await saveMessage(
|
||||
req,
|
||||
{ ...finalResponseMessage, user: userId },
|
||||
{ context: 'api/server/controllers/AskController.js - response end' },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (!client?.skipSaveUserMessage && latestUserMessage) {
|
||||
await saveMessage(req, latestUserMessage, {
|
||||
context: "api/server/controllers/AskController.js - don't skip saving user message",
|
||||
});
|
||||
}
|
||||
|
||||
if (typeof addTitle === 'function' && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response: { ...response },
|
||||
client,
|
||||
})
|
||||
.then(() => {
|
||||
logger.debug('[AskController] Title generation started');
|
||||
})
|
||||
.catch((err) => {
|
||||
logger.error('[AskController] Error in title generation', err);
|
||||
})
|
||||
.finally(() => {
|
||||
logger.debug('[AskController] Title generation completed');
|
||||
performCleanup();
|
||||
});
|
||||
} else {
|
||||
performCleanup();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[AskController] Error handling request', error);
|
||||
let partialText = '';
|
||||
try {
|
||||
const currentClient = clientRef?.deref();
|
||||
partialText =
|
||||
currentClient?.getStreamText != null ? currentClient.getStreamText() : getPartialText();
|
||||
} catch (getTextError) {
|
||||
logger.error('[AskController] Error calling getText() during error handling', getTextError);
|
||||
}
|
||||
|
||||
handleAbortError(res, req, error, {
|
||||
sender,
|
||||
partialText,
|
||||
conversationId: reqDataContext.conversationId,
|
||||
messageId: reqDataContext.responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? reqDataContext.userMessageId ?? parentMessageId,
|
||||
userMessageId: reqDataContext.userMessageId,
|
||||
})
|
||||
.catch((err) => {
|
||||
logger.error('[AskController] Error in `handleAbortError` during catch block', err);
|
||||
})
|
||||
.finally(() => {
|
||||
performCleanup();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = AskController;
|
||||
@@ -84,7 +84,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
||||
}
|
||||
|
||||
if (abortKey) {
|
||||
logger.debug('[AskController] Cleaning up abort controller');
|
||||
logger.debug('[EditController] Cleaning up abort controller');
|
||||
cleanupAbortController(abortKey);
|
||||
abortKey = null;
|
||||
}
|
||||
|
||||
437
api/server/controllers/PermissionsController.js
Normal file
437
api/server/controllers/PermissionsController.js
Normal file
@@ -0,0 +1,437 @@
|
||||
/**
|
||||
* @import { TUpdateResourcePermissionsRequest, TUpdateResourcePermissionsResponse } from 'librechat-data-provider'
|
||||
*/
|
||||
|
||||
const mongoose = require('mongoose');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
getAvailableRoles,
|
||||
ensurePrincipalExists,
|
||||
getEffectivePermissions,
|
||||
ensureGroupPrincipalExists,
|
||||
bulkUpdateResourcePermissions,
|
||||
} = require('~/server/services/PermissionService');
|
||||
const { AclEntry } = require('~/db/models');
|
||||
const {
|
||||
searchPrincipals: searchLocalPrincipals,
|
||||
sortPrincipalsByRelevance,
|
||||
calculateRelevanceScore,
|
||||
} = require('~/models');
|
||||
const {
|
||||
searchEntraIdPrincipals,
|
||||
entraIdPrincipalFeatureEnabled,
|
||||
} = require('~/server/services/GraphApiService');
|
||||
|
||||
/**
|
||||
* Generic controller for resource permission endpoints
|
||||
* Delegates validation and logic to PermissionService
|
||||
*/
|
||||
|
||||
/**
|
||||
* Bulk update permissions for a resource (grant, update, remove)
|
||||
* @route PUT /api/{resourceType}/{resourceId}/permissions
|
||||
* @param {Object} req - Express request object
|
||||
* @param {Object} req.params - Route parameters
|
||||
* @param {string} req.params.resourceType - Resource type (e.g., 'agent')
|
||||
* @param {string} req.params.resourceId - Resource ID
|
||||
* @param {TUpdateResourcePermissionsRequest} req.body - Request body
|
||||
* @param {Object} res - Express response object
|
||||
* @returns {Promise<TUpdateResourcePermissionsResponse>} Updated permissions response
|
||||
*/
|
||||
const updateResourcePermissions = async (req, res) => {
|
||||
try {
|
||||
const { resourceType, resourceId } = req.params;
|
||||
/** @type {TUpdateResourcePermissionsRequest} */
|
||||
const { updated, removed, public: isPublic, publicAccessRoleId } = req.body;
|
||||
const { id: userId } = req.user;
|
||||
|
||||
// Prepare principals for the service call
|
||||
const updatedPrincipals = [];
|
||||
const revokedPrincipals = [];
|
||||
|
||||
// Add updated principals
|
||||
if (updated && Array.isArray(updated)) {
|
||||
updatedPrincipals.push(...updated);
|
||||
}
|
||||
|
||||
// Add public permission if enabled
|
||||
if (isPublic && publicAccessRoleId) {
|
||||
updatedPrincipals.push({
|
||||
type: 'public',
|
||||
id: null,
|
||||
accessRoleId: publicAccessRoleId,
|
||||
});
|
||||
}
|
||||
|
||||
// Prepare authentication context for enhanced group member fetching
|
||||
const useEntraId = entraIdPrincipalFeatureEnabled(req.user);
|
||||
const authHeader = req.headers.authorization;
|
||||
const accessToken =
|
||||
authHeader && authHeader.startsWith('Bearer ') ? authHeader.substring(7) : null;
|
||||
const authContext =
|
||||
useEntraId && accessToken
|
||||
? {
|
||||
accessToken,
|
||||
sub: req.user.openidId,
|
||||
}
|
||||
: null;
|
||||
|
||||
// Ensure updated principals exist in the database before processing permissions
|
||||
const validatedPrincipals = [];
|
||||
for (const principal of updatedPrincipals) {
|
||||
try {
|
||||
let principalId;
|
||||
|
||||
if (principal.type === 'public') {
|
||||
principalId = null; // Public principals don't need database records
|
||||
} else if (principal.type === 'user') {
|
||||
principalId = await ensurePrincipalExists(principal);
|
||||
} else if (principal.type === 'group') {
|
||||
// Pass authContext to enable member fetching for Entra ID groups when available
|
||||
principalId = await ensureGroupPrincipalExists(principal, authContext);
|
||||
} else {
|
||||
logger.error(`Unsupported principal type: ${principal.type}`);
|
||||
continue; // Skip invalid principal types
|
||||
}
|
||||
|
||||
// Update the principal with the validated ID for ACL operations
|
||||
validatedPrincipals.push({
|
||||
...principal,
|
||||
id: principalId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error ensuring principal exists:', {
|
||||
principal: {
|
||||
type: principal.type,
|
||||
id: principal.id,
|
||||
name: principal.name,
|
||||
source: principal.source,
|
||||
},
|
||||
error: error.message,
|
||||
});
|
||||
// Continue with other principals instead of failing the entire operation
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Add removed principals
|
||||
if (removed && Array.isArray(removed)) {
|
||||
revokedPrincipals.push(...removed);
|
||||
}
|
||||
|
||||
// If public is disabled, add public to revoked list
|
||||
if (!isPublic) {
|
||||
revokedPrincipals.push({
|
||||
type: 'public',
|
||||
id: null,
|
||||
});
|
||||
}
|
||||
|
||||
const results = await bulkUpdateResourcePermissions({
|
||||
resourceType,
|
||||
resourceId,
|
||||
updatedPrincipals: validatedPrincipals,
|
||||
revokedPrincipals,
|
||||
grantedBy: userId,
|
||||
});
|
||||
|
||||
/** @type {TUpdateResourcePermissionsResponse} */
|
||||
const response = {
|
||||
message: 'Permissions updated successfully',
|
||||
results: {
|
||||
principals: results.granted,
|
||||
public: isPublic || false,
|
||||
publicAccessRoleId: isPublic ? publicAccessRoleId : undefined,
|
||||
},
|
||||
};
|
||||
|
||||
res.status(200).json(response);
|
||||
} catch (error) {
|
||||
logger.error('Error updating resource permissions:', error);
|
||||
res.status(400).json({
|
||||
error: 'Failed to update permissions',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get principals with their permission roles for a resource (UI-friendly format)
|
||||
* Uses efficient aggregation pipeline to join User/Group data in single query
|
||||
* @route GET /api/permissions/{resourceType}/{resourceId}
|
||||
*/
|
||||
const getResourcePermissions = async (req, res) => {
|
||||
try {
|
||||
const { resourceType, resourceId } = req.params;
|
||||
|
||||
// Use aggregation pipeline for efficient single-query data retrieval
|
||||
const results = await AclEntry.aggregate([
|
||||
// Match ACL entries for this resource
|
||||
{
|
||||
$match: {
|
||||
resourceType,
|
||||
resourceId: mongoose.Types.ObjectId.isValid(resourceId)
|
||||
? mongoose.Types.ObjectId.createFromHexString(resourceId)
|
||||
: resourceId,
|
||||
},
|
||||
},
|
||||
// Lookup AccessRole information
|
||||
{
|
||||
$lookup: {
|
||||
from: 'accessroles',
|
||||
localField: 'roleId',
|
||||
foreignField: '_id',
|
||||
as: 'role',
|
||||
},
|
||||
},
|
||||
// Lookup User information (for user principals)
|
||||
{
|
||||
$lookup: {
|
||||
from: 'users',
|
||||
localField: 'principalId',
|
||||
foreignField: '_id',
|
||||
as: 'userInfo',
|
||||
},
|
||||
},
|
||||
// Lookup Group information (for group principals)
|
||||
{
|
||||
$lookup: {
|
||||
from: 'groups',
|
||||
localField: 'principalId',
|
||||
foreignField: '_id',
|
||||
as: 'groupInfo',
|
||||
},
|
||||
},
|
||||
// Project final structure
|
||||
{
|
||||
$project: {
|
||||
principalType: 1,
|
||||
principalId: 1,
|
||||
accessRoleId: { $arrayElemAt: ['$role.accessRoleId', 0] },
|
||||
userInfo: { $arrayElemAt: ['$userInfo', 0] },
|
||||
groupInfo: { $arrayElemAt: ['$groupInfo', 0] },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
const principals = [];
|
||||
let publicPermission = null;
|
||||
|
||||
// Process aggregation results
|
||||
for (const result of results) {
|
||||
if (result.principalType === 'public') {
|
||||
publicPermission = {
|
||||
public: true,
|
||||
publicAccessRoleId: result.accessRoleId,
|
||||
};
|
||||
} else if (result.principalType === 'user' && result.userInfo) {
|
||||
principals.push({
|
||||
type: 'user',
|
||||
id: result.userInfo._id.toString(),
|
||||
name: result.userInfo.name || result.userInfo.username,
|
||||
email: result.userInfo.email,
|
||||
avatar: result.userInfo.avatar,
|
||||
source: !result.userInfo._id ? 'entra' : 'local',
|
||||
idOnTheSource: result.userInfo.idOnTheSource || result.userInfo._id.toString(),
|
||||
accessRoleId: result.accessRoleId,
|
||||
});
|
||||
} else if (result.principalType === 'group' && result.groupInfo) {
|
||||
principals.push({
|
||||
type: 'group',
|
||||
id: result.groupInfo._id.toString(),
|
||||
name: result.groupInfo.name,
|
||||
email: result.groupInfo.email,
|
||||
description: result.groupInfo.description,
|
||||
avatar: result.groupInfo.avatar,
|
||||
source: result.groupInfo.source || 'local',
|
||||
idOnTheSource: result.groupInfo.idOnTheSource || result.groupInfo._id.toString(),
|
||||
accessRoleId: result.accessRoleId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Return response in format expected by frontend
|
||||
const response = {
|
||||
resourceType,
|
||||
resourceId,
|
||||
principals,
|
||||
public: publicPermission?.public || false,
|
||||
...(publicPermission?.publicAccessRoleId && {
|
||||
publicAccessRoleId: publicPermission.publicAccessRoleId,
|
||||
}),
|
||||
};
|
||||
|
||||
res.status(200).json(response);
|
||||
} catch (error) {
|
||||
logger.error('Error getting resource permissions principals:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to get permissions principals',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get available roles for a resource type
|
||||
* @route GET /api/{resourceType}/roles
|
||||
*/
|
||||
const getResourceRoles = async (req, res) => {
|
||||
try {
|
||||
const { resourceType } = req.params;
|
||||
|
||||
const roles = await getAvailableRoles({ resourceType });
|
||||
|
||||
res.status(200).json(
|
||||
roles.map((role) => ({
|
||||
accessRoleId: role.accessRoleId,
|
||||
name: role.name,
|
||||
description: role.description,
|
||||
permBits: role.permBits,
|
||||
})),
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('Error getting resource roles:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to get roles',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get user's effective permission bitmask for a resource
|
||||
* @route GET /api/{resourceType}/{resourceId}/effective
|
||||
*/
|
||||
const getUserEffectivePermissions = async (req, res) => {
|
||||
try {
|
||||
const { resourceType, resourceId } = req.params;
|
||||
const { id: userId } = req.user;
|
||||
|
||||
const permissionBits = await getEffectivePermissions({
|
||||
userId,
|
||||
resourceType,
|
||||
resourceId,
|
||||
});
|
||||
|
||||
res.status(200).json({
|
||||
permissionBits,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error getting user effective permissions:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to get effective permissions',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Search for users and groups to grant permissions
|
||||
* Supports hybrid local database + Entra ID search when configured
|
||||
* @route GET /api/permissions/search-principals
|
||||
*/
|
||||
const searchPrincipals = async (req, res) => {
|
||||
try {
|
||||
const { q: query, limit = 20, type } = req.query;
|
||||
|
||||
if (!query || query.trim().length === 0) {
|
||||
return res.status(400).json({
|
||||
error: 'Query parameter "q" is required and must not be empty',
|
||||
});
|
||||
}
|
||||
|
||||
if (query.trim().length < 2) {
|
||||
return res.status(400).json({
|
||||
error: 'Query must be at least 2 characters long',
|
||||
});
|
||||
}
|
||||
|
||||
const searchLimit = Math.min(Math.max(1, parseInt(limit) || 10), 50);
|
||||
const typeFilter = ['user', 'group'].includes(type) ? type : null;
|
||||
|
||||
const localResults = await searchLocalPrincipals(query.trim(), searchLimit, typeFilter);
|
||||
let allPrincipals = [...localResults];
|
||||
|
||||
const useEntraId = entraIdPrincipalFeatureEnabled(req.user);
|
||||
|
||||
if (useEntraId && localResults.length < searchLimit) {
|
||||
try {
|
||||
const graphTypeMap = {
|
||||
user: 'users',
|
||||
group: 'groups',
|
||||
null: 'all',
|
||||
};
|
||||
|
||||
const authHeader = req.headers.authorization;
|
||||
const accessToken =
|
||||
authHeader && authHeader.startsWith('Bearer ') ? authHeader.substring(7) : null;
|
||||
|
||||
if (accessToken) {
|
||||
const graphResults = await searchEntraIdPrincipals(
|
||||
accessToken,
|
||||
req.user.openidId,
|
||||
query.trim(),
|
||||
graphTypeMap[typeFilter],
|
||||
searchLimit - localResults.length,
|
||||
);
|
||||
|
||||
const localEmails = new Set(
|
||||
localResults.map((p) => p.email?.toLowerCase()).filter(Boolean),
|
||||
);
|
||||
const localGroupSourceIds = new Set(
|
||||
localResults.map((p) => p.idOnTheSource).filter(Boolean),
|
||||
);
|
||||
|
||||
for (const principal of graphResults) {
|
||||
const isDuplicateByEmail =
|
||||
principal.email && localEmails.has(principal.email.toLowerCase());
|
||||
const isDuplicateBySourceId =
|
||||
principal.idOnTheSource && localGroupSourceIds.has(principal.idOnTheSource);
|
||||
|
||||
if (!isDuplicateByEmail && !isDuplicateBySourceId) {
|
||||
allPrincipals.push(principal);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (graphError) {
|
||||
logger.warn('Graph API search failed, falling back to local results:', graphError.message);
|
||||
}
|
||||
}
|
||||
const scoredResults = allPrincipals.map((item) => ({
|
||||
...item,
|
||||
_searchScore: calculateRelevanceScore(item, query.trim()),
|
||||
}));
|
||||
|
||||
allPrincipals = sortPrincipalsByRelevance(scoredResults)
|
||||
.slice(0, searchLimit)
|
||||
.map((result) => {
|
||||
const { _searchScore, ...resultWithoutScore } = result;
|
||||
return resultWithoutScore;
|
||||
});
|
||||
res.status(200).json({
|
||||
query: query.trim(),
|
||||
limit: searchLimit,
|
||||
type: typeFilter,
|
||||
results: allPrincipals,
|
||||
count: allPrincipals.length,
|
||||
sources: {
|
||||
local: allPrincipals.filter((r) => r.source === 'local').length,
|
||||
entra: allPrincipals.filter((r) => r.source === 'entra').length,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error searching principals:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to search principals',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
updateResourcePermissions,
|
||||
getResourcePermissions,
|
||||
getResourceRoles,
|
||||
getUserEffectivePermissions,
|
||||
searchPrincipals,
|
||||
};
|
||||
@@ -1,11 +1,10 @@
|
||||
const fs = require('fs').promises;
|
||||
const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { logger, PermissionBits } = require('@librechat/data-schemas');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
FileSources,
|
||||
SystemRoles,
|
||||
FileSources,
|
||||
EToolResources,
|
||||
actionDelimiter,
|
||||
} = require('librechat-data-provider');
|
||||
@@ -14,16 +13,20 @@ const {
|
||||
createAgent,
|
||||
updateAgent,
|
||||
deleteAgent,
|
||||
getListAgents,
|
||||
getListAgentsByAccess,
|
||||
} = require('~/models/Agent');
|
||||
const {
|
||||
grantPermission,
|
||||
findAccessibleResources,
|
||||
findPubliclyAccessibleResources,
|
||||
hasPublicPermission,
|
||||
} = require('~/server/services/PermissionService');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||
const { filterFile } = require('~/server/services/Files/process');
|
||||
const { updateAction, getActions } = require('~/models/Action');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { updateAgentProjects } = require('~/models/Agent');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { revertAgentVersion } = require('~/models/Agent');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
|
||||
@@ -69,6 +72,27 @@ const createAgentHandler = async (req, res) => {
|
||||
|
||||
agentData.id = `agent_${nanoid()}`;
|
||||
const agent = await createAgent(agentData);
|
||||
|
||||
// Automatically grant owner permissions to the creator
|
||||
try {
|
||||
await grantPermission({
|
||||
principalType: 'user',
|
||||
principalId: userId,
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
accessRoleId: 'agent_owner',
|
||||
grantedBy: userId,
|
||||
});
|
||||
logger.debug(
|
||||
`[createAgent] Granted owner permissions to user ${userId} for agent ${agent.id}`,
|
||||
);
|
||||
} catch (permissionError) {
|
||||
logger.error(
|
||||
`[createAgent] Failed to grant owner permissions for agent ${agent.id}:`,
|
||||
permissionError,
|
||||
);
|
||||
}
|
||||
|
||||
res.status(201).json(agent);
|
||||
} catch (error) {
|
||||
logger.error('[/Agents] Error creating agent', error);
|
||||
@@ -87,21 +111,14 @@ const createAgentHandler = async (req, res) => {
|
||||
* @returns {Promise<Agent>} 200 - success response - application/json
|
||||
* @returns {Error} 404 - Agent not found
|
||||
*/
|
||||
const getAgentHandler = async (req, res) => {
|
||||
const getAgentHandler = async (req, res, expandProperties = false) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const author = req.user.id;
|
||||
|
||||
let query = { id, author };
|
||||
|
||||
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, ['agentIds']);
|
||||
if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) {
|
||||
query = {
|
||||
$or: [{ id, $in: globalProject.agentIds }, query],
|
||||
};
|
||||
}
|
||||
|
||||
const agent = await getAgent(query);
|
||||
// Permissions are validated by middleware before calling this function
|
||||
// Simply load the agent by ID
|
||||
const agent = await getAgent({ id });
|
||||
|
||||
if (!agent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
@@ -118,23 +135,45 @@ const getAgentHandler = async (req, res) => {
|
||||
}
|
||||
|
||||
agent.author = agent.author.toString();
|
||||
|
||||
// @deprecated - isCollaborative replaced by ACL permissions
|
||||
agent.isCollaborative = !!agent.isCollaborative;
|
||||
|
||||
// Check if agent is public
|
||||
const isPublic = await hasPublicPermission({
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
agent.isPublic = isPublic;
|
||||
|
||||
if (agent.author !== author) {
|
||||
delete agent.author;
|
||||
}
|
||||
|
||||
if (!agent.isCollaborative && agent.author !== author && req.user.role !== SystemRoles.ADMIN) {
|
||||
if (!expandProperties) {
|
||||
// VIEW permission: Basic agent info only
|
||||
return res.status(200).json({
|
||||
_id: agent._id,
|
||||
id: agent.id,
|
||||
name: agent.name,
|
||||
description: agent.description,
|
||||
avatar: agent.avatar,
|
||||
author: agent.author,
|
||||
provider: agent.provider,
|
||||
model: agent.model,
|
||||
projectIds: agent.projectIds,
|
||||
// @deprecated - isCollaborative replaced by ACL permissions
|
||||
isCollaborative: agent.isCollaborative,
|
||||
isPublic: agent.isPublic,
|
||||
version: agent.version,
|
||||
// Safe metadata
|
||||
createdAt: agent.createdAt,
|
||||
updatedAt: agent.updatedAt,
|
||||
});
|
||||
}
|
||||
|
||||
// EDIT permission: Full agent details including sensitive configuration
|
||||
return res.status(200).json(agent);
|
||||
} catch (error) {
|
||||
logger.error('[/Agents/:id] Error retrieving agent', error);
|
||||
@@ -154,42 +193,20 @@ const getAgentHandler = async (req, res) => {
|
||||
const updateAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const { projectIds, removeProjectIds, ...updateData } = req.body;
|
||||
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
||||
const { _id, ...updateData } = req.body;
|
||||
const existingAgent = await getAgent({ id });
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
return res.status(403).json({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
}
|
||||
|
||||
/** @type {boolean} */
|
||||
const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0;
|
||||
|
||||
let updatedAgent =
|
||||
Object.keys(updateData).length > 0
|
||||
? await updateAgent({ id }, updateData, {
|
||||
updatingUserId: req.user.id,
|
||||
skipVersioning: isProjectUpdate,
|
||||
})
|
||||
: existingAgent;
|
||||
|
||||
if (isProjectUpdate) {
|
||||
updatedAgent = await updateAgentProjects({
|
||||
user: req.user,
|
||||
agentId: id,
|
||||
projectIds,
|
||||
removeProjectIds,
|
||||
});
|
||||
}
|
||||
|
||||
if (updatedAgent.author) {
|
||||
updatedAgent.author = updatedAgent.author.toString();
|
||||
}
|
||||
@@ -307,6 +324,26 @@ const duplicateAgentHandler = async (req, res) => {
|
||||
newAgentData.actions = agentActions;
|
||||
const newAgent = await createAgent(newAgentData);
|
||||
|
||||
// Automatically grant owner permissions to the duplicator
|
||||
try {
|
||||
await grantPermission({
|
||||
principalType: 'user',
|
||||
principalId: userId,
|
||||
resourceType: 'agent',
|
||||
resourceId: newAgent._id,
|
||||
accessRoleId: 'agent_owner',
|
||||
grantedBy: userId,
|
||||
});
|
||||
logger.debug(
|
||||
`[duplicateAgent] Granted owner permissions to user ${userId} for duplicated agent ${newAgent.id}`,
|
||||
);
|
||||
} catch (permissionError) {
|
||||
logger.error(
|
||||
`[duplicateAgent] Failed to grant owner permissions for duplicated agent ${newAgent.id}:`,
|
||||
permissionError,
|
||||
);
|
||||
}
|
||||
|
||||
return res.status(201).json({
|
||||
agent: newAgent,
|
||||
actions: newActionsList,
|
||||
@@ -333,7 +370,7 @@ const deleteAgentHandler = async (req, res) => {
|
||||
if (!agent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
await deleteAgent({ id, author: req.user.id });
|
||||
await deleteAgent({ id });
|
||||
return res.json({ message: 'Agent deleted' });
|
||||
} catch (error) {
|
||||
logger.error('[/Agents/:id] Error deleting Agent', error);
|
||||
@@ -342,7 +379,7 @@ const deleteAgentHandler = async (req, res) => {
|
||||
};
|
||||
|
||||
/**
|
||||
*
|
||||
* Lists agents using ACL-aware permissions (ownership + explicit shares).
|
||||
* @route GET /Agents
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.query - Request query
|
||||
@@ -351,9 +388,31 @@ const deleteAgentHandler = async (req, res) => {
|
||||
*/
|
||||
const getListAgentsHandler = async (req, res) => {
|
||||
try {
|
||||
const data = await getListAgents({
|
||||
author: req.user.id,
|
||||
const userId = req.user.id;
|
||||
|
||||
// Get agent IDs the user has VIEW access to via ACL
|
||||
const accessibleIds = await findAccessibleResources({
|
||||
userId,
|
||||
resourceType: 'agent',
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
const publiclyAccessibleIds = await findPubliclyAccessibleResources({
|
||||
resourceType: 'agent',
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
// Use the new ACL-aware function
|
||||
const data = await getListAgentsByAccess({
|
||||
accessibleIds,
|
||||
otherParams: {}, // Can add query params here if needed
|
||||
});
|
||||
if (data?.data?.length) {
|
||||
data.data = data.data.map((agent) => {
|
||||
if (publiclyAccessibleIds.some((id) => id.equals(agent._id))) {
|
||||
agent.isPublic = true;
|
||||
}
|
||||
return agent;
|
||||
});
|
||||
}
|
||||
return res.json(data);
|
||||
} catch (error) {
|
||||
logger.error('[/Agents] Error listing Agents', error);
|
||||
@@ -431,7 +490,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
};
|
||||
|
||||
promises.push(
|
||||
await updateAgent({ id: agent_id, author: req.user.id }, data, {
|
||||
await updateAgent({ id: agent_id }, data, {
|
||||
updatingUserId: req.user.id,
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -39,7 +39,9 @@ const startServer = async () => {
|
||||
await connectDb();
|
||||
|
||||
logger.info('Connected to MongoDB');
|
||||
await indexSync();
|
||||
indexSync().catch((err) => {
|
||||
logger.error('[indexSync] Background sync failed:', err);
|
||||
});
|
||||
|
||||
app.disable('x-powered-by');
|
||||
app.set('trust proxy', trusted_proxy);
|
||||
@@ -95,7 +97,6 @@ const startServer = async () => {
|
||||
app.use('/api/actions', routes.actions);
|
||||
app.use('/api/keys', routes.keys);
|
||||
app.use('/api/user', routes.user);
|
||||
app.use('/api/ask', routes.ask);
|
||||
app.use('/api/search', routes.search);
|
||||
app.use('/api/edit', routes.edit);
|
||||
app.use('/api/messages', routes.messages);
|
||||
@@ -116,8 +117,9 @@ const startServer = async () => {
|
||||
app.use('/api/roles', routes.roles);
|
||||
app.use('/api/agents', routes.agents);
|
||||
app.use('/api/banner', routes.banner);
|
||||
app.use('/api/bedrock', routes.bedrock);
|
||||
app.use('/api/memories', routes.memories);
|
||||
app.use('/api/permissions', routes.accessPermissions);
|
||||
|
||||
app.use('/api/tags', routes.tags);
|
||||
app.use('/api/mcp', routes.mcp);
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants, isAgentsEndpoint } = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
|
||||
/**
|
||||
* Agent ID resolver function for agent_id from request body
|
||||
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
|
||||
* This is used specifically for chat routes where agent_id comes from request body
|
||||
*
|
||||
* @param {string} agentCustomId - Custom agent ID from request body
|
||||
* @returns {Promise<Object|null>} Agent document with _id field, or null if not found
|
||||
*/
|
||||
const resolveAgentIdFromBody = async (agentCustomId) => {
|
||||
// Handle ephemeral agents - they don't need permission checks
|
||||
if (agentCustomId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
return null; // No permission check needed for ephemeral agents
|
||||
}
|
||||
|
||||
return await getAgent({ id: agentCustomId });
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware factory that creates middleware to check agent access permissions from request body.
|
||||
* This middleware is specifically designed for chat routes where the agent_id comes from req.body
|
||||
* instead of route parameters.
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Basic usage for agent chat (requires VIEW permission)
|
||||
* router.post('/chat',
|
||||
* canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }),
|
||||
* buildEndpointOption,
|
||||
* chatController
|
||||
* );
|
||||
*/
|
||||
const canAccessAgentFromBody = (options) => {
|
||||
const { requiredPermission } = options;
|
||||
|
||||
// Validate required options
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number');
|
||||
}
|
||||
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
const { endpoint, agent_id } = req.body;
|
||||
let agentId = agent_id;
|
||||
|
||||
if (!isAgentsEndpoint(endpoint)) {
|
||||
agentId = Constants.EPHEMERAL_AGENT_ID;
|
||||
}
|
||||
|
||||
if (!agentId) {
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: 'agent_id is required in request body',
|
||||
});
|
||||
}
|
||||
|
||||
// Skip permission checks for ephemeral agents
|
||||
if (agentId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const agentAccessMiddleware = canAccessResource({
|
||||
resourceType: 'agent',
|
||||
requiredPermission,
|
||||
resourceIdParam: 'agent_id', // This will be ignored since we use custom resolver
|
||||
idResolver: () => resolveAgentIdFromBody(agentId),
|
||||
});
|
||||
|
||||
const tempReq = {
|
||||
...req,
|
||||
params: {
|
||||
...req.params,
|
||||
agent_id: agentId,
|
||||
},
|
||||
};
|
||||
|
||||
return agentAccessMiddleware(tempReq, res, next);
|
||||
} catch (error) {
|
||||
logger.error('Failed to validate agent access permissions', error);
|
||||
return res.status(500).json({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to validate agent access permissions',
|
||||
});
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
canAccessAgentFromBody,
|
||||
};
|
||||
@@ -0,0 +1,58 @@
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
|
||||
/**
|
||||
* Agent ID resolver function
|
||||
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
|
||||
*
|
||||
* @param {string} agentCustomId - Custom agent ID from route parameter
|
||||
* @returns {Promise<Object|null>} Agent document with _id field, or null if not found
|
||||
*/
|
||||
const resolveAgentId = async (agentCustomId) => {
|
||||
return await getAgent({ id: agentCustomId });
|
||||
};
|
||||
|
||||
/**
|
||||
* Agent-specific middleware factory that creates middleware to check agent access permissions.
|
||||
* This middleware extends the generic canAccessResource to handle agent custom ID resolution.
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @param {string} [options.resourceIdParam='id'] - The name of the route parameter containing the agent custom ID
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Basic usage for viewing agents
|
||||
* router.get('/agents/:id',
|
||||
* canAccessAgentResource({ requiredPermission: 1 }),
|
||||
* getAgent
|
||||
* );
|
||||
*
|
||||
* @example
|
||||
* // Custom resource ID parameter and edit permission
|
||||
* router.patch('/agents/:agent_id',
|
||||
* canAccessAgentResource({
|
||||
* requiredPermission: 2,
|
||||
* resourceIdParam: 'agent_id'
|
||||
* }),
|
||||
* updateAgent
|
||||
* );
|
||||
*/
|
||||
const canAccessAgentResource = (options) => {
|
||||
const { requiredPermission, resourceIdParam = 'id' } = options;
|
||||
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error('canAccessAgentResource: requiredPermission is required and must be a number');
|
||||
}
|
||||
|
||||
return canAccessResource({
|
||||
resourceType: 'agent',
|
||||
requiredPermission,
|
||||
resourceIdParam,
|
||||
idResolver: resolveAgentId,
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
canAccessAgentResource,
|
||||
};
|
||||
@@ -0,0 +1,384 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { canAccessAgentResource } = require('./canAccessAgentResource');
|
||||
const { User, Role, AclEntry } = require('~/db/models');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
|
||||
describe('canAccessAgentResource middleware', () => {
|
||||
let mongoServer;
|
||||
let req, res, next;
|
||||
let testUser;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
await Role.create({
|
||||
name: 'test-role',
|
||||
permissions: {
|
||||
AGENTS: {
|
||||
USE: true,
|
||||
CREATE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Create a test user
|
||||
testUser = await User.create({
|
||||
email: 'test@example.com',
|
||||
name: 'Test User',
|
||||
username: 'testuser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
req = {
|
||||
user: { id: testUser._id.toString(), role: 'test-role' },
|
||||
params: {},
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
next = jest.fn();
|
||||
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('middleware factory', () => {
|
||||
test('should throw error if requiredPermission is not provided', () => {
|
||||
expect(() => canAccessAgentResource({})).toThrow(
|
||||
'canAccessAgentResource: requiredPermission is required and must be a number',
|
||||
);
|
||||
});
|
||||
|
||||
test('should throw error if requiredPermission is not a number', () => {
|
||||
expect(() => canAccessAgentResource({ requiredPermission: '1' })).toThrow(
|
||||
'canAccessAgentResource: requiredPermission is required and must be a number',
|
||||
);
|
||||
});
|
||||
|
||||
test('should create middleware with default resourceIdParam', () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 });
|
||||
expect(typeof middleware).toBe('function');
|
||||
expect(middleware.length).toBe(3); // Express middleware signature
|
||||
});
|
||||
|
||||
test('should create middleware with custom resourceIdParam', () => {
|
||||
const middleware = canAccessAgentResource({
|
||||
requiredPermission: 2,
|
||||
resourceIdParam: 'agent_id',
|
||||
});
|
||||
expect(typeof middleware).toBe('function');
|
||||
expect(middleware.length).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('permission checking with real agents', () => {
|
||||
test('should allow access when user is the agent author', async () => {
|
||||
// Create an agent owned by the test user
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry for the author (owner permissions)
|
||||
await AclEntry.create({
|
||||
principalType: 'user',
|
||||
principalId: testUser._id,
|
||||
principalModel: 'User',
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions (1+2+4+8)
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should deny access when user is not the author and has no ACL entry', async () => {
|
||||
// Create an agent owned by a different user
|
||||
const otherUser = await User.create({
|
||||
email: 'other@example.com',
|
||||
name: 'Other User',
|
||||
username: 'otheruser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Other User Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry for the other user (owner)
|
||||
await AclEntry.create({
|
||||
principalType: 'user',
|
||||
principalId: otherUser._id,
|
||||
principalModel: 'User',
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to access this agent',
|
||||
});
|
||||
});
|
||||
|
||||
test('should allow access when user has ACL entry with sufficient permissions', async () => {
|
||||
// Create an agent owned by a different user
|
||||
const otherUser = await User.create({
|
||||
email: 'other2@example.com',
|
||||
name: 'Other User 2',
|
||||
username: 'otheruser2',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Shared Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry granting view permission to test user
|
||||
await AclEntry.create({
|
||||
principalType: 'user',
|
||||
principalId: testUser._id,
|
||||
principalModel: 'User',
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
permBits: 1, // VIEW permission
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should deny access when ACL permissions are insufficient', async () => {
|
||||
// Create an agent owned by a different user
|
||||
const otherUser = await User.create({
|
||||
email: 'other3@example.com',
|
||||
name: 'Other User 3',
|
||||
username: 'otheruser3',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Limited Access Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry granting only view permission
|
||||
await AclEntry.create({
|
||||
principalType: 'user',
|
||||
principalId: testUser._id,
|
||||
principalModel: 'User',
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
permBits: 1, // VIEW permission only
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 2 }); // EDIT permission required
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to access this agent',
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle non-existent agent', async () => {
|
||||
req.params.id = 'agent_nonexistent';
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(404);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Not Found',
|
||||
message: 'agent not found',
|
||||
});
|
||||
});
|
||||
|
||||
test('should use custom resourceIdParam', async () => {
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Custom Param Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry for the author
|
||||
await AclEntry.create({
|
||||
principalType: 'user',
|
||||
principalId: testUser._id,
|
||||
principalModel: 'User',
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
req.params.agent_id = agent.id; // Using custom param name
|
||||
|
||||
const middleware = canAccessAgentResource({
|
||||
requiredPermission: 1,
|
||||
resourceIdParam: 'agent_id',
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('permission levels', () => {
|
||||
let agent;
|
||||
|
||||
beforeEach(async () => {
|
||||
agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Permission Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry with all permissions for the owner
|
||||
await AclEntry.create({
|
||||
principalType: 'user',
|
||||
principalId: testUser._id,
|
||||
principalModel: 'User',
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions (1+2+4+8)
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
});
|
||||
|
||||
test('should support view permission (1)', async () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should support edit permission (2)', async () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 2 });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should support delete permission (4)', async () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 4 });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should support share permission (8)', async () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 8 });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should support combined permissions', async () => {
|
||||
const viewAndEdit = 1 | 2; // 3
|
||||
const middleware = canAccessAgentResource({ requiredPermission: viewAndEdit });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('integration with agent operations', () => {
|
||||
test('should work with agent CRUD operations', async () => {
|
||||
const agentId = `agent_${Date.now()}`;
|
||||
|
||||
// Create agent
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Integration Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
description: 'Testing integration',
|
||||
});
|
||||
|
||||
// Create ACL entry for the author
|
||||
await AclEntry.create({
|
||||
principalType: 'user',
|
||||
principalId: testUser._id,
|
||||
principalModel: 'User',
|
||||
resourceType: 'agent',
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agentId;
|
||||
|
||||
// Test view access
|
||||
const viewMiddleware = canAccessAgentResource({ requiredPermission: 1 });
|
||||
await viewMiddleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Update the agent
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: agentId }, { description: 'Updated description' });
|
||||
|
||||
// Test edit access
|
||||
const editMiddleware = canAccessAgentResource({ requiredPermission: 2 });
|
||||
await editMiddleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
157
api/server/middleware/accessResources/canAccessResource.js
Normal file
157
api/server/middleware/accessResources/canAccessResource.js
Normal file
@@ -0,0 +1,157 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { checkPermission } = require('~/server/services/PermissionService');
|
||||
|
||||
/**
|
||||
* Generic base middleware factory that creates middleware to check resource access permissions.
|
||||
* This middleware expects MongoDB ObjectIds as resource identifiers for ACL permission checks.
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {string} options.resourceType - The type of resource (e.g., 'agent', 'file', 'project')
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @param {string} [options.resourceIdParam='resourceId'] - The name of the route parameter containing the resource ID
|
||||
* @param {Function} [options.idResolver] - Optional function to resolve custom IDs to ObjectIds
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Direct usage with ObjectId (for resources that use MongoDB ObjectId in routes)
|
||||
* router.get('/prompts/:promptId',
|
||||
* canAccessResource({ resourceType: 'prompt', requiredPermission: 1 }),
|
||||
* getPrompt
|
||||
* );
|
||||
*
|
||||
* @example
|
||||
* // Usage with custom ID resolver (for resources that use custom string IDs)
|
||||
* router.get('/agents/:id',
|
||||
* canAccessResource({
|
||||
* resourceType: 'agent',
|
||||
* requiredPermission: 1,
|
||||
* resourceIdParam: 'id',
|
||||
* idResolver: (customId) => resolveAgentId(customId)
|
||||
* }),
|
||||
* getAgent
|
||||
* );
|
||||
*/
|
||||
const canAccessResource = (options) => {
|
||||
const {
|
||||
resourceType,
|
||||
requiredPermission,
|
||||
resourceIdParam = 'resourceId',
|
||||
idResolver = null,
|
||||
} = options;
|
||||
|
||||
if (!resourceType || typeof resourceType !== 'string') {
|
||||
throw new Error('canAccessResource: resourceType is required and must be a string');
|
||||
}
|
||||
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error('canAccessResource: requiredPermission is required and must be a number');
|
||||
}
|
||||
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
// Extract resource ID from route parameters
|
||||
const rawResourceId = req.params[resourceIdParam];
|
||||
|
||||
if (!rawResourceId) {
|
||||
logger.warn(`[canAccessResource] Missing ${resourceIdParam} in route parameters`);
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: `${resourceIdParam} is required`,
|
||||
});
|
||||
}
|
||||
|
||||
// Check if user is authenticated
|
||||
if (!req.user || !req.user.id) {
|
||||
logger.warn(
|
||||
`[canAccessResource] Unauthenticated request for ${resourceType} ${rawResourceId}`,
|
||||
);
|
||||
return res.status(401).json({
|
||||
error: 'Unauthorized',
|
||||
message: 'Authentication required',
|
||||
});
|
||||
}
|
||||
// if system admin let through
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return next();
|
||||
}
|
||||
const userId = req.user.id;
|
||||
let resourceId = rawResourceId;
|
||||
let resourceInfo = null;
|
||||
|
||||
// Resolve custom ID to ObjectId if resolver is provided
|
||||
if (idResolver) {
|
||||
logger.debug(
|
||||
`[canAccessResource] Resolving ${resourceType} custom ID ${rawResourceId} to ObjectId`,
|
||||
);
|
||||
|
||||
const resolutionResult = await idResolver(rawResourceId);
|
||||
|
||||
if (!resolutionResult) {
|
||||
logger.warn(`[canAccessResource] ${resourceType} not found: ${rawResourceId}`);
|
||||
return res.status(404).json({
|
||||
error: 'Not Found',
|
||||
message: `${resourceType} not found`,
|
||||
});
|
||||
}
|
||||
|
||||
// Handle different resolver return formats
|
||||
if (typeof resolutionResult === 'string' || resolutionResult._id) {
|
||||
resourceId = resolutionResult._id || resolutionResult;
|
||||
resourceInfo = typeof resolutionResult === 'object' ? resolutionResult : null;
|
||||
} else {
|
||||
resourceId = resolutionResult;
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`[canAccessResource] Resolved ${resourceType} ${rawResourceId} to ObjectId ${resourceId}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check permissions using PermissionService with ObjectId
|
||||
const hasPermission = await checkPermission({
|
||||
userId,
|
||||
resourceType,
|
||||
resourceId,
|
||||
requiredPermission,
|
||||
});
|
||||
|
||||
if (hasPermission) {
|
||||
logger.debug(
|
||||
`[canAccessResource] User ${userId} has permission ${requiredPermission} on ${resourceType} ${rawResourceId} (${resourceId})`,
|
||||
);
|
||||
|
||||
req.resourceAccess = {
|
||||
resourceType,
|
||||
resourceId, // MongoDB ObjectId for ACL operations
|
||||
customResourceId: rawResourceId, // Original ID from route params
|
||||
permission: requiredPermission,
|
||||
userId,
|
||||
...(resourceInfo && { resourceInfo }),
|
||||
};
|
||||
|
||||
return next();
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`[canAccessResource] User ${userId} denied access to ${resourceType} ${rawResourceId} ` +
|
||||
`(required permission: ${requiredPermission})`,
|
||||
);
|
||||
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: `Insufficient permissions to access this ${resourceType}`,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`[canAccessResource] Error checking access for ${resourceType}:`, error);
|
||||
return res.status(500).json({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to check resource access permissions',
|
||||
});
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
canAccessResource,
|
||||
};
|
||||
9
api/server/middleware/accessResources/index.js
Normal file
9
api/server/middleware/accessResources/index.js
Normal file
@@ -0,0 +1,9 @@
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { canAccessAgentResource } = require('./canAccessAgentResource');
|
||||
const { canAccessAgentFromBody } = require('./canAccessAgentFromBody');
|
||||
|
||||
module.exports = {
|
||||
canAccessResource,
|
||||
canAccessAgentResource,
|
||||
canAccessAgentFromBody,
|
||||
};
|
||||
@@ -1,11 +1,11 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const {
|
||||
parseCompactConvo,
|
||||
EndpointURLs,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
EndpointURLs,
|
||||
parseCompactConvo,
|
||||
} = require('librechat-data-provider');
|
||||
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
|
||||
const { getModelsConfig } = require('~/server/controllers/ModelController');
|
||||
const assistants = require('~/server/services/Endpoints/assistants');
|
||||
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
@@ -36,6 +36,9 @@ async function buildEndpointOption(req, res, next) {
|
||||
try {
|
||||
parsedBody = parseCompactConvo({ endpoint, endpointType, conversation: req.body });
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`Error parsing conversation for endpoint ${endpoint}${error?.message ? `: ${error.message}` : ''}`,
|
||||
);
|
||||
return handleError(res, { text: 'Error parsing conversation' });
|
||||
}
|
||||
|
||||
@@ -77,6 +80,7 @@ async function buildEndpointOption(req, res, next) {
|
||||
conversation: currentModelSpec.preset,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`Error parsing model spec for endpoint ${endpoint}`, error);
|
||||
return handleError(res, { text: 'Error parsing model spec' });
|
||||
}
|
||||
}
|
||||
@@ -84,20 +88,23 @@ async function buildEndpointOption(req, res, next) {
|
||||
try {
|
||||
const isAgents =
|
||||
isAgentsEndpoint(endpoint) || req.baseUrl.startsWith(EndpointURLs[EModelEndpoint.agents]);
|
||||
const endpointFn = buildFunction[isAgents ? EModelEndpoint.agents : (endpointType ?? endpoint)];
|
||||
const builder = isAgents ? (...args) => endpointFn(req, ...args) : endpointFn;
|
||||
const builder = isAgents
|
||||
? (...args) => buildFunction[EModelEndpoint.agents](req, ...args)
|
||||
: buildFunction[endpointType ?? endpoint];
|
||||
|
||||
// TODO: use object params
|
||||
req.body.endpointOption = await builder(endpoint, parsedBody, endpointType);
|
||||
|
||||
// TODO: use `getModelsConfig` only when necessary
|
||||
const modelsConfig = await getModelsConfig(req);
|
||||
req.body.endpointOption.modelsConfig = modelsConfig;
|
||||
if (req.body.files && !isAgents) {
|
||||
req.body.endpointOption.attachments = processFiles(req.body.files);
|
||||
}
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error building endpoint option for endpoint ${endpoint} with type ${endpointType}`,
|
||||
error,
|
||||
);
|
||||
return handleError(res, { text: 'Error building endpoint option' });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ const concurrentLimiter = require('./concurrentLimiter');
|
||||
const validateEndpoint = require('./validateEndpoint');
|
||||
const requireLocalAuth = require('./requireLocalAuth');
|
||||
const canDeleteAccount = require('./canDeleteAccount');
|
||||
const accessResources = require('./accessResources');
|
||||
const setBalanceConfig = require('./setBalanceConfig');
|
||||
const requireLdapAuth = require('./requireLdapAuth');
|
||||
const abortMiddleware = require('./abortMiddleware');
|
||||
@@ -29,6 +30,7 @@ module.exports = {
|
||||
...validate,
|
||||
...limiters,
|
||||
...roles,
|
||||
...accessResources,
|
||||
noIndex,
|
||||
checkBan,
|
||||
uaParser,
|
||||
|
||||
251
api/server/middleware/roles/access.spec.js
Normal file
251
api/server/middleware/roles/access.spec.js
Normal file
@@ -0,0 +1,251 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { checkAccess, generateCheckAccess } = require('./access');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { Role } = require('~/db/models');
|
||||
|
||||
// Mock only the logger
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('Access Middleware', () => {
|
||||
let mongoServer;
|
||||
let req, res, next;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
|
||||
// Create test roles
|
||||
await Role.create({
|
||||
name: 'user',
|
||||
permissions: {
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await Role.create({
|
||||
name: 'admin',
|
||||
permissions: {
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
req = {
|
||||
user: { id: 'user123', role: 'user' },
|
||||
body: {},
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
next = jest.fn();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('checkAccess', () => {
|
||||
test('should return false if user is not provided', async () => {
|
||||
const result = await checkAccess(null, PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should return true if user has required permission', async () => {
|
||||
const result = await checkAccess(req.user, PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test('should return false if user lacks required permission', async () => {
|
||||
const result = await checkAccess(req.user, PermissionTypes.AGENTS, [Permissions.CREATE]);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should return true if user has any of multiple permissions', async () => {
|
||||
const result = await checkAccess(req.user, PermissionTypes.AGENTS, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test('should check body properties when permission is not directly granted', async () => {
|
||||
// User role doesn't have CREATE permission, but bodyProps allows it
|
||||
const bodyProps = {
|
||||
[Permissions.CREATE]: ['agentId', 'name'],
|
||||
};
|
||||
|
||||
const checkObject = { agentId: 'agent123' };
|
||||
|
||||
const result = await checkAccess(
|
||||
req.user,
|
||||
PermissionTypes.AGENTS,
|
||||
[Permissions.CREATE],
|
||||
bodyProps,
|
||||
checkObject,
|
||||
);
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test('should return false if role is not found', async () => {
|
||||
req.user.role = 'nonexistent';
|
||||
const result = await checkAccess(req.user, PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false if role has no permissions for the requested type', async () => {
|
||||
await Role.create({
|
||||
name: 'limited',
|
||||
permissions: {
|
||||
// Explicitly set AGENTS permissions to false
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
},
|
||||
// Has permissions for other types
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
req.user.role = 'limited';
|
||||
|
||||
const result = await checkAccess(req.user, PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should handle admin role with all permissions', async () => {
|
||||
req.user.role = 'admin';
|
||||
|
||||
const createResult = await checkAccess(req.user, PermissionTypes.AGENTS, [
|
||||
Permissions.CREATE,
|
||||
]);
|
||||
expect(createResult).toBe(true);
|
||||
|
||||
const shareResult = await checkAccess(req.user, PermissionTypes.AGENTS, [
|
||||
Permissions.SHARED_GLOBAL,
|
||||
]);
|
||||
expect(shareResult).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateCheckAccess', () => {
|
||||
test('should call next() when user has required permission', async () => {
|
||||
const middleware = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 when user lacks permission', async () => {
|
||||
const middleware = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.CREATE]);
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({ message: 'Forbidden: Insufficient permissions' });
|
||||
});
|
||||
|
||||
test('should check body properties when configured', async () => {
|
||||
req.body = { agentId: 'agent123', description: 'test' };
|
||||
|
||||
const bodyProps = {
|
||||
[Permissions.CREATE]: ['agentId'],
|
||||
};
|
||||
|
||||
const middleware = generateCheckAccess(
|
||||
PermissionTypes.AGENTS,
|
||||
[Permissions.CREATE],
|
||||
bodyProps,
|
||||
);
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle database errors gracefully', async () => {
|
||||
// Create a user with an invalid role that will cause getRoleByName to fail
|
||||
req.user.role = { invalid: 'object' }; // This will cause an error when querying
|
||||
|
||||
const middleware = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(500);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
message: expect.stringContaining('Server error:'),
|
||||
});
|
||||
});
|
||||
|
||||
test('should work with multiple permission types', async () => {
|
||||
req.user.role = 'admin';
|
||||
|
||||
const middleware = generateCheckAccess(PermissionTypes.AGENTS, [
|
||||
Permissions.USE,
|
||||
Permissions.CREATE,
|
||||
Permissions.SHARED_GLOBAL,
|
||||
]);
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle missing user gracefully', async () => {
|
||||
req.user = null;
|
||||
|
||||
const middleware = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(500);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
message: expect.stringContaining('Server error:'),
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle role with no AGENTS permissions', async () => {
|
||||
await Role.create({
|
||||
name: 'noaccess',
|
||||
permissions: {
|
||||
// Explicitly set AGENTS with all permissions false
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
req.user.role = 'noaccess';
|
||||
|
||||
const middleware = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({ message: 'Forbidden: Insufficient permissions' });
|
||||
});
|
||||
});
|
||||
});
|
||||
62
api/server/routes/accessPermissions.js
Normal file
62
api/server/routes/accessPermissions.js
Normal file
@@ -0,0 +1,62 @@
|
||||
const express = require('express');
|
||||
const { PermissionBits } = require('@librechat/data-schemas');
|
||||
const {
|
||||
getUserEffectivePermissions,
|
||||
updateResourcePermissions,
|
||||
getResourcePermissions,
|
||||
getResourceRoles,
|
||||
searchPrincipals,
|
||||
} = require('~/server/controllers/PermissionsController');
|
||||
const { requireJwtAuth, checkBan, uaParser, canAccessResource } = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
// Apply common middleware
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
/**
|
||||
* Generic routes for resource permissions
|
||||
* Pattern: /api/permissions/{resourceType}/{resourceId}
|
||||
*/
|
||||
|
||||
/**
|
||||
* GET /api/permissions/search-principals
|
||||
* Search for users and groups to grant permissions
|
||||
*/
|
||||
router.get('/search-principals', searchPrincipals);
|
||||
|
||||
/**
|
||||
* GET /api/permissions/{resourceType}/roles
|
||||
* Get available roles for a resource type
|
||||
*/
|
||||
router.get('/:resourceType/roles', getResourceRoles);
|
||||
|
||||
/**
|
||||
* GET /api/permissions/{resourceType}/{resourceId}
|
||||
* Get all permissions for a specific resource
|
||||
*/
|
||||
router.get('/:resourceType/:resourceId', getResourcePermissions);
|
||||
|
||||
/**
|
||||
* PUT /api/permissions/{resourceType}/{resourceId}
|
||||
* Bulk update permissions for a specific resource
|
||||
*/
|
||||
router.put(
|
||||
'/:resourceType/:resourceId',
|
||||
canAccessResource({
|
||||
resourceType: 'agent',
|
||||
requiredPermission: PermissionBits.SHARE,
|
||||
resourceIdParam: 'resourceId',
|
||||
}),
|
||||
updateResourcePermissions,
|
||||
);
|
||||
|
||||
/**
|
||||
* GET /api/permissions/{resourceType}/{resourceId}/effective
|
||||
* Get user's effective permissions for a specific resource
|
||||
*/
|
||||
router.get('/:resourceType/:resourceId/effective', getUserEffectivePermissions);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,20 +1,15 @@
|
||||
const express = require('express');
|
||||
const { nanoid } = require('nanoid');
|
||||
const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider');
|
||||
const { logger, PermissionBits } = require('@librechat/data-schemas');
|
||||
const { actionDelimiter, removeNullishValues } = require('librechat-data-provider');
|
||||
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
|
||||
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||
const { canAccessAgentResource } = require('~/server/middleware');
|
||||
const { getAgent, updateAgent } = require('~/models/Agent');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
// If the user has ADMIN role
|
||||
// then action edition is possible even if not owner of the assistant
|
||||
const isAdmin = (req) => {
|
||||
return req.user.role === SystemRoles.ADMIN;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves all user's actions
|
||||
* @route GET /actions/
|
||||
@@ -23,9 +18,8 @@ const isAdmin = (req) => {
|
||||
*/
|
||||
router.get('/', async (req, res) => {
|
||||
try {
|
||||
const admin = isAdmin(req);
|
||||
// If admin, get all actions, otherwise only user's actions
|
||||
const searchParams = admin ? {} : { user: req.user.id };
|
||||
// Get all actions for the user (admin permissions handled by middleware if needed)
|
||||
const searchParams = { user: req.user.id };
|
||||
res.json(await getActions(searchParams));
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
@@ -41,106 +35,110 @@ router.get('/', async (req, res) => {
|
||||
* @param {ActionMetadata} req.body.metadata - Metadata for the action.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
router.post('/:agent_id', async (req, res) => {
|
||||
try {
|
||||
const { agent_id } = req.params;
|
||||
router.post(
|
||||
'/:agent_id',
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'agent_id',
|
||||
}),
|
||||
async (req, res) => {
|
||||
try {
|
||||
const { agent_id } = req.params;
|
||||
|
||||
/** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */
|
||||
const { functions, action_id: _action_id, metadata: _metadata } = req.body;
|
||||
if (!functions.length) {
|
||||
return res.status(400).json({ message: 'No functions provided' });
|
||||
}
|
||||
|
||||
let metadata = await encryptMetadata(removeNullishValues(_metadata, true));
|
||||
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
|
||||
if (!isDomainAllowed) {
|
||||
return res.status(400).json({ message: 'Domain not allowed' });
|
||||
}
|
||||
|
||||
let { domain } = metadata;
|
||||
domain = await domainParser(domain, true);
|
||||
|
||||
if (!domain) {
|
||||
return res.status(400).json({ message: 'No domain provided' });
|
||||
}
|
||||
|
||||
const action_id = _action_id ?? nanoid();
|
||||
const initialPromises = [];
|
||||
const admin = isAdmin(req);
|
||||
|
||||
// If admin, can edit any agent, otherwise only user's agents
|
||||
const agentQuery = admin ? { id: agent_id } : { id: agent_id, author: req.user.id };
|
||||
// TODO: share agents
|
||||
initialPromises.push(getAgent(agentQuery));
|
||||
if (_action_id) {
|
||||
initialPromises.push(getActions({ action_id }, true));
|
||||
}
|
||||
|
||||
/** @type {[Agent, [Action|undefined]]} */
|
||||
const [agent, actions_result] = await Promise.all(initialPromises);
|
||||
if (!agent) {
|
||||
return res.status(404).json({ message: 'Agent not found for adding action' });
|
||||
}
|
||||
|
||||
if (actions_result && actions_result.length) {
|
||||
const action = actions_result[0];
|
||||
metadata = { ...action.metadata, ...metadata };
|
||||
}
|
||||
|
||||
const { actions: _actions = [], author: agent_author } = agent ?? {};
|
||||
const actions = [];
|
||||
for (const action of _actions) {
|
||||
const [_action_domain, current_action_id] = action.split(actionDelimiter);
|
||||
if (current_action_id === action_id) {
|
||||
continue;
|
||||
/** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */
|
||||
const { functions, action_id: _action_id, metadata: _metadata } = req.body;
|
||||
if (!functions.length) {
|
||||
return res.status(400).json({ message: 'No functions provided' });
|
||||
}
|
||||
|
||||
actions.push(action);
|
||||
}
|
||||
|
||||
actions.push(`${domain}${actionDelimiter}${action_id}`);
|
||||
|
||||
/** @type {string[]}} */
|
||||
const { tools: _tools = [] } = agent;
|
||||
|
||||
const tools = _tools
|
||||
.filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id))))
|
||||
.concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`));
|
||||
|
||||
// Force version update since actions are changing
|
||||
const updatedAgent = await updateAgent(
|
||||
agentQuery,
|
||||
{ tools, actions },
|
||||
{
|
||||
updatingUserId: req.user.id,
|
||||
forceVersion: true,
|
||||
},
|
||||
);
|
||||
|
||||
// Only update user field for new actions
|
||||
const actionUpdateData = { metadata, agent_id };
|
||||
if (!actions_result || !actions_result.length) {
|
||||
// For new actions, use the agent owner's user ID
|
||||
actionUpdateData.user = agent_author || req.user.id;
|
||||
}
|
||||
|
||||
/** @type {[Action]} */
|
||||
const updatedAction = await updateAction({ action_id }, actionUpdateData);
|
||||
|
||||
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
|
||||
for (let field of sensitiveFields) {
|
||||
if (updatedAction.metadata[field]) {
|
||||
delete updatedAction.metadata[field];
|
||||
let metadata = await encryptMetadata(removeNullishValues(_metadata, true));
|
||||
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
|
||||
if (!isDomainAllowed) {
|
||||
return res.status(400).json({ message: 'Domain not allowed' });
|
||||
}
|
||||
}
|
||||
|
||||
res.json([updatedAgent, updatedAction]);
|
||||
} catch (error) {
|
||||
const message = 'Trouble updating the Agent Action';
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
});
|
||||
let { domain } = metadata;
|
||||
domain = await domainParser(domain, true);
|
||||
|
||||
if (!domain) {
|
||||
return res.status(400).json({ message: 'No domain provided' });
|
||||
}
|
||||
|
||||
const action_id = _action_id ?? nanoid();
|
||||
const initialPromises = [];
|
||||
|
||||
// Permissions already validated by middleware - load agent directly
|
||||
initialPromises.push(getAgent({ id: agent_id }));
|
||||
if (_action_id) {
|
||||
initialPromises.push(getActions({ action_id }, true));
|
||||
}
|
||||
|
||||
/** @type {[Agent, [Action|undefined]]} */
|
||||
const [agent, actions_result] = await Promise.all(initialPromises);
|
||||
if (!agent) {
|
||||
return res.status(404).json({ message: 'Agent not found for adding action' });
|
||||
}
|
||||
|
||||
if (actions_result && actions_result.length) {
|
||||
const action = actions_result[0];
|
||||
metadata = { ...action.metadata, ...metadata };
|
||||
}
|
||||
|
||||
const { actions: _actions = [], author: agent_author } = agent ?? {};
|
||||
const actions = [];
|
||||
for (const action of _actions) {
|
||||
const [_action_domain, current_action_id] = action.split(actionDelimiter);
|
||||
if (current_action_id === action_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
actions.push(action);
|
||||
}
|
||||
|
||||
actions.push(`${domain}${actionDelimiter}${action_id}`);
|
||||
|
||||
/** @type {string[]}} */
|
||||
const { tools: _tools = [] } = agent;
|
||||
|
||||
const tools = _tools
|
||||
.filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id))))
|
||||
.concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`));
|
||||
|
||||
// Force version update since actions are changing
|
||||
const updatedAgent = await updateAgent(
|
||||
{ id: agent_id },
|
||||
{ tools, actions },
|
||||
{
|
||||
updatingUserId: req.user.id,
|
||||
forceVersion: true,
|
||||
},
|
||||
);
|
||||
|
||||
// Only update user field for new actions
|
||||
const actionUpdateData = { metadata, agent_id };
|
||||
if (!actions_result || !actions_result.length) {
|
||||
// For new actions, use the agent owner's user ID
|
||||
actionUpdateData.user = agent_author || req.user.id;
|
||||
}
|
||||
|
||||
/** @type {[Action]} */
|
||||
const updatedAction = await updateAction({ action_id }, actionUpdateData);
|
||||
|
||||
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
|
||||
for (let field of sensitiveFields) {
|
||||
if (updatedAction.metadata[field]) {
|
||||
delete updatedAction.metadata[field];
|
||||
}
|
||||
}
|
||||
|
||||
res.json([updatedAgent, updatedAction]);
|
||||
} catch (error) {
|
||||
const message = 'Trouble updating the Agent Action';
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
/**
|
||||
* Deletes an action for a specific agent.
|
||||
@@ -149,52 +147,55 @@ router.post('/:agent_id', async (req, res) => {
|
||||
* @param {string} req.params.action_id - The ID of the action to delete.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
router.delete('/:agent_id/:action_id', async (req, res) => {
|
||||
try {
|
||||
const { agent_id, action_id } = req.params;
|
||||
const admin = isAdmin(req);
|
||||
router.delete(
|
||||
'/:agent_id/:action_id',
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'agent_id',
|
||||
}),
|
||||
async (req, res) => {
|
||||
try {
|
||||
const { agent_id, action_id } = req.params;
|
||||
|
||||
// If admin, can delete any agent, otherwise only user's agents
|
||||
const agentQuery = admin ? { id: agent_id } : { id: agent_id, author: req.user.id };
|
||||
const agent = await getAgent(agentQuery);
|
||||
if (!agent) {
|
||||
return res.status(404).json({ message: 'Agent not found for deleting action' });
|
||||
}
|
||||
|
||||
const { tools = [], actions = [] } = agent;
|
||||
|
||||
let domain = '';
|
||||
const updatedActions = actions.filter((action) => {
|
||||
if (action.includes(action_id)) {
|
||||
[domain] = action.split(actionDelimiter);
|
||||
return false;
|
||||
// Permissions already validated by middleware - load agent directly
|
||||
const agent = await getAgent({ id: agent_id });
|
||||
if (!agent) {
|
||||
return res.status(404).json({ message: 'Agent not found for deleting action' });
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
domain = await domainParser(domain, true);
|
||||
const { tools = [], actions = [] } = agent;
|
||||
|
||||
if (!domain) {
|
||||
return res.status(400).json({ message: 'No domain provided' });
|
||||
let domain = '';
|
||||
const updatedActions = actions.filter((action) => {
|
||||
if (action.includes(action_id)) {
|
||||
[domain] = action.split(actionDelimiter);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
domain = await domainParser(domain, true);
|
||||
|
||||
if (!domain) {
|
||||
return res.status(400).json({ message: 'No domain provided' });
|
||||
}
|
||||
|
||||
const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain)));
|
||||
|
||||
// Force version update since actions are being removed
|
||||
await updateAgent(
|
||||
{ id: agent_id },
|
||||
{ tools: updatedTools, actions: updatedActions },
|
||||
{ updatingUserId: req.user.id, forceVersion: true },
|
||||
);
|
||||
await deleteAction({ action_id });
|
||||
res.status(200).json({ message: 'Action deleted successfully' });
|
||||
} catch (error) {
|
||||
const message = 'Trouble deleting the Agent Action';
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
|
||||
const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain)));
|
||||
|
||||
// Force version update since actions are being removed
|
||||
await updateAgent(
|
||||
agentQuery,
|
||||
{ tools: updatedTools, actions: updatedActions },
|
||||
{ updatingUserId: req.user.id, forceVersion: true },
|
||||
);
|
||||
// If admin, can delete any action, otherwise only user's actions
|
||||
const actionQuery = admin ? { action_id } : { action_id, user: req.user.id };
|
||||
await deleteAction(actionQuery);
|
||||
res.status(200).json({ message: 'Action deleted successfully' });
|
||||
} catch (error) {
|
||||
const message = 'Trouble deleting the Agent Action';
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
const express = require('express');
|
||||
const { PermissionBits } = require('@librechat/data-schemas');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const {
|
||||
setHeaders,
|
||||
@@ -7,6 +8,7 @@ const {
|
||||
generateCheckAccess,
|
||||
validateConvoAccess,
|
||||
buildEndpointOption,
|
||||
canAccessAgentFromBody,
|
||||
} = require('~/server/middleware');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/agents');
|
||||
const AgentController = require('~/server/controllers/agents/request');
|
||||
@@ -17,8 +19,12 @@ const router = express.Router();
|
||||
router.use(moderateText);
|
||||
|
||||
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
|
||||
const checkAgentResourceAccess = canAccessAgentFromBody({
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
router.use(checkAgentAccess);
|
||||
router.use(checkAgentResourceAccess);
|
||||
router.use(validateConvoAccess);
|
||||
router.use(buildEndpointOption);
|
||||
router.use(setHeaders);
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
const express = require('express');
|
||||
const { PermissionBits } = require('@librechat/data-schemas');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
|
||||
const {
|
||||
requireJwtAuth,
|
||||
generateCheckAccess,
|
||||
canAccessAgentResource,
|
||||
} = require('~/server/middleware');
|
||||
const v1 = require('~/server/controllers/agents/v1');
|
||||
const actions = require('./actions');
|
||||
const tools = require('./tools');
|
||||
@@ -46,13 +51,38 @@ router.use('/tools', tools);
|
||||
router.post('/', checkAgentCreate, v1.createAgent);
|
||||
|
||||
/**
|
||||
* Retrieves an agent.
|
||||
* Retrieves basic agent information (VIEW permission required).
|
||||
* Returns safe, non-sensitive agent data for viewing purposes.
|
||||
* @route GET /agents/:id
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Agent} 200 - Success response - application/json
|
||||
* @returns {Agent} 200 - Basic agent info - application/json
|
||||
*/
|
||||
router.get('/:id', checkAgentAccess, v1.getAgent);
|
||||
router.get(
|
||||
'/:id',
|
||||
checkAgentAccess,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
v1.getAgent,
|
||||
);
|
||||
|
||||
/**
|
||||
* Retrieves full agent details including sensitive configuration (EDIT permission required).
|
||||
* Returns complete agent data for editing/configuration purposes.
|
||||
* @route GET /agents/:id/expanded
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Agent} 200 - Full agent details - application/json
|
||||
*/
|
||||
router.get(
|
||||
'/:id/expanded',
|
||||
checkAgentAccess,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
(req, res) => v1.getAgent(req, res, true), // Expanded version
|
||||
);
|
||||
/**
|
||||
* Updates an agent.
|
||||
* @route PATCH /agents/:id
|
||||
@@ -60,7 +90,15 @@ router.get('/:id', checkAgentAccess, v1.getAgent);
|
||||
* @param {AgentUpdateParams} req.body - The agent update parameters.
|
||||
* @returns {Agent} 200 - Success response - application/json
|
||||
*/
|
||||
router.patch('/:id', checkGlobalAgentShare, v1.updateAgent);
|
||||
router.patch(
|
||||
'/:id',
|
||||
checkGlobalAgentShare,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
v1.updateAgent,
|
||||
);
|
||||
|
||||
/**
|
||||
* Duplicates an agent.
|
||||
@@ -68,7 +106,15 @@ router.patch('/:id', checkGlobalAgentShare, v1.updateAgent);
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Agent} 201 - Success response - application/json
|
||||
*/
|
||||
router.post('/:id/duplicate', checkAgentCreate, v1.duplicateAgent);
|
||||
router.post(
|
||||
'/:id/duplicate',
|
||||
checkAgentCreate,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
v1.duplicateAgent,
|
||||
);
|
||||
|
||||
/**
|
||||
* Deletes an agent.
|
||||
@@ -76,7 +122,15 @@ router.post('/:id/duplicate', checkAgentCreate, v1.duplicateAgent);
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Agent} 200 - success response - application/json
|
||||
*/
|
||||
router.delete('/:id', checkAgentCreate, v1.deleteAgent);
|
||||
router.delete(
|
||||
'/:id',
|
||||
checkAgentCreate,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.DELETE,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
v1.deleteAgent,
|
||||
);
|
||||
|
||||
/**
|
||||
* Reverts an agent to a previous version.
|
||||
@@ -103,6 +157,14 @@ router.get('/', checkAgentAccess, v1.getListAgents);
|
||||
* @param {string} [req.body.metadata] - Optional metadata for the agent's avatar.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
avatar.post('/:agent_id/avatar/', checkAgentAccess, v1.uploadAgentAvatar);
|
||||
avatar.post(
|
||||
'/:agent_id/avatar/',
|
||||
checkAgentAccess,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'agent_id',
|
||||
}),
|
||||
v1.uploadAgentAvatar,
|
||||
);
|
||||
|
||||
module.exports = { v1: router, avatar };
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
const { Keyv } = require('keyv');
|
||||
const { KeyvFile } = require('keyv-file');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => {
|
||||
try {
|
||||
const conversationsCache = new Keyv({
|
||||
store: new KeyvFile({ filename: './data/cache.json' }),
|
||||
namespace: 'chatgpt', // should be 'bing' for bing/sydney
|
||||
});
|
||||
|
||||
const {
|
||||
conversationId,
|
||||
messageId: userMessageId,
|
||||
parentMessageId: userParentMessageId,
|
||||
text: userText,
|
||||
} = userMessage;
|
||||
const {
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: responseParentMessageId,
|
||||
text: responseText,
|
||||
} = responseMessage;
|
||||
|
||||
let conversation = await conversationsCache.get(conversationId);
|
||||
// used to generate a title for the conversation if none exists
|
||||
// let isNewConversation = false;
|
||||
if (!conversation) {
|
||||
conversation = {
|
||||
messages: [],
|
||||
createdAt: Date.now(),
|
||||
};
|
||||
// isNewConversation = true;
|
||||
}
|
||||
|
||||
const roles = (options) => {
|
||||
if (endpoint === 'openAI') {
|
||||
return options?.chatGptLabel || 'ChatGPT';
|
||||
}
|
||||
};
|
||||
|
||||
let _userMessage = {
|
||||
id: userMessageId,
|
||||
parentMessageId: userParentMessageId,
|
||||
role: 'User',
|
||||
message: userText,
|
||||
};
|
||||
|
||||
let _responseMessage = {
|
||||
id: responseMessageId,
|
||||
parentMessageId: responseParentMessageId,
|
||||
role: roles(endpointOption),
|
||||
message: responseText,
|
||||
};
|
||||
|
||||
conversation.messages.push(_userMessage, _responseMessage);
|
||||
|
||||
await conversationsCache.set(conversationId, conversation);
|
||||
} catch (error) {
|
||||
logger.error('[addToCache] Error adding conversation to cache', error);
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = addToCache;
|
||||
@@ -1,25 +0,0 @@
|
||||
const express = require('express');
|
||||
const AskController = require('~/server/controllers/AskController');
|
||||
const { addTitle, initializeClient } = require('~/server/services/Endpoints/anthropic');
|
||||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient, addTitle);
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,25 +0,0 @@
|
||||
const express = require('express');
|
||||
const AskController = require('~/server/controllers/AskController');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/custom');
|
||||
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
||||
const {
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient, addTitle);
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,24 +0,0 @@
|
||||
const express = require('express');
|
||||
const AskController = require('~/server/controllers/AskController');
|
||||
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
|
||||
const {
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient, addTitle);
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,241 +0,0 @@
|
||||
const express = require('express');
|
||||
const { getResponseSender, Constants } = require('librechat-data-provider');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
|
||||
const { sendMessage, createOnProgress } = require('~/server/utils');
|
||||
const { addTitle } = require('~/server/services/Endpoints/openAI');
|
||||
const { saveMessage, updateMessage } = require('~/models');
|
||||
const {
|
||||
handleAbort,
|
||||
createAbortController,
|
||||
handleAbortError,
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
moderateText,
|
||||
} = require('~/server/middleware');
|
||||
const { validateTools } = require('~/app');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(moderateText);
|
||||
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res) => {
|
||||
let {
|
||||
text,
|
||||
endpointOption,
|
||||
conversationId,
|
||||
parentMessageId = null,
|
||||
overrideParentMessageId = null,
|
||||
} = req.body;
|
||||
|
||||
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
|
||||
|
||||
let userMessage;
|
||||
let userMessagePromise;
|
||||
let promptTokens;
|
||||
let userMessageId;
|
||||
let responseMessageId;
|
||||
const sender = getResponseSender({
|
||||
...endpointOption,
|
||||
model: endpointOption.modelOptions.model,
|
||||
});
|
||||
const newConvo = !conversationId;
|
||||
const user = req.user.id;
|
||||
|
||||
const plugins = [];
|
||||
|
||||
const getReqData = (data = {}) => {
|
||||
for (let key in data) {
|
||||
if (key === 'userMessage') {
|
||||
userMessage = data[key];
|
||||
userMessageId = data[key].messageId;
|
||||
} else if (key === 'userMessagePromise') {
|
||||
userMessagePromise = data[key];
|
||||
} else if (key === 'responseMessageId') {
|
||||
responseMessageId = data[key];
|
||||
} else if (key === 'promptTokens') {
|
||||
promptTokens = data[key];
|
||||
} else if (!conversationId && key === 'conversationId') {
|
||||
conversationId = data[key];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let streaming = null;
|
||||
let timer = null;
|
||||
|
||||
const {
|
||||
onProgress: progressCallback,
|
||||
sendIntermediateMessage,
|
||||
getPartialText,
|
||||
} = createOnProgress({
|
||||
onProgress: () => {
|
||||
if (timer) {
|
||||
clearTimeout(timer);
|
||||
}
|
||||
|
||||
streaming = new Promise((resolve) => {
|
||||
timer = setTimeout(() => {
|
||||
resolve();
|
||||
}, 250);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
const pluginMap = new Map();
|
||||
const onAgentAction = async (action, runId) => {
|
||||
pluginMap.set(runId, action.tool);
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
const onToolStart = async (tool, input, runId, parentRunId) => {
|
||||
const pluginName = pluginMap.get(parentRunId);
|
||||
const latestPlugin = {
|
||||
runId,
|
||||
loading: true,
|
||||
inputs: [input],
|
||||
latest: pluginName,
|
||||
outputs: null,
|
||||
};
|
||||
|
||||
if (streaming) {
|
||||
await streaming;
|
||||
}
|
||||
const extraTokens = ':::plugin:::\n';
|
||||
plugins.push(latestPlugin);
|
||||
sendIntermediateMessage(
|
||||
res,
|
||||
{ plugins, parentMessageId: userMessage.messageId, messageId: responseMessageId },
|
||||
extraTokens,
|
||||
);
|
||||
};
|
||||
|
||||
const onToolEnd = async (output, runId) => {
|
||||
if (streaming) {
|
||||
await streaming;
|
||||
}
|
||||
|
||||
const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId);
|
||||
|
||||
if (pluginIndex !== -1) {
|
||||
plugins[pluginIndex].loading = false;
|
||||
plugins[pluginIndex].outputs = output;
|
||||
}
|
||||
};
|
||||
|
||||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
userMessagePromise,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: overrideParentMessageId ?? userMessageId,
|
||||
text: getPartialText(),
|
||||
plugins: plugins.map((p) => ({ ...p, loading: false })),
|
||||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
try {
|
||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
const onChainEnd = () => {
|
||||
if (!client.skipSaveUserMessage) {
|
||||
saveMessage(
|
||||
req,
|
||||
{ ...userMessage, user },
|
||||
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
|
||||
);
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
overrideParentMessageId,
|
||||
getReqData,
|
||||
onAgentAction,
|
||||
onChainEnd,
|
||||
onToolStart,
|
||||
onToolEnd,
|
||||
onStart,
|
||||
getPartialText,
|
||||
...endpointOption,
|
||||
progressCallback,
|
||||
progressOptions: {
|
||||
res,
|
||||
// parentMessageId: overrideParentMessageId || userMessageId,
|
||||
plugins,
|
||||
},
|
||||
abortController,
|
||||
});
|
||||
|
||||
if (overrideParentMessageId) {
|
||||
response.parentMessageId = overrideParentMessageId;
|
||||
}
|
||||
|
||||
logger.debug('[/ask/gptPlugins]', response);
|
||||
|
||||
const { conversation = {} } = await response.databasePromise;
|
||||
delete response.databasePromise;
|
||||
conversation.title =
|
||||
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
|
||||
|
||||
sendMessage(res, {
|
||||
title: conversation.title,
|
||||
final: true,
|
||||
conversation,
|
||||
requestMessage: userMessage,
|
||||
responseMessage: response,
|
||||
});
|
||||
res.end();
|
||||
|
||||
if (parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
addTitle(req, {
|
||||
text,
|
||||
response,
|
||||
client,
|
||||
});
|
||||
}
|
||||
|
||||
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
|
||||
if (response.plugins?.length > 0) {
|
||||
await updateMessage(
|
||||
req,
|
||||
{ ...response, user },
|
||||
{ context: 'api/server/routes/ask/gptPlugins.js - save plugins used' },
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
const partialText = getPartialText();
|
||||
handleAbortError(res, req, error, {
|
||||
partialText,
|
||||
conversationId,
|
||||
sender,
|
||||
messageId: responseMessageId,
|
||||
parentMessageId: userMessageId ?? parentMessageId,
|
||||
});
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,47 +0,0 @@
|
||||
const express = require('express');
|
||||
const { EModelEndpoint } = require('librechat-data-provider');
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
requireJwtAuth,
|
||||
messageIpLimiter,
|
||||
concurrentLimiter,
|
||||
messageUserLimiter,
|
||||
validateConvoAccess,
|
||||
} = require('~/server/middleware');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const gptPlugins = require('./gptPlugins');
|
||||
const anthropic = require('./anthropic');
|
||||
const custom = require('./custom');
|
||||
const google = require('./google');
|
||||
const openAI = require('./openAI');
|
||||
|
||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||
router.use(concurrentLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
||||
router.use(messageIpLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_USER)) {
|
||||
router.use(messageUserLimiter);
|
||||
}
|
||||
|
||||
router.use(validateConvoAccess);
|
||||
|
||||
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
|
||||
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
|
||||
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
|
||||
router.use(`/${EModelEndpoint.google}`, google);
|
||||
router.use(`/${EModelEndpoint.custom}`, custom);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,27 +0,0 @@
|
||||
const express = require('express');
|
||||
const AskController = require('~/server/controllers/AskController');
|
||||
const { addTitle, initializeClient } = require('~/server/services/Endpoints/openAI');
|
||||
const {
|
||||
handleAbort,
|
||||
setHeaders,
|
||||
validateModel,
|
||||
validateEndpoint,
|
||||
buildEndpointOption,
|
||||
moderateText,
|
||||
} = require('~/server/middleware');
|
||||
|
||||
const router = express.Router();
|
||||
router.use(moderateText);
|
||||
|
||||
router.post(
|
||||
'/',
|
||||
validateEndpoint,
|
||||
validateModel,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AskController(req, res, next, initializeClient, addTitle);
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,37 +0,0 @@
|
||||
const express = require('express');
|
||||
|
||||
const router = express.Router();
|
||||
const {
|
||||
setHeaders,
|
||||
handleAbort,
|
||||
moderateText,
|
||||
// validateModel,
|
||||
// validateEndpoint,
|
||||
buildEndpointOption,
|
||||
} = require('~/server/middleware');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/bedrock');
|
||||
const AgentController = require('~/server/controllers/agents/request');
|
||||
const addTitle = require('~/server/services/Endpoints/agents/title');
|
||||
|
||||
router.use(moderateText);
|
||||
|
||||
/**
|
||||
* @route POST /
|
||||
* @desc Chat with an assistant
|
||||
* @access Public
|
||||
* @param {express.Request} req - The request object, containing the request data.
|
||||
* @param {express.Response} res - The response object, used to send back a response.
|
||||
* @returns {void}
|
||||
*/
|
||||
router.post(
|
||||
'/',
|
||||
// validateModel,
|
||||
// validateEndpoint,
|
||||
buildEndpointOption,
|
||||
setHeaders,
|
||||
async (req, res, next) => {
|
||||
await AgentController(req, res, next, initializeClient, addTitle);
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,35 +0,0 @@
|
||||
const express = require('express');
|
||||
const {
|
||||
uaParser,
|
||||
checkBan,
|
||||
requireJwtAuth,
|
||||
messageIpLimiter,
|
||||
concurrentLimiter,
|
||||
messageUserLimiter,
|
||||
} = require('~/server/middleware');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const chat = require('./chat');
|
||||
|
||||
const { LIMIT_CONCURRENT_MESSAGES, LIMIT_MESSAGE_IP, LIMIT_MESSAGE_USER } = process.env ?? {};
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
if (isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
|
||||
router.use(concurrentLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_IP)) {
|
||||
router.use(messageIpLimiter);
|
||||
}
|
||||
|
||||
if (isEnabled(LIMIT_MESSAGE_USER)) {
|
||||
router.use(messageUserLimiter);
|
||||
}
|
||||
|
||||
router.use('/chat', chat);
|
||||
|
||||
module.exports = router;
|
||||
@@ -1,3 +1,4 @@
|
||||
const accessPermissions = require('./accessPermissions');
|
||||
const assistants = require('./assistants');
|
||||
const categories = require('./categories');
|
||||
const tokenizer = require('./tokenizer');
|
||||
@@ -9,7 +10,6 @@ const presets = require('./presets');
|
||||
const prompts = require('./prompts');
|
||||
const balance = require('./balance');
|
||||
const plugins = require('./plugins');
|
||||
const bedrock = require('./bedrock');
|
||||
const actions = require('./actions');
|
||||
const banner = require('./banner');
|
||||
const search = require('./search');
|
||||
@@ -26,11 +26,10 @@ const auth = require('./auth');
|
||||
const edit = require('./edit');
|
||||
const keys = require('./keys');
|
||||
const user = require('./user');
|
||||
const ask = require('./ask');
|
||||
const mcp = require('./mcp');
|
||||
|
||||
module.exports = {
|
||||
ask,
|
||||
mcp,
|
||||
edit,
|
||||
auth,
|
||||
keys,
|
||||
@@ -46,7 +45,6 @@ module.exports = {
|
||||
search,
|
||||
config,
|
||||
models,
|
||||
bedrock,
|
||||
prompts,
|
||||
plugins,
|
||||
actions,
|
||||
@@ -59,5 +57,5 @@ module.exports = {
|
||||
assistants,
|
||||
categories,
|
||||
staticRoute,
|
||||
mcp,
|
||||
accessPermissions,
|
||||
};
|
||||
|
||||
@@ -9,6 +9,7 @@ const {
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
} = require('~/server/middleware');
|
||||
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
|
||||
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { logger } = require('~/config');
|
||||
@@ -35,6 +36,7 @@ const oauthHandler = async (req, res) => {
|
||||
req.user.provider == 'openid' &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res);
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
jest.mock('~/models', () => ({
|
||||
initializeRoles: jest.fn(),
|
||||
seedDefaultRoles: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Role', () => ({
|
||||
updateAccessPermissions: jest.fn(),
|
||||
|
||||
@@ -17,6 +17,7 @@ const {
|
||||
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
|
||||
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
|
||||
const { initializeFirebase } = require('./Files/Firebase/initialize');
|
||||
const { seedDefaultRoles, initializeRoles } = require('~/models');
|
||||
const loadCustomConfig = require('./Config/loadCustomConfig');
|
||||
const handleRateLimits = require('./Config/handleRateLimits');
|
||||
const { loadDefaultInterface } = require('./start/interface');
|
||||
@@ -26,7 +27,6 @@ const { processModelSpecs } = require('./start/modelSpecs');
|
||||
const { initializeS3 } = require('./Files/S3/initialize');
|
||||
const { loadAndFormatTools } = require('./ToolService');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { initializeRoles } = require('~/models');
|
||||
const { setCachedTools } = require('./Config');
|
||||
const paths = require('~/config/paths');
|
||||
|
||||
@@ -37,6 +37,7 @@ const paths = require('~/config/paths');
|
||||
*/
|
||||
const AppService = async (app) => {
|
||||
await initializeRoles();
|
||||
await seedDefaultRoles();
|
||||
/** @type {TCustomConfig} */
|
||||
const config = (await loadCustomConfig()) ?? {};
|
||||
const configDefaults = getConfigDefaults();
|
||||
|
||||
@@ -28,6 +28,7 @@ jest.mock('./Files/Firebase/initialize', () => ({
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
initializeRoles: jest.fn(),
|
||||
seedDefaultRoles: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Role', () => ({
|
||||
updateAccessPermissions: jest.fn(),
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
const { Providers } = require('@librechat/agents');
|
||||
const { primeResources, optionalChainWithEmptyCheck } = require('@librechat/api');
|
||||
const {
|
||||
primeResources,
|
||||
extractLibreChatParams,
|
||||
optionalChainWithEmptyCheck,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
ErrorTypes,
|
||||
EModelEndpoint,
|
||||
@@ -15,10 +19,9 @@ const initGoogle = require('~/server/services/Endpoints/google/initialize');
|
||||
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { processFiles } = require('~/server/services/Files/process');
|
||||
const { getFiles, getToolFilesByIds } = require('~/models/File');
|
||||
const { getConvoFiles } = require('~/models/Conversation');
|
||||
const { getToolFilesByIds } = require('~/models/File');
|
||||
const { getModelMaxTokens } = require('~/utils');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
const providerConfigMap = {
|
||||
[Providers.XAI]: initCustom,
|
||||
@@ -71,7 +74,7 @@ const initializeAgent = async ({
|
||||
),
|
||||
);
|
||||
|
||||
const { resendFiles = true, ...modelOptions } = _modelOptions;
|
||||
const { resendFiles, maxContextTokens, modelOptions } = extractLibreChatParams(_modelOptions);
|
||||
|
||||
if (isInitialAgent && conversationId != null && resendFiles) {
|
||||
const fileIds = (await getConvoFiles(conversationId)) ?? [];
|
||||
@@ -145,9 +148,8 @@ const initializeAgent = async ({
|
||||
modelOptions.maxTokens,
|
||||
0,
|
||||
);
|
||||
const maxContextTokens = optionalChainWithEmptyCheck(
|
||||
modelOptions.maxContextTokens,
|
||||
modelOptions.max_context_tokens,
|
||||
const agentMaxContextTokens = optionalChainWithEmptyCheck(
|
||||
maxContextTokens,
|
||||
getModelMaxTokens(tokensModel, providerEndpointMap[provider]),
|
||||
4096,
|
||||
);
|
||||
@@ -189,7 +191,7 @@ const initializeAgent = async ({
|
||||
attachments,
|
||||
resendFiles,
|
||||
toolContextMap,
|
||||
maxContextTokens: (maxContextTokens - maxTokens) * 0.9,
|
||||
maxContextTokens: (agentMaxContextTokens - maxTokens) * 0.9,
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
const { isAgentsEndpoint, Constants } = require('librechat-data-provider');
|
||||
const { isAgentsEndpoint, removeNullishValues, Constants } = require('librechat-data-provider');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
||||
const { spec, iconURL, agent_id, instructions, maxContextTokens, ...model_parameters } =
|
||||
parsedBody;
|
||||
const { spec, iconURL, agent_id, instructions, ...model_parameters } = parsedBody;
|
||||
const agentPromise = loadAgent({
|
||||
req,
|
||||
agent_id: isAgentsEndpoint(endpoint) ? agent_id : Constants.EPHEMERAL_AGENT_ID,
|
||||
@@ -15,19 +14,16 @@ const buildOptions = (req, endpoint, parsedBody, endpointType) => {
|
||||
return undefined;
|
||||
});
|
||||
|
||||
const endpointOption = {
|
||||
return removeNullishValues({
|
||||
spec,
|
||||
iconURL,
|
||||
endpoint,
|
||||
agent_id,
|
||||
endpointType,
|
||||
instructions,
|
||||
maxContextTokens,
|
||||
model_parameters,
|
||||
agent: agentPromise,
|
||||
};
|
||||
|
||||
return endpointOption;
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = { buildOptions };
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { createContentAggregator } = require('@librechat/agents');
|
||||
const { Constants, EModelEndpoint, getResponseSender } = require('librechat-data-provider');
|
||||
const {
|
||||
getDefaultHandlers,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
getResponseSender,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
createToolEndCallback,
|
||||
getDefaultHandlers,
|
||||
} = require('~/server/controllers/agents/callbacks');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { getCustomEndpointConfig } = require('~/server/services/Config');
|
||||
const { loadAgentTools } = require('~/server/services/ToolService');
|
||||
const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
@@ -61,6 +67,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
}
|
||||
|
||||
const primaryAgent = await endpointOption.agent;
|
||||
delete endpointOption.agent;
|
||||
if (!primaryAgent) {
|
||||
throw new Error('Agent not found');
|
||||
}
|
||||
@@ -108,11 +115,25 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
}
|
||||
}
|
||||
|
||||
let endpointConfig = req.app.locals[primaryConfig.endpoint];
|
||||
if (!isAgentsEndpoint(primaryConfig.endpoint) && !endpointConfig) {
|
||||
try {
|
||||
endpointConfig = await getCustomEndpointConfig(primaryConfig.endpoint);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
|
||||
err,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const sender =
|
||||
primaryAgent.name ??
|
||||
getResponseSender({
|
||||
...endpointOption,
|
||||
model: endpointOption.model_parameters.model,
|
||||
modelDisplayLabel: endpointConfig?.modelDisplayLabel,
|
||||
modelLabel: endpointOption.model_parameters.modelLabel,
|
||||
});
|
||||
|
||||
const client = new AgentClient({
|
||||
|
||||
525
api/server/services/GraphApiService.js
Normal file
525
api/server/services/GraphApiService.js
Normal file
@@ -0,0 +1,525 @@
|
||||
const client = require('openid-client');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const { Client } = require('@microsoft/microsoft-graph-client');
|
||||
const { getOpenIdConfig } = require('~/strategies/openidStrategy');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
|
||||
/**
|
||||
* @import { TPrincipalSearchResult, TGraphPerson, TGraphUser, TGraphGroup, TGraphPeopleResponse, TGraphUsersResponse, TGraphGroupsResponse } from 'librechat-data-provider'
|
||||
*/
|
||||
|
||||
/**
|
||||
* Checks if Entra ID principal search feature is enabled based on environment variables and user authentication
|
||||
* @param {Object} user - User object from request
|
||||
* @param {string} user.provider - Authentication provider
|
||||
* @param {string} user.openidId - OpenID subject identifier
|
||||
* @returns {boolean} True if Entra ID principal search is enabled and user is authenticated via OpenID
|
||||
*/
|
||||
const entraIdPrincipalFeatureEnabled = (user) => {
|
||||
return (
|
||||
isEnabled(process.env.USE_ENTRA_ID_FOR_PEOPLE_SEARCH) &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) &&
|
||||
user?.provider === 'openid' &&
|
||||
user?.openidId
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a Microsoft Graph client with on-behalf-of token exchange
|
||||
* @param {string} accessToken - OpenID Connect access token from user
|
||||
* @param {string} sub - Subject identifier from token claims
|
||||
* @returns {Promise<Client>} Authenticated Graph API client
|
||||
*/
|
||||
const createGraphClient = async (accessToken, sub) => {
|
||||
try {
|
||||
// Reason: Use existing OpenID configuration and token exchange pattern from openidStrategy.js
|
||||
const openidConfig = getOpenIdConfig();
|
||||
const exchangedToken = await exchangeTokenForGraphAccess(openidConfig, accessToken, sub);
|
||||
|
||||
const graphClient = Client.init({
|
||||
authProvider: (done) => {
|
||||
done(null, exchangedToken);
|
||||
},
|
||||
});
|
||||
|
||||
return graphClient;
|
||||
} catch (error) {
|
||||
logger.error('[createGraphClient] Error creating Graph client:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Exchange OpenID token for Graph API access using on-behalf-of flow
|
||||
* Similar to exchangeAccessTokenIfNeeded in openidStrategy.js but for Graph scopes
|
||||
* @param {Configuration} config - OpenID configuration
|
||||
* @param {string} accessToken - Original access token
|
||||
* @param {string} sub - Subject identifier
|
||||
* @returns {Promise<string>} Graph API access token
|
||||
*/
|
||||
const exchangeTokenForGraphAccess = async (config, accessToken, sub) => {
|
||||
try {
|
||||
const tokensCache = getLogStores(CacheKeys.OPENID_EXCHANGED_TOKENS);
|
||||
const cacheKey = `${sub}:graph`;
|
||||
|
||||
const cachedToken = await tokensCache.get(cacheKey);
|
||||
if (cachedToken) {
|
||||
return cachedToken.access_token;
|
||||
}
|
||||
|
||||
const graphScopes = process.env.OPENID_GRAPH_SCOPES || 'User.Read,People.Read,Group.Read.All';
|
||||
const scopeString = graphScopes
|
||||
.split(',')
|
||||
.map((scope) => `https://graph.microsoft.com/${scope}`)
|
||||
.join(' ');
|
||||
|
||||
const grantResponse = await client.genericGrantRequest(
|
||||
config,
|
||||
'urn:ietf:params:oauth:grant-type:jwt-bearer',
|
||||
{
|
||||
scope: scopeString,
|
||||
assertion: accessToken,
|
||||
requested_token_use: 'on_behalf_of',
|
||||
},
|
||||
);
|
||||
|
||||
await tokensCache.set(
|
||||
cacheKey,
|
||||
{
|
||||
access_token: grantResponse.access_token,
|
||||
},
|
||||
grantResponse.expires_in * 1000,
|
||||
);
|
||||
|
||||
return grantResponse.access_token;
|
||||
} catch (error) {
|
||||
logger.error('[exchangeTokenForGraphAccess] Token exchange failed:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Search for principals (people and groups) using Microsoft Graph API
|
||||
* Uses searchContacts first, then searchUsers and searchGroups to fill remaining slots
|
||||
* @param {string} accessToken - OpenID Connect access token
|
||||
* @param {string} sub - Subject identifier
|
||||
* @param {string} query - Search query string
|
||||
* @param {string} type - Type filter ('users', 'groups', or 'all')
|
||||
* @param {number} limit - Maximum number of results
|
||||
* @returns {Promise<TPrincipalSearchResult[]>} Array of principal search results
|
||||
*/
|
||||
const searchEntraIdPrincipals = async (accessToken, sub, query, type = 'all', limit = 10) => {
|
||||
try {
|
||||
if (!query || query.trim().length < 2) {
|
||||
return [];
|
||||
}
|
||||
const graphClient = await createGraphClient(accessToken, sub);
|
||||
let allResults = [];
|
||||
|
||||
if (type === 'users' || type === 'all') {
|
||||
const contactResults = await searchContacts(graphClient, query, limit);
|
||||
allResults.push(...contactResults);
|
||||
}
|
||||
if (allResults.length >= limit) {
|
||||
return allResults.slice(0, limit);
|
||||
}
|
||||
|
||||
if (type === 'users') {
|
||||
const userResults = await searchUsers(graphClient, query, limit);
|
||||
allResults.push(...userResults);
|
||||
} else if (type === 'groups') {
|
||||
const groupResults = await searchGroups(graphClient, query, limit);
|
||||
allResults.push(...groupResults);
|
||||
} else if (type === 'all') {
|
||||
const [userResults, groupResults] = await Promise.all([
|
||||
searchUsers(graphClient, query, limit),
|
||||
searchGroups(graphClient, query, limit),
|
||||
]);
|
||||
|
||||
allResults.push(...userResults, ...groupResults);
|
||||
}
|
||||
|
||||
const seenIds = new Set();
|
||||
const uniqueResults = allResults.filter((result) => {
|
||||
if (seenIds.has(result.idOnTheSource)) {
|
||||
return false;
|
||||
}
|
||||
seenIds.add(result.idOnTheSource);
|
||||
return true;
|
||||
});
|
||||
|
||||
return uniqueResults.slice(0, limit);
|
||||
} catch (error) {
|
||||
logger.error('[searchEntraIdPrincipals] Error searching principals:', error);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get current user's Entra ID group memberships from Microsoft Graph
|
||||
* Uses /me/memberOf endpoint to get groups the user is a member of
|
||||
* @param {string} accessToken - OpenID Connect access token
|
||||
* @param {string} sub - Subject identifier
|
||||
* @returns {Promise<Array<string>>} Array of group ID strings (GUIDs)
|
||||
*/
|
||||
const getUserEntraGroups = async (accessToken, sub) => {
|
||||
try {
|
||||
const graphClient = await createGraphClient(accessToken, sub);
|
||||
|
||||
const groupsResponse = await graphClient.api('/me/memberOf').select('id').get();
|
||||
|
||||
return (groupsResponse.value || []).map((group) => group.id);
|
||||
} catch (error) {
|
||||
logger.error('[getUserEntraGroups] Error fetching user groups:', error);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get current user's owned Entra ID groups from Microsoft Graph
|
||||
* Uses /me/ownedObjects/microsoft.graph.group endpoint to get groups the user owns
|
||||
* @param {string} accessToken - OpenID Connect access token
|
||||
* @param {string} sub - Subject identifier
|
||||
* @returns {Promise<Array<string>>} Array of group ID strings (GUIDs)
|
||||
*/
|
||||
const getUserOwnedEntraGroups = async (accessToken, sub) => {
|
||||
try {
|
||||
const graphClient = await createGraphClient(accessToken, sub);
|
||||
|
||||
const groupsResponse = await graphClient
|
||||
.api('/me/ownedObjects/microsoft.graph.group')
|
||||
.select('id')
|
||||
.get();
|
||||
|
||||
return (groupsResponse.value || []).map((group) => group.id);
|
||||
} catch (error) {
|
||||
logger.error('[getUserOwnedEntraGroups] Error fetching user owned groups:', error);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get group members from Microsoft Graph API
|
||||
* Recursively fetches all members using pagination (@odata.nextLink)
|
||||
* @param {string} accessToken - OpenID Connect access token
|
||||
* @param {string} sub - Subject identifier
|
||||
* @param {string} groupId - Entra ID group object ID
|
||||
* @returns {Promise<Array>} Array of member IDs (idOnTheSource values)
|
||||
*/
|
||||
const getGroupMembers = async (accessToken, sub, groupId) => {
|
||||
try {
|
||||
const graphClient = await createGraphClient(accessToken, sub);
|
||||
const allMembers = [];
|
||||
let nextLink = `/groups/${groupId}/members`;
|
||||
|
||||
while (nextLink) {
|
||||
const membersResponse = await graphClient.api(nextLink).select('id').top(999).get();
|
||||
|
||||
const members = membersResponse.value || [];
|
||||
allMembers.push(...members.map((member) => member.id));
|
||||
|
||||
nextLink = membersResponse['@odata.nextLink']
|
||||
? membersResponse['@odata.nextLink'].split('/v1.0')[1]
|
||||
: null;
|
||||
}
|
||||
|
||||
return allMembers;
|
||||
} catch (error) {
|
||||
logger.error('[getGroupMembers] Error fetching group members:', error);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
/**
|
||||
* Get group owners from Microsoft Graph API
|
||||
* Recursively fetches all owners using pagination (@odata.nextLink)
|
||||
* @param {string} accessToken - OpenID Connect access token
|
||||
* @param {string} sub - Subject identifier
|
||||
* @param {string} groupId - Entra ID group object ID
|
||||
* @returns {Promise<Array>} Array of owner IDs (idOnTheSource values)
|
||||
*/
|
||||
const getGroupOwners = async (accessToken, sub, groupId) => {
|
||||
try {
|
||||
const graphClient = await createGraphClient(accessToken, sub);
|
||||
const allOwners = [];
|
||||
let nextLink = `/groups/${groupId}/owners`;
|
||||
|
||||
while (nextLink) {
|
||||
const ownersResponse = await graphClient.api(nextLink).select('id').top(999).get();
|
||||
|
||||
const owners = ownersResponse.value || [];
|
||||
allOwners.push(...owners.map((member) => member.id));
|
||||
|
||||
nextLink = ownersResponse['@odata.nextLink']
|
||||
? ownersResponse['@odata.nextLink'].split('/v1.0')[1]
|
||||
: null;
|
||||
}
|
||||
|
||||
return allOwners;
|
||||
} catch (error) {
|
||||
logger.error('[getGroupOwners] Error fetching group owners:', error);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
/**
|
||||
* Search for contacts (users only) using Microsoft Graph /me/people endpoint
|
||||
* Returns mapped TPrincipalSearchResult objects for users only
|
||||
* @param {Client} graphClient - Authenticated Microsoft Graph client
|
||||
* @param {string} query - Search query string
|
||||
* @param {number} limit - Maximum number of results (default: 10)
|
||||
* @returns {Promise<TPrincipalSearchResult[]>} Array of mapped user contact results
|
||||
*/
|
||||
const searchContacts = async (graphClient, query, limit = 10) => {
|
||||
try {
|
||||
if (!query || query.trim().length < 2) {
|
||||
return [];
|
||||
}
|
||||
if (
|
||||
process.env.OPENID_GRAPH_SCOPES &&
|
||||
!process.env.OPENID_GRAPH_SCOPES.toLowerCase().includes('people.read')
|
||||
) {
|
||||
logger.warn('[searchContacts] People.Read scope is not enabled, skipping contact search');
|
||||
return [];
|
||||
}
|
||||
// Reason: Search only for OrganizationUser (person) type, not groups
|
||||
const filter = "personType/subclass eq 'OrganizationUser'";
|
||||
|
||||
let apiCall = graphClient
|
||||
.api('/me/people')
|
||||
.search(`"${query}"`)
|
||||
.select(
|
||||
'id,displayName,givenName,surname,userPrincipalName,jobTitle,department,companyName,scoredEmailAddresses,personType,phones',
|
||||
)
|
||||
.header('ConsistencyLevel', 'eventual')
|
||||
.filter(filter)
|
||||
.top(limit);
|
||||
|
||||
const contactsResponse = await apiCall.get();
|
||||
return (contactsResponse.value || []).map(mapContactToTPrincipalSearchResult);
|
||||
} catch (error) {
|
||||
logger.error('[searchContacts] Error searching contacts:', error);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Search for users using Microsoft Graph /users endpoint
|
||||
* Returns mapped TPrincipalSearchResult objects
|
||||
* @param {Client} graphClient - Authenticated Microsoft Graph client
|
||||
* @param {string} query - Search query string
|
||||
* @param {number} limit - Maximum number of results (default: 10)
|
||||
* @returns {Promise<TPrincipalSearchResult[]>} Array of mapped user results
|
||||
*/
|
||||
const searchUsers = async (graphClient, query, limit = 10) => {
|
||||
try {
|
||||
if (!query || query.trim().length < 2) {
|
||||
return [];
|
||||
}
|
||||
|
||||
// Reason: Search users by display name, email, and user principal name
|
||||
const usersResponse = await graphClient
|
||||
.api('/users')
|
||||
.search(
|
||||
`"displayName:${query}" OR "userPrincipalName:${query}" OR "mail:${query}" OR "givenName:${query}" OR "surname:${query}"`,
|
||||
)
|
||||
.select(
|
||||
'id,displayName,givenName,surname,userPrincipalName,jobTitle,department,companyName,mail,phones',
|
||||
)
|
||||
.header('ConsistencyLevel', 'eventual')
|
||||
.top(limit)
|
||||
.get();
|
||||
|
||||
return (usersResponse.value || []).map(mapUserToTPrincipalSearchResult);
|
||||
} catch (error) {
|
||||
logger.error('[searchUsers] Error searching users:', error);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Search for groups using Microsoft Graph /groups endpoint
|
||||
* Returns mapped TPrincipalSearchResult objects, includes all group types
|
||||
* @param {Client} graphClient - Authenticated Microsoft Graph client
|
||||
* @param {string} query - Search query string
|
||||
* @param {number} limit - Maximum number of results (default: 10)
|
||||
* @returns {Promise<TPrincipalSearchResult[]>} Array of mapped group results
|
||||
*/
|
||||
const searchGroups = async (graphClient, query, limit = 10) => {
|
||||
try {
|
||||
if (!query || query.trim().length < 2) {
|
||||
return [];
|
||||
}
|
||||
|
||||
// Reason: Search all groups by display name and email without filtering group types
|
||||
const groupsResponse = await graphClient
|
||||
.api('/groups')
|
||||
.search(`"displayName:${query}" OR "mail:${query}" OR "mailNickname:${query}"`)
|
||||
.select('id,displayName,mail,mailNickname,description,groupTypes,resourceProvisioningOptions')
|
||||
.header('ConsistencyLevel', 'eventual')
|
||||
.top(limit)
|
||||
.get();
|
||||
|
||||
return (groupsResponse.value || []).map(mapGroupToTPrincipalSearchResult);
|
||||
} catch (error) {
|
||||
logger.error('[searchGroups] Error searching groups:', error);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Test Graph API connectivity and permissions
|
||||
* @param {string} accessToken - OpenID Connect access token
|
||||
* @param {string} sub - Subject identifier
|
||||
* @returns {Promise<Object>} Test results with available permissions
|
||||
*/
|
||||
const testGraphApiAccess = async (accessToken, sub) => {
|
||||
try {
|
||||
const graphClient = await createGraphClient(accessToken, sub);
|
||||
const results = {
|
||||
userAccess: false,
|
||||
peopleAccess: false,
|
||||
groupsAccess: false,
|
||||
usersEndpointAccess: false,
|
||||
groupsEndpointAccess: false,
|
||||
errors: [],
|
||||
};
|
||||
|
||||
// Test User.Read permission
|
||||
try {
|
||||
await graphClient.api('/me').select('id,displayName').get();
|
||||
results.userAccess = true;
|
||||
} catch (error) {
|
||||
results.errors.push(`User.Read: ${error.message}`);
|
||||
}
|
||||
|
||||
// Test People.Read permission with OrganizationUser filter
|
||||
try {
|
||||
await graphClient
|
||||
.api('/me/people')
|
||||
.filter("personType/subclass eq 'OrganizationUser'")
|
||||
.top(1)
|
||||
.get();
|
||||
results.peopleAccess = true;
|
||||
} catch (error) {
|
||||
results.errors.push(`People.Read (OrganizationUser): ${error.message}`);
|
||||
}
|
||||
|
||||
// Test People.Read permission with UnifiedGroup filter
|
||||
try {
|
||||
await graphClient
|
||||
.api('/me/people')
|
||||
.filter("personType/subclass eq 'UnifiedGroup'")
|
||||
.top(1)
|
||||
.get();
|
||||
results.groupsAccess = true;
|
||||
} catch (error) {
|
||||
results.errors.push(`People.Read (UnifiedGroup): ${error.message}`);
|
||||
}
|
||||
|
||||
// Test /users endpoint access (requires User.Read.All or similar)
|
||||
try {
|
||||
await graphClient
|
||||
.api('/users')
|
||||
.search('"displayName:test"')
|
||||
.select('id,displayName,userPrincipalName')
|
||||
.top(1)
|
||||
.get();
|
||||
results.usersEndpointAccess = true;
|
||||
} catch (error) {
|
||||
results.errors.push(`Users endpoint: ${error.message}`);
|
||||
}
|
||||
|
||||
// Test /groups endpoint access (requires Group.Read.All or similar)
|
||||
try {
|
||||
await graphClient
|
||||
.api('/groups')
|
||||
.search('"displayName:test"')
|
||||
.select('id,displayName,mail')
|
||||
.top(1)
|
||||
.get();
|
||||
results.groupsEndpointAccess = true;
|
||||
} catch (error) {
|
||||
results.errors.push(`Groups endpoint: ${error.message}`);
|
||||
}
|
||||
|
||||
return results;
|
||||
} catch (error) {
|
||||
logger.error('[testGraphApiAccess] Error testing Graph API access:', error);
|
||||
return {
|
||||
userAccess: false,
|
||||
peopleAccess: false,
|
||||
groupsAccess: false,
|
||||
usersEndpointAccess: false,
|
||||
groupsEndpointAccess: false,
|
||||
errors: [error.message],
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Map Graph API user object to TPrincipalSearchResult format
|
||||
* @param {TGraphUser} user - Raw user object from Graph API
|
||||
* @returns {TPrincipalSearchResult} Mapped user result
|
||||
*/
|
||||
const mapUserToTPrincipalSearchResult = (user) => {
|
||||
return {
|
||||
id: null,
|
||||
type: 'user',
|
||||
name: user.displayName,
|
||||
email: user.mail || user.userPrincipalName,
|
||||
username: user.userPrincipalName,
|
||||
source: 'entra',
|
||||
idOnTheSource: user.id,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Map Graph API group object to TPrincipalSearchResult format
|
||||
* @param {TGraphGroup} group - Raw group object from Graph API
|
||||
* @returns {TPrincipalSearchResult} Mapped group result
|
||||
*/
|
||||
const mapGroupToTPrincipalSearchResult = (group) => {
|
||||
return {
|
||||
id: null,
|
||||
type: 'group',
|
||||
name: group.displayName,
|
||||
email: group.mail || group.userPrincipalName,
|
||||
description: group.description,
|
||||
source: 'entra',
|
||||
idOnTheSource: group.id,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Map Graph API /me/people contact object to TPrincipalSearchResult format
|
||||
* Handles both user and group contacts from the people endpoint
|
||||
* @param {TGraphPerson} contact - Raw contact object from Graph API /me/people
|
||||
* @returns {TPrincipalSearchResult} Mapped contact result
|
||||
*/
|
||||
const mapContactToTPrincipalSearchResult = (contact) => {
|
||||
const isGroup = contact.personType?.class === 'Group';
|
||||
const primaryEmail = contact.scoredEmailAddresses?.[0]?.address;
|
||||
|
||||
return {
|
||||
id: null,
|
||||
type: isGroup ? 'group' : 'user',
|
||||
name: contact.displayName,
|
||||
email: primaryEmail,
|
||||
username: !isGroup ? contact.userPrincipalName : undefined,
|
||||
source: 'entra',
|
||||
idOnTheSource: contact.id,
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
getGroupMembers,
|
||||
getGroupOwners,
|
||||
createGraphClient,
|
||||
getUserEntraGroups,
|
||||
getUserOwnedEntraGroups,
|
||||
testGraphApiAccess,
|
||||
searchEntraIdPrincipals,
|
||||
exchangeTokenForGraphAccess,
|
||||
entraIdPrincipalFeatureEnabled,
|
||||
};
|
||||
720
api/server/services/GraphApiService.spec.js
Normal file
720
api/server/services/GraphApiService.spec.js
Normal file
@@ -0,0 +1,720 @@
|
||||
jest.mock('@microsoft/microsoft-graph-client');
|
||||
jest.mock('~/strategies/openidStrategy');
|
||||
jest.mock('~/cache/getLogStores');
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/data-schemas'),
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
createAxiosInstance: jest.fn(() => ({
|
||||
create: jest.fn(),
|
||||
defaults: {},
|
||||
})),
|
||||
}));
|
||||
jest.mock('~/utils', () => ({
|
||||
logAxiosError: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({}));
|
||||
jest.mock('~/server/services/Files/strategies', () => ({
|
||||
getStrategyFunctions: jest.fn(),
|
||||
}));
|
||||
|
||||
const mongoose = require('mongoose');
|
||||
const client = require('openid-client');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { Client } = require('@microsoft/microsoft-graph-client');
|
||||
const { getOpenIdConfig } = require('~/strategies/openidStrategy');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const GraphApiService = require('./GraphApiService');
|
||||
|
||||
describe('GraphApiService', () => {
|
||||
let mongoServer;
|
||||
let mockGraphClient;
|
||||
let mockTokensCache;
|
||||
let mockOpenIdConfig;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// Clean up environment variables
|
||||
delete process.env.OPENID_GRAPH_SCOPES;
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
jest.clearAllMocks();
|
||||
await mongoose.connection.dropDatabase();
|
||||
|
||||
// Set up environment variable for People.Read scope
|
||||
process.env.OPENID_GRAPH_SCOPES = 'User.Read,People.Read,Group.Read.All';
|
||||
|
||||
// Mock Graph client
|
||||
mockGraphClient = {
|
||||
api: jest.fn().mockReturnThis(),
|
||||
search: jest.fn().mockReturnThis(),
|
||||
filter: jest.fn().mockReturnThis(),
|
||||
select: jest.fn().mockReturnThis(),
|
||||
header: jest.fn().mockReturnThis(),
|
||||
top: jest.fn().mockReturnThis(),
|
||||
get: jest.fn(),
|
||||
};
|
||||
|
||||
Client.init.mockReturnValue(mockGraphClient);
|
||||
|
||||
// Mock tokens cache
|
||||
mockTokensCache = {
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
};
|
||||
getLogStores.mockReturnValue(mockTokensCache);
|
||||
|
||||
// Mock OpenID config
|
||||
mockOpenIdConfig = {
|
||||
client_id: 'test-client-id',
|
||||
issuer: 'https://test-issuer.com',
|
||||
};
|
||||
getOpenIdConfig.mockReturnValue(mockOpenIdConfig);
|
||||
|
||||
// Mock openid-client (using the existing jest mock configuration)
|
||||
if (client.genericGrantRequest) {
|
||||
client.genericGrantRequest.mockResolvedValue({
|
||||
access_token: 'mocked-graph-token',
|
||||
expires_in: 3600,
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
describe('Dependency Contract Tests', () => {
|
||||
it('should fail if getOpenIdConfig interface changes', () => {
|
||||
// Reason: Ensure getOpenIdConfig returns expected structure
|
||||
const config = getOpenIdConfig();
|
||||
|
||||
expect(config).toBeDefined();
|
||||
expect(typeof config).toBe('object');
|
||||
// Add specific property checks that GraphApiService depends on
|
||||
expect(config).toHaveProperty('client_id');
|
||||
expect(config).toHaveProperty('issuer');
|
||||
|
||||
// Ensure the function is callable
|
||||
expect(typeof getOpenIdConfig).toBe('function');
|
||||
});
|
||||
|
||||
it('should fail if openid-client.genericGrantRequest interface changes', () => {
|
||||
// Reason: Ensure client.genericGrantRequest maintains expected signature
|
||||
if (client.genericGrantRequest) {
|
||||
expect(typeof client.genericGrantRequest).toBe('function');
|
||||
|
||||
// Test that it accepts the expected parameters
|
||||
const mockCall = client.genericGrantRequest(
|
||||
mockOpenIdConfig,
|
||||
'urn:ietf:params:oauth:grant-type:jwt-bearer',
|
||||
{
|
||||
scope: 'test-scope',
|
||||
assertion: 'test-token',
|
||||
requested_token_use: 'on_behalf_of',
|
||||
},
|
||||
);
|
||||
|
||||
expect(mockCall).toBeDefined();
|
||||
}
|
||||
});
|
||||
|
||||
it('should fail if Microsoft Graph Client interface changes', () => {
|
||||
// Reason: Ensure Graph Client maintains expected fluent API
|
||||
expect(typeof Client.init).toBe('function');
|
||||
|
||||
const client = Client.init({ authProvider: jest.fn() });
|
||||
expect(client).toHaveProperty('api');
|
||||
expect(typeof client.api).toBe('function');
|
||||
});
|
||||
});
|
||||
|
||||
describe('createGraphClient', () => {
|
||||
it('should create graph client with exchanged token', async () => {
|
||||
const accessToken = 'test-access-token';
|
||||
const sub = 'test-user-id';
|
||||
|
||||
const result = await GraphApiService.createGraphClient(accessToken, sub);
|
||||
|
||||
expect(getOpenIdConfig).toHaveBeenCalled();
|
||||
expect(Client.init).toHaveBeenCalledWith({
|
||||
authProvider: expect.any(Function),
|
||||
});
|
||||
expect(result).toBe(mockGraphClient);
|
||||
});
|
||||
|
||||
it('should handle token exchange errors gracefully', async () => {
|
||||
if (client.genericGrantRequest) {
|
||||
client.genericGrantRequest.mockRejectedValue(new Error('Token exchange failed'));
|
||||
}
|
||||
|
||||
await expect(GraphApiService.createGraphClient('invalid-token', 'test-user')).rejects.toThrow(
|
||||
'Token exchange failed',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('exchangeTokenForGraphAccess', () => {
|
||||
it('should return cached token if available', async () => {
|
||||
const cachedToken = { access_token: 'cached-token' };
|
||||
mockTokensCache.get.mockResolvedValue(cachedToken);
|
||||
|
||||
const result = await GraphApiService.exchangeTokenForGraphAccess(
|
||||
mockOpenIdConfig,
|
||||
'test-token',
|
||||
'test-user',
|
||||
);
|
||||
|
||||
expect(result).toBe('cached-token');
|
||||
expect(mockTokensCache.get).toHaveBeenCalledWith('test-user:graph');
|
||||
if (client.genericGrantRequest) {
|
||||
expect(client.genericGrantRequest).not.toHaveBeenCalled();
|
||||
}
|
||||
});
|
||||
|
||||
it('should exchange token and cache result', async () => {
|
||||
mockTokensCache.get.mockResolvedValue(null);
|
||||
|
||||
const result = await GraphApiService.exchangeTokenForGraphAccess(
|
||||
mockOpenIdConfig,
|
||||
'test-token',
|
||||
'test-user',
|
||||
);
|
||||
|
||||
if (client.genericGrantRequest) {
|
||||
expect(client.genericGrantRequest).toHaveBeenCalledWith(
|
||||
mockOpenIdConfig,
|
||||
'urn:ietf:params:oauth:grant-type:jwt-bearer',
|
||||
{
|
||||
scope:
|
||||
'https://graph.microsoft.com/User.Read https://graph.microsoft.com/People.Read https://graph.microsoft.com/Group.Read.All',
|
||||
assertion: 'test-token',
|
||||
requested_token_use: 'on_behalf_of',
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
expect(mockTokensCache.set).toHaveBeenCalledWith(
|
||||
'test-user:graph',
|
||||
{ access_token: 'mocked-graph-token' },
|
||||
3600000,
|
||||
);
|
||||
|
||||
expect(result).toBe('mocked-graph-token');
|
||||
});
|
||||
|
||||
it('should use custom scopes from environment', async () => {
|
||||
const originalEnv = process.env.OPENID_GRAPH_SCOPES;
|
||||
process.env.OPENID_GRAPH_SCOPES = 'Custom.Read,Custom.Write';
|
||||
|
||||
mockTokensCache.get.mockResolvedValue(null);
|
||||
|
||||
await GraphApiService.exchangeTokenForGraphAccess(
|
||||
mockOpenIdConfig,
|
||||
'test-token',
|
||||
'test-user',
|
||||
);
|
||||
|
||||
if (client.genericGrantRequest) {
|
||||
expect(client.genericGrantRequest).toHaveBeenCalledWith(
|
||||
mockOpenIdConfig,
|
||||
'urn:ietf:params:oauth:grant-type:jwt-bearer',
|
||||
{
|
||||
scope:
|
||||
'https://graph.microsoft.com/Custom.Read https://graph.microsoft.com/Custom.Write',
|
||||
assertion: 'test-token',
|
||||
requested_token_use: 'on_behalf_of',
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
process.env.OPENID_GRAPH_SCOPES = originalEnv;
|
||||
});
|
||||
});
|
||||
|
||||
describe('searchEntraIdPrincipals', () => {
|
||||
// Mock data used by multiple tests
|
||||
const mockContactsResponse = {
|
||||
value: [
|
||||
{
|
||||
id: 'contact-user-1',
|
||||
displayName: 'John Doe',
|
||||
userPrincipalName: 'john@company.com',
|
||||
mail: 'john@company.com',
|
||||
personType: { class: 'Person', subclass: 'OrganizationUser' },
|
||||
scoredEmailAddresses: [{ address: 'john@company.com', relevanceScore: 0.9 }],
|
||||
},
|
||||
{
|
||||
id: 'contact-group-1',
|
||||
displayName: 'Marketing Team',
|
||||
mail: 'marketing@company.com',
|
||||
personType: { class: 'Group', subclass: 'UnifiedGroup' },
|
||||
scoredEmailAddresses: [{ address: 'marketing@company.com', relevanceScore: 0.8 }],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const mockUsersResponse = {
|
||||
value: [
|
||||
{
|
||||
id: 'dir-user-1',
|
||||
displayName: 'Jane Smith',
|
||||
userPrincipalName: 'jane@company.com',
|
||||
mail: 'jane@company.com',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const mockGroupsResponse = {
|
||||
value: [
|
||||
{
|
||||
id: 'dir-group-1',
|
||||
displayName: 'Development Team',
|
||||
mail: 'dev@company.com',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset mock call history for each test
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Re-apply the Client.init mock after clearAllMocks
|
||||
Client.init.mockReturnValue(mockGraphClient);
|
||||
|
||||
// Re-apply openid-client mock
|
||||
if (client.genericGrantRequest) {
|
||||
client.genericGrantRequest.mockResolvedValue({
|
||||
access_token: 'mocked-graph-token',
|
||||
expires_in: 3600,
|
||||
});
|
||||
}
|
||||
|
||||
// Re-apply cache mock
|
||||
mockTokensCache.get.mockResolvedValue(null); // Force token exchange
|
||||
mockTokensCache.set.mockResolvedValue();
|
||||
getLogStores.mockReturnValue(mockTokensCache);
|
||||
getOpenIdConfig.mockReturnValue(mockOpenIdConfig);
|
||||
});
|
||||
|
||||
it('should return empty results for short queries', async () => {
|
||||
const result = await GraphApiService.searchEntraIdPrincipals('token', 'user', 'a', 'all', 10);
|
||||
|
||||
expect(result).toEqual([]);
|
||||
expect(mockGraphClient.api).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should search contacts first and additional users for users type', async () => {
|
||||
// Mock responses for this specific test
|
||||
const contactsFilteredResponse = {
|
||||
value: [
|
||||
{
|
||||
id: 'contact-user-1',
|
||||
displayName: 'John Doe',
|
||||
userPrincipalName: 'john@company.com',
|
||||
mail: 'john@company.com',
|
||||
personType: { class: 'Person', subclass: 'OrganizationUser' },
|
||||
scoredEmailAddresses: [{ address: 'john@company.com', relevanceScore: 0.9 }],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockGraphClient.get
|
||||
.mockResolvedValueOnce(contactsFilteredResponse) // contacts call
|
||||
.mockResolvedValueOnce(mockUsersResponse); // users call
|
||||
|
||||
const result = await GraphApiService.searchEntraIdPrincipals(
|
||||
'token',
|
||||
'user',
|
||||
'john',
|
||||
'users',
|
||||
10,
|
||||
);
|
||||
|
||||
// Should call contacts first with user filter
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/people');
|
||||
expect(mockGraphClient.search).toHaveBeenCalledWith('"john"');
|
||||
expect(mockGraphClient.filter).toHaveBeenCalledWith(
|
||||
"personType/subclass eq 'OrganizationUser'",
|
||||
);
|
||||
|
||||
// Should call users endpoint for additional results
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/users');
|
||||
expect(mockGraphClient.search).toHaveBeenCalledWith(
|
||||
'"displayName:john" OR "userPrincipalName:john" OR "mail:john" OR "givenName:john" OR "surname:john"',
|
||||
);
|
||||
|
||||
// Should return TPrincipalSearchResult array
|
||||
expect(Array.isArray(result)).toBe(true);
|
||||
expect(result).toHaveLength(2); // 1 from contacts + 1 from users
|
||||
expect(result[0]).toMatchObject({
|
||||
id: null,
|
||||
type: 'user',
|
||||
name: 'John Doe',
|
||||
email: 'john@company.com',
|
||||
source: 'entra',
|
||||
idOnTheSource: 'contact-user-1',
|
||||
});
|
||||
});
|
||||
|
||||
it('should search groups endpoint only for groups type', async () => {
|
||||
// Mock responses for this specific test - only groups endpoint called
|
||||
mockGraphClient.get.mockResolvedValueOnce(mockGroupsResponse); // only groups call
|
||||
|
||||
const result = await GraphApiService.searchEntraIdPrincipals(
|
||||
'token',
|
||||
'user',
|
||||
'team',
|
||||
'groups',
|
||||
10,
|
||||
);
|
||||
|
||||
// Should NOT call contacts for groups type
|
||||
expect(mockGraphClient.api).not.toHaveBeenCalledWith('/me/people');
|
||||
|
||||
// Should call groups endpoint only
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/groups');
|
||||
expect(mockGraphClient.search).toHaveBeenCalledWith(
|
||||
'"displayName:team" OR "mail:team" OR "mailNickname:team"',
|
||||
);
|
||||
|
||||
expect(Array.isArray(result)).toBe(true);
|
||||
expect(result).toHaveLength(1); // 1 from groups only
|
||||
});
|
||||
|
||||
it('should search all endpoints for all type', async () => {
|
||||
// Mock responses for this specific test
|
||||
mockGraphClient.get
|
||||
.mockResolvedValueOnce(mockContactsResponse) // contacts call (both user and group)
|
||||
.mockResolvedValueOnce(mockUsersResponse) // users call
|
||||
.mockResolvedValueOnce(mockGroupsResponse); // groups call
|
||||
|
||||
const result = await GraphApiService.searchEntraIdPrincipals(
|
||||
'token',
|
||||
'user',
|
||||
'test',
|
||||
'all',
|
||||
10,
|
||||
);
|
||||
|
||||
// Should call contacts first with user filter
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/people');
|
||||
expect(mockGraphClient.search).toHaveBeenCalledWith('"test"');
|
||||
expect(mockGraphClient.filter).toHaveBeenCalledWith(
|
||||
"personType/subclass eq 'OrganizationUser'",
|
||||
);
|
||||
|
||||
// Should call both users and groups endpoints
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/users');
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/groups');
|
||||
|
||||
expect(Array.isArray(result)).toBe(true);
|
||||
expect(result).toHaveLength(4); // 2 from contacts + 1 from users + 1 from groups
|
||||
});
|
||||
|
||||
it('should early exit if contacts reach limit', async () => {
|
||||
// Mock contacts to return exactly the limit
|
||||
const limitedContactsResponse = {
|
||||
value: Array(10).fill({
|
||||
id: 'contact-1',
|
||||
displayName: 'Contact User',
|
||||
mail: 'contact@company.com',
|
||||
personType: { class: 'Person', subclass: 'OrganizationUser' },
|
||||
}),
|
||||
};
|
||||
|
||||
mockGraphClient.get.mockResolvedValueOnce(limitedContactsResponse);
|
||||
|
||||
const result = await GraphApiService.searchEntraIdPrincipals(
|
||||
'token',
|
||||
'user',
|
||||
'test',
|
||||
'all',
|
||||
10,
|
||||
);
|
||||
|
||||
// Should call contacts first
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/people');
|
||||
expect(mockGraphClient.search).toHaveBeenCalledWith('"test"');
|
||||
// Should not call users endpoint since limit was reached
|
||||
expect(mockGraphClient.api).not.toHaveBeenCalledWith('/users');
|
||||
|
||||
expect(result).toHaveLength(10);
|
||||
});
|
||||
|
||||
it('should deduplicate results based on idOnTheSource', async () => {
|
||||
// Mock responses with duplicate IDs
|
||||
const duplicateContactsResponse = {
|
||||
value: [
|
||||
{
|
||||
id: 'duplicate-id',
|
||||
displayName: 'John Doe',
|
||||
mail: 'john@company.com',
|
||||
personType: { class: 'Person', subclass: 'OrganizationUser' },
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
const duplicateUsersResponse = {
|
||||
value: [
|
||||
{
|
||||
id: 'duplicate-id', // Same ID as contact
|
||||
displayName: 'John Doe',
|
||||
mail: 'john@company.com',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockGraphClient.get
|
||||
.mockResolvedValueOnce(duplicateContactsResponse)
|
||||
.mockResolvedValueOnce(duplicateUsersResponse);
|
||||
|
||||
const result = await GraphApiService.searchEntraIdPrincipals(
|
||||
'token',
|
||||
'user',
|
||||
'john',
|
||||
'users',
|
||||
10,
|
||||
);
|
||||
|
||||
// Should only return one result despite duplicate IDs
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].idOnTheSource).toBe('duplicate-id');
|
||||
});
|
||||
|
||||
it('should handle Graph API errors gracefully', async () => {
|
||||
mockGraphClient.get.mockRejectedValue(new Error('Graph API error'));
|
||||
|
||||
const result = await GraphApiService.searchEntraIdPrincipals(
|
||||
'token',
|
||||
'user',
|
||||
'test',
|
||||
'all',
|
||||
10,
|
||||
);
|
||||
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getUserEntraGroups', () => {
|
||||
it('should fetch user groups from memberOf endpoint', async () => {
|
||||
const mockGroupsResponse = {
|
||||
value: [
|
||||
{
|
||||
id: 'group-1',
|
||||
},
|
||||
{
|
||||
id: 'group-2',
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
mockGraphClient.get.mockResolvedValue(mockGroupsResponse);
|
||||
|
||||
const result = await GraphApiService.getUserEntraGroups('token', 'user');
|
||||
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/memberOf');
|
||||
expect(mockGraphClient.select).toHaveBeenCalledWith('id');
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result).toEqual(['group-1', 'group-2']);
|
||||
});
|
||||
|
||||
it('should return empty array on error', async () => {
|
||||
mockGraphClient.get.mockRejectedValue(new Error('API error'));
|
||||
|
||||
const result = await GraphApiService.getUserEntraGroups('token', 'user');
|
||||
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle empty response', async () => {
|
||||
const mockGroupsResponse = {
|
||||
value: [],
|
||||
};
|
||||
|
||||
mockGraphClient.get.mockResolvedValue(mockGroupsResponse);
|
||||
|
||||
const result = await GraphApiService.getUserEntraGroups('token', 'user');
|
||||
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle missing value property', async () => {
|
||||
mockGraphClient.get.mockResolvedValue({});
|
||||
|
||||
const result = await GraphApiService.getUserEntraGroups('token', 'user');
|
||||
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('testGraphApiAccess', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should test all permissions and return success results', async () => {
|
||||
// Mock successful responses for all tests
|
||||
mockGraphClient.get
|
||||
.mockResolvedValueOnce({ id: 'user-123', displayName: 'Test User' }) // /me test
|
||||
.mockResolvedValueOnce({ value: [] }) // people OrganizationUser test
|
||||
.mockResolvedValueOnce({ value: [] }) // people UnifiedGroup test
|
||||
.mockResolvedValueOnce({ value: [] }) // /users endpoint test
|
||||
.mockResolvedValueOnce({ value: [] }); // /groups endpoint test
|
||||
|
||||
const result = await GraphApiService.testGraphApiAccess('token', 'user');
|
||||
|
||||
expect(result).toEqual({
|
||||
userAccess: true,
|
||||
peopleAccess: true,
|
||||
groupsAccess: true,
|
||||
usersEndpointAccess: true,
|
||||
groupsEndpointAccess: true,
|
||||
errors: [],
|
||||
});
|
||||
|
||||
// Verify all endpoints were tested
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/me');
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/people');
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/users');
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/groups');
|
||||
expect(mockGraphClient.filter).toHaveBeenCalledWith(
|
||||
"personType/subclass eq 'OrganizationUser'",
|
||||
);
|
||||
expect(mockGraphClient.filter).toHaveBeenCalledWith("personType/subclass eq 'UnifiedGroup'");
|
||||
expect(mockGraphClient.search).toHaveBeenCalledWith('"displayName:test"');
|
||||
});
|
||||
|
||||
it('should handle partial failures and record errors', async () => {
|
||||
// Mock mixed success/failure responses
|
||||
mockGraphClient.get
|
||||
.mockResolvedValueOnce({ id: 'user-123', displayName: 'Test User' }) // /me success
|
||||
.mockRejectedValueOnce(new Error('People access denied')) // people OrganizationUser fail
|
||||
.mockResolvedValueOnce({ value: [] }) // people UnifiedGroup success
|
||||
.mockRejectedValueOnce(new Error('Users endpoint access denied')) // /users fail
|
||||
.mockResolvedValueOnce({ value: [] }); // /groups success
|
||||
|
||||
const result = await GraphApiService.testGraphApiAccess('token', 'user');
|
||||
|
||||
expect(result).toEqual({
|
||||
userAccess: true,
|
||||
peopleAccess: false,
|
||||
groupsAccess: true,
|
||||
usersEndpointAccess: false,
|
||||
groupsEndpointAccess: true,
|
||||
errors: [
|
||||
'People.Read (OrganizationUser): People access denied',
|
||||
'Users endpoint: Users endpoint access denied',
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle complete Graph client creation failure', async () => {
|
||||
// Mock token exchange failure to test error handling
|
||||
if (client.genericGrantRequest) {
|
||||
client.genericGrantRequest.mockRejectedValue(new Error('Token exchange failed'));
|
||||
}
|
||||
|
||||
const result = await GraphApiService.testGraphApiAccess('invalid-token', 'user');
|
||||
|
||||
expect(result).toEqual({
|
||||
userAccess: false,
|
||||
peopleAccess: false,
|
||||
groupsAccess: false,
|
||||
usersEndpointAccess: false,
|
||||
groupsEndpointAccess: false,
|
||||
errors: ['Token exchange failed'],
|
||||
});
|
||||
});
|
||||
|
||||
it('should record all permission errors', async () => {
|
||||
// Mock all requests to fail
|
||||
mockGraphClient.get
|
||||
.mockRejectedValueOnce(new Error('User.Read denied'))
|
||||
.mockRejectedValueOnce(new Error('People.Read OrganizationUser denied'))
|
||||
.mockRejectedValueOnce(new Error('People.Read UnifiedGroup denied'))
|
||||
.mockRejectedValueOnce(new Error('Users directory access denied'))
|
||||
.mockRejectedValueOnce(new Error('Groups directory access denied'));
|
||||
|
||||
const result = await GraphApiService.testGraphApiAccess('token', 'user');
|
||||
|
||||
expect(result).toEqual({
|
||||
userAccess: false,
|
||||
peopleAccess: false,
|
||||
groupsAccess: false,
|
||||
usersEndpointAccess: false,
|
||||
groupsEndpointAccess: false,
|
||||
errors: [
|
||||
'User.Read: User.Read denied',
|
||||
'People.Read (OrganizationUser): People.Read OrganizationUser denied',
|
||||
'People.Read (UnifiedGroup): People.Read UnifiedGroup denied',
|
||||
'Users endpoint: Users directory access denied',
|
||||
'Groups endpoint: Groups directory access denied',
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should test new endpoints with correct search patterns', async () => {
|
||||
// Mock successful responses for endpoint testing
|
||||
mockGraphClient.get
|
||||
.mockResolvedValueOnce({ id: 'user-123', displayName: 'Test User' }) // /me
|
||||
.mockResolvedValueOnce({ value: [] }) // people OrganizationUser
|
||||
.mockResolvedValueOnce({ value: [] }) // people UnifiedGroup
|
||||
.mockResolvedValueOnce({ value: [] }) // /users
|
||||
.mockResolvedValueOnce({ value: [] }); // /groups
|
||||
|
||||
await GraphApiService.testGraphApiAccess('token', 'user');
|
||||
|
||||
// Verify /users endpoint test
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/users');
|
||||
expect(mockGraphClient.search).toHaveBeenCalledWith('"displayName:test"');
|
||||
expect(mockGraphClient.select).toHaveBeenCalledWith('id,displayName,userPrincipalName');
|
||||
|
||||
// Verify /groups endpoint test
|
||||
expect(mockGraphClient.api).toHaveBeenCalledWith('/groups');
|
||||
expect(mockGraphClient.select).toHaveBeenCalledWith('id,displayName,mail');
|
||||
});
|
||||
|
||||
it('should handle endpoint-specific permission failures', async () => {
|
||||
// Mock specific endpoint failures
|
||||
mockGraphClient.get
|
||||
.mockResolvedValueOnce({ id: 'user-123', displayName: 'Test User' }) // /me success
|
||||
.mockResolvedValueOnce({ value: [] }) // people OrganizationUser success
|
||||
.mockResolvedValueOnce({ value: [] }) // people UnifiedGroup success
|
||||
.mockRejectedValueOnce(new Error('Insufficient privileges')) // /users fail (User.Read.All needed)
|
||||
.mockRejectedValueOnce(new Error('Access denied to groups')); // /groups fail (Group.Read.All needed)
|
||||
|
||||
const result = await GraphApiService.testGraphApiAccess('token', 'user');
|
||||
|
||||
expect(result).toEqual({
|
||||
userAccess: true,
|
||||
peopleAccess: true,
|
||||
groupsAccess: true,
|
||||
usersEndpointAccess: false,
|
||||
groupsEndpointAccess: false,
|
||||
errors: [
|
||||
'Users endpoint: Insufficient privileges',
|
||||
'Groups endpoint: Access denied to groups',
|
||||
],
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
721
api/server/services/PermissionService.js
Normal file
721
api/server/services/PermissionService.js
Normal file
@@ -0,0 +1,721 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { getTransactionSupport, logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const {
|
||||
entraIdPrincipalFeatureEnabled,
|
||||
getUserEntraGroups,
|
||||
getUserOwnedEntraGroups,
|
||||
getGroupMembers,
|
||||
getGroupOwners,
|
||||
} = require('~/server/services/GraphApiService');
|
||||
const {
|
||||
findGroupByExternalId,
|
||||
findRoleByIdentifier,
|
||||
getUserPrincipals,
|
||||
createGroup,
|
||||
createUser,
|
||||
updateUser,
|
||||
findUser,
|
||||
grantPermission: grantPermissionACL,
|
||||
findAccessibleResources: findAccessibleResourcesACL,
|
||||
hasPermission,
|
||||
getEffectivePermissions: getEffectivePermissionsACL,
|
||||
findEntriesByPrincipalsAndResource,
|
||||
} = require('~/models');
|
||||
const { AclEntry, AccessRole, Group } = require('~/db/models');
|
||||
|
||||
/** @type {boolean|null} */
|
||||
let transactionSupportCache = null;
|
||||
|
||||
/**
|
||||
* @import { TPrincipal } from 'librechat-data-provider'
|
||||
*/
|
||||
/**
|
||||
* Grant a permission to a principal for a resource using a role
|
||||
* @param {Object} params - Parameters for granting role-based permission
|
||||
* @param {string} params.principalType - 'user', 'group', or 'public'
|
||||
* @param {string|mongoose.Types.ObjectId|null} params.principalId - The ID of the principal (null for 'public')
|
||||
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
|
||||
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
|
||||
* @param {string} params.accessRoleId - The ID of the role (e.g., 'agent_viewer', 'agent_editor')
|
||||
* @param {string|mongoose.Types.ObjectId} params.grantedBy - User ID granting the permission
|
||||
* @param {mongoose.ClientSession} [params.session] - Optional MongoDB session for transactions
|
||||
* @returns {Promise<Object>} The created or updated ACL entry
|
||||
*/
|
||||
const grantPermission = async ({
|
||||
principalType,
|
||||
principalId,
|
||||
resourceType,
|
||||
resourceId,
|
||||
accessRoleId,
|
||||
grantedBy,
|
||||
session,
|
||||
}) => {
|
||||
try {
|
||||
if (!['user', 'group', 'public'].includes(principalType)) {
|
||||
throw new Error(`Invalid principal type: ${principalType}`);
|
||||
}
|
||||
|
||||
if (principalType !== 'public' && !principalId) {
|
||||
throw new Error('Principal ID is required for user and group principals');
|
||||
}
|
||||
|
||||
if (principalId && !mongoose.Types.ObjectId.isValid(principalId)) {
|
||||
throw new Error(`Invalid principal ID: ${principalId}`);
|
||||
}
|
||||
|
||||
if (!resourceId || !mongoose.Types.ObjectId.isValid(resourceId)) {
|
||||
throw new Error(`Invalid resource ID: ${resourceId}`);
|
||||
}
|
||||
|
||||
// Get the role to determine permission bits
|
||||
const role = await findRoleByIdentifier(accessRoleId);
|
||||
if (!role) {
|
||||
throw new Error(`Role ${accessRoleId} not found`);
|
||||
}
|
||||
|
||||
// Ensure the role is for the correct resource type
|
||||
if (role.resourceType !== resourceType) {
|
||||
throw new Error(
|
||||
`Role ${accessRoleId} is for ${role.resourceType} resources, not ${resourceType}`,
|
||||
);
|
||||
}
|
||||
return await grantPermissionACL(
|
||||
principalType,
|
||||
principalId,
|
||||
resourceType,
|
||||
resourceId,
|
||||
role.permBits,
|
||||
grantedBy,
|
||||
session,
|
||||
role._id,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(`[PermissionService.grantPermission] Error: ${error.message}`);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if a user has specific permission bits on a resource
|
||||
* @param {Object} params - Parameters for checking permissions
|
||||
* @param {string|mongoose.Types.ObjectId} params.userId - The ID of the user
|
||||
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
|
||||
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
|
||||
* @param {number} params.requiredPermissions - The permission bits required (e.g., 1 for VIEW, 3 for VIEW+EDIT)
|
||||
* @returns {Promise<boolean>} Whether the user has the required permission bits
|
||||
*/
|
||||
const checkPermission = async ({ userId, resourceType, resourceId, requiredPermission }) => {
|
||||
try {
|
||||
if (typeof requiredPermission !== 'number' || requiredPermission < 1) {
|
||||
throw new Error('requiredPermission must be a positive number');
|
||||
}
|
||||
|
||||
// Get all principals for the user (user + groups + public)
|
||||
const principals = await getUserPrincipals(userId);
|
||||
|
||||
if (principals.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return await hasPermission(principals, resourceType, resourceId, requiredPermission);
|
||||
} catch (error) {
|
||||
logger.error(`[PermissionService.checkPermission] Error: ${error.message}`);
|
||||
// Re-throw validation errors
|
||||
if (error.message.includes('requiredPermission must be')) {
|
||||
throw error;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get effective permission bitmask for a user on a resource
|
||||
* @param {Object} params - Parameters for getting effective permissions
|
||||
* @param {string|mongoose.Types.ObjectId} params.userId - The ID of the user
|
||||
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
|
||||
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
|
||||
* @returns {Promise<number>} Effective permission bitmask
|
||||
*/
|
||||
const getEffectivePermissions = async ({ userId, resourceType, resourceId }) => {
|
||||
try {
|
||||
// Get all principals for the user (user + groups + public)
|
||||
const principals = await getUserPrincipals(userId);
|
||||
|
||||
if (principals.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
return await getEffectivePermissionsACL(principals, resourceType, resourceId);
|
||||
} catch (error) {
|
||||
logger.error(`[PermissionService.getEffectivePermissions] Error: ${error.message}`);
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Find all resources of a specific type that a user has access to with specific permission bits
|
||||
* @param {Object} params - Parameters for finding accessible resources
|
||||
* @param {string|mongoose.Types.ObjectId} params.userId - The ID of the user
|
||||
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
|
||||
* @param {number} params.requiredPermissions - The minimum permission bits required (e.g., 1 for VIEW, 3 for VIEW+EDIT)
|
||||
* @returns {Promise<Array>} Array of resource IDs
|
||||
*/
|
||||
const findAccessibleResources = async ({ userId, resourceType, requiredPermissions }) => {
|
||||
try {
|
||||
if (typeof requiredPermissions !== 'number' || requiredPermissions < 1) {
|
||||
throw new Error('requiredPermissions must be a positive number');
|
||||
}
|
||||
|
||||
// Get all principals for the user (user + groups + public)
|
||||
const principalsList = await getUserPrincipals(userId);
|
||||
|
||||
if (principalsList.length === 0) {
|
||||
return [];
|
||||
}
|
||||
return await findAccessibleResourcesACL(principalsList, resourceType, requiredPermissions);
|
||||
} catch (error) {
|
||||
logger.error(`[PermissionService.findAccessibleResources] Error: ${error.message}`);
|
||||
// Re-throw validation errors
|
||||
if (error.message.includes('requiredPermissions must be')) {
|
||||
throw error;
|
||||
}
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Find all publicly accessible resources of a specific type
|
||||
* @param {Object} params - Parameters for finding publicly accessible resources
|
||||
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
|
||||
* @param {number} params.requiredPermissions - The minimum permission bits required (e.g., 1 for VIEW, 3 for VIEW+EDIT)
|
||||
* @returns {Promise<Array>} Array of resource IDs
|
||||
*/
|
||||
const findPubliclyAccessibleResources = async ({ resourceType, requiredPermissions }) => {
|
||||
try {
|
||||
if (typeof requiredPermissions !== 'number' || requiredPermissions < 1) {
|
||||
throw new Error('requiredPermissions must be a positive number');
|
||||
}
|
||||
|
||||
// Find all public ACL entries where the public principal has at least the required permission bits
|
||||
const entries = await AclEntry.find({
|
||||
principalType: 'public',
|
||||
resourceType,
|
||||
permBits: { $bitsAllSet: requiredPermissions },
|
||||
}).distinct('resourceId');
|
||||
|
||||
return entries;
|
||||
} catch (error) {
|
||||
logger.error(`[PermissionService.findPubliclyAccessibleResources] Error: ${error.message}`);
|
||||
// Re-throw validation errors
|
||||
if (error.message.includes('requiredPermissions must be')) {
|
||||
throw error;
|
||||
}
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get available roles for a resource type
|
||||
* @param {Object} params - Parameters for getting available roles
|
||||
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
|
||||
* @returns {Promise<Array>} Array of role definitions
|
||||
*/
|
||||
const getAvailableRoles = async ({ resourceType }) => {
|
||||
try {
|
||||
return await AccessRole.find({ resourceType }).lean();
|
||||
} catch (error) {
|
||||
logger.error(`[PermissionService.getAvailableRoles] Error: ${error.message}`);
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Ensures a principal exists in the database based on TPrincipal data
|
||||
* Creates user if it doesn't exist locally (for Entra ID users)
|
||||
* @param {Object} principal - TPrincipal object from frontend
|
||||
* @param {string} principal.type - 'user', 'group', or 'public'
|
||||
* @param {string} [principal.id] - Local database ID (null for Entra ID principals not yet synced)
|
||||
* @param {string} principal.name - Display name
|
||||
* @param {string} [principal.email] - Email address
|
||||
* @param {string} [principal.source] - 'local' or 'entra'
|
||||
* @param {string} [principal.idOnTheSource] - Entra ID object ID for external principals
|
||||
* @returns {Promise<string|null>} Returns the principalId for database operations, null for public
|
||||
*/
|
||||
const ensurePrincipalExists = async function (principal) {
|
||||
if (principal.type === 'public') {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (principal.id) {
|
||||
return principal.id;
|
||||
}
|
||||
|
||||
if (principal.type === 'user' && principal.source === 'entra') {
|
||||
if (!principal.email || !principal.idOnTheSource) {
|
||||
throw new Error('Entra ID user principals must have email and idOnTheSource');
|
||||
}
|
||||
|
||||
let existingUser = await findUser({ idOnTheSource: principal.idOnTheSource });
|
||||
|
||||
if (!existingUser) {
|
||||
existingUser = await findUser({ email: principal.email.toLowerCase() });
|
||||
}
|
||||
|
||||
if (existingUser) {
|
||||
if (!existingUser.idOnTheSource && principal.idOnTheSource) {
|
||||
await updateUser(existingUser._id, {
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
provider: 'openid',
|
||||
});
|
||||
}
|
||||
return existingUser._id.toString();
|
||||
}
|
||||
|
||||
const userData = {
|
||||
name: principal.name,
|
||||
email: principal.email.toLowerCase(),
|
||||
emailVerified: false,
|
||||
provider: 'openid',
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
};
|
||||
|
||||
const userId = await createUser(userData, true, false);
|
||||
return userId.toString();
|
||||
}
|
||||
|
||||
if (principal.type === 'group') {
|
||||
throw new Error('Group principals should be handled by group-specific methods');
|
||||
}
|
||||
|
||||
throw new Error(`Unsupported principal type: ${principal.type}`);
|
||||
};
|
||||
|
||||
/**
|
||||
* Ensures a group principal exists in the database based on TPrincipal data
|
||||
* Creates group if it doesn't exist locally (for Entra ID groups)
|
||||
* For Entra ID groups, always synchronizes member IDs when authentication context is provided
|
||||
* @param {Object} principal - TPrincipal object from frontend
|
||||
* @param {string} principal.type - Must be 'group'
|
||||
* @param {string} [principal.id] - Local database ID (null for Entra ID principals not yet synced)
|
||||
* @param {string} principal.name - Display name
|
||||
* @param {string} [principal.email] - Email address
|
||||
* @param {string} [principal.description] - Group description
|
||||
* @param {string} [principal.source] - 'local' or 'entra'
|
||||
* @param {string} [principal.idOnTheSource] - Entra ID object ID for external principals
|
||||
* @param {Object} [authContext] - Optional authentication context for fetching member data
|
||||
* @param {string} [authContext.accessToken] - Access token for Graph API calls
|
||||
* @param {string} [authContext.sub] - Subject identifier
|
||||
* @returns {Promise<string>} Returns the groupId for database operations
|
||||
*/
|
||||
const ensureGroupPrincipalExists = async function (principal, authContext = null) {
|
||||
if (principal.type !== 'group') {
|
||||
throw new Error(`Invalid principal type: ${principal.type}. Expected 'group'`);
|
||||
}
|
||||
|
||||
if (principal.source === 'entra') {
|
||||
if (!principal.name || !principal.idOnTheSource) {
|
||||
throw new Error('Entra ID group principals must have name and idOnTheSource');
|
||||
}
|
||||
|
||||
let memberIds = [];
|
||||
if (authContext && authContext.accessToken && authContext.sub) {
|
||||
try {
|
||||
memberIds = await getGroupMembers(
|
||||
authContext.accessToken,
|
||||
authContext.sub,
|
||||
principal.idOnTheSource,
|
||||
);
|
||||
|
||||
// Include group owners as members if feature is enabled
|
||||
if (isEnabled(process.env.ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS)) {
|
||||
const ownerIds = await getGroupOwners(
|
||||
authContext.accessToken,
|
||||
authContext.sub,
|
||||
principal.idOnTheSource,
|
||||
);
|
||||
if (ownerIds && ownerIds.length > 0) {
|
||||
memberIds.push(...ownerIds);
|
||||
// Remove duplicates
|
||||
memberIds = [...new Set(memberIds)];
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to fetch group members from Graph API:', error);
|
||||
}
|
||||
}
|
||||
|
||||
let existingGroup = await findGroupByExternalId(principal.idOnTheSource, 'entra');
|
||||
|
||||
if (!existingGroup && principal.email) {
|
||||
existingGroup = await Group.findOne({ email: principal.email.toLowerCase() }).lean();
|
||||
}
|
||||
|
||||
if (existingGroup) {
|
||||
const updateData = {};
|
||||
let needsUpdate = false;
|
||||
|
||||
if (!existingGroup.idOnTheSource && principal.idOnTheSource) {
|
||||
updateData.idOnTheSource = principal.idOnTheSource;
|
||||
updateData.source = 'entra';
|
||||
needsUpdate = true;
|
||||
}
|
||||
|
||||
if (principal.description && existingGroup.description !== principal.description) {
|
||||
updateData.description = principal.description;
|
||||
needsUpdate = true;
|
||||
}
|
||||
|
||||
if (principal.email && existingGroup.email !== principal.email.toLowerCase()) {
|
||||
updateData.email = principal.email.toLowerCase();
|
||||
needsUpdate = true;
|
||||
}
|
||||
|
||||
if (authContext && authContext.accessToken && authContext.sub) {
|
||||
updateData.memberIds = memberIds;
|
||||
needsUpdate = true;
|
||||
}
|
||||
|
||||
if (needsUpdate) {
|
||||
await Group.findByIdAndUpdate(existingGroup._id, { $set: updateData }, { new: true });
|
||||
}
|
||||
|
||||
return existingGroup._id.toString();
|
||||
}
|
||||
|
||||
const groupData = {
|
||||
name: principal.name,
|
||||
source: 'entra',
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
memberIds: memberIds, // Store idOnTheSource values of group members (empty if no auth context)
|
||||
};
|
||||
|
||||
if (principal.email) {
|
||||
groupData.email = principal.email.toLowerCase();
|
||||
}
|
||||
|
||||
if (principal.description) {
|
||||
groupData.description = principal.description;
|
||||
}
|
||||
|
||||
const newGroup = await createGroup(groupData);
|
||||
return newGroup._id.toString();
|
||||
}
|
||||
if (principal.id && authContext == null) {
|
||||
return principal.id;
|
||||
}
|
||||
|
||||
throw new Error(`Unsupported group principal source: ${principal.source}`);
|
||||
};
|
||||
|
||||
/**
|
||||
* Synchronize user's Entra ID group memberships on sign-in
|
||||
* Gets user's group IDs from GraphAPI and updates memberships only for existing groups in database
|
||||
* Optionally includes groups the user owns if ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS is enabled
|
||||
* @param {Object} user - User object with authentication context
|
||||
* @param {string} user.openidId - User's OpenID subject identifier
|
||||
* @param {string} user.idOnTheSource - User's Entra ID (oid from token claims)
|
||||
* @param {string} user.provider - Authentication provider ('openid')
|
||||
* @param {string} accessToken - Access token for Graph API calls
|
||||
* @param {mongoose.ClientSession} [session] - Optional MongoDB session for transactions
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
const syncUserEntraGroupMemberships = async (user, accessToken, session = null) => {
|
||||
try {
|
||||
if (!entraIdPrincipalFeatureEnabled(user) || !accessToken || !user.idOnTheSource) {
|
||||
return;
|
||||
}
|
||||
|
||||
const memberGroupIds = await getUserEntraGroups(accessToken, user.openidId);
|
||||
let allGroupIds = [...(memberGroupIds || [])];
|
||||
|
||||
// Include owned groups if feature is enabled
|
||||
if (isEnabled(process.env.ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS)) {
|
||||
const ownedGroupIds = await getUserOwnedEntraGroups(accessToken, user.openidId);
|
||||
if (ownedGroupIds && ownedGroupIds.length > 0) {
|
||||
allGroupIds.push(...ownedGroupIds);
|
||||
// Remove duplicates
|
||||
allGroupIds = [...new Set(allGroupIds)];
|
||||
}
|
||||
}
|
||||
|
||||
if (!allGroupIds || allGroupIds.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const sessionOptions = session ? { session } : {};
|
||||
|
||||
await Group.updateMany(
|
||||
{
|
||||
idOnTheSource: { $in: allGroupIds },
|
||||
source: 'entra',
|
||||
memberIds: { $ne: user.idOnTheSource },
|
||||
},
|
||||
{ $addToSet: { memberIds: user.idOnTheSource } },
|
||||
sessionOptions,
|
||||
);
|
||||
|
||||
await Group.updateMany(
|
||||
{
|
||||
source: 'entra',
|
||||
memberIds: user.idOnTheSource,
|
||||
idOnTheSource: { $nin: allGroupIds },
|
||||
},
|
||||
{ $pull: { memberIds: user.idOnTheSource } },
|
||||
sessionOptions,
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(`[PermissionService.syncUserEntraGroupMemberships] Error syncing groups:`, error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if public has a specific permission on a resource
|
||||
* @param {Object} params - Parameters for checking public permission
|
||||
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
|
||||
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
|
||||
* @param {number} params.requiredPermissions - The permission bits required (e.g., 1 for VIEW, 3 for VIEW+EDIT)
|
||||
* @returns {Promise<boolean>} Whether public has the required permission bits
|
||||
*/
|
||||
const hasPublicPermission = async ({ resourceType, resourceId, requiredPermissions }) => {
|
||||
try {
|
||||
if (typeof requiredPermissions !== 'number' || requiredPermissions < 1) {
|
||||
throw new Error('requiredPermissions must be a positive number');
|
||||
}
|
||||
|
||||
// Use public principal to check permissions
|
||||
const publicPrincipal = [{ principalType: 'public' }];
|
||||
|
||||
const entries = await findEntriesByPrincipalsAndResource(
|
||||
publicPrincipal,
|
||||
resourceType,
|
||||
resourceId,
|
||||
);
|
||||
|
||||
// Check if any entry has the required permission bits
|
||||
return entries.some((entry) => (entry.permBits & requiredPermissions) === requiredPermissions);
|
||||
} catch (error) {
|
||||
logger.error(`[PermissionService.hasPublicPermission] Error: ${error.message}`);
|
||||
// Re-throw validation errors
|
||||
if (error.message.includes('requiredPermissions must be')) {
|
||||
throw error;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Bulk update permissions for a resource (grant, update, revoke)
|
||||
* Efficiently handles multiple permission changes in a single transaction
|
||||
*
|
||||
* @param {Object} params - Parameters for bulk permission update
|
||||
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
|
||||
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
|
||||
* @param {Array<TPrincipal>} params.updatedPrincipals - Array of principals to grant/update permissions for
|
||||
* @param {Array<TPrincipal>} params.revokedPrincipals - Array of principals to revoke permissions from
|
||||
* @param {string|mongoose.Types.ObjectId} params.grantedBy - User ID making the changes
|
||||
* @param {mongoose.ClientSession} [params.session] - Optional MongoDB session for transactions
|
||||
* @returns {Promise<Object>} Results object with granted, updated, revoked arrays and error details
|
||||
*/
|
||||
const bulkUpdateResourcePermissions = async ({
|
||||
resourceType,
|
||||
resourceId,
|
||||
updatedPrincipals = [],
|
||||
revokedPrincipals = [],
|
||||
grantedBy,
|
||||
session,
|
||||
}) => {
|
||||
const supportsTransactions = await getTransactionSupport(mongoose, transactionSupportCache);
|
||||
transactionSupportCache = supportsTransactions;
|
||||
let localSession = session;
|
||||
let shouldEndSession = false;
|
||||
|
||||
try {
|
||||
if (!Array.isArray(updatedPrincipals)) {
|
||||
throw new Error('updatedPrincipals must be an array');
|
||||
}
|
||||
|
||||
if (!Array.isArray(revokedPrincipals)) {
|
||||
throw new Error('revokedPrincipals must be an array');
|
||||
}
|
||||
|
||||
if (!resourceId || !mongoose.Types.ObjectId.isValid(resourceId)) {
|
||||
throw new Error(`Invalid resource ID: ${resourceId}`);
|
||||
}
|
||||
|
||||
if (!localSession && supportsTransactions) {
|
||||
localSession = await mongoose.startSession();
|
||||
localSession.startTransaction();
|
||||
shouldEndSession = true;
|
||||
}
|
||||
|
||||
const sessionOptions = localSession ? { session: localSession } : {};
|
||||
|
||||
const roles = await AccessRole.find({ resourceType }).lean();
|
||||
const rolesMap = new Map();
|
||||
roles.forEach((role) => {
|
||||
rolesMap.set(role.accessRoleId, role);
|
||||
});
|
||||
|
||||
const results = {
|
||||
granted: [],
|
||||
updated: [],
|
||||
revoked: [],
|
||||
errors: [],
|
||||
};
|
||||
|
||||
const bulkWrites = [];
|
||||
|
||||
for (const principal of updatedPrincipals) {
|
||||
try {
|
||||
if (!principal.accessRoleId) {
|
||||
results.errors.push({
|
||||
principal,
|
||||
error: 'accessRoleId is required for updated principals',
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
const role = rolesMap.get(principal.accessRoleId);
|
||||
if (!role) {
|
||||
results.errors.push({
|
||||
principal,
|
||||
error: `Role ${principal.accessRoleId} not found`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
const query = {
|
||||
principalType: principal.type,
|
||||
resourceType,
|
||||
resourceId,
|
||||
};
|
||||
|
||||
if (principal.type !== 'public') {
|
||||
query.principalId = principal.id;
|
||||
}
|
||||
|
||||
const update = {
|
||||
$set: {
|
||||
permBits: role.permBits,
|
||||
roleId: role._id,
|
||||
grantedBy,
|
||||
grantedAt: new Date(),
|
||||
},
|
||||
$setOnInsert: {
|
||||
principalType: principal.type,
|
||||
resourceType,
|
||||
resourceId,
|
||||
...(principal.type !== 'public' && {
|
||||
principalId: principal.id,
|
||||
principalModel: principal.type === 'user' ? 'User' : 'Group',
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
bulkWrites.push({
|
||||
updateOne: {
|
||||
filter: query,
|
||||
update: update,
|
||||
upsert: true,
|
||||
},
|
||||
});
|
||||
|
||||
results.granted.push({
|
||||
type: principal.type,
|
||||
id: principal.id,
|
||||
name: principal.name,
|
||||
email: principal.email,
|
||||
source: principal.source,
|
||||
avatar: principal.avatar,
|
||||
description: principal.description,
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
accessRoleId: principal.accessRoleId,
|
||||
memberCount: principal.memberCount,
|
||||
memberIds: principal.memberIds,
|
||||
});
|
||||
} catch (error) {
|
||||
results.errors.push({
|
||||
principal,
|
||||
error: error.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (bulkWrites.length > 0) {
|
||||
await AclEntry.bulkWrite(bulkWrites, sessionOptions);
|
||||
}
|
||||
|
||||
const deleteQueries = [];
|
||||
for (const principal of revokedPrincipals) {
|
||||
try {
|
||||
const query = {
|
||||
principalType: principal.type,
|
||||
resourceType,
|
||||
resourceId,
|
||||
};
|
||||
|
||||
if (principal.type !== 'public') {
|
||||
query.principalId = principal.id;
|
||||
}
|
||||
|
||||
deleteQueries.push(query);
|
||||
|
||||
results.revoked.push({
|
||||
type: principal.type,
|
||||
id: principal.id,
|
||||
name: principal.name,
|
||||
email: principal.email,
|
||||
source: principal.source,
|
||||
avatar: principal.avatar,
|
||||
description: principal.description,
|
||||
idOnTheSource: principal.idOnTheSource,
|
||||
memberCount: principal.memberCount,
|
||||
});
|
||||
} catch (error) {
|
||||
results.errors.push({
|
||||
principal,
|
||||
error: error.message,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (deleteQueries.length > 0) {
|
||||
await AclEntry.deleteMany(
|
||||
{
|
||||
$or: deleteQueries,
|
||||
},
|
||||
sessionOptions,
|
||||
);
|
||||
}
|
||||
|
||||
if (shouldEndSession && supportsTransactions) {
|
||||
await localSession.commitTransaction();
|
||||
}
|
||||
|
||||
return results;
|
||||
} catch (error) {
|
||||
if (shouldEndSession && supportsTransactions) {
|
||||
await localSession.abortTransaction();
|
||||
}
|
||||
logger.error(`[PermissionService.bulkUpdateResourcePermissions] Error: ${error.message}`);
|
||||
throw error;
|
||||
} finally {
|
||||
if (shouldEndSession && localSession) {
|
||||
localSession.endSession();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
grantPermission,
|
||||
checkPermission,
|
||||
getEffectivePermissions,
|
||||
findAccessibleResources,
|
||||
findPubliclyAccessibleResources,
|
||||
hasPublicPermission,
|
||||
getAvailableRoles,
|
||||
bulkUpdateResourcePermissions,
|
||||
ensurePrincipalExists,
|
||||
ensureGroupPrincipalExists,
|
||||
syncUserEntraGroupMemberships,
|
||||
};
|
||||
1058
api/server/services/PermissionService.spec.js
Normal file
1058
api/server/services/PermissionService.spec.js
Normal file
File diff suppressed because it is too large
Load Diff
@@ -365,6 +365,7 @@ async function setupOpenId() {
|
||||
email: userinfo.email || '',
|
||||
emailVerified: userinfo.email_verified || false,
|
||||
name: fullName,
|
||||
idOnTheSource: userinfo.oid,
|
||||
};
|
||||
|
||||
const balanceConfig = await getBalanceConfig();
|
||||
@@ -375,6 +376,7 @@ async function setupOpenId() {
|
||||
user.openidId = userinfo.sub;
|
||||
user.username = username;
|
||||
user.name = fullName;
|
||||
user.idOnTheSource = userinfo.oid;
|
||||
}
|
||||
|
||||
if (!!userinfo && userinfo.picture && !user.avatar?.includes('manual=true')) {
|
||||
|
||||
83
client/src/Providers/BadgeRowContext.tsx
Normal file
83
client/src/Providers/BadgeRowContext.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import React, { createContext, useContext } from 'react';
|
||||
import { Tools, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import { useMCPSelect, useToolToggle, useCodeApiKeyForm, useSearchApiKeyForm } from '~/hooks';
|
||||
|
||||
interface BadgeRowContextType {
|
||||
conversationId?: string | null;
|
||||
mcpSelect: ReturnType<typeof useMCPSelect>;
|
||||
webSearch: ReturnType<typeof useToolToggle>;
|
||||
codeInterpreter: ReturnType<typeof useToolToggle>;
|
||||
fileSearch: ReturnType<typeof useToolToggle>;
|
||||
codeApiKeyForm: ReturnType<typeof useCodeApiKeyForm>;
|
||||
searchApiKeyForm: ReturnType<typeof useSearchApiKeyForm>;
|
||||
}
|
||||
|
||||
const BadgeRowContext = createContext<BadgeRowContextType | undefined>(undefined);
|
||||
|
||||
export function useBadgeRowContext() {
|
||||
const context = useContext(BadgeRowContext);
|
||||
if (context === undefined) {
|
||||
throw new Error('useBadgeRowContext must be used within a BadgeRowProvider');
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
interface BadgeRowProviderProps {
|
||||
children: React.ReactNode;
|
||||
conversationId?: string | null;
|
||||
}
|
||||
|
||||
export default function BadgeRowProvider({ children, conversationId }: BadgeRowProviderProps) {
|
||||
/** MCPSelect hook */
|
||||
const mcpSelect = useMCPSelect({ conversationId });
|
||||
|
||||
/** CodeInterpreter hooks */
|
||||
const codeApiKeyForm = useCodeApiKeyForm({});
|
||||
const { setIsDialogOpen: setCodeDialogOpen } = codeApiKeyForm;
|
||||
|
||||
const codeInterpreter = useToolToggle({
|
||||
conversationId,
|
||||
setIsDialogOpen: setCodeDialogOpen,
|
||||
toolKey: Tools.execute_code,
|
||||
localStorageKey: LocalStorageKeys.LAST_CODE_TOGGLE_,
|
||||
authConfig: {
|
||||
toolId: Tools.execute_code,
|
||||
queryOptions: { retry: 1 },
|
||||
},
|
||||
});
|
||||
|
||||
/** WebSearch hooks */
|
||||
const searchApiKeyForm = useSearchApiKeyForm({});
|
||||
const { setIsDialogOpen: setWebSearchDialogOpen } = searchApiKeyForm;
|
||||
|
||||
const webSearch = useToolToggle({
|
||||
conversationId,
|
||||
toolKey: Tools.web_search,
|
||||
localStorageKey: LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_,
|
||||
setIsDialogOpen: setWebSearchDialogOpen,
|
||||
authConfig: {
|
||||
toolId: Tools.web_search,
|
||||
queryOptions: { retry: 1 },
|
||||
},
|
||||
});
|
||||
|
||||
/** FileSearch hook */
|
||||
const fileSearch = useToolToggle({
|
||||
conversationId,
|
||||
toolKey: Tools.file_search,
|
||||
localStorageKey: LocalStorageKeys.LAST_FILE_SEARCH_TOGGLE_,
|
||||
isAuthenticated: true,
|
||||
});
|
||||
|
||||
const value: BadgeRowContextType = {
|
||||
mcpSelect,
|
||||
webSearch,
|
||||
fileSearch,
|
||||
conversationId,
|
||||
codeApiKeyForm,
|
||||
codeInterpreter,
|
||||
searchApiKeyForm,
|
||||
};
|
||||
|
||||
return <BadgeRowContext.Provider value={value}>{children}</BadgeRowContext.Provider>;
|
||||
}
|
||||
@@ -22,3 +22,5 @@ export * from './CodeBlockContext';
|
||||
export * from './ToolCallsMapContext';
|
||||
export * from './SetConvoContext';
|
||||
export * from './SearchContext';
|
||||
export * from './BadgeRowContext';
|
||||
export { default as BadgeRowProvider } from './BadgeRowContext';
|
||||
|
||||
@@ -7,6 +7,7 @@ export type TAgentOption = OptionWithIcon &
|
||||
knowledge_files?: Array<[string, ExtendedFile]>;
|
||||
context_files?: Array<[string, ExtendedFile]>;
|
||||
code_files?: Array<[string, ExtendedFile]>;
|
||||
_id?: string;
|
||||
};
|
||||
|
||||
export type TAgentCapabilities = {
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import React, {
|
||||
memo,
|
||||
useState,
|
||||
useRef,
|
||||
useEffect,
|
||||
useCallback,
|
||||
useMemo,
|
||||
useState,
|
||||
useEffect,
|
||||
forwardRef,
|
||||
useReducer,
|
||||
useCallback,
|
||||
} from 'react';
|
||||
import { useRecoilValue, useRecoilCallback } from 'recoil';
|
||||
import type { LucideIcon } from 'lucide-react';
|
||||
import CodeInterpreter from './CodeInterpreter';
|
||||
import { BadgeRowProvider } from '~/Providers';
|
||||
import ToolsDropdown from './ToolsDropdown';
|
||||
import type { BadgeItem } from '~/common';
|
||||
import { useChatBadges } from '~/hooks';
|
||||
import { Badge } from '~/components/ui';
|
||||
import ToolDialogs from './ToolDialogs';
|
||||
import FileSearch from './FileSearch';
|
||||
import MCPSelect from './MCPSelect';
|
||||
import WebSearch from './WebSearch';
|
||||
import store from '~/store';
|
||||
@@ -313,78 +317,83 @@ function BadgeRow({
|
||||
}, [dragState.draggedBadge, handleMouseMove, handleMouseUp]);
|
||||
|
||||
return (
|
||||
<div ref={containerRef} className="relative flex flex-wrap items-center gap-2">
|
||||
{tempBadges.map((badge, index) => (
|
||||
<React.Fragment key={badge.id}>
|
||||
{dragState.draggedBadge && dragState.insertIndex === index && ghostBadge && (
|
||||
<div className="badge-icon h-full">
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isEditing={isEditing}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<BadgeWrapper
|
||||
badge={badge}
|
||||
isEditing={isEditing}
|
||||
isInChat={isInChat}
|
||||
onToggle={handleBadgeToggle}
|
||||
onDelete={handleDelete}
|
||||
onMouseDown={handleMouseDown}
|
||||
badgeRefs={badgeRefs}
|
||||
/>
|
||||
</React.Fragment>
|
||||
))}
|
||||
{dragState.draggedBadge && dragState.insertIndex === tempBadges.length && ghostBadge && (
|
||||
<div className="badge-icon h-full">
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isEditing={isEditing}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{showEphemeralBadges === true && (
|
||||
<>
|
||||
<WebSearch conversationId={conversationId} />
|
||||
<CodeInterpreter conversationId={conversationId} />
|
||||
<MCPSelect conversationId={conversationId} />
|
||||
</>
|
||||
)}
|
||||
{ghostBadge && (
|
||||
<div
|
||||
className="ghost-badge h-full"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
transform: `translateX(${dragState.mouseX - dragState.offsetX - (containerRectRef.current?.left || 0)}px)`,
|
||||
zIndex: 10,
|
||||
pointerEvents: 'none',
|
||||
}}
|
||||
>
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
isEditing
|
||||
isDragging
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<BadgeRowProvider conversationId={conversationId}>
|
||||
<div ref={containerRef} className="relative flex flex-wrap items-center gap-2">
|
||||
{showEphemeralBadges === true && <ToolsDropdown />}
|
||||
{tempBadges.map((badge, index) => (
|
||||
<React.Fragment key={badge.id}>
|
||||
{dragState.draggedBadge && dragState.insertIndex === index && ghostBadge && (
|
||||
<div className="badge-icon h-full">
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isEditing={isEditing}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<BadgeWrapper
|
||||
badge={badge}
|
||||
isEditing={isEditing}
|
||||
isInChat={isInChat}
|
||||
onToggle={handleBadgeToggle}
|
||||
onDelete={handleDelete}
|
||||
onMouseDown={handleMouseDown}
|
||||
badgeRefs={badgeRefs}
|
||||
/>
|
||||
</React.Fragment>
|
||||
))}
|
||||
{dragState.draggedBadge && dragState.insertIndex === tempBadges.length && ghostBadge && (
|
||||
<div className="badge-icon h-full">
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isEditing={isEditing}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{showEphemeralBadges === true && (
|
||||
<>
|
||||
<WebSearch />
|
||||
<CodeInterpreter />
|
||||
<FileSearch />
|
||||
<MCPSelect />
|
||||
</>
|
||||
)}
|
||||
{ghostBadge && (
|
||||
<div
|
||||
className="ghost-badge h-full"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
transform: `translateX(${dragState.mouseX - dragState.offsetX - (containerRectRef.current?.left || 0)}px)`,
|
||||
zIndex: 10,
|
||||
pointerEvents: 'none',
|
||||
}}
|
||||
>
|
||||
<Badge
|
||||
id={ghostBadge.id}
|
||||
icon={ghostBadge.icon as LucideIcon}
|
||||
label={ghostBadge.label}
|
||||
isActive={dragState.draggedBadgeActive}
|
||||
isAvailable={ghostBadge.isAvailable}
|
||||
isInChat={isInChat}
|
||||
isEditing
|
||||
isDragging
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<ToolDialogs />
|
||||
</BadgeRowProvider>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,122 +1,37 @@
|
||||
import debounce from 'lodash/debounce';
|
||||
import React, { memo, useMemo, useCallback, useRef } from 'react';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import React, { memo } from 'react';
|
||||
import { TerminalSquareIcon } from 'lucide-react';
|
||||
import {
|
||||
Tools,
|
||||
AuthType,
|
||||
Constants,
|
||||
LocalStorageKeys,
|
||||
PermissionTypes,
|
||||
Permissions,
|
||||
} from 'librechat-data-provider';
|
||||
import ApiKeyDialog from '~/components/SidePanel/Agents/Code/ApiKeyDialog';
|
||||
import { useLocalize, useHasAccess, useCodeApiKeyForm } from '~/hooks';
|
||||
import { PermissionTypes, Permissions } from 'librechat-data-provider';
|
||||
import CheckboxButton from '~/components/ui/CheckboxButton';
|
||||
import useLocalStorage from '~/hooks/useLocalStorageAlt';
|
||||
import { useVerifyAgentToolAuth } from '~/data-provider';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { useLocalize, useHasAccess } from '~/hooks';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
|
||||
const storageCondition = (value: unknown, rawCurrentValue?: string | null) => {
|
||||
if (rawCurrentValue) {
|
||||
try {
|
||||
const currentValue = rawCurrentValue?.trim() ?? '';
|
||||
if (currentValue === 'true' && value === false) {
|
||||
return true;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return value !== undefined && value !== null && value !== '' && value !== false;
|
||||
};
|
||||
|
||||
function CodeInterpreter({ conversationId }: { conversationId?: string | null }) {
|
||||
const triggerRef = useRef<HTMLInputElement>(null);
|
||||
function CodeInterpreter() {
|
||||
const localize = useLocalize();
|
||||
const key = conversationId ?? Constants.NEW_CONVO;
|
||||
const { codeInterpreter, codeApiKeyForm } = useBadgeRowContext();
|
||||
const { toggleState: runCode, debouncedChange, isPinned } = codeInterpreter;
|
||||
const { badgeTriggerRef } = codeApiKeyForm;
|
||||
|
||||
const canRunCode = useHasAccess({
|
||||
permissionType: PermissionTypes.RUN_CODE,
|
||||
permission: Permissions.USE,
|
||||
});
|
||||
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key));
|
||||
const isCodeToggleEnabled = useMemo(() => {
|
||||
return ephemeralAgent?.execute_code ?? false;
|
||||
}, [ephemeralAgent?.execute_code]);
|
||||
|
||||
const { data } = useVerifyAgentToolAuth(
|
||||
{ toolId: Tools.execute_code },
|
||||
{
|
||||
retry: 1,
|
||||
},
|
||||
);
|
||||
const authType = useMemo(() => data?.message ?? false, [data?.message]);
|
||||
const isAuthenticated = useMemo(() => data?.authenticated ?? false, [data?.authenticated]);
|
||||
const { methods, onSubmit, isDialogOpen, setIsDialogOpen, handleRevokeApiKey } =
|
||||
useCodeApiKeyForm({});
|
||||
|
||||
const setValue = useCallback(
|
||||
(isChecked: boolean) => {
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
execute_code: isChecked,
|
||||
}));
|
||||
},
|
||||
[setEphemeralAgent],
|
||||
);
|
||||
|
||||
const [runCode, setRunCode] = useLocalStorage<boolean>(
|
||||
`${LocalStorageKeys.LAST_CODE_TOGGLE_}${key}`,
|
||||
isCodeToggleEnabled,
|
||||
setValue,
|
||||
storageCondition,
|
||||
);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(e: React.ChangeEvent<HTMLInputElement>, isChecked: boolean) => {
|
||||
if (!isAuthenticated) {
|
||||
setIsDialogOpen(true);
|
||||
e.preventDefault();
|
||||
return;
|
||||
}
|
||||
setRunCode(isChecked);
|
||||
},
|
||||
[setRunCode, setIsDialogOpen, isAuthenticated],
|
||||
);
|
||||
|
||||
const debouncedChange = useMemo(
|
||||
() => debounce(handleChange, 50, { leading: true }),
|
||||
[handleChange],
|
||||
);
|
||||
|
||||
if (!canRunCode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
(runCode || isPinned) && (
|
||||
<CheckboxButton
|
||||
ref={triggerRef}
|
||||
ref={badgeTriggerRef}
|
||||
className="max-w-fit"
|
||||
defaultChecked={runCode}
|
||||
checked={runCode}
|
||||
setValue={debouncedChange}
|
||||
label={localize('com_assistants_code_interpreter')}
|
||||
isCheckedClassName="border-purple-600/40 bg-purple-500/10 hover:bg-purple-700/10"
|
||||
icon={<TerminalSquareIcon className="icon-md" />}
|
||||
/>
|
||||
<ApiKeyDialog
|
||||
onSubmit={onSubmit}
|
||||
isOpen={isDialogOpen}
|
||||
triggerRef={triggerRef}
|
||||
register={methods.register}
|
||||
onRevoke={handleRevokeApiKey}
|
||||
onOpenChange={setIsDialogOpen}
|
||||
handleSubmit={methods.handleSubmit}
|
||||
isToolAuthenticated={isAuthenticated}
|
||||
isUserProvided={authType === AuthType.USER_PROVIDED}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
28
client/src/components/Chat/Input/FileSearch.tsx
Normal file
28
client/src/components/Chat/Input/FileSearch.tsx
Normal file
@@ -0,0 +1,28 @@
|
||||
import React, { memo } from 'react';
|
||||
import CheckboxButton from '~/components/ui/CheckboxButton';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { VectorIcon } from '~/components/svg';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
function FileSearch() {
|
||||
const localize = useLocalize();
|
||||
const { fileSearch } = useBadgeRowContext();
|
||||
const { toggleState: fileSearchEnabled, debouncedChange, isPinned } = fileSearch;
|
||||
|
||||
return (
|
||||
<>
|
||||
{(fileSearchEnabled || isPinned) && (
|
||||
<CheckboxButton
|
||||
className="max-w-fit"
|
||||
checked={fileSearchEnabled}
|
||||
setValue={debouncedChange}
|
||||
label={localize('com_assistants_file_search')}
|
||||
isCheckedClassName="border-green-600/40 bg-green-500/10 hover:bg-green-700/10"
|
||||
icon={<VectorIcon className="icon-md" />}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default memo(FileSearch);
|
||||
@@ -1,31 +1,21 @@
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import {
|
||||
Constants,
|
||||
supportsFiles,
|
||||
mergeFileConfig,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgent,
|
||||
EndpointFileConfig,
|
||||
fileConfig as defaultFileConfig,
|
||||
} from 'librechat-data-provider';
|
||||
import { useChatContext } from '~/Providers';
|
||||
import { useGetFileConfig } from '~/data-provider';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import AttachFileMenu from './AttachFileMenu';
|
||||
import AttachFile from './AttachFile';
|
||||
import { useChatContext } from '~/Providers';
|
||||
|
||||
function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
|
||||
const { conversation } = useChatContext();
|
||||
|
||||
const conversationId = conversation?.conversationId ?? Constants.NEW_CONVO;
|
||||
const { endpoint: _endpoint, endpointType } = conversation ?? { endpoint: null };
|
||||
|
||||
const key = conversation?.conversationId ?? Constants.NEW_CONVO;
|
||||
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(key));
|
||||
const isAgents = useMemo(
|
||||
() => isAgentsEndpoint(_endpoint) || isEphemeralAgent(_endpoint, ephemeralAgent),
|
||||
[_endpoint, ephemeralAgent],
|
||||
);
|
||||
const isAgents = useMemo(() => isAgentsEndpoint(_endpoint), [_endpoint]);
|
||||
|
||||
const { data: fileConfig = defaultFileConfig } = useGetFileConfig({
|
||||
select: (data) => mergeFileConfig(data),
|
||||
@@ -38,11 +28,8 @@ function AttachFileChat({ disableInputs }: { disableInputs: boolean }) {
|
||||
const endpointSupportsFiles: boolean = supportsFiles[endpointType ?? _endpoint ?? ''] ?? false;
|
||||
const isUploadDisabled = (disableInputs || endpointFileConfig?.disabled) ?? false;
|
||||
|
||||
if (isAgents) {
|
||||
return <AttachFileMenu disabled={disableInputs} />;
|
||||
}
|
||||
if (endpointSupportsFiles && !isUploadDisabled) {
|
||||
return <AttachFile disabled={disableInputs} />;
|
||||
if (isAgents || (endpointSupportsFiles && !isUploadDisabled)) {
|
||||
return <AttachFileMenu disabled={disableInputs} conversationId={conversationId} />;
|
||||
}
|
||||
|
||||
return null;
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
import { useSetRecoilState } from 'recoil';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import React, { useRef, useState, useMemo } from 'react';
|
||||
import { FileSearch, ImageUpIcon, TerminalSquareIcon, FileType2Icon } from 'lucide-react';
|
||||
import { EToolResources, EModelEndpoint, defaultAgentCapabilities } from 'librechat-data-provider';
|
||||
import { FileUpload, TooltipAnchor, DropdownPopup, AttachmentIcon } from '~/components';
|
||||
import { EToolResources, EModelEndpoint } from 'librechat-data-provider';
|
||||
import { useGetEndpointsQuery } from '~/data-provider';
|
||||
import { useLocalize, useFileHandling } from '~/hooks';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface AttachFileProps {
|
||||
interface AttachFileMenuProps {
|
||||
conversationId: string;
|
||||
disabled?: boolean | null;
|
||||
}
|
||||
|
||||
const AttachFile = ({ disabled }: AttachFileProps) => {
|
||||
const AttachFileMenu = ({ disabled, conversationId }: AttachFileMenuProps) => {
|
||||
const localize = useLocalize();
|
||||
const isUploadDisabled = disabled ?? false;
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const [isPopoverActive, setIsPopoverActive] = useState(false);
|
||||
const setEphemeralAgent = useSetRecoilState(ephemeralAgentByConvoId(conversationId));
|
||||
const [toolResource, setToolResource] = useState<EToolResources | undefined>();
|
||||
const { data: endpointsConfig } = useGetEndpointsQuery();
|
||||
const { handleFileChange } = useFileHandling({
|
||||
@@ -69,6 +73,7 @@ const AttachFile = ({ disabled }: AttachFileProps) => {
|
||||
label: localize('com_ui_upload_file_search'),
|
||||
onClick: () => {
|
||||
setToolResource(EToolResources.file_search);
|
||||
/** File search is not automatically enabled to simulate legacy behavior */
|
||||
handleUploadClick();
|
||||
},
|
||||
icon: <FileSearch className="icon-md" />,
|
||||
@@ -80,6 +85,10 @@ const AttachFile = ({ disabled }: AttachFileProps) => {
|
||||
label: localize('com_ui_upload_code_files'),
|
||||
onClick: () => {
|
||||
setToolResource(EToolResources.execute_code);
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
[EToolResources.execute_code]: true,
|
||||
}));
|
||||
handleUploadClick();
|
||||
},
|
||||
icon: <TerminalSquareIcon className="icon-md" />,
|
||||
@@ -87,7 +96,7 @@ const AttachFile = ({ disabled }: AttachFileProps) => {
|
||||
}
|
||||
|
||||
return items;
|
||||
}, [capabilities, localize, setToolResource]);
|
||||
}, [capabilities, localize, setToolResource, setEphemeralAgent]);
|
||||
|
||||
const menuTrigger = (
|
||||
<TooltipAnchor
|
||||
@@ -132,4 +141,4 @@ const AttachFile = ({ disabled }: AttachFileProps) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default React.memo(AttachFile);
|
||||
export default React.memo(AttachFileMenu);
|
||||
|
||||
@@ -7,7 +7,7 @@ import useLocalize from '~/hooks/useLocalize';
|
||||
import { OGDialog } from '~/components/ui';
|
||||
|
||||
interface DragDropModalProps {
|
||||
onOptionSelect: (option: string | undefined) => void;
|
||||
onOptionSelect: (option: EToolResources | undefined) => void;
|
||||
files: File[];
|
||||
isVisible: boolean;
|
||||
setShowModal: (showModal: boolean) => void;
|
||||
|
||||
@@ -1,75 +1,29 @@
|
||||
import React, { memo, useRef, useMemo, useEffect, useCallback, useState } from 'react';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import { Settings2 } from 'lucide-react';
|
||||
import React, { memo, useCallback, useState } from 'react';
|
||||
import { SettingsIcon } from 'lucide-react';
|
||||
import { Constants } from 'librechat-data-provider';
|
||||
import { useUpdateUserPluginsMutation } from 'librechat-data-provider/react-query';
|
||||
import { Constants, EModelEndpoint, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import type { TPlugin, TPluginAuthConfig, TUpdateUserPlugins } from 'librechat-data-provider';
|
||||
import type { TUpdateUserPlugins } from 'librechat-data-provider';
|
||||
import type { McpServerInfo } from '~/hooks/Plugins/useMCPSelect';
|
||||
import MCPConfigDialog, { type ConfigFieldDetail } from '~/components/ui/MCPConfigDialog';
|
||||
import { useAvailableToolsQuery } from '~/data-provider';
|
||||
import useLocalStorage from '~/hooks/useLocalStorageAlt';
|
||||
import { useToastContext, useBadgeRowContext } from '~/Providers';
|
||||
import MultiSelect from '~/components/ui/MultiSelect';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { useToastContext } from '~/Providers';
|
||||
import MCPIcon from '~/components/ui/MCPIcon';
|
||||
import { MCPIcon } from '~/components/svg';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
interface McpServerInfo {
|
||||
name: string;
|
||||
pluginKey: string;
|
||||
authConfig?: TPluginAuthConfig[];
|
||||
authenticated?: boolean;
|
||||
}
|
||||
|
||||
// Helper function to extract mcp_serverName from a full pluginKey like action_mcp_serverName
|
||||
const getBaseMCPPluginKey = (fullPluginKey: string): string => {
|
||||
const parts = fullPluginKey.split(Constants.mcp_delimiter);
|
||||
return Constants.mcp_prefix + parts[parts.length - 1];
|
||||
};
|
||||
|
||||
const storageCondition = (value: unknown, rawCurrentValue?: string | null) => {
|
||||
if (rawCurrentValue) {
|
||||
try {
|
||||
const currentValue = rawCurrentValue?.trim() ?? '';
|
||||
if (currentValue.length > 2) {
|
||||
return true;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return Array.isArray(value) && value.length > 0;
|
||||
};
|
||||
|
||||
function MCPSelect({ conversationId }: { conversationId?: string | null }) {
|
||||
function MCPSelect() {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const key = conversationId ?? Constants.NEW_CONVO;
|
||||
const hasSetFetched = useRef<string | null>(null);
|
||||
const { mcpSelect } = useBadgeRowContext();
|
||||
const { mcpValues, setMCPValues, mcpServerNames, mcpToolDetails, isPinned } = mcpSelect;
|
||||
|
||||
const [isConfigModalOpen, setIsConfigModalOpen] = useState(false);
|
||||
const [selectedToolForConfig, setSelectedToolForConfig] = useState<McpServerInfo | null>(null);
|
||||
|
||||
const { data: mcpToolDetails, isFetched } = useAvailableToolsQuery(EModelEndpoint.agents, {
|
||||
select: (data: TPlugin[]) => {
|
||||
const mcpToolsMap = new Map<string, McpServerInfo>();
|
||||
data.forEach((tool) => {
|
||||
const isMCP = tool.pluginKey.includes(Constants.mcp_delimiter);
|
||||
if (isMCP && tool.chatMenu !== false) {
|
||||
const parts = tool.pluginKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
if (!mcpToolsMap.has(serverName)) {
|
||||
mcpToolsMap.set(serverName, {
|
||||
name: serverName,
|
||||
pluginKey: tool.pluginKey,
|
||||
authConfig: tool.authConfig,
|
||||
authenticated: tool.authenticated,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
return Array.from(mcpToolsMap.values());
|
||||
},
|
||||
});
|
||||
|
||||
const updateUserPluginsMutation = useUpdateUserPluginsMutation({
|
||||
onSuccess: () => {
|
||||
setIsConfigModalOpen(false);
|
||||
@@ -84,48 +38,6 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) {
|
||||
},
|
||||
});
|
||||
|
||||
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key));
|
||||
const mcpState = useMemo(() => {
|
||||
return ephemeralAgent?.mcp ?? [];
|
||||
}, [ephemeralAgent?.mcp]);
|
||||
|
||||
const setSelectedValues = useCallback(
|
||||
(values: string[] | null | undefined) => {
|
||||
if (!values) {
|
||||
return;
|
||||
}
|
||||
if (!Array.isArray(values)) {
|
||||
return;
|
||||
}
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
mcp: values,
|
||||
}));
|
||||
},
|
||||
[setEphemeralAgent],
|
||||
);
|
||||
const [mcpValues, setMCPValues] = useLocalStorage<string[]>(
|
||||
`${LocalStorageKeys.LAST_MCP_}${key}`,
|
||||
mcpState,
|
||||
setSelectedValues,
|
||||
storageCondition,
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (hasSetFetched.current === key) {
|
||||
return;
|
||||
}
|
||||
if (!isFetched) {
|
||||
return;
|
||||
}
|
||||
hasSetFetched.current = key;
|
||||
if ((mcpToolDetails?.length ?? 0) > 0) {
|
||||
setMCPValues(mcpValues.filter((mcp) => mcpToolDetails?.some((tool) => tool.name === mcp)));
|
||||
return;
|
||||
}
|
||||
setMCPValues([]);
|
||||
}, [isFetched, setMCPValues, mcpToolDetails, key, mcpValues]);
|
||||
|
||||
const renderSelectedValues = useCallback(
|
||||
(values: string[], placeholder?: string) => {
|
||||
if (values.length === 0) {
|
||||
@@ -139,10 +51,6 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) {
|
||||
[localize],
|
||||
);
|
||||
|
||||
const mcpServerNames = useMemo(() => {
|
||||
return (mcpToolDetails ?? []).map((tool) => tool.name);
|
||||
}, [mcpToolDetails]);
|
||||
|
||||
const handleConfigSave = useCallback(
|
||||
(targetName: string, authData: Record<string, string>) => {
|
||||
if (selectedToolForConfig && selectedToolForConfig.name === targetName) {
|
||||
@@ -198,10 +106,10 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) {
|
||||
setSelectedToolForConfig(tool);
|
||||
setIsConfigModalOpen(true);
|
||||
}}
|
||||
className="ml-2 flex h-6 w-6 items-center justify-center rounded p-1 hover:bg-black/10 dark:hover:bg-white/10"
|
||||
className="ml-2 flex h-6 w-6 items-center justify-center rounded p-1 hover:bg-surface-secondary"
|
||||
aria-label={`Configure ${serverName}`}
|
||||
>
|
||||
<Settings2 className={`h-4 w-4 ${tool.authenticated ? 'text-green-500' : ''}`} />
|
||||
<SettingsIcon className={`h-4 w-4 ${tool.authenticated ? 'text-green-500' : ''}`} />
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
@@ -212,6 +120,11 @@ function MCPSelect({ conversationId }: { conversationId?: string | null }) {
|
||||
[mcpToolDetails, setSelectedToolForConfig, setIsConfigModalOpen],
|
||||
);
|
||||
|
||||
// Don't render if no servers are selected and not pinned
|
||||
if ((!mcpValues || mcpValues.length === 0) && !isPinned) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!mcpToolDetails || mcpToolDetails.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
96
client/src/components/Chat/Input/MCPSubMenu.tsx
Normal file
96
client/src/components/Chat/Input/MCPSubMenu.tsx
Normal file
@@ -0,0 +1,96 @@
|
||||
import React from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { ChevronRight } from 'lucide-react';
|
||||
import { PinIcon, MCPIcon } from '~/components/svg';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface MCPSubMenuProps {
|
||||
isMCPPinned: boolean;
|
||||
setIsMCPPinned: (value: boolean) => void;
|
||||
mcpValues?: string[];
|
||||
mcpServerNames: string[];
|
||||
handleMCPToggle: (serverName: string) => void;
|
||||
}
|
||||
|
||||
const MCPSubMenu = ({
|
||||
mcpValues,
|
||||
isMCPPinned,
|
||||
mcpServerNames,
|
||||
setIsMCPPinned,
|
||||
handleMCPToggle,
|
||||
...props
|
||||
}: MCPSubMenuProps) => {
|
||||
const localize = useLocalize();
|
||||
|
||||
const menuStore = Ariakit.useMenuStore({
|
||||
showTimeout: 100,
|
||||
placement: 'right',
|
||||
});
|
||||
|
||||
return (
|
||||
<Ariakit.MenuProvider store={menuStore}>
|
||||
<Ariakit.MenuItem
|
||||
{...props}
|
||||
render={
|
||||
<Ariakit.MenuButton className="flex w-full cursor-pointer items-center justify-between rounded-lg p-2 hover:bg-surface-hover" />
|
||||
}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<MCPIcon className="icon-md" />
|
||||
<span>{localize('com_ui_mcp_servers')}</span>
|
||||
<ChevronRight className="ml-auto h-3 w-3" />
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsMCPPinned(!isMCPPinned);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-tertiary hover:shadow-sm',
|
||||
!isMCPPinned && 'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isMCPPinned ? 'Unpin' : 'Pin'}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isMCPPinned} />
|
||||
</div>
|
||||
</button>
|
||||
</Ariakit.MenuItem>
|
||||
<Ariakit.Menu
|
||||
gutter={-4}
|
||||
shift={-8}
|
||||
unmountOnHide
|
||||
portal={true}
|
||||
className={cn(
|
||||
'animate-popover-left z-50 ml-3 flex min-w-[200px] flex-col rounded-xl',
|
||||
'border border-border-light bg-surface-secondary p-1 shadow-lg',
|
||||
)}
|
||||
>
|
||||
{mcpServerNames.map((serverName) => (
|
||||
<Ariakit.MenuItem
|
||||
key={serverName}
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
handleMCPToggle(serverName);
|
||||
}}
|
||||
className={cn(
|
||||
'flex items-center gap-2 rounded-lg px-2 py-1.5 text-text-primary hover:cursor-pointer',
|
||||
'scroll-m-1 outline-none transition-colors',
|
||||
'hover:bg-black/[0.075] dark:hover:bg-white/10',
|
||||
'data-[active-item]:bg-black/[0.075] dark:data-[active-item]:bg-white/10',
|
||||
'w-full min-w-0 text-sm',
|
||||
)}
|
||||
>
|
||||
<Ariakit.MenuItemCheck checked={mcpValues?.includes(serverName) ?? false} />
|
||||
<span>{serverName}</span>
|
||||
</Ariakit.MenuItem>
|
||||
))}
|
||||
</Ariakit.Menu>
|
||||
</Ariakit.MenuProvider>
|
||||
);
|
||||
};
|
||||
|
||||
export default React.memo(MCPSubMenu);
|
||||
66
client/src/components/Chat/Input/ToolDialogs.tsx
Normal file
66
client/src/components/Chat/Input/ToolDialogs.tsx
Normal file
@@ -0,0 +1,66 @@
|
||||
import React, { useMemo } from 'react';
|
||||
import { AuthType } from 'librechat-data-provider';
|
||||
import SearchApiKeyDialog from '~/components/SidePanel/Agents/Search/ApiKeyDialog';
|
||||
import CodeApiKeyDialog from '~/components/SidePanel/Agents/Code/ApiKeyDialog';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
|
||||
function ToolDialogs() {
|
||||
const { webSearch, codeInterpreter, searchApiKeyForm, codeApiKeyForm } = useBadgeRowContext();
|
||||
const { authData: webSearchAuthData } = webSearch;
|
||||
const { authData: codeAuthData } = codeInterpreter;
|
||||
|
||||
const {
|
||||
methods: searchMethods,
|
||||
onSubmit: searchOnSubmit,
|
||||
isDialogOpen: searchDialogOpen,
|
||||
setIsDialogOpen: setSearchDialogOpen,
|
||||
handleRevokeApiKey: searchHandleRevoke,
|
||||
badgeTriggerRef: searchBadgeTriggerRef,
|
||||
menuTriggerRef: searchMenuTriggerRef,
|
||||
} = searchApiKeyForm;
|
||||
|
||||
const {
|
||||
methods: codeMethods,
|
||||
onSubmit: codeOnSubmit,
|
||||
isDialogOpen: codeDialogOpen,
|
||||
setIsDialogOpen: setCodeDialogOpen,
|
||||
handleRevokeApiKey: codeHandleRevoke,
|
||||
badgeTriggerRef: codeBadgeTriggerRef,
|
||||
menuTriggerRef: codeMenuTriggerRef,
|
||||
} = codeApiKeyForm;
|
||||
|
||||
const searchAuthTypes = useMemo(
|
||||
() => webSearchAuthData?.authTypes ?? [],
|
||||
[webSearchAuthData?.authTypes],
|
||||
);
|
||||
const codeAuthType = useMemo(() => codeAuthData?.message ?? false, [codeAuthData?.message]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<SearchApiKeyDialog
|
||||
onSubmit={searchOnSubmit}
|
||||
authTypes={searchAuthTypes}
|
||||
isOpen={searchDialogOpen}
|
||||
onRevoke={searchHandleRevoke}
|
||||
register={searchMethods.register}
|
||||
onOpenChange={setSearchDialogOpen}
|
||||
handleSubmit={searchMethods.handleSubmit}
|
||||
triggerRefs={[searchMenuTriggerRef, searchBadgeTriggerRef]}
|
||||
isToolAuthenticated={webSearchAuthData?.authenticated ?? false}
|
||||
/>
|
||||
<CodeApiKeyDialog
|
||||
onSubmit={codeOnSubmit}
|
||||
isOpen={codeDialogOpen}
|
||||
onRevoke={codeHandleRevoke}
|
||||
register={codeMethods.register}
|
||||
onOpenChange={setCodeDialogOpen}
|
||||
handleSubmit={codeMethods.handleSubmit}
|
||||
triggerRefs={[codeMenuTriggerRef, codeBadgeTriggerRef]}
|
||||
isUserProvided={codeAuthType === AuthType.USER_PROVIDED}
|
||||
isToolAuthenticated={codeAuthData?.authenticated ?? false}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default ToolDialogs;
|
||||
322
client/src/components/Chat/Input/ToolsDropdown.tsx
Normal file
322
client/src/components/Chat/Input/ToolsDropdown.tsx
Normal file
@@ -0,0 +1,322 @@
|
||||
import React, { useState, useMemo, useCallback } from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { Globe, Settings, Settings2, TerminalSquareIcon } from 'lucide-react';
|
||||
import type { MenuItemProps } from '~/common';
|
||||
import { Permissions, PermissionTypes, AuthType } from 'librechat-data-provider';
|
||||
import { TooltipAnchor, DropdownPopup } from '~/components';
|
||||
import MCPSubMenu from '~/components/Chat/Input/MCPSubMenu';
|
||||
import { PinIcon, VectorIcon } from '~/components/svg';
|
||||
import { useLocalize, useHasAccess } from '~/hooks';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface ToolsDropdownProps {
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const ToolsDropdown = ({ disabled }: ToolsDropdownProps) => {
|
||||
const localize = useLocalize();
|
||||
const isDisabled = disabled ?? false;
|
||||
const [isPopoverActive, setIsPopoverActive] = useState(false);
|
||||
const { webSearch, codeInterpreter, fileSearch, mcpSelect, searchApiKeyForm, codeApiKeyForm } =
|
||||
useBadgeRowContext();
|
||||
const { setIsDialogOpen: setIsCodeDialogOpen, menuTriggerRef: codeMenuTriggerRef } =
|
||||
codeApiKeyForm;
|
||||
const { setIsDialogOpen: setIsSearchDialogOpen, menuTriggerRef: searchMenuTriggerRef } =
|
||||
searchApiKeyForm;
|
||||
const {
|
||||
isPinned: isSearchPinned,
|
||||
setIsPinned: setIsSearchPinned,
|
||||
authData: webSearchAuthData,
|
||||
} = webSearch;
|
||||
const {
|
||||
isPinned: isCodePinned,
|
||||
setIsPinned: setIsCodePinned,
|
||||
authData: codeAuthData,
|
||||
} = codeInterpreter;
|
||||
const { isPinned: isFileSearchPinned, setIsPinned: setIsFileSearchPinned } = fileSearch;
|
||||
const {
|
||||
mcpValues,
|
||||
mcpServerNames,
|
||||
isPinned: isMCPPinned,
|
||||
setIsPinned: setIsMCPPinned,
|
||||
} = mcpSelect;
|
||||
|
||||
const canUseWebSearch = useHasAccess({
|
||||
permissionType: PermissionTypes.WEB_SEARCH,
|
||||
permission: Permissions.USE,
|
||||
});
|
||||
|
||||
const canRunCode = useHasAccess({
|
||||
permissionType: PermissionTypes.RUN_CODE,
|
||||
permission: Permissions.USE,
|
||||
});
|
||||
|
||||
const showWebSearchSettings = useMemo(() => {
|
||||
const authTypes = webSearchAuthData?.authTypes ?? [];
|
||||
if (authTypes.length === 0) return true;
|
||||
return !authTypes.every(([, authType]) => authType === AuthType.SYSTEM_DEFINED);
|
||||
}, [webSearchAuthData?.authTypes]);
|
||||
|
||||
const showCodeSettings = useMemo(
|
||||
() => codeAuthData?.message !== AuthType.SYSTEM_DEFINED,
|
||||
[codeAuthData?.message],
|
||||
);
|
||||
|
||||
const handleWebSearchToggle = useCallback(() => {
|
||||
const newValue = !webSearch.toggleState;
|
||||
webSearch.debouncedChange({ isChecked: newValue });
|
||||
}, [webSearch]);
|
||||
|
||||
const handleCodeInterpreterToggle = useCallback(() => {
|
||||
const newValue = !codeInterpreter.toggleState;
|
||||
codeInterpreter.debouncedChange({ isChecked: newValue });
|
||||
}, [codeInterpreter]);
|
||||
|
||||
const handleFileSearchToggle = useCallback(() => {
|
||||
const newValue = !fileSearch.toggleState;
|
||||
fileSearch.debouncedChange({ isChecked: newValue });
|
||||
}, [fileSearch]);
|
||||
|
||||
const handleMCPToggle = useCallback(
|
||||
(serverName: string) => {
|
||||
const currentValues = mcpSelect.mcpValues ?? [];
|
||||
const newValues = currentValues.includes(serverName)
|
||||
? currentValues.filter((v) => v !== serverName)
|
||||
: [...currentValues, serverName];
|
||||
mcpSelect.setMCPValues(newValues);
|
||||
},
|
||||
[mcpSelect],
|
||||
);
|
||||
|
||||
const dropdownItems = useMemo(() => {
|
||||
const items: MenuItemProps[] = [
|
||||
{
|
||||
render: () => (
|
||||
<div className="px-3 py-2 text-xs font-semibold text-text-secondary">
|
||||
{localize('com_ui_tools')}
|
||||
</div>
|
||||
),
|
||||
hideOnClick: false,
|
||||
},
|
||||
];
|
||||
|
||||
items.push({
|
||||
onClick: handleFileSearchToggle,
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<div {...props}>
|
||||
<div className="flex items-center gap-2">
|
||||
<VectorIcon className="icon-md" />
|
||||
<span>{localize('com_assistants_file_search')}</span>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsFileSearchPinned(!isFileSearchPinned);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
!isFileSearchPinned && 'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isFileSearchPinned ? 'Unpin' : 'Pin'}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isFileSearchPinned} />
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
),
|
||||
});
|
||||
|
||||
if (canUseWebSearch) {
|
||||
items.push({
|
||||
onClick: handleWebSearchToggle,
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<div {...props}>
|
||||
<div className="flex items-center gap-2">
|
||||
<Globe className="icon-md" />
|
||||
<span>{localize('com_ui_web_search')}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{showWebSearchSettings && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsSearchDialogOpen(true);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label="Configure web search"
|
||||
ref={searchMenuTriggerRef}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<Settings className="h-4 w-4" />
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsSearchPinned(!isSearchPinned);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
!isSearchPinned && 'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isSearchPinned ? 'Unpin' : 'Pin'}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isSearchPinned} />
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if (canRunCode) {
|
||||
items.push({
|
||||
onClick: handleCodeInterpreterToggle,
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<div {...props}>
|
||||
<div className="flex items-center gap-2">
|
||||
<TerminalSquareIcon className="icon-md" />
|
||||
<span>{localize('com_assistants_code_interpreter')}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
{showCodeSettings && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsCodeDialogOpen(true);
|
||||
}}
|
||||
ref={codeMenuTriggerRef}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
'text-text-secondary hover:text-text-primary',
|
||||
)}
|
||||
aria-label="Configure code interpreter"
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<Settings className="h-4 w-4" />
|
||||
</div>
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
type="button"
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setIsCodePinned(!isCodePinned);
|
||||
}}
|
||||
className={cn(
|
||||
'rounded p-1 transition-all duration-200',
|
||||
'hover:bg-surface-secondary hover:shadow-sm',
|
||||
!isCodePinned && 'text-text-primary hover:text-text-primary',
|
||||
)}
|
||||
aria-label={isCodePinned ? 'Unpin' : 'Pin'}
|
||||
>
|
||||
<div className="h-4 w-4">
|
||||
<PinIcon unpin={isCodePinned} />
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
if (mcpServerNames && mcpServerNames.length > 0) {
|
||||
items.push({
|
||||
hideOnClick: false,
|
||||
render: (props) => (
|
||||
<MCPSubMenu
|
||||
{...props}
|
||||
mcpValues={mcpValues}
|
||||
mcpServerNames={mcpServerNames}
|
||||
isMCPPinned={isMCPPinned}
|
||||
setIsMCPPinned={setIsMCPPinned}
|
||||
handleMCPToggle={handleMCPToggle}
|
||||
/>
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
return items;
|
||||
}, [
|
||||
localize,
|
||||
mcpValues,
|
||||
canRunCode,
|
||||
isMCPPinned,
|
||||
isCodePinned,
|
||||
mcpServerNames,
|
||||
isSearchPinned,
|
||||
setIsMCPPinned,
|
||||
canUseWebSearch,
|
||||
setIsCodePinned,
|
||||
handleMCPToggle,
|
||||
showCodeSettings,
|
||||
setIsSearchPinned,
|
||||
isFileSearchPinned,
|
||||
codeMenuTriggerRef,
|
||||
setIsCodeDialogOpen,
|
||||
searchMenuTriggerRef,
|
||||
showWebSearchSettings,
|
||||
setIsFileSearchPinned,
|
||||
handleWebSearchToggle,
|
||||
setIsSearchDialogOpen,
|
||||
handleFileSearchToggle,
|
||||
handleCodeInterpreterToggle,
|
||||
]);
|
||||
|
||||
const menuTrigger = (
|
||||
<TooltipAnchor
|
||||
render={
|
||||
<Ariakit.MenuButton
|
||||
disabled={isDisabled}
|
||||
id="tools-dropdown-button"
|
||||
aria-label="Tools Options"
|
||||
className={cn(
|
||||
'flex size-9 items-center justify-center rounded-full p-1 transition-colors hover:bg-surface-hover focus:outline-none focus:ring-2 focus:ring-primary focus:ring-opacity-50',
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
<Settings2 className="icon-md" />
|
||||
</div>
|
||||
</Ariakit.MenuButton>
|
||||
}
|
||||
id="tools-dropdown-button"
|
||||
description={localize('com_ui_tools')}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
);
|
||||
|
||||
return (
|
||||
<DropdownPopup
|
||||
itemClassName="flex w-full cursor-pointer items-center justify-between hover:bg-surface-hover gap-5"
|
||||
menuId="tools-dropdown-menu"
|
||||
isOpen={isPopoverActive}
|
||||
setIsOpen={setIsPopoverActive}
|
||||
modal={true}
|
||||
unmountOnHide={true}
|
||||
trigger={menuTrigger}
|
||||
items={dropdownItems}
|
||||
iconClassName="mr-0"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default React.memo(ToolsDropdown);
|
||||
@@ -1,122 +1,37 @@
|
||||
import React, { memo, useRef, useMemo, useCallback } from 'react';
|
||||
import React, { memo } from 'react';
|
||||
import { Globe } from 'lucide-react';
|
||||
import debounce from 'lodash/debounce';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import {
|
||||
Tools,
|
||||
AuthType,
|
||||
Constants,
|
||||
Permissions,
|
||||
PermissionTypes,
|
||||
LocalStorageKeys,
|
||||
} from 'librechat-data-provider';
|
||||
import ApiKeyDialog from '~/components/SidePanel/Agents/Search/ApiKeyDialog';
|
||||
import { useLocalize, useHasAccess, useSearchApiKeyForm } from '~/hooks';
|
||||
import { Permissions, PermissionTypes } from 'librechat-data-provider';
|
||||
import CheckboxButton from '~/components/ui/CheckboxButton';
|
||||
import useLocalStorage from '~/hooks/useLocalStorageAlt';
|
||||
import { useVerifyAgentToolAuth } from '~/data-provider';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
import { useLocalize, useHasAccess } from '~/hooks';
|
||||
import { useBadgeRowContext } from '~/Providers';
|
||||
|
||||
const storageCondition = (value: unknown, rawCurrentValue?: string | null) => {
|
||||
if (rawCurrentValue) {
|
||||
try {
|
||||
const currentValue = rawCurrentValue?.trim() ?? '';
|
||||
if (currentValue === 'true' && value === false) {
|
||||
return true;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return value !== undefined && value !== null && value !== '' && value !== false;
|
||||
};
|
||||
|
||||
function WebSearch({ conversationId }: { conversationId?: string | null }) {
|
||||
const triggerRef = useRef<HTMLInputElement>(null);
|
||||
function WebSearch() {
|
||||
const localize = useLocalize();
|
||||
const key = conversationId ?? Constants.NEW_CONVO;
|
||||
const { webSearch: webSearchData, searchApiKeyForm } = useBadgeRowContext();
|
||||
const { toggleState: webSearch, debouncedChange, isPinned } = webSearchData;
|
||||
const { badgeTriggerRef } = searchApiKeyForm;
|
||||
|
||||
const canUseWebSearch = useHasAccess({
|
||||
permissionType: PermissionTypes.WEB_SEARCH,
|
||||
permission: Permissions.USE,
|
||||
});
|
||||
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key));
|
||||
const isWebSearchToggleEnabled = useMemo(() => {
|
||||
return ephemeralAgent?.web_search ?? false;
|
||||
}, [ephemeralAgent?.web_search]);
|
||||
|
||||
const { data } = useVerifyAgentToolAuth(
|
||||
{ toolId: Tools.web_search },
|
||||
{
|
||||
retry: 1,
|
||||
},
|
||||
);
|
||||
const authTypes = useMemo(() => data?.authTypes ?? [], [data?.authTypes]);
|
||||
const isAuthenticated = useMemo(() => data?.authenticated ?? false, [data?.authenticated]);
|
||||
const { methods, onSubmit, isDialogOpen, setIsDialogOpen, handleRevokeApiKey } =
|
||||
useSearchApiKeyForm({});
|
||||
|
||||
const setValue = useCallback(
|
||||
(isChecked: boolean) => {
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
web_search: isChecked,
|
||||
}));
|
||||
},
|
||||
[setEphemeralAgent],
|
||||
);
|
||||
|
||||
const [webSearch, setWebSearch] = useLocalStorage<boolean>(
|
||||
`${LocalStorageKeys.LAST_WEB_SEARCH_TOGGLE_}${key}`,
|
||||
isWebSearchToggleEnabled,
|
||||
setValue,
|
||||
storageCondition,
|
||||
);
|
||||
|
||||
const handleChange = useCallback(
|
||||
(e: React.ChangeEvent<HTMLInputElement>, isChecked: boolean) => {
|
||||
if (!isAuthenticated) {
|
||||
setIsDialogOpen(true);
|
||||
e.preventDefault();
|
||||
return;
|
||||
}
|
||||
setWebSearch(isChecked);
|
||||
},
|
||||
[setWebSearch, setIsDialogOpen, isAuthenticated],
|
||||
);
|
||||
|
||||
const debouncedChange = useMemo(
|
||||
() => debounce(handleChange, 50, { leading: true }),
|
||||
[handleChange],
|
||||
);
|
||||
|
||||
if (!canUseWebSearch) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
(webSearch || isPinned) && (
|
||||
<CheckboxButton
|
||||
ref={triggerRef}
|
||||
ref={badgeTriggerRef}
|
||||
className="max-w-fit"
|
||||
defaultChecked={webSearch}
|
||||
checked={webSearch}
|
||||
setValue={debouncedChange}
|
||||
label={localize('com_ui_search')}
|
||||
isCheckedClassName="border-blue-600/40 bg-blue-500/10 hover:bg-blue-700/10"
|
||||
icon={<Globe className="icon-md" />}
|
||||
/>
|
||||
<ApiKeyDialog
|
||||
onSubmit={onSubmit}
|
||||
authTypes={authTypes}
|
||||
isOpen={isDialogOpen}
|
||||
triggerRef={triggerRef}
|
||||
register={methods.register}
|
||||
onRevoke={handleRevokeApiKey}
|
||||
onOpenChange={setIsDialogOpen}
|
||||
handleSubmit={methods.handleSubmit}
|
||||
isToolAuthenticated={isAuthenticated}
|
||||
/>
|
||||
</>
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ export function filterModels(
|
||||
let modelName = modelId;
|
||||
|
||||
if (isAgentsEndpoint(endpoint.value) && agentsMap && agentsMap[modelId]) {
|
||||
modelName = agentsMap[modelId].name || modelId;
|
||||
modelName = agentsMap[modelId]?.name || modelId;
|
||||
} else if (
|
||||
isAssistantsEndpoint(endpoint.value) &&
|
||||
assistantsMap &&
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
import { useWatch, useFormContext } from 'react-hook-form';
|
||||
import { SystemRoles, Permissions, PermissionTypes } from 'librechat-data-provider';
|
||||
import {
|
||||
SystemRoles,
|
||||
Permissions,
|
||||
PermissionTypes,
|
||||
PERMISSION_BITS,
|
||||
} from 'librechat-data-provider';
|
||||
import type { AgentForm, AgentPanelProps } from '~/common';
|
||||
import { useLocalize, useAuthContext, useHasAccess } from '~/hooks';
|
||||
import { useLocalize, useAuthContext, useHasAccess, useResourcePermissions } from '~/hooks';
|
||||
import GrantAccessDialog from './Sharing/GrantAccessDialog';
|
||||
import { useUpdateAgentMutation } from '~/data-provider';
|
||||
import AdvancedButton from './Advanced/AdvancedButton';
|
||||
import VersionButton from './Version/VersionButton';
|
||||
import DuplicateAgent from './DuplicateAgent';
|
||||
import AdminSettings from './AdminSettings';
|
||||
import DeleteButton from './DeleteButton';
|
||||
import { Spinner } from '~/components';
|
||||
import ShareAgent from './ShareAgent';
|
||||
import { Panel } from '~/common';
|
||||
import VersionButton from './Version/VersionButton';
|
||||
|
||||
export default function AgentFooter({
|
||||
activePanel,
|
||||
@@ -32,12 +37,17 @@ export default function AgentFooter({
|
||||
const { control } = methods;
|
||||
const agent = useWatch({ control, name: 'agent' });
|
||||
const agent_id = useWatch({ control, name: 'id' });
|
||||
|
||||
const hasAccessToShareAgents = useHasAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permission: Permissions.SHARED_GLOBAL,
|
||||
});
|
||||
const { hasPermission, isLoading: permissionsLoading } = useResourcePermissions(
|
||||
'agent',
|
||||
agent?._id || '',
|
||||
);
|
||||
|
||||
const canShareThisAgent = hasPermission(PERMISSION_BITS.SHARE);
|
||||
const canDeleteThisAgent = hasPermission(PERMISSION_BITS.DELETE);
|
||||
const renderSaveButton = () => {
|
||||
if (createMutation.isLoading || updateMutation.isLoading) {
|
||||
return <Spinner className="icon-md" aria-hidden="true" />;
|
||||
@@ -59,18 +69,21 @@ export default function AgentFooter({
|
||||
{user?.role === SystemRoles.ADMIN && showButtons && <AdminSettings />}
|
||||
{/* Context Button */}
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
<DeleteButton
|
||||
agent_id={agent_id}
|
||||
setCurrentAgentId={setCurrentAgentId}
|
||||
createMutation={createMutation}
|
||||
/>
|
||||
{(agent?.author === user?.id || user?.role === SystemRoles.ADMIN) &&
|
||||
hasAccessToShareAgents && (
|
||||
<ShareAgent
|
||||
{(agent?.author === user?.id || user?.role === SystemRoles.ADMIN || canDeleteThisAgent) &&
|
||||
!permissionsLoading && (
|
||||
<DeleteButton
|
||||
agent_id={agent_id}
|
||||
setCurrentAgentId={setCurrentAgentId}
|
||||
createMutation={createMutation}
|
||||
/>
|
||||
)}
|
||||
{(agent?.author === user?.id || user?.role === SystemRoles.ADMIN || canShareThisAgent) &&
|
||||
hasAccessToShareAgents &&
|
||||
!permissionsLoading && (
|
||||
<GrantAccessDialog
|
||||
agentDbId={agent?._id}
|
||||
agentId={agent_id}
|
||||
agentName={agent?.name ?? ''}
|
||||
projectIds={agent?.projectIds ?? []}
|
||||
isCollaborative={agent?.isCollaborative}
|
||||
/>
|
||||
)}
|
||||
{agent && agent.author === user?.id && <DuplicateAgent agent_id={agent_id} />}
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
SystemRoles,
|
||||
EModelEndpoint,
|
||||
TAgentsEndpoint,
|
||||
PERMISSION_BITS,
|
||||
TEndpointsConfig,
|
||||
isAssistantsEndpoint,
|
||||
} from 'librechat-data-provider';
|
||||
@@ -16,8 +17,10 @@ import {
|
||||
useCreateAgentMutation,
|
||||
useUpdateAgentMutation,
|
||||
useGetAgentByIdQuery,
|
||||
useGetExpandedAgentByIdQuery,
|
||||
} from '~/data-provider';
|
||||
import { createProviderOption, getDefaultAgentFormValues } from '~/utils';
|
||||
import { useResourcePermissions } from '~/hooks/useResourcePermissions';
|
||||
import { useSelectAgent, useLocalize, useAuthContext } from '~/hooks';
|
||||
import { useAgentPanelContext } from '~/Providers/AgentPanelContext';
|
||||
import AgentPanelSkeleton from './AgentPanelSkeleton';
|
||||
@@ -50,10 +53,29 @@ export default function AgentPanel({
|
||||
const { onSelect: onSelectAgent } = useSelectAgent();
|
||||
|
||||
const modelsQuery = useGetModelsQuery();
|
||||
const agentQuery = useGetAgentByIdQuery(current_agent_id ?? '', {
|
||||
|
||||
// Basic agent query for initial permission check
|
||||
const basicAgentQuery = useGetAgentByIdQuery(current_agent_id ?? '', {
|
||||
enabled: !!(current_agent_id ?? '') && current_agent_id !== Constants.EPHEMERAL_AGENT_ID,
|
||||
});
|
||||
|
||||
const { hasPermission, isLoading: permissionsLoading } = useResourcePermissions(
|
||||
'agent',
|
||||
basicAgentQuery.data?._id || '',
|
||||
);
|
||||
|
||||
const canEdit = hasPermission(PERMISSION_BITS.EDIT);
|
||||
|
||||
const expandedAgentQuery = useGetExpandedAgentByIdQuery(current_agent_id ?? '', {
|
||||
enabled:
|
||||
!!(current_agent_id ?? '') &&
|
||||
current_agent_id !== Constants.EPHEMERAL_AGENT_ID &&
|
||||
canEdit &&
|
||||
!permissionsLoading,
|
||||
});
|
||||
|
||||
const agentQuery = canEdit && expandedAgentQuery.data ? expandedAgentQuery : basicAgentQuery;
|
||||
|
||||
const models = useMemo(() => modelsQuery.data ?? {}, [modelsQuery.data]);
|
||||
const methods = useForm<AgentForm>({
|
||||
defaultValues: getDefaultAgentFormValues(),
|
||||
@@ -242,19 +264,16 @@ export default function AgentPanel({
|
||||
}, [agent_id, onSelectAgent]);
|
||||
|
||||
const canEditAgent = useMemo(() => {
|
||||
const canEdit =
|
||||
(agentQuery.data?.isCollaborative ?? false)
|
||||
? true
|
||||
: agentQuery.data?.author === user?.id || user?.role === SystemRoles.ADMIN;
|
||||
if (!agentQuery.data?.id) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return agentQuery.data?.id != null && agentQuery.data.id ? canEdit : true;
|
||||
}, [
|
||||
agentQuery.data?.isCollaborative,
|
||||
agentQuery.data?.author,
|
||||
agentQuery.data?.id,
|
||||
user?.id,
|
||||
user?.role,
|
||||
]);
|
||||
if (agentQuery.data?.author === user?.id || user?.role === SystemRoles.ADMIN) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return canEdit;
|
||||
}, [agentQuery.data?.author, agentQuery.data?.id, user?.id, user?.role, canEdit]);
|
||||
|
||||
return (
|
||||
<FormProvider {...methods}>
|
||||
|
||||
@@ -43,9 +43,7 @@ export default function AgentSelect({
|
||||
|
||||
const resetAgentForm = useCallback(
|
||||
(fullAgent: Agent) => {
|
||||
const { instanceProjectId } = startupConfig ?? {};
|
||||
const isGlobal =
|
||||
(instanceProjectId != null && fullAgent.projectIds?.includes(instanceProjectId)) ?? false;
|
||||
const isGlobal = fullAgent.isPublic ?? false;
|
||||
const update = {
|
||||
...fullAgent,
|
||||
provider: createProviderOption(fullAgent.provider),
|
||||
|
||||
@@ -15,6 +15,7 @@ export default function ApiKeyDialog({
|
||||
register,
|
||||
handleSubmit,
|
||||
triggerRef,
|
||||
triggerRefs,
|
||||
}: {
|
||||
isOpen: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
@@ -24,7 +25,8 @@ export default function ApiKeyDialog({
|
||||
isToolAuthenticated: boolean;
|
||||
register: UseFormRegister<ApiKeyFormData>;
|
||||
handleSubmit: UseFormHandleSubmit<ApiKeyFormData>;
|
||||
triggerRef?: RefObject<HTMLInputElement>;
|
||||
triggerRef?: RefObject<HTMLInputElement | HTMLButtonElement>;
|
||||
triggerRefs?: RefObject<HTMLInputElement | HTMLButtonElement>[];
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const languageIcons = [
|
||||
@@ -41,7 +43,12 @@ export default function ApiKeyDialog({
|
||||
];
|
||||
|
||||
return (
|
||||
<OGDialog open={isOpen} onOpenChange={onOpenChange} triggerRef={triggerRef}>
|
||||
<OGDialog
|
||||
open={isOpen}
|
||||
onOpenChange={onOpenChange}
|
||||
triggerRef={triggerRef}
|
||||
triggerRefs={triggerRefs}
|
||||
>
|
||||
<OGDialogTemplate
|
||||
className="w-11/12 sm:w-[450px]"
|
||||
title=""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import { useFormContext, Controller } from 'react-hook-form';
|
||||
import { MCP } from 'librechat-data-provider/dist/types/types/assistants';
|
||||
import type { MCP } from 'librechat-data-provider';
|
||||
import MCPAuth from '~/components/SidePanel/Builder/MCPAuth';
|
||||
import MCPIcon from '~/components/SidePanel/Agents/MCPIcon';
|
||||
import { Label, Checkbox } from '~/components/ui';
|
||||
|
||||
@@ -21,6 +21,7 @@ export default function ApiKeyDialog({
|
||||
register,
|
||||
handleSubmit,
|
||||
triggerRef,
|
||||
triggerRefs,
|
||||
}: {
|
||||
isOpen: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
@@ -30,7 +31,8 @@ export default function ApiKeyDialog({
|
||||
isToolAuthenticated: boolean;
|
||||
register: UseFormRegister<SearchApiKeyFormData>;
|
||||
handleSubmit: UseFormHandleSubmit<SearchApiKeyFormData>;
|
||||
triggerRef?: React.RefObject<HTMLInputElement>;
|
||||
triggerRef?: React.RefObject<HTMLInputElement | HTMLButtonElement>;
|
||||
triggerRefs?: React.RefObject<HTMLInputElement | HTMLButtonElement>[];
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const { data: config } = useGetStartupConfig();
|
||||
@@ -181,7 +183,12 @@ export default function ApiKeyDialog({
|
||||
}
|
||||
|
||||
return (
|
||||
<OGDialog open={isOpen} onOpenChange={onOpenChange} triggerRef={triggerRef}>
|
||||
<OGDialog
|
||||
open={isOpen}
|
||||
onOpenChange={onOpenChange}
|
||||
triggerRef={triggerRef}
|
||||
triggerRefs={triggerRefs}
|
||||
>
|
||||
<OGDialogTemplate
|
||||
className="w-11/12 sm:w-[500px]"
|
||||
title=""
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
import React, { useEffect, useMemo } from 'react';
|
||||
import { Share2Icon } from 'lucide-react';
|
||||
import { useForm, Controller } from 'react-hook-form';
|
||||
import { Permissions } from 'librechat-data-provider';
|
||||
import type { TStartupConfig, AgentUpdateParams } from 'librechat-data-provider';
|
||||
import {
|
||||
Button,
|
||||
Switch,
|
||||
OGDialog,
|
||||
OGDialogTitle,
|
||||
OGDialogClose,
|
||||
OGDialogContent,
|
||||
OGDialogTrigger,
|
||||
} from '~/components/ui';
|
||||
import { useUpdateAgentMutation, useGetStartupConfig } from '~/data-provider';
|
||||
import { cn, removeFocusOutlines } from '~/utils';
|
||||
import { useToastContext } from '~/Providers';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
type FormValues = {
|
||||
[Permissions.SHARED_GLOBAL]: boolean;
|
||||
[Permissions.UPDATE]: boolean;
|
||||
};
|
||||
|
||||
export default function ShareAgent({
|
||||
agent_id = '',
|
||||
agentName,
|
||||
projectIds = [],
|
||||
isCollaborative = false,
|
||||
}: {
|
||||
agent_id?: string;
|
||||
agentName?: string;
|
||||
projectIds?: string[];
|
||||
isCollaborative?: boolean;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
const { data: startupConfig = {} as TStartupConfig, isFetching } = useGetStartupConfig();
|
||||
const { instanceProjectId } = startupConfig;
|
||||
const agentIsGlobal = useMemo(
|
||||
() => !!projectIds.includes(instanceProjectId),
|
||||
[projectIds, instanceProjectId],
|
||||
);
|
||||
|
||||
const {
|
||||
watch,
|
||||
control,
|
||||
setValue,
|
||||
getValues,
|
||||
handleSubmit,
|
||||
formState: { isSubmitting },
|
||||
} = useForm<FormValues>({
|
||||
mode: 'onChange',
|
||||
defaultValues: {
|
||||
[Permissions.SHARED_GLOBAL]: agentIsGlobal,
|
||||
[Permissions.UPDATE]: isCollaborative,
|
||||
},
|
||||
});
|
||||
|
||||
const sharedGlobalValue = watch(Permissions.SHARED_GLOBAL);
|
||||
|
||||
useEffect(() => {
|
||||
if (!sharedGlobalValue) {
|
||||
setValue(Permissions.UPDATE, false);
|
||||
}
|
||||
}, [sharedGlobalValue, setValue]);
|
||||
|
||||
useEffect(() => {
|
||||
setValue(Permissions.SHARED_GLOBAL, agentIsGlobal);
|
||||
setValue(Permissions.UPDATE, isCollaborative);
|
||||
}, [agentIsGlobal, isCollaborative, setValue]);
|
||||
|
||||
const updateAgent = useUpdateAgentMutation({
|
||||
onSuccess: (data) => {
|
||||
showToast({
|
||||
message: `${localize('com_assistants_update_success')} ${
|
||||
data.name ?? localize('com_ui_agent')
|
||||
}`,
|
||||
status: 'success',
|
||||
});
|
||||
},
|
||||
onError: (err) => {
|
||||
const error = err as Error;
|
||||
showToast({
|
||||
message: `${localize('com_agents_update_error')}${
|
||||
error.message ? ` ${localize('com_ui_error')}: ${error.message}` : ''
|
||||
}`,
|
||||
status: 'error',
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
if (!agent_id || !instanceProjectId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const onSubmit = (data: FormValues) => {
|
||||
if (!agent_id || !instanceProjectId) {
|
||||
return;
|
||||
}
|
||||
|
||||
const payload = {} as AgentUpdateParams;
|
||||
|
||||
if (data[Permissions.UPDATE] !== isCollaborative) {
|
||||
payload.isCollaborative = data[Permissions.UPDATE];
|
||||
}
|
||||
|
||||
if (data[Permissions.SHARED_GLOBAL] !== agentIsGlobal) {
|
||||
if (data[Permissions.SHARED_GLOBAL]) {
|
||||
payload.projectIds = [startupConfig.instanceProjectId];
|
||||
} else {
|
||||
payload.removeProjectIds = [startupConfig.instanceProjectId];
|
||||
payload.isCollaborative = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (Object.keys(payload).length > 0) {
|
||||
updateAgent.mutate({
|
||||
agent_id,
|
||||
data: payload,
|
||||
});
|
||||
} else {
|
||||
showToast({
|
||||
message: localize('com_ui_no_changes'),
|
||||
status: 'info',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<OGDialog>
|
||||
<OGDialogTrigger asChild>
|
||||
<button
|
||||
className={cn(
|
||||
'btn btn-neutral border-token-border-light relative h-9 rounded-lg font-medium',
|
||||
removeFocusOutlines,
|
||||
)}
|
||||
aria-label={localize(
|
||||
'com_ui_share_var',
|
||||
{ 0: agentName != null && agentName !== '' ? `"${agentName}"` : localize('com_ui_agent') },
|
||||
)}
|
||||
type="button"
|
||||
>
|
||||
<div className="flex items-center justify-center gap-2 text-blue-500">
|
||||
<Share2Icon className="icon-md h-4 w-4" />
|
||||
</div>
|
||||
</button>
|
||||
</OGDialogTrigger>
|
||||
<OGDialogContent className="w-11/12 md:max-w-xl">
|
||||
<OGDialogTitle>
|
||||
{localize(
|
||||
'com_ui_share_var',
|
||||
{ 0: agentName != null && agentName !== '' ? `"${agentName}"` : localize('com_ui_agent') },
|
||||
)}
|
||||
</OGDialogTitle>
|
||||
<form
|
||||
className="p-2"
|
||||
onSubmit={(e) => {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
handleSubmit(onSubmit)(e);
|
||||
}}
|
||||
>
|
||||
<div className="flex items-center justify-between gap-2 py-2">
|
||||
<div className="flex items-center">
|
||||
<button
|
||||
type="button"
|
||||
className="mr-2 cursor-pointer"
|
||||
disabled={isFetching || updateAgent.isLoading || !instanceProjectId}
|
||||
onClick={() =>
|
||||
setValue(Permissions.SHARED_GLOBAL, !getValues(Permissions.SHARED_GLOBAL), {
|
||||
shouldDirty: true,
|
||||
})
|
||||
}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
setValue(Permissions.SHARED_GLOBAL, !getValues(Permissions.SHARED_GLOBAL), {
|
||||
shouldDirty: true,
|
||||
});
|
||||
}
|
||||
}}
|
||||
aria-checked={getValues(Permissions.SHARED_GLOBAL)}
|
||||
role="checkbox"
|
||||
>
|
||||
{localize('com_ui_share_to_all_users')}
|
||||
</button>
|
||||
<label htmlFor={Permissions.SHARED_GLOBAL} className="select-none">
|
||||
{agentIsGlobal && (
|
||||
<span className="ml-2 text-xs">{localize('com_ui_agent_shared_to_all')}</span>
|
||||
)}
|
||||
</label>
|
||||
</div>
|
||||
<Controller
|
||||
name={Permissions.SHARED_GLOBAL}
|
||||
control={control}
|
||||
disabled={isFetching || updateAgent.isLoading || !instanceProjectId}
|
||||
render={({ field }) => (
|
||||
<Switch
|
||||
{...field}
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
value={field.value.toString()}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="mb-4 flex items-center justify-between gap-2 py-2">
|
||||
<div className="flex items-center">
|
||||
<button
|
||||
type="button"
|
||||
className="mr-2 cursor-pointer"
|
||||
disabled={
|
||||
isFetching || updateAgent.isLoading || !instanceProjectId || !sharedGlobalValue
|
||||
}
|
||||
onClick={() =>
|
||||
setValue(Permissions.UPDATE, !getValues(Permissions.UPDATE), {
|
||||
shouldDirty: true,
|
||||
})
|
||||
}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
setValue(Permissions.UPDATE, !getValues(Permissions.UPDATE), {
|
||||
shouldDirty: true,
|
||||
});
|
||||
}
|
||||
}}
|
||||
aria-checked={getValues(Permissions.UPDATE)}
|
||||
role="checkbox"
|
||||
>
|
||||
{localize('com_agents_allow_editing')}
|
||||
</button>
|
||||
{/* <label htmlFor={Permissions.UPDATE} className="select-none">
|
||||
{agentIsGlobal && (
|
||||
<span className="ml-2 text-xs">{localize('com_ui_agent_editing_allowed')}</span>
|
||||
)}
|
||||
</label> */}
|
||||
</div>
|
||||
<Controller
|
||||
name={Permissions.UPDATE}
|
||||
control={control}
|
||||
disabled={
|
||||
isFetching || updateAgent.isLoading || !instanceProjectId || !sharedGlobalValue
|
||||
}
|
||||
render={({ field }) => (
|
||||
<Switch
|
||||
{...field}
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
value={field.value.toString()}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex justify-end">
|
||||
<OGDialogClose asChild>
|
||||
<Button
|
||||
variant="submit"
|
||||
size="sm"
|
||||
type="submit"
|
||||
disabled={isSubmitting || isFetching}
|
||||
>
|
||||
{localize('com_ui_save')}
|
||||
</Button>
|
||||
</OGDialogClose>
|
||||
</div>
|
||||
</form>
|
||||
</OGDialogContent>
|
||||
</OGDialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
import React from 'react';
|
||||
import { ACCESS_ROLE_IDS } from 'librechat-data-provider';
|
||||
import type { AccessRole } from 'librechat-data-provider';
|
||||
import { SelectDropDownPop } from '~/components/ui';
|
||||
import { useGetAccessRolesQuery } from 'librechat-data-provider/react-query';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
interface AccessRolesPickerProps {
|
||||
resourceType?: string;
|
||||
selectedRoleId?: string;
|
||||
onRoleChange: (roleId: string) => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function AccessRolesPicker({
|
||||
resourceType = 'agent',
|
||||
selectedRoleId = ACCESS_ROLE_IDS.AGENT_VIEWER,
|
||||
onRoleChange,
|
||||
className = '',
|
||||
}: AccessRolesPickerProps) {
|
||||
const localize = useLocalize();
|
||||
|
||||
// Fetch access roles from API
|
||||
const { data: accessRoles, isLoading: rolesLoading } = useGetAccessRolesQuery(resourceType);
|
||||
|
||||
// Helper function to get localized role name and description
|
||||
const getLocalizedRoleInfo = (roleId: string) => {
|
||||
switch (roleId) {
|
||||
case 'agent_viewer':
|
||||
return {
|
||||
name: localize('com_ui_role_viewer'),
|
||||
description: localize('com_ui_role_viewer_desc'),
|
||||
};
|
||||
case 'agent_editor':
|
||||
return {
|
||||
name: localize('com_ui_role_editor'),
|
||||
description: localize('com_ui_role_editor_desc'),
|
||||
};
|
||||
case 'agent_manager':
|
||||
return {
|
||||
name: localize('com_ui_role_manager'),
|
||||
description: localize('com_ui_role_manager_desc'),
|
||||
};
|
||||
case 'agent_owner':
|
||||
return {
|
||||
name: localize('com_ui_role_owner'),
|
||||
description: localize('com_ui_role_owner_desc'),
|
||||
};
|
||||
default:
|
||||
return {
|
||||
name: localize('com_ui_unknown'),
|
||||
description: localize('com_ui_unknown'),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Find the currently selected role
|
||||
const selectedRole = accessRoles?.find((role) => role.accessRoleId === selectedRoleId);
|
||||
|
||||
if (rolesLoading || !accessRoles) {
|
||||
return (
|
||||
<div className={className}>
|
||||
<div className="flex items-center justify-center py-2">
|
||||
<div className="h-4 w-4 animate-spin rounded-full border-2 border-gray-300 border-t-blue-600"></div>
|
||||
<span className="ml-2 text-sm text-gray-500">Loading roles...</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
<SelectDropDownPop
|
||||
availableValues={accessRoles.map((role: AccessRole) => {
|
||||
const localizedInfo = getLocalizedRoleInfo(role.accessRoleId);
|
||||
return {
|
||||
value: role.accessRoleId,
|
||||
label: localizedInfo.name,
|
||||
description: localizedInfo.description,
|
||||
};
|
||||
})}
|
||||
showLabel={false}
|
||||
value={
|
||||
selectedRole
|
||||
? (() => {
|
||||
const localizedInfo = getLocalizedRoleInfo(selectedRole.accessRoleId);
|
||||
return {
|
||||
value: selectedRole.accessRoleId,
|
||||
label: localizedInfo.name,
|
||||
description: localizedInfo.description,
|
||||
};
|
||||
})()
|
||||
: null
|
||||
}
|
||||
setValue={onRoleChange}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,266 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Share2Icon, Users, Loader, Shield, Link, CopyCheck } from 'lucide-react';
|
||||
import { ACCESS_ROLE_IDS } from 'librechat-data-provider';
|
||||
import type { TPrincipal } from 'librechat-data-provider';
|
||||
import {
|
||||
Button,
|
||||
OGDialog,
|
||||
OGDialogTitle,
|
||||
OGDialogClose,
|
||||
OGDialogContent,
|
||||
OGDialogTrigger,
|
||||
} from '~/components/ui';
|
||||
import { cn, removeFocusOutlines } from '~/utils';
|
||||
import { useToastContext } from '~/Providers';
|
||||
import { useLocalize, useCopyToClipboard } from '~/hooks';
|
||||
import {
|
||||
useGetResourcePermissionsQuery,
|
||||
useUpdateResourcePermissionsMutation,
|
||||
} from 'librechat-data-provider/react-query';
|
||||
|
||||
import PeoplePicker from './PeoplePicker/PeoplePicker';
|
||||
import PublicSharingToggle from './PublicSharingToggle';
|
||||
import ManagePermissionsDialog from './ManagePermissionsDialog';
|
||||
import AccessRolesPicker from './AccessRolesPicker';
|
||||
|
||||
export default function GrantAccessDialog({
|
||||
agentName,
|
||||
onGrantAccess,
|
||||
resourceType = 'agent',
|
||||
agentDbId,
|
||||
agentId,
|
||||
}: {
|
||||
agentDbId?: string | null;
|
||||
agentId?: string | null;
|
||||
agentName?: string;
|
||||
onGrantAccess?: (shares: TPrincipal[], isPublic: boolean, publicRole: string) => void;
|
||||
resourceType?: string;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
|
||||
const {
|
||||
data: permissionsData,
|
||||
// isLoading: isLoadingPermissions,
|
||||
// error: permissionsError,
|
||||
} = useGetResourcePermissionsQuery(resourceType, agentDbId!, {
|
||||
enabled: !!agentDbId,
|
||||
});
|
||||
|
||||
const updatePermissionsMutation = useUpdateResourcePermissionsMutation();
|
||||
|
||||
const [newShares, setNewShares] = useState<TPrincipal[]>([]);
|
||||
const [defaultPermissionId, setDefaultPermissionId] = useState<string>(
|
||||
ACCESS_ROLE_IDS.AGENT_VIEWER,
|
||||
);
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
const [isCopying, setIsCopying] = useState(false);
|
||||
|
||||
const agentUrl = `${window.location.origin}/c/new?agent_id=${agentId}`;
|
||||
const copyAgentUrl = useCopyToClipboard({ text: agentUrl });
|
||||
|
||||
const currentShares: TPrincipal[] =
|
||||
permissionsData?.principals?.map((principal) => ({
|
||||
type: principal.type,
|
||||
id: principal.id,
|
||||
name: principal.name,
|
||||
email: principal.email,
|
||||
source: principal.source,
|
||||
avatar: principal.avatar,
|
||||
description: principal.description,
|
||||
accessRoleId: principal.accessRoleId,
|
||||
})) || [];
|
||||
|
||||
const currentIsPublic = permissionsData?.public ?? false;
|
||||
const currentPublicRole = permissionsData?.publicAccessRoleId || ACCESS_ROLE_IDS.AGENT_VIEWER;
|
||||
|
||||
const [isPublic, setIsPublic] = useState(false);
|
||||
const [publicRole, setPublicRole] = useState<string>(ACCESS_ROLE_IDS.AGENT_VIEWER);
|
||||
|
||||
useEffect(() => {
|
||||
if (permissionsData && isModalOpen) {
|
||||
setIsPublic(currentIsPublic ?? false);
|
||||
setPublicRole(currentPublicRole);
|
||||
}
|
||||
}, [permissionsData, isModalOpen, currentIsPublic, currentPublicRole]);
|
||||
|
||||
if (!agentDbId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handleGrantAccess = async () => {
|
||||
try {
|
||||
const sharesToAdd = newShares.map((share) => ({
|
||||
...share,
|
||||
accessRoleId: defaultPermissionId,
|
||||
}));
|
||||
|
||||
const allShares = [...currentShares, ...sharesToAdd];
|
||||
|
||||
await updatePermissionsMutation.mutateAsync({
|
||||
resourceType,
|
||||
resourceId: agentDbId,
|
||||
data: {
|
||||
updated: sharesToAdd,
|
||||
removed: [],
|
||||
public: isPublic,
|
||||
publicAccessRoleId: isPublic ? publicRole : undefined,
|
||||
},
|
||||
});
|
||||
|
||||
if (onGrantAccess) {
|
||||
onGrantAccess(allShares, isPublic, publicRole);
|
||||
}
|
||||
|
||||
showToast({
|
||||
message: `Access granted successfully to ${newShares.length} ${newShares.length === 1 ? 'person' : 'people'}${isPublic ? ' and made public' : ''}`,
|
||||
status: 'success',
|
||||
});
|
||||
|
||||
setNewShares([]);
|
||||
setDefaultPermissionId(ACCESS_ROLE_IDS.AGENT_VIEWER);
|
||||
setIsPublic(false);
|
||||
setPublicRole(ACCESS_ROLE_IDS.AGENT_VIEWER);
|
||||
setIsModalOpen(false);
|
||||
} catch (error) {
|
||||
console.error('Error granting access:', error);
|
||||
showToast({
|
||||
message: 'Failed to grant access. Please try again.',
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleCancel = () => {
|
||||
setNewShares([]);
|
||||
setDefaultPermissionId(ACCESS_ROLE_IDS.AGENT_VIEWER);
|
||||
setIsPublic(false);
|
||||
setPublicRole(ACCESS_ROLE_IDS.AGENT_VIEWER);
|
||||
setIsModalOpen(false);
|
||||
};
|
||||
|
||||
const totalCurrentShares = currentShares.length + (currentIsPublic ? 1 : 0);
|
||||
const submitButtonActive =
|
||||
newShares.length > 0 || isPublic !== currentIsPublic || publicRole !== currentPublicRole;
|
||||
return (
|
||||
<OGDialog open={isModalOpen} onOpenChange={setIsModalOpen} modal>
|
||||
<OGDialogTrigger asChild>
|
||||
<button
|
||||
className={cn(
|
||||
'btn btn-neutral border-token-border-light relative h-9 rounded-lg font-medium',
|
||||
removeFocusOutlines,
|
||||
)}
|
||||
aria-label={localize('com_ui_share_var', {
|
||||
0: agentName != null && agentName !== '' ? `"${agentName}"` : localize('com_ui_agent'),
|
||||
})}
|
||||
type="button"
|
||||
>
|
||||
<div className="flex items-center justify-center gap-2 text-blue-500">
|
||||
<Share2Icon className="icon-md h-4 w-4" />
|
||||
{totalCurrentShares > 0 && (
|
||||
<span className="rounded-full bg-blue-100 px-1.5 py-0.5 text-xs font-medium text-blue-800 dark:bg-blue-900 dark:text-blue-300">
|
||||
{totalCurrentShares}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</button>
|
||||
</OGDialogTrigger>
|
||||
|
||||
<OGDialogContent className="max-h-[90vh] w-11/12 overflow-y-auto md:max-w-3xl">
|
||||
<OGDialogTitle>
|
||||
<div className="flex items-center gap-2">
|
||||
<Users className="h-5 w-5" />
|
||||
{localize('com_ui_share_var', {
|
||||
0:
|
||||
agentName != null && agentName !== '' ? `"${agentName}"` : localize('com_ui_agent'),
|
||||
})}
|
||||
</div>
|
||||
</OGDialogTitle>
|
||||
|
||||
<div className="space-y-6 p-2">
|
||||
<PeoplePicker
|
||||
onSelectionChange={setNewShares}
|
||||
placeholder={localize('com_ui_search_people_placeholder')}
|
||||
/>
|
||||
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<Shield className="h-4 w-4 text-text-secondary" />
|
||||
<label className="text-sm font-medium text-text-primary">
|
||||
{localize('com_ui_permission_level')}
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<AccessRolesPicker
|
||||
resourceType={resourceType}
|
||||
selectedRoleId={defaultPermissionId}
|
||||
onRoleChange={setDefaultPermissionId}
|
||||
/>
|
||||
</div>
|
||||
<PublicSharingToggle
|
||||
isPublic={isPublic}
|
||||
publicRole={publicRole}
|
||||
onPublicToggle={setIsPublic}
|
||||
onPublicRoleChange={setPublicRole}
|
||||
resourceType={resourceType}
|
||||
/>
|
||||
<div className="flex justify-between border-t pt-4">
|
||||
<div className="flex gap-2">
|
||||
<ManagePermissionsDialog
|
||||
agentDbId={agentDbId}
|
||||
agentName={agentName}
|
||||
resourceType={resourceType}
|
||||
/>
|
||||
{agentId && (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
if (isCopying) return;
|
||||
copyAgentUrl(setIsCopying);
|
||||
showToast({
|
||||
message: localize('com_ui_agent_url_copied'),
|
||||
status: 'success',
|
||||
});
|
||||
}}
|
||||
disabled={isCopying}
|
||||
className={cn('shrink-0', isCopying ? 'cursor-default' : '')}
|
||||
aria-label={localize('com_ui_copy_url_to_clipboard')}
|
||||
title={
|
||||
isCopying
|
||||
? localize('com_ui_agent_url_copied')
|
||||
: localize('com_ui_copy_url_to_clipboard')
|
||||
}
|
||||
>
|
||||
{isCopying ? <CopyCheck className="h-4 w-4" /> : <Link className="h-4 w-4" />}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
<OGDialogClose asChild>
|
||||
<Button variant="outline" onClick={handleCancel}>
|
||||
{localize('com_ui_cancel')}
|
||||
</Button>
|
||||
</OGDialogClose>
|
||||
<Button
|
||||
onClick={handleGrantAccess}
|
||||
disabled={updatePermissionsMutation.isLoading || !submitButtonActive}
|
||||
className="min-w-[120px]"
|
||||
>
|
||||
{updatePermissionsMutation.isLoading ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Loader className="h-4 w-4 animate-spin" />
|
||||
{localize('com_ui_granting')}
|
||||
</div>
|
||||
) : (
|
||||
localize('com_ui_grant_access')
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</OGDialogContent>
|
||||
</OGDialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,349 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Settings, Users, Loader, UserCheck, Trash2, Shield } from 'lucide-react';
|
||||
import { ACCESS_ROLE_IDS, TPrincipal } from 'librechat-data-provider';
|
||||
import {
|
||||
Button,
|
||||
OGDialog,
|
||||
OGDialogTitle,
|
||||
OGDialogClose,
|
||||
OGDialogContent,
|
||||
OGDialogTrigger,
|
||||
} from '~/components/ui';
|
||||
import { cn, removeFocusOutlines } from '~/utils';
|
||||
import { useToastContext } from '~/Providers';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import {
|
||||
useGetAccessRolesQuery,
|
||||
useGetResourcePermissionsQuery,
|
||||
useUpdateResourcePermissionsMutation,
|
||||
} from 'librechat-data-provider/react-query';
|
||||
|
||||
import SelectedPrincipalsList from './PeoplePicker/SelectedPrincipalsList';
|
||||
import PublicSharingToggle from './PublicSharingToggle';
|
||||
|
||||
export default function ManagePermissionsDialog({
|
||||
agentDbId,
|
||||
agentName,
|
||||
resourceType = 'agent',
|
||||
onUpdatePermissions,
|
||||
}: {
|
||||
agentDbId: string;
|
||||
agentName?: string;
|
||||
resourceType?: string;
|
||||
onUpdatePermissions?: (shares: TPrincipal[], isPublic: boolean, publicRole: string) => void;
|
||||
}) {
|
||||
const localize = useLocalize();
|
||||
const { showToast } = useToastContext();
|
||||
|
||||
const {
|
||||
data: permissionsData,
|
||||
isLoading: isLoadingPermissions,
|
||||
error: permissionsError,
|
||||
} = useGetResourcePermissionsQuery(resourceType, agentDbId, {
|
||||
enabled: !!agentDbId,
|
||||
});
|
||||
const {
|
||||
data: accessRoles,
|
||||
// isLoading,
|
||||
} = useGetAccessRolesQuery(resourceType);
|
||||
|
||||
const updatePermissionsMutation = useUpdateResourcePermissionsMutation();
|
||||
|
||||
const [managedShares, setManagedShares] = useState<TPrincipal[]>([]);
|
||||
const [managedIsPublic, setManagedIsPublic] = useState(false);
|
||||
const [managedPublicRole, setManagedPublicRole] = useState<string>(ACCESS_ROLE_IDS.AGENT_VIEWER);
|
||||
const [isModalOpen, setIsModalOpen] = useState(false);
|
||||
const [hasChanges, setHasChanges] = useState(false);
|
||||
|
||||
const currentShares: TPrincipal[] = permissionsData?.principals || [];
|
||||
|
||||
const isPublic = permissionsData?.public || false;
|
||||
const publicRole = permissionsData?.publicAccessRoleId || ACCESS_ROLE_IDS.AGENT_VIEWER;
|
||||
|
||||
useEffect(() => {
|
||||
if (permissionsData) {
|
||||
setManagedShares(currentShares);
|
||||
setManagedIsPublic(isPublic);
|
||||
setManagedPublicRole(publicRole);
|
||||
setHasChanges(false);
|
||||
}
|
||||
}, [permissionsData, isModalOpen]);
|
||||
|
||||
if (!agentDbId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (permissionsError) {
|
||||
return <div className="text-sm text-red-600">{localize('com_ui_permissions_failed_load')}</div>;
|
||||
}
|
||||
|
||||
const handleRemoveShare = (idOnTheSource: string) => {
|
||||
setManagedShares(managedShares.filter((s) => s.idOnTheSource !== idOnTheSource));
|
||||
setHasChanges(true);
|
||||
};
|
||||
|
||||
const handleRoleChange = (idOnTheSource: string, newRole: string) => {
|
||||
setManagedShares(
|
||||
managedShares.map((s) =>
|
||||
s.idOnTheSource === idOnTheSource ? { ...s, accessRoleId: newRole } : s,
|
||||
),
|
||||
);
|
||||
setHasChanges(true);
|
||||
};
|
||||
|
||||
const handleSaveChanges = async () => {
|
||||
try {
|
||||
const originalSharesMap = new Map(
|
||||
currentShares.map((share) => [`${share.type}-${share.idOnTheSource}`, share]),
|
||||
);
|
||||
const managedSharesMap = new Map(
|
||||
managedShares.map((share) => [`${share.type}-${share.idOnTheSource}`, share]),
|
||||
);
|
||||
|
||||
const updated = managedShares.filter((share) => {
|
||||
const key = `${share.type}-${share.idOnTheSource}`;
|
||||
const original = originalSharesMap.get(key);
|
||||
return !original || original.accessRoleId !== share.accessRoleId;
|
||||
});
|
||||
|
||||
const removed = currentShares.filter((share) => {
|
||||
const key = `${share.type}-${share.idOnTheSource}`;
|
||||
return !managedSharesMap.has(key);
|
||||
});
|
||||
|
||||
await updatePermissionsMutation.mutateAsync({
|
||||
resourceType,
|
||||
resourceId: agentDbId,
|
||||
data: {
|
||||
updated,
|
||||
removed,
|
||||
public: managedIsPublic,
|
||||
publicAccessRoleId: managedIsPublic ? managedPublicRole : undefined,
|
||||
},
|
||||
});
|
||||
|
||||
if (onUpdatePermissions) {
|
||||
onUpdatePermissions(managedShares, managedIsPublic, managedPublicRole);
|
||||
}
|
||||
|
||||
showToast({
|
||||
message: localize('com_ui_permissions_updated_success'),
|
||||
status: 'success',
|
||||
});
|
||||
|
||||
setIsModalOpen(false);
|
||||
} catch (error) {
|
||||
console.error('Error updating permissions:', error);
|
||||
showToast({
|
||||
message: localize('com_ui_permissions_failed_update'),
|
||||
status: 'error',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const handleCancel = () => {
|
||||
setManagedShares(currentShares);
|
||||
setManagedIsPublic(isPublic);
|
||||
setManagedPublicRole(publicRole);
|
||||
setIsModalOpen(false);
|
||||
};
|
||||
|
||||
const handleRevokeAll = () => {
|
||||
setManagedShares([]);
|
||||
setManagedIsPublic(false);
|
||||
setHasChanges(true);
|
||||
};
|
||||
const handlePublicToggle = (isPublic: boolean) => {
|
||||
setManagedIsPublic(isPublic);
|
||||
setHasChanges(true);
|
||||
if (!isPublic) {
|
||||
setManagedPublicRole(ACCESS_ROLE_IDS.AGENT_VIEWER);
|
||||
}
|
||||
};
|
||||
const handlePublicRoleChange = (role: string) => {
|
||||
setManagedPublicRole(role);
|
||||
setHasChanges(true);
|
||||
};
|
||||
const totalShares = managedShares.length + (managedIsPublic ? 1 : 0);
|
||||
const originalTotalShares = currentShares.length + (isPublic ? 1 : 0);
|
||||
|
||||
/** Check if there's at least one owner (user, group, or public with owner role) */
|
||||
const hasAtLeastOneOwner =
|
||||
managedShares.some((share) => share.accessRoleId === ACCESS_ROLE_IDS.AGENT_OWNER) ||
|
||||
(managedIsPublic && managedPublicRole === ACCESS_ROLE_IDS.AGENT_OWNER);
|
||||
|
||||
let peopleLabel = localize('com_ui_people');
|
||||
if (managedShares.length === 1) {
|
||||
peopleLabel = localize('com_ui_person');
|
||||
}
|
||||
|
||||
let buttonAriaLabel = localize('com_ui_manage_permissions_for') + ' agent';
|
||||
if (agentName != null && agentName !== '') {
|
||||
buttonAriaLabel = localize('com_ui_manage_permissions_for') + ` "${agentName}"`;
|
||||
}
|
||||
|
||||
let dialogTitle = localize('com_ui_manage_permissions_for') + ' Agent';
|
||||
if (agentName != null && agentName !== '') {
|
||||
dialogTitle = localize('com_ui_manage_permissions_for') + ` "${agentName}"`;
|
||||
}
|
||||
|
||||
let publicSuffix = '';
|
||||
if (managedIsPublic) {
|
||||
publicSuffix = localize('com_ui_and_public');
|
||||
}
|
||||
|
||||
return (
|
||||
<OGDialog open={isModalOpen} onOpenChange={setIsModalOpen}>
|
||||
<OGDialogTrigger asChild>
|
||||
<button
|
||||
className={cn(
|
||||
'btn btn-neutral border-token-border-light relative h-9 rounded-lg font-medium',
|
||||
removeFocusOutlines,
|
||||
)}
|
||||
aria-label={buttonAriaLabel}
|
||||
type="button"
|
||||
>
|
||||
<div className="flex items-center justify-center gap-2 text-blue-500">
|
||||
<Settings className="icon-md h-4 w-4" />
|
||||
<span className="hidden sm:inline">{localize('com_ui_manage')}</span>
|
||||
{originalTotalShares > 0 && `(${originalTotalShares})`}
|
||||
</div>
|
||||
</button>
|
||||
</OGDialogTrigger>
|
||||
|
||||
<OGDialogContent className="max-h-[90vh] w-11/12 overflow-y-auto md:max-w-3xl">
|
||||
<OGDialogTitle>
|
||||
<div className="flex items-center gap-2">
|
||||
<Shield className="h-5 w-5 text-blue-500" />
|
||||
{dialogTitle}
|
||||
</div>
|
||||
</OGDialogTitle>
|
||||
|
||||
<div className="space-y-6 p-2">
|
||||
<div className="rounded-lg bg-surface-tertiary p-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h3 className="text-sm font-medium text-text-primary">
|
||||
{localize('com_ui_current_access')}
|
||||
</h3>
|
||||
<p className="text-xs text-text-secondary">
|
||||
{(() => {
|
||||
if (totalShares === 0) {
|
||||
return localize('com_ui_no_users_groups_access');
|
||||
}
|
||||
return localize('com_ui_shared_with_count', {
|
||||
0: managedShares.length,
|
||||
1: peopleLabel,
|
||||
2: publicSuffix,
|
||||
});
|
||||
})()}
|
||||
</p>
|
||||
</div>
|
||||
{(managedShares.length > 0 || managedIsPublic) && (
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={handleRevokeAll}
|
||||
className="text-red-600 hover:text-red-700"
|
||||
>
|
||||
<Trash2 className="mr-2 h-4 w-4" />
|
||||
{localize('com_ui_revoke_all')}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{(() => {
|
||||
if (isLoadingPermissions) {
|
||||
return (
|
||||
<div className="flex items-center justify-center p-8">
|
||||
<Loader className="h-6 w-6 animate-spin" />
|
||||
<span className="ml-2 text-sm text-text-secondary">
|
||||
{localize('com_ui_loading_permissions')}
|
||||
</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (managedShares.length > 0) {
|
||||
return (
|
||||
<div>
|
||||
<h3 className="mb-3 flex items-center gap-2 text-sm font-medium text-text-primary">
|
||||
<UserCheck className="h-4 w-4" />
|
||||
{localize('com_ui_user_group_permissions')} ({managedShares.length})
|
||||
</h3>
|
||||
<SelectedPrincipalsList
|
||||
principles={managedShares}
|
||||
onRemoveHandler={handleRemoveShare}
|
||||
availableRoles={accessRoles || []}
|
||||
onRoleChange={(id, newRole) => handleRoleChange(id, newRole)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="rounded-lg border-2 border-dashed border-border-light p-8 text-center">
|
||||
<Users className="mx-auto h-8 w-8 text-text-secondary" />
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
{localize('com_ui_no_individual_access')}
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
<div>
|
||||
<h3 className="mb-3 text-sm font-medium text-text-primary">
|
||||
{localize('com_ui_public_access')}
|
||||
</h3>
|
||||
<PublicSharingToggle
|
||||
isPublic={managedIsPublic}
|
||||
publicRole={managedPublicRole}
|
||||
onPublicToggle={handlePublicToggle}
|
||||
onPublicRoleChange={handlePublicRoleChange}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex justify-end gap-3 border-t pt-4">
|
||||
<OGDialogClose asChild>
|
||||
<Button variant="outline" onClick={handleCancel}>
|
||||
{localize('com_ui_cancel')}
|
||||
</Button>
|
||||
</OGDialogClose>
|
||||
<Button
|
||||
onClick={handleSaveChanges}
|
||||
disabled={
|
||||
updatePermissionsMutation.isLoading ||
|
||||
!hasChanges ||
|
||||
isLoadingPermissions ||
|
||||
!hasAtLeastOneOwner
|
||||
}
|
||||
className="min-w-[120px]"
|
||||
>
|
||||
{updatePermissionsMutation.isLoading ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<Loader className="h-4 w-4 animate-spin" />
|
||||
{localize('com_ui_saving')}
|
||||
</div>
|
||||
) : (
|
||||
localize('com_ui_save_changes')
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{hasChanges && (
|
||||
<div className="text-xs text-orange-600 dark:text-orange-400">
|
||||
* {localize('com_ui_unsaved_changes')}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{!hasAtLeastOneOwner && hasChanges && (
|
||||
<div className="text-xs text-red-600 dark:text-red-400">
|
||||
* {localize('com_ui_at_least_one_owner_required')}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</OGDialogContent>
|
||||
</OGDialog>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
import React, { useState, useMemo } from 'react';
|
||||
import type { TPrincipal, PrincipalSearchParams } from 'librechat-data-provider';
|
||||
import { useSearchPrincipalsQuery } from 'librechat-data-provider/react-query';
|
||||
|
||||
import { SearchPicker } from '~/components/ui/SearchPicker';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import PeoplePickerSearchItem from './PeoplePickerSearchItem';
|
||||
import SelectedPrincipalsList from './SelectedPrincipalsList';
|
||||
|
||||
interface PeoplePickerProps {
|
||||
onSelectionChange: (principals: TPrincipal[]) => void;
|
||||
placeholder?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function PeoplePicker({
|
||||
onSelectionChange,
|
||||
placeholder,
|
||||
className = '',
|
||||
}: PeoplePickerProps) {
|
||||
const localize = useLocalize();
|
||||
const [searchQuery, setSearchQuery] = useState('');
|
||||
const [selectedShares, setSelectedShares] = useState<TPrincipal[]>([]);
|
||||
|
||||
const searchParams: PrincipalSearchParams = useMemo(
|
||||
() => ({
|
||||
q: searchQuery,
|
||||
limit: 30,
|
||||
}),
|
||||
[searchQuery],
|
||||
);
|
||||
|
||||
const {
|
||||
data: searchResponse,
|
||||
isLoading: queryIsLoading,
|
||||
error,
|
||||
} = useSearchPrincipalsQuery(searchParams, {
|
||||
enabled: searchQuery.length >= 2,
|
||||
});
|
||||
|
||||
const isLoading = searchQuery.length >= 2 && queryIsLoading;
|
||||
|
||||
const selectableResults = useMemo(() => {
|
||||
const results = searchResponse?.results || [];
|
||||
|
||||
return results.filter(
|
||||
(result) => !selectedShares.some((share) => share.idOnTheSource === result.idOnTheSource),
|
||||
);
|
||||
}, [searchResponse?.results, selectedShares]);
|
||||
|
||||
if (error) {
|
||||
console.error('Principal search error:', error);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={`space-y-3 ${className}`}>
|
||||
<div className="relative">
|
||||
<SearchPicker<TPrincipal & { key: string; value: string }>
|
||||
options={selectableResults.map((s) => {
|
||||
const key = s.idOnTheSource || 'unknown' + 'picker_key';
|
||||
const value = s.idOnTheSource || 'Unknown';
|
||||
return {
|
||||
...s,
|
||||
id: s.id ?? undefined,
|
||||
key,
|
||||
value,
|
||||
};
|
||||
})}
|
||||
renderOptions={(o) => <PeoplePickerSearchItem principal={o} />}
|
||||
placeholder={placeholder || localize('com_ui_search_default_placeholder')}
|
||||
query={searchQuery}
|
||||
onQueryChange={(query: string) => {
|
||||
setSearchQuery(query);
|
||||
}}
|
||||
onPick={(principal) => {
|
||||
console.log('Selected Principal:', principal);
|
||||
setSelectedShares((prev) => {
|
||||
const newArray = [...prev, principal];
|
||||
onSelectionChange([...newArray]);
|
||||
return newArray;
|
||||
});
|
||||
setSearchQuery('');
|
||||
}}
|
||||
label={localize('com_ui_search_users_groups')}
|
||||
isLoading={isLoading}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<SelectedPrincipalsList
|
||||
principles={selectedShares}
|
||||
onRemoveHandler={(idOnTheSource: string) => {
|
||||
setSelectedShares((prev) => {
|
||||
const newArray = prev.filter((share) => share.idOnTheSource !== idOnTheSource);
|
||||
onSelectionChange(newArray);
|
||||
return newArray;
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
import React, { forwardRef } from 'react';
|
||||
import type { TPrincipal } from 'librechat-data-provider';
|
||||
import { cn } from '~/utils';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import PrincipalAvatar from '../PrincipalAvatar';
|
||||
|
||||
interface PeoplePickerSearchItemProps extends React.HTMLAttributes<HTMLDivElement> {
|
||||
principal: TPrincipal;
|
||||
}
|
||||
|
||||
const PeoplePickerSearchItem = forwardRef<HTMLDivElement, PeoplePickerSearchItemProps>(
|
||||
function PeoplePickerSearchItem(
|
||||
{ principal, className, style, onClick, ...props },
|
||||
forwardedRef,
|
||||
) {
|
||||
const localize = useLocalize();
|
||||
const { name, email, type } = principal;
|
||||
|
||||
// Display name with fallback
|
||||
const displayName = name || localize('com_ui_unknown');
|
||||
const subtitle = email || `${type} (${principal.source || 'local'})`;
|
||||
|
||||
return (
|
||||
<div
|
||||
{...props}
|
||||
ref={forwardedRef}
|
||||
className={cn('flex items-center gap-3 p-2', className)}
|
||||
style={style}
|
||||
onClick={(event) => {
|
||||
onClick?.(event);
|
||||
}}
|
||||
>
|
||||
<PrincipalAvatar principal={principal} size="md" />
|
||||
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="truncate text-sm font-medium text-text-primary">{displayName}</div>
|
||||
<div className="truncate text-xs text-text-secondary">{subtitle}</div>
|
||||
</div>
|
||||
|
||||
<div className="flex-shrink-0">
|
||||
<span
|
||||
className={cn(
|
||||
'inline-flex items-center rounded-full px-2 py-1 text-xs font-medium',
|
||||
type === 'user'
|
||||
? 'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-300'
|
||||
: 'bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-300',
|
||||
)}
|
||||
>
|
||||
{type === 'user' ? localize('com_ui_user') : localize('com_ui_group')}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
export default PeoplePickerSearchItem;
|
||||
@@ -0,0 +1,149 @@
|
||||
import React, { useState, useId } from 'react';
|
||||
import { Users, X, ExternalLink, ChevronDown } from 'lucide-react';
|
||||
import * as Menu from '@ariakit/react/menu';
|
||||
import type { TPrincipal, TAccessRole } from 'librechat-data-provider';
|
||||
import { Button, DropdownPopup } from '~/components/ui';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import PrincipalAvatar from '../PrincipalAvatar';
|
||||
|
||||
interface SelectedPrincipalsListProps {
|
||||
principles: TPrincipal[];
|
||||
onRemoveHandler: (idOnTheSource: string) => void;
|
||||
onRoleChange?: (idOnTheSource: string, newRoleId: string) => void;
|
||||
availableRoles?: Omit<TAccessRole, 'resourceType'>[];
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function SelectedPrincipalsList({
|
||||
principles,
|
||||
onRemoveHandler,
|
||||
className = '',
|
||||
onRoleChange,
|
||||
availableRoles,
|
||||
}: SelectedPrincipalsListProps) {
|
||||
const localize = useLocalize();
|
||||
|
||||
const getPrincipalDisplayInfo = (principal: TPrincipal) => {
|
||||
const displayName = principal.name || localize('com_ui_unknown');
|
||||
const subtitle = principal.email || `${principal.type} (${principal.source || 'local'})`;
|
||||
|
||||
return { displayName, subtitle };
|
||||
};
|
||||
|
||||
if (principles.length === 0) {
|
||||
return (
|
||||
<div className={`space-y-3 ${className}`}>
|
||||
<div className="rounded-lg border border-dashed border-border py-8 text-center text-muted-foreground">
|
||||
<Users className="mx-auto mb-2 h-8 w-8 opacity-50" />
|
||||
<p className="mt-1 text-xs">{localize('com_ui_search_above_to_add')}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={`space-y-3 ${className}`}>
|
||||
<div className="space-y-2">
|
||||
{principles.map((share) => {
|
||||
const { displayName, subtitle } = getPrincipalDisplayInfo(share);
|
||||
return (
|
||||
<div
|
||||
key={share.idOnTheSource + '-principalList'}
|
||||
className="bg-surface flex items-center justify-between rounded-lg border border-border p-3"
|
||||
>
|
||||
<div className="flex min-w-0 flex-1 items-center gap-3">
|
||||
<PrincipalAvatar principal={share} size="md" />
|
||||
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="truncate text-sm font-medium">{displayName}</div>
|
||||
<div className="flex items-center gap-1 text-xs text-muted-foreground">
|
||||
<span>{subtitle}</span>
|
||||
{share.source === 'entra' && (
|
||||
<>
|
||||
<ExternalLink className="h-3 w-3" />
|
||||
<span>{localize('com_ui_azure_ad')}</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-shrink-0 items-center gap-2">
|
||||
{!!share.accessRoleId && !!onRoleChange && (
|
||||
<RoleSelector
|
||||
currentRole={share.accessRoleId}
|
||||
onRoleChange={(newRole) => {
|
||||
onRoleChange?.(share.idOnTheSource!, newRole);
|
||||
}}
|
||||
availableRoles={availableRoles ?? []}
|
||||
/>
|
||||
)}
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => onRemoveHandler(share.idOnTheSource!)}
|
||||
className="h-8 w-8 p-0 hover:bg-destructive/10 hover:text-destructive"
|
||||
aria-label={localize('com_ui_remove_user', { 0: displayName })}
|
||||
>
|
||||
<X className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface RoleSelectorProps {
|
||||
currentRole: string;
|
||||
onRoleChange: (newRole: string) => void;
|
||||
availableRoles: Omit<TAccessRole, 'resourceType'>[];
|
||||
}
|
||||
|
||||
function RoleSelector({ currentRole, onRoleChange, availableRoles }: RoleSelectorProps) {
|
||||
const menuId = useId();
|
||||
const [isMenuOpen, setIsMenuOpen] = useState(false);
|
||||
const localize = useLocalize();
|
||||
|
||||
const getLocalizedRoleName = (roleId: string) => {
|
||||
switch (roleId) {
|
||||
case 'agent_viewer':
|
||||
return localize('com_ui_role_viewer');
|
||||
case 'agent_editor':
|
||||
return localize('com_ui_role_editor');
|
||||
case 'agent_manager':
|
||||
return localize('com_ui_role_manager');
|
||||
case 'agent_owner':
|
||||
return localize('com_ui_role_owner');
|
||||
default:
|
||||
return localize('com_ui_unknown');
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<DropdownPopup
|
||||
portal={true}
|
||||
mountByState={true}
|
||||
unmountOnHide={true}
|
||||
preserveTabOrder={true}
|
||||
isOpen={isMenuOpen}
|
||||
setIsOpen={setIsMenuOpen}
|
||||
trigger={
|
||||
<Menu.MenuButton className="flex h-8 items-center gap-2 rounded-md border border-border-medium bg-surface-secondary px-2 py-1 text-sm font-medium transition-colors duration-200 hover:bg-surface-tertiary">
|
||||
<span className="hidden sm:inline">{getLocalizedRoleName(currentRole)}</span>
|
||||
<ChevronDown className="h-3 w-3" />
|
||||
</Menu.MenuButton>
|
||||
}
|
||||
items={availableRoles?.map((role) => ({
|
||||
id: role.accessRoleId,
|
||||
label: getLocalizedRoleName(role.accessRoleId),
|
||||
|
||||
onClick: () => onRoleChange(role.accessRoleId),
|
||||
}))}
|
||||
menuId={menuId}
|
||||
className="z-50 [pointer-events:auto]"
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
import React from 'react';
|
||||
import { Users, User } from 'lucide-react';
|
||||
import type { TPrincipal } from 'librechat-data-provider';
|
||||
import { cn } from '~/utils';
|
||||
|
||||
interface PrincipalAvatarProps {
|
||||
principal: TPrincipal;
|
||||
size?: 'sm' | 'md' | 'lg';
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function PrincipalAvatar({
|
||||
principal,
|
||||
size = 'md',
|
||||
className,
|
||||
}: PrincipalAvatarProps) {
|
||||
const { avatar, type, name } = principal;
|
||||
const displayName = name || 'Unknown';
|
||||
|
||||
// Size variants
|
||||
const sizeClasses = {
|
||||
sm: 'h-6 w-6',
|
||||
md: 'h-8 w-8',
|
||||
lg: 'h-10 w-10',
|
||||
};
|
||||
|
||||
const iconSizeClasses = {
|
||||
sm: 'h-3 w-3',
|
||||
md: 'h-4 w-4',
|
||||
lg: 'h-5 w-5',
|
||||
};
|
||||
|
||||
const avatarSizeClass = sizeClasses[size];
|
||||
const iconSizeClass = iconSizeClasses[size];
|
||||
|
||||
// Avatar or icon logic
|
||||
if (avatar) {
|
||||
return (
|
||||
<div className={cn('flex-shrink-0', className)}>
|
||||
<img
|
||||
src={avatar}
|
||||
alt={`${displayName} avatar`}
|
||||
className={cn(avatarSizeClass, 'rounded-full object-cover')}
|
||||
onError={(e) => {
|
||||
// Fallback to icon if image fails to load
|
||||
const target = e.target as HTMLImageElement;
|
||||
target.style.display = 'none';
|
||||
target.nextElementSibling?.classList.remove('hidden');
|
||||
}}
|
||||
/>
|
||||
{/* Hidden fallback icon that shows if image fails */}
|
||||
<div className={cn('hidden', avatarSizeClass)}>
|
||||
{type === 'user' ? (
|
||||
<div
|
||||
className={cn(
|
||||
avatarSizeClass,
|
||||
'flex items-center justify-center rounded-full bg-blue-100 dark:bg-blue-900',
|
||||
)}
|
||||
>
|
||||
<User className={cn(iconSizeClass, 'text-blue-600 dark:text-blue-400')} />
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
className={cn(
|
||||
avatarSizeClass,
|
||||
'flex items-center justify-center rounded-full bg-green-100 dark:bg-green-900',
|
||||
)}
|
||||
>
|
||||
<Users className={cn(iconSizeClass, 'text-green-600 dark:text-green-400')} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Fallback icon based on type
|
||||
return (
|
||||
<div className={cn('flex-shrink-0', className)}>
|
||||
{type === 'user' ? (
|
||||
<div
|
||||
className={cn(
|
||||
avatarSizeClass,
|
||||
'flex items-center justify-center rounded-full bg-blue-100 dark:bg-blue-900',
|
||||
)}
|
||||
>
|
||||
<User className={cn(iconSizeClass, 'text-blue-600 dark:text-blue-400')} />
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
className={cn(
|
||||
avatarSizeClass,
|
||||
'flex items-center justify-center rounded-full bg-green-100 dark:bg-green-900',
|
||||
)}
|
||||
>
|
||||
<Users className={cn(iconSizeClass, 'text-green-600 dark:text-green-400')} />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
import React from 'react';
|
||||
import { Globe } from 'lucide-react';
|
||||
import { Switch } from '~/components/ui';
|
||||
import { useLocalize } from '~/hooks';
|
||||
import AccessRolesPicker from './AccessRolesPicker';
|
||||
|
||||
interface PublicSharingToggleProps {
|
||||
isPublic: boolean;
|
||||
publicRole: string;
|
||||
onPublicToggle: (isPublic: boolean) => void;
|
||||
onPublicRoleChange: (role: string) => void;
|
||||
className?: string;
|
||||
resourceType?: string;
|
||||
}
|
||||
|
||||
export default function PublicSharingToggle({
|
||||
isPublic,
|
||||
publicRole,
|
||||
onPublicToggle,
|
||||
onPublicRoleChange,
|
||||
className = '',
|
||||
resourceType = 'agent',
|
||||
}: PublicSharingToggleProps) {
|
||||
const localize = useLocalize();
|
||||
|
||||
return (
|
||||
<div className={`space-y-3 border-t pt-4 ${className}`}>
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h3 className="flex items-center gap-2 text-sm font-medium">
|
||||
<Globe className="h-4 w-4" />
|
||||
{localize('com_ui_share_with_everyone')}
|
||||
</h3>
|
||||
<p className="mt-1 text-xs text-muted-foreground">
|
||||
{localize('com_ui_make_agent_available_all_users')}
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
checked={isPublic}
|
||||
onCheckedChange={onPublicToggle}
|
||||
aria-label={localize('com_ui_share_with_everyone')}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{isPublic && (
|
||||
<div>
|
||||
<label className="mb-2 block text-sm font-medium">
|
||||
{localize('com_ui_public_access_level')}
|
||||
</label>
|
||||
<AccessRolesPicker
|
||||
resourceType={resourceType}
|
||||
selectedRoleId={publicRole}
|
||||
onRoleChange={onPublicRoleChange}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,31 +1,41 @@
|
||||
import React from 'react';
|
||||
import { SystemRoles } from 'librechat-data-provider';
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import type { UseMutationResult } from '@tanstack/react-query';
|
||||
import '@testing-library/jest-dom/extend-expect';
|
||||
import type { Agent, AgentCreateParams, TUser } from 'librechat-data-provider';
|
||||
import AgentFooter from '../AgentFooter';
|
||||
import { Panel } from '~/common';
|
||||
import type { Agent, AgentCreateParams, TUser } from 'librechat-data-provider';
|
||||
import { SystemRoles } from 'librechat-data-provider';
|
||||
import * as reactHookForm from 'react-hook-form';
|
||||
import * as hooks from '~/hooks';
|
||||
import type { UseMutationResult } from '@tanstack/react-query';
|
||||
|
||||
const mockUseWatch = jest.fn();
|
||||
const mockUseAuthContext = jest.fn();
|
||||
const mockUseHasAccess = jest.fn();
|
||||
const mockUseResourcePermissions = jest.fn();
|
||||
|
||||
jest.mock('react-hook-form', () => ({
|
||||
useFormContext: () => ({
|
||||
control: {},
|
||||
}),
|
||||
useWatch: () => {
|
||||
return {
|
||||
agent: {
|
||||
name: 'Test Agent',
|
||||
author: 'user-123',
|
||||
projectIds: ['project-1'],
|
||||
isCollaborative: false,
|
||||
},
|
||||
id: 'agent-123',
|
||||
};
|
||||
},
|
||||
useWatch: (params) => mockUseWatch(params),
|
||||
}));
|
||||
|
||||
// Default mock implementations
|
||||
mockUseWatch.mockImplementation(({ name }) => {
|
||||
if (name === 'agent') {
|
||||
return {
|
||||
_id: 'agent-db-123',
|
||||
name: 'Test Agent',
|
||||
author: 'user-123',
|
||||
projectIds: ['project-1'],
|
||||
isCollaborative: false,
|
||||
};
|
||||
}
|
||||
if (name === 'id') {
|
||||
return 'agent-123';
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
const mockUser = {
|
||||
id: 'user-123',
|
||||
username: 'testuser',
|
||||
@@ -39,6 +49,26 @@ const mockUser = {
|
||||
updatedAt: '2023-01-01T00:00:00.000Z',
|
||||
} as TUser;
|
||||
|
||||
// Default auth context
|
||||
mockUseAuthContext.mockReturnValue({
|
||||
user: mockUser,
|
||||
token: 'mock-token',
|
||||
isAuthenticated: true,
|
||||
error: undefined,
|
||||
login: jest.fn(),
|
||||
logout: jest.fn(),
|
||||
setError: jest.fn(),
|
||||
roles: {},
|
||||
});
|
||||
|
||||
// Default access and permissions
|
||||
mockUseHasAccess.mockReturnValue(true);
|
||||
mockUseResourcePermissions.mockReturnValue({
|
||||
hasPermission: () => true,
|
||||
isLoading: false,
|
||||
permissionBits: 0,
|
||||
});
|
||||
|
||||
jest.mock('~/hooks', () => ({
|
||||
useLocalize: () => (key) => {
|
||||
const translations = {
|
||||
@@ -47,17 +77,9 @@ jest.mock('~/hooks', () => ({
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
useAuthContext: () => ({
|
||||
user: mockUser,
|
||||
token: 'mock-token',
|
||||
isAuthenticated: true,
|
||||
error: undefined,
|
||||
login: jest.fn(),
|
||||
logout: jest.fn(),
|
||||
setError: jest.fn(),
|
||||
roles: {},
|
||||
}),
|
||||
useHasAccess: () => true,
|
||||
useAuthContext: () => mockUseAuthContext(),
|
||||
useHasAccess: () => mockUseHasAccess(),
|
||||
useResourcePermissions: () => mockUseResourcePermissions(),
|
||||
}));
|
||||
|
||||
const createBaseMutation = <T = Agent, P = any>(
|
||||
@@ -126,9 +148,9 @@ jest.mock('../DeleteButton', () => ({
|
||||
default: jest.fn(() => <div data-testid="delete-button" />),
|
||||
}));
|
||||
|
||||
jest.mock('../ShareAgent', () => ({
|
||||
jest.mock('../Sharing/GrantAccessDialog', () => ({
|
||||
__esModule: true,
|
||||
default: jest.fn(() => <div data-testid="share-agent" />),
|
||||
default: jest.fn(() => <div data-testid="grant-access-dialog" />),
|
||||
}));
|
||||
|
||||
jest.mock('../DuplicateAgent', () => ({
|
||||
@@ -186,6 +208,40 @@ describe('AgentFooter', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
// Reset to default mock implementations
|
||||
mockUseWatch.mockImplementation(({ name }) => {
|
||||
if (name === 'agent') {
|
||||
return {
|
||||
_id: 'agent-db-123',
|
||||
name: 'Test Agent',
|
||||
author: 'user-123',
|
||||
projectIds: ['project-1'],
|
||||
isCollaborative: false,
|
||||
};
|
||||
}
|
||||
if (name === 'id') {
|
||||
return 'agent-123';
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
// Reset auth context to default user
|
||||
mockUseAuthContext.mockReturnValue({
|
||||
user: mockUser,
|
||||
token: 'mock-token',
|
||||
isAuthenticated: true,
|
||||
error: undefined,
|
||||
login: jest.fn(),
|
||||
logout: jest.fn(),
|
||||
setError: jest.fn(),
|
||||
roles: {},
|
||||
});
|
||||
// Reset access and permissions to defaults
|
||||
mockUseHasAccess.mockReturnValue(true);
|
||||
mockUseResourcePermissions.mockReturnValue({
|
||||
hasPermission: () => true,
|
||||
isLoading: false,
|
||||
permissionBits: 0,
|
||||
});
|
||||
});
|
||||
|
||||
describe('Main Functionality', () => {
|
||||
@@ -196,8 +252,8 @@ describe('AgentFooter', () => {
|
||||
expect(screen.getByTestId('version-button')).toBeInTheDocument();
|
||||
expect(screen.getByTestId('delete-button')).toBeInTheDocument();
|
||||
expect(screen.queryByTestId('admin-settings')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId('share-agent')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId('duplicate-agent')).not.toBeInTheDocument();
|
||||
expect(screen.getByTestId('grant-access-dialog')).toBeInTheDocument();
|
||||
expect(screen.getByTestId('duplicate-agent')).toBeInTheDocument();
|
||||
expect(screen.queryByTestId('spinner')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
@@ -227,42 +283,125 @@ describe('AgentFooter', () => {
|
||||
});
|
||||
|
||||
test('adjusts UI based on agent ID existence', () => {
|
||||
jest.spyOn(reactHookForm, 'useWatch').mockImplementation(() => ({
|
||||
agent: { name: 'Test Agent', author: 'user-123' },
|
||||
id: undefined,
|
||||
}));
|
||||
mockUseWatch.mockImplementation(({ name }) => {
|
||||
if (name === 'agent') {
|
||||
return null; // No agent means no delete/share/duplicate buttons
|
||||
}
|
||||
if (name === 'id') {
|
||||
return undefined; // No ID means create mode
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
// When there's no agent, permissions should also return false
|
||||
mockUseResourcePermissions.mockReturnValue({
|
||||
hasPermission: () => false,
|
||||
isLoading: false,
|
||||
permissionBits: 0,
|
||||
});
|
||||
|
||||
render(<AgentFooter {...defaultProps} />);
|
||||
expect(screen.getByText('Save')).toBeInTheDocument();
|
||||
expect(screen.getByTestId('version-button')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('adjusts UI based on user role', () => {
|
||||
jest.spyOn(hooks, 'useAuthContext').mockReturnValue(createAuthContext(mockUsers.admin));
|
||||
render(<AgentFooter {...defaultProps} />);
|
||||
expect(screen.queryByTestId('admin-settings')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId('share-agent')).not.toBeInTheDocument();
|
||||
|
||||
jest.clearAllMocks();
|
||||
jest.spyOn(hooks, 'useAuthContext').mockReturnValue(createAuthContext(mockUsers.different));
|
||||
render(<AgentFooter {...defaultProps} />);
|
||||
expect(screen.queryByTestId('share-agent')).not.toBeInTheDocument();
|
||||
expect(screen.getByText('Create')).toBeInTheDocument();
|
||||
expect(screen.queryByTestId('version-button')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId('delete-button')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId('grant-access-dialog')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId('duplicate-agent')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('adjusts UI based on permissions', () => {
|
||||
jest.spyOn(hooks, 'useHasAccess').mockReturnValue(false);
|
||||
test('adjusts UI based on user role', () => {
|
||||
mockUseAuthContext.mockReturnValue(createAuthContext(mockUsers.admin));
|
||||
const { unmount } = render(<AgentFooter {...defaultProps} />);
|
||||
expect(screen.getByTestId('admin-settings')).toBeInTheDocument();
|
||||
expect(screen.getByTestId('grant-access-dialog')).toBeInTheDocument();
|
||||
|
||||
// Clean up the first render
|
||||
unmount();
|
||||
|
||||
jest.clearAllMocks();
|
||||
mockUseAuthContext.mockReturnValue(createAuthContext(mockUsers.different));
|
||||
mockUseWatch.mockImplementation(({ name }) => {
|
||||
if (name === 'agent') {
|
||||
return { name: 'Test Agent', author: 'different-author', _id: 'agent-123' };
|
||||
}
|
||||
if (name === 'id') {
|
||||
return 'agent-123';
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
render(<AgentFooter {...defaultProps} />);
|
||||
expect(screen.queryByTestId('share-agent')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId('grant-access-dialog')).toBeInTheDocument(); // Still shows because hasAccess is true
|
||||
expect(screen.queryByTestId('duplicate-agent')).not.toBeInTheDocument(); // Should not show for different author
|
||||
});
|
||||
|
||||
test('adjusts UI based on permissions', () => {
|
||||
mockUseHasAccess.mockReturnValue(false);
|
||||
// Also need to ensure the agent is not owned by the user and user is not admin
|
||||
mockUseWatch.mockImplementation(({ name }) => {
|
||||
if (name === 'agent') {
|
||||
return {
|
||||
_id: 'agent-db-123',
|
||||
name: 'Test Agent',
|
||||
author: 'different-user', // Different author
|
||||
projectIds: ['project-1'],
|
||||
isCollaborative: false,
|
||||
};
|
||||
}
|
||||
if (name === 'id') {
|
||||
return 'agent-123';
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
// Mock permissions to not allow sharing
|
||||
mockUseResourcePermissions.mockReturnValue({
|
||||
hasPermission: () => false, // No permissions
|
||||
isLoading: false,
|
||||
permissionBits: 0,
|
||||
});
|
||||
render(<AgentFooter {...defaultProps} />);
|
||||
expect(screen.queryByTestId('grant-access-dialog')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test('hides action buttons when permissions are loading', () => {
|
||||
// Ensure we have an agent that would normally show buttons
|
||||
mockUseWatch.mockImplementation(({ name }) => {
|
||||
if (name === 'agent') {
|
||||
return {
|
||||
_id: 'agent-db-123',
|
||||
name: 'Test Agent',
|
||||
author: 'user-123', // Same as current user
|
||||
projectIds: ['project-1'],
|
||||
isCollaborative: false,
|
||||
};
|
||||
}
|
||||
if (name === 'id') {
|
||||
return 'agent-123';
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
mockUseResourcePermissions.mockReturnValue({
|
||||
hasPermission: () => true,
|
||||
isLoading: true, // This should hide the buttons
|
||||
permissionBits: 0,
|
||||
});
|
||||
render(<AgentFooter {...defaultProps} />);
|
||||
expect(screen.queryByTestId('delete-button')).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId('grant-access-dialog')).not.toBeInTheDocument();
|
||||
// Duplicate button should still show as it doesn't depend on permissions loading
|
||||
expect(screen.getByTestId('duplicate-agent')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
test('handles null agent data', () => {
|
||||
jest.spyOn(reactHookForm, 'useWatch').mockImplementation(() => ({
|
||||
agent: null,
|
||||
id: 'agent-123',
|
||||
}));
|
||||
mockUseWatch.mockImplementation(({ name }) => {
|
||||
if (name === 'agent') {
|
||||
return null;
|
||||
}
|
||||
if (name === 'id') {
|
||||
return 'agent-123';
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
render(<AgentFooter {...defaultProps} />);
|
||||
expect(screen.getByText('Save')).toBeInTheDocument();
|
||||
|
||||
@@ -1,15 +1,31 @@
|
||||
export default function MCPIcon() {
|
||||
export default function MCPIcon({ className }: { className?: string }) {
|
||||
return (
|
||||
<svg
|
||||
width="195"
|
||||
height="195"
|
||||
viewBox="0 2 195 195"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="24"
|
||||
height="24"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
className="h-4 w-4"
|
||||
className={className}
|
||||
>
|
||||
<path d="M11.016 2.099a3.998 3.998 0 0 1 5.58.072l.073.074a3.991 3.991 0 0 1 1.058 3.318 3.994 3.994 0 0 1 3.32 1.06l.073.071.048.047.071.075a3.998 3.998 0 0 1 0 5.506l-.071.074-8.183 8.182-.034.042a.267.267 0 0 0 .034.335l1.68 1.68a.8.8 0 0 1-1.131 1.13l-1.68-1.679a1.866 1.866 0 0 1-.034-2.604l8.26-8.261a2.4 2.4 0 0 0-.044-3.349l-.047-.047-.044-.043a2.4 2.4 0 0 0-3.349.043l-6.832 6.832-.03.029a.8.8 0 0 1-1.1-1.16l6.876-6.875a2.4 2.4 0 0 0-.044-3.35l-.179-.161a2.399 2.399 0 0 0-3.169.119l-.045.043-9.047 9.047-.03.028a.8.8 0 0 1-1.1-1.16l9.046-9.046.074-.072Z" />
|
||||
<path d="M13.234 4.404a.8.8 0 0 1 1.1 1.16l-6.69 6.691a2.399 2.399 0 1 0 3.393 3.393l6.691-6.692a.8.8 0 0 1 1.131 1.131l-6.691 6.692a4 4 0 0 1-5.581.07l-.073-.07a3.998 3.998 0 0 1 0-5.655l6.69-6.691.03-.029Z" />
|
||||
<path
|
||||
d="M25 97.8528L92.8823 29.9706C102.255 20.598 117.451 20.598 126.823 29.9706V29.9706C136.196 39.3431 136.196 54.5391 126.823 63.9117L75.5581 115.177"
|
||||
stroke="currentColor"
|
||||
strokeWidth="12"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
<path
|
||||
d="M76.2653 114.47L126.823 63.9117C136.196 54.5391 151.392 54.5391 160.765 63.9117L161.118 64.2652C170.491 73.6378 170.491 88.8338 161.118 98.2063L99.7248 159.6C96.6006 162.724 96.6006 167.789 99.7248 170.913L112.331 183.52"
|
||||
stroke="currentColor"
|
||||
strokeWidth="12"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
<path
|
||||
d="M109.853 46.9411L59.6482 97.1457C50.2757 106.518 50.2757 121.714 59.6482 131.087V131.087C69.0208 140.459 84.2168 140.459 93.5894 131.087L143.794 80.8822"
|
||||
stroke="currentColor"
|
||||
strokeWidth="12"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
15
client/src/components/svg/VectorIcon.tsx
Normal file
15
client/src/components/svg/VectorIcon.tsx
Normal file
@@ -0,0 +1,15 @@
|
||||
export default function VectorIcon({ className }: { className?: string }) {
|
||||
return (
|
||||
<svg
|
||||
width="20"
|
||||
height="20"
|
||||
viewBox="0 0 20 20"
|
||||
fill="currentColor"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={className}
|
||||
aria-label=""
|
||||
>
|
||||
<path d="M7.45996 14.375C7.45996 13.3616 6.63844 12.54 5.625 12.54C4.61156 12.54 3.79004 13.3616 3.79004 14.375C3.79004 15.3884 4.61156 16.21 5.625 16.21C6.63844 16.21 7.45996 15.3884 7.45996 14.375ZM16.21 14.375C16.21 13.3616 15.3884 12.54 14.375 12.54C13.3616 12.54 12.54 13.3616 12.54 14.375C12.54 15.3884 13.3616 16.21 14.375 16.21C15.3884 16.21 16.21 15.3884 16.21 14.375ZM7.45996 5.625C7.45996 4.61156 6.63844 3.79004 5.625 3.79004C4.61156 3.79004 3.79004 4.61156 3.79004 5.625C3.79004 6.63844 4.61156 7.45996 5.625 7.45996C6.63844 7.45996 7.45996 6.63844 7.45996 5.625ZM16.21 5.625C16.21 4.61156 15.3884 3.79004 14.375 3.79004C13.3616 3.79004 12.54 4.61156 12.54 5.625C12.54 6.63844 13.3616 7.45996 14.375 7.45996C15.3884 7.45996 16.21 6.63844 16.21 5.625ZM17.54 14.375C17.54 16.123 16.123 17.54 14.375 17.54C12.627 17.54 11.21 16.123 11.21 14.375C11.21 12.627 12.627 11.21 14.375 11.21C16.123 11.21 17.54 12.627 17.54 14.375ZM8.79004 5.625C8.79004 7.37298 7.37298 8.79004 5.625 8.79004C3.87702 8.79004 2.45996 7.37298 2.45996 5.625C2.45996 3.87702 3.87702 2.45996 5.625 2.45996C7.37298 2.45996 8.79004 3.87702 8.79004 5.625ZM17.54 5.625C17.54 7.37298 16.123 8.79004 14.375 8.79004C13.7416 8.79004 13.153 8.60173 12.6582 8.28125L8.28125 12.6582C8.60173 13.153 8.79004 13.7416 8.79004 14.375C8.79004 16.123 7.37298 17.54 5.625 17.54C3.87702 17.54 2.45996 16.123 2.45996 14.375C2.45996 12.627 3.87702 11.21 5.625 11.21C6.25794 11.21 6.84623 11.3977 7.34082 11.7178L11.7178 7.34082C11.3977 6.84623 11.21 6.25794 11.21 5.625C11.21 3.87702 12.627 2.45996 14.375 2.45996C16.123 2.45996 17.54 3.87702 17.54 5.625Z" />
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -62,3 +62,5 @@ export { default as ThumbUpIcon } from './ThumbUpIcon';
|
||||
export { default as ThumbDownIcon } from './ThumbDownIcon';
|
||||
export { default as XAIcon } from './XAIcon';
|
||||
export { default as PersonalizationIcon } from './PersonalizationIcon';
|
||||
export { default as MCPIcon } from './MCPIcon';
|
||||
export { default as VectorIcon } from './VectorIcon';
|
||||
|
||||
@@ -9,11 +9,12 @@ const CheckboxButton = React.forwardRef<
|
||||
icon?: React.ReactNode;
|
||||
label: string;
|
||||
className?: string;
|
||||
checked?: boolean;
|
||||
defaultChecked?: boolean;
|
||||
isCheckedClassName?: string;
|
||||
setValue?: (e: React.ChangeEvent<HTMLInputElement>, isChecked: boolean) => void;
|
||||
setValue?: (values: { e?: React.ChangeEvent<HTMLInputElement>; isChecked: boolean }) => void;
|
||||
}
|
||||
>(({ icon, label, setValue, className, defaultChecked, isCheckedClassName }, ref) => {
|
||||
>(({ icon, label, setValue, className, checked, defaultChecked, isCheckedClassName }, ref) => {
|
||||
const checkbox = useCheckboxStore();
|
||||
const isChecked = useStoreState(checkbox, (state) => state?.value);
|
||||
const onChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
@@ -21,20 +22,28 @@ const CheckboxButton = React.forwardRef<
|
||||
if (typeof isChecked !== 'boolean') {
|
||||
return;
|
||||
}
|
||||
setValue?.(e, !isChecked);
|
||||
setValue?.({ e, isChecked: !isChecked });
|
||||
};
|
||||
|
||||
// Sync with controlled checked prop
|
||||
useEffect(() => {
|
||||
if (defaultChecked) {
|
||||
if (checked !== undefined) {
|
||||
checkbox.setValue(checked);
|
||||
}
|
||||
}, [checked, checkbox]);
|
||||
|
||||
// Set initial value from defaultChecked
|
||||
useEffect(() => {
|
||||
if (defaultChecked !== undefined && checked === undefined) {
|
||||
checkbox.setValue(defaultChecked);
|
||||
}
|
||||
}, [defaultChecked, checkbox]);
|
||||
}, [defaultChecked, checked, checkbox]);
|
||||
|
||||
return (
|
||||
<Checkbox
|
||||
ref={ref}
|
||||
store={checkbox}
|
||||
onChange={onChange}
|
||||
defaultChecked={defaultChecked}
|
||||
className={cn(
|
||||
// Base styling from MultiSelect's selectClassName
|
||||
'group relative inline-flex items-center justify-center gap-1.5',
|
||||
|
||||
@@ -97,7 +97,13 @@ const Dropdown: React.FC<DropdownProps> = ({
|
||||
<Select.SelectPopover
|
||||
portal={portal}
|
||||
store={selectProps}
|
||||
className={cn('popover-ui', sizeClasses, className, 'max-h-[80vh] overflow-y-auto')}
|
||||
className={cn(
|
||||
'popover-ui',
|
||||
sizeClasses,
|
||||
className,
|
||||
'max-h-[80vh] overflow-y-auto',
|
||||
'[pointer-events:auto]', // Override body's pointer-events:none when in modal
|
||||
)}
|
||||
>
|
||||
{options.map((item, index) => {
|
||||
if (isDivider(item)) {
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
export default function MCPIcon({ className }: { className?: string }) {
|
||||
return (
|
||||
<svg
|
||||
width="195"
|
||||
height="195"
|
||||
viewBox="0 2 195 195"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className={className}
|
||||
>
|
||||
<path
|
||||
d="M25 97.8528L92.8823 29.9706C102.255 20.598 117.451 20.598 126.823 29.9706V29.9706C136.196 39.3431 136.196 54.5391 126.823 63.9117L75.5581 115.177"
|
||||
stroke="currentColor"
|
||||
strokeWidth="12"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
<path
|
||||
d="M76.2653 114.47L126.823 63.9117C136.196 54.5391 151.392 54.5391 160.765 63.9117L161.118 64.2652C170.491 73.6378 170.491 88.8338 161.118 98.2063L99.7248 159.6C96.6006 162.724 96.6006 167.789 99.7248 170.913L112.331 183.52"
|
||||
stroke="currentColor"
|
||||
strokeWidth="12"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
<path
|
||||
d="M109.853 46.9411L59.6482 97.1457C50.2757 106.518 50.2757 121.714 59.6482 131.087V131.087C69.0208 140.459 84.2168 140.459 93.5894 131.087L143.794 80.8822"
|
||||
stroke="currentColor"
|
||||
strokeWidth="12"
|
||||
strokeLinecap="round"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -5,16 +5,26 @@ import { cn } from '~/utils';
|
||||
|
||||
interface OGDialogProps extends DialogPrimitive.DialogProps {
|
||||
triggerRef?: React.RefObject<HTMLButtonElement | HTMLInputElement | null>;
|
||||
triggerRefs?: React.RefObject<HTMLButtonElement | HTMLInputElement | null>[];
|
||||
}
|
||||
|
||||
const Dialog = React.forwardRef<HTMLDivElement, OGDialogProps>(
|
||||
({ children, triggerRef, onOpenChange, ...props }, _ref) => {
|
||||
({ children, triggerRef, triggerRefs, onOpenChange, ...props }, _ref) => {
|
||||
const handleOpenChange = (open: boolean) => {
|
||||
if (!open && triggerRef?.current) {
|
||||
setTimeout(() => {
|
||||
triggerRef.current?.focus();
|
||||
}, 0);
|
||||
}
|
||||
if (triggerRefs?.length) {
|
||||
triggerRefs.forEach((ref) => {
|
||||
if (ref?.current) {
|
||||
setTimeout(() => {
|
||||
ref.current?.focus();
|
||||
}, 0);
|
||||
}
|
||||
});
|
||||
}
|
||||
onOpenChange?.(open);
|
||||
};
|
||||
|
||||
|
||||
192
client/src/components/ui/SearchPicker.tsx
Normal file
192
client/src/components/ui/SearchPicker.tsx
Normal file
@@ -0,0 +1,192 @@
|
||||
'use client';
|
||||
|
||||
import * as React from 'react';
|
||||
import * as Ariakit from '@ariakit/react';
|
||||
import { Search, X } from 'lucide-react';
|
||||
import { cn } from '~/utils';
|
||||
import { Spinner } from '~/components/svg';
|
||||
import { Skeleton } from '~/components/ui';
|
||||
import { useLocalize } from '~/hooks';
|
||||
|
||||
type SearchPickerProps<TOption extends { key: string }> = {
|
||||
options: TOption[];
|
||||
renderOptions: (option: TOption) => React.ReactElement;
|
||||
query: string;
|
||||
onQueryChange: (query: string) => void;
|
||||
onPick: (pickedOption: TOption) => void;
|
||||
placeholder?: string;
|
||||
inputClassName?: string;
|
||||
label: string;
|
||||
resetValueOnHide?: boolean;
|
||||
isSmallScreen?: boolean;
|
||||
isLoading?: boolean;
|
||||
minQueryLengthForNoResults?: number;
|
||||
};
|
||||
|
||||
export function SearchPicker<TOption extends { key: string; value: string }>({
|
||||
options,
|
||||
renderOptions,
|
||||
onPick,
|
||||
onQueryChange,
|
||||
query,
|
||||
label,
|
||||
isSmallScreen = false,
|
||||
placeholder,
|
||||
resetValueOnHide = false,
|
||||
isLoading = false,
|
||||
minQueryLengthForNoResults = 2,
|
||||
}: SearchPickerProps<TOption>) {
|
||||
const localize = useLocalize();
|
||||
const [_open, setOpen] = React.useState(false);
|
||||
const inputRef = React.useRef<HTMLInputElement>(null);
|
||||
const combobox = Ariakit.useComboboxStore({
|
||||
resetValueOnHide,
|
||||
});
|
||||
const onPickHandler = (option: TOption) => {
|
||||
onQueryChange('');
|
||||
onPick(option);
|
||||
setOpen(false);
|
||||
if (inputRef.current) {
|
||||
inputRef.current.focus();
|
||||
}
|
||||
};
|
||||
const showClearIcon = query.trim().length > 0;
|
||||
const clearText = () => {
|
||||
onQueryChange('');
|
||||
if (inputRef.current) {
|
||||
inputRef.current.focus();
|
||||
}
|
||||
};
|
||||
return (
|
||||
<Ariakit.ComboboxProvider store={combobox}>
|
||||
<Ariakit.ComboboxLabel className="text-token-text-primary mb-2 block font-medium">
|
||||
{label}
|
||||
</Ariakit.ComboboxLabel>
|
||||
<div className="py-1.5">
|
||||
<div
|
||||
className={cn(
|
||||
'group relative mt-1 flex h-10 cursor-pointer items-center gap-3 rounded-lg border-border-medium px-3 py-2 text-text-primary transition-colors duration-200 focus-within:bg-surface-hover hover:bg-surface-hover',
|
||||
isSmallScreen === true ? 'mb-2 h-14 rounded-2xl' : '',
|
||||
)}
|
||||
>
|
||||
{isLoading ? (
|
||||
<Spinner className="absolute left-3 h-4 w-4 text-text-primary" />
|
||||
) : (
|
||||
<Search className="absolute left-3 h-4 w-4 text-text-secondary group-focus-within:text-text-primary group-hover:text-text-primary" />
|
||||
)}
|
||||
<Ariakit.Combobox
|
||||
ref={inputRef}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Escape' && combobox.getState().open) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
onQueryChange('');
|
||||
setOpen(false);
|
||||
}
|
||||
}}
|
||||
store={combobox}
|
||||
setValueOnClick={false}
|
||||
setValueOnChange={false}
|
||||
onChange={(e) => {
|
||||
onQueryChange(e.target.value);
|
||||
}}
|
||||
value={query}
|
||||
// autoSelect
|
||||
placeholder={placeholder || localize('com_ui_select_options')}
|
||||
className="m-0 mr-0 w-full rounded-md border-none bg-transparent p-0 py-2 pl-9 pr-3 text-sm leading-tight text-text-primary placeholder-text-secondary placeholder-opacity-100 focus:outline-none focus-visible:outline-none group-focus-within:placeholder-text-primary group-hover:placeholder-text-primary"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={`${localize('com_ui_clear')} ${localize('com_ui_search')}`}
|
||||
className={cn(
|
||||
'absolute right-[7px] flex h-5 w-5 items-center justify-center rounded-full border-none bg-transparent p-0 transition-opacity duration-200',
|
||||
showClearIcon ? 'opacity-100' : 'opacity-0',
|
||||
isSmallScreen === true ? 'right-[16px]' : '',
|
||||
)}
|
||||
onClick={clearText}
|
||||
tabIndex={showClearIcon ? 0 : -1}
|
||||
disabled={!showClearIcon}
|
||||
>
|
||||
<X className="h-5 w-5 cursor-pointer" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<Ariakit.ComboboxPopover
|
||||
portal={false} //todo fix focus when set to true
|
||||
gutter={10}
|
||||
// sameWidth
|
||||
open={
|
||||
isLoading ||
|
||||
options.length > 0 ||
|
||||
(query.trim().length >= minQueryLengthForNoResults && !isLoading)
|
||||
}
|
||||
store={combobox}
|
||||
unmountOnHide
|
||||
autoFocusOnShow={false}
|
||||
modal={false}
|
||||
className={cn(
|
||||
'animate-popover z-[9999] min-w-64 overflow-hidden rounded-xl border border-border-light bg-surface-secondary shadow-lg',
|
||||
'[pointer-events:auto]', // Override body's pointer-events:none when in modal
|
||||
)}
|
||||
>
|
||||
{(() => {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="space-y-2 p-2">
|
||||
{Array.from({ length: 3 }).map((_, index) => (
|
||||
<div key={index} className="flex items-center gap-3 px-3 py-2">
|
||||
<Skeleton className="h-8 w-8 rounded-full" />
|
||||
<div className="flex-1 space-y-1">
|
||||
<Skeleton className="h-4 w-3/4" />
|
||||
<Skeleton className="h-3 w-1/2" />
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (options.length > 0) {
|
||||
return options.map((o) => (
|
||||
<Ariakit.ComboboxItem
|
||||
key={o.key}
|
||||
focusOnHover
|
||||
// hideOnClick
|
||||
value={o.value}
|
||||
selectValueOnClick={false}
|
||||
onClick={() => onPickHandler(o)}
|
||||
className={cn(
|
||||
'flex w-full cursor-pointer items-center px-3 text-sm',
|
||||
'text-text-primary hover:bg-surface-tertiary',
|
||||
'data-[active-item]:bg-surface-tertiary',
|
||||
)}
|
||||
render={renderOptions(o)}
|
||||
></Ariakit.ComboboxItem>
|
||||
));
|
||||
}
|
||||
|
||||
if (query.trim().length >= minQueryLengthForNoResults) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex items-center justify-center px-4 py-8 text-center',
|
||||
'text-sm text-text-secondary',
|
||||
)}
|
||||
>
|
||||
<div className="flex flex-col items-center gap-2">
|
||||
<Search className="h-8 w-8 text-text-tertiary opacity-50" />
|
||||
<div className="font-medium">{localize('com_ui_no_results_found')}</div>
|
||||
<div className="text-xs text-text-tertiary">
|
||||
{localize('com_ui_try_adjusting_search')}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
})()}
|
||||
</Ariakit.ComboboxPopover>
|
||||
</Ariakit.ComboboxProvider>
|
||||
);
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
import React from 'react';
|
||||
import React, { useState } from 'react';
|
||||
import { Root, Trigger, Content, Portal } from '@radix-ui/react-popover';
|
||||
import MenuItem from '~/components/Chat/Menus/UI/MenuItem';
|
||||
import type { Option } from '~/common';
|
||||
@@ -32,6 +32,7 @@ function SelectDropDownPop({
|
||||
footer,
|
||||
}: SelectDropDownProps) {
|
||||
const localize = useLocalize();
|
||||
const [open, setOpen] = useState(false);
|
||||
const transitionProps = { className: 'top-full mt-3' };
|
||||
if (showAbove) {
|
||||
transitionProps.className = 'bottom-full mb-3';
|
||||
@@ -54,8 +55,13 @@ function SelectDropDownPop({
|
||||
const hasSearchRender = Boolean(searchRender);
|
||||
const options = hasSearchRender ? filteredValues : availableValues;
|
||||
|
||||
const handleSelect = (selectedValue: string) => {
|
||||
setValue(selectedValue);
|
||||
setOpen(false);
|
||||
};
|
||||
|
||||
return (
|
||||
<Root>
|
||||
<Root open={open} onOpenChange={setOpen}>
|
||||
<div className={'flex items-center justify-center gap-2'}>
|
||||
<div className={'relative w-full'}>
|
||||
<Trigger asChild>
|
||||
@@ -108,19 +114,32 @@ function SelectDropDownPop({
|
||||
side="bottom"
|
||||
align="start"
|
||||
className={cn(
|
||||
'mr-3 mt-2 max-h-[52vh] w-full max-w-[85vw] overflow-hidden overflow-y-auto rounded-lg border border-gray-200 bg-white shadow-lg dark:border-gray-700 dark:bg-gray-700 dark:text-white sm:max-w-full lg:max-h-[52vh]',
|
||||
'z-50 mr-3 mt-2 max-h-[52vh] w-full max-w-[85vw] overflow-hidden overflow-y-auto rounded-lg border border-gray-200 bg-white shadow-lg dark:border-gray-700 dark:bg-gray-700 dark:text-white sm:max-w-full lg:max-h-[52vh]',
|
||||
hasSearchRender && 'relative',
|
||||
)}
|
||||
>
|
||||
{searchRender}
|
||||
{options.map((option) => {
|
||||
if (typeof option === 'string') {
|
||||
return (
|
||||
<MenuItem
|
||||
key={option}
|
||||
title={option}
|
||||
value={option}
|
||||
selected={!!(value && value === option)}
|
||||
onClick={() => handleSelect(option)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<MenuItem
|
||||
key={option}
|
||||
title={option}
|
||||
value={option}
|
||||
selected={!!(value && value === option)}
|
||||
onClick={() => setValue(option)}
|
||||
key={option.value}
|
||||
title={option.label}
|
||||
description={option.description}
|
||||
value={option.value}
|
||||
icon={option.icon}
|
||||
selected={!!(value && value === option.value)}
|
||||
onClick={() => handleSelect(option.value)}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -28,7 +28,6 @@ export * from './Pagination';
|
||||
export * from './Progress';
|
||||
export * from './InputOTP';
|
||||
export { default as Badge } from './Badge';
|
||||
export { default as MCPIcon } from './MCPIcon';
|
||||
export { default as Combobox } from './Combobox';
|
||||
export { default as Dropdown } from './Dropdown';
|
||||
export { default as SplitText } from './SplitText';
|
||||
|
||||
@@ -83,6 +83,10 @@ export const useUpdateAgentMutation = (
|
||||
});
|
||||
|
||||
queryClient.setQueryData<t.Agent>([QueryKeys.agent, variables.agent_id], updatedAgent);
|
||||
queryClient.setQueryData<t.Agent>(
|
||||
[QueryKeys.agent, variables.agent_id, 'expanded'],
|
||||
updatedAgent,
|
||||
);
|
||||
return options?.onSuccess?.(updatedAgent, variables, context);
|
||||
},
|
||||
},
|
||||
@@ -121,6 +125,7 @@ export const useDeleteAgentMutation = (
|
||||
});
|
||||
|
||||
queryClient.removeQueries([QueryKeys.agent, variables.agent_id]);
|
||||
queryClient.removeQueries([QueryKeys.agent, variables.agent_id, 'expanded']);
|
||||
|
||||
return options?.onSuccess?.(_data, variables, data);
|
||||
},
|
||||
@@ -241,6 +246,10 @@ export const useUpdateAgentAction = (
|
||||
});
|
||||
|
||||
queryClient.setQueryData<t.Agent>([QueryKeys.agent, variables.agent_id], updatedAgent);
|
||||
queryClient.setQueryData<t.Agent>(
|
||||
[QueryKeys.agent, variables.agent_id, 'expanded'],
|
||||
updatedAgent,
|
||||
);
|
||||
return options?.onSuccess?.(updateAgentActionResponse, variables, context);
|
||||
},
|
||||
});
|
||||
@@ -293,8 +302,7 @@ export const useDeleteAgentAction = (
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
queryClient.setQueryData<t.Agent>([QueryKeys.agent, variables.agent_id], (prev) => {
|
||||
const updaterFn = (prev) => {
|
||||
if (!prev) {
|
||||
return prev;
|
||||
}
|
||||
@@ -303,7 +311,12 @@ export const useDeleteAgentAction = (
|
||||
...prev,
|
||||
tools: prev.tools?.filter((tool) => !tool.includes(domain ?? '')),
|
||||
};
|
||||
});
|
||||
};
|
||||
queryClient.setQueryData<t.Agent>([QueryKeys.agent, variables.agent_id], updaterFn);
|
||||
queryClient.setQueryData<t.Agent>(
|
||||
[QueryKeys.agent, variables.agent_id, 'expanded'],
|
||||
updaterFn,
|
||||
);
|
||||
return options?.onSuccess?.(_data, variables, context);
|
||||
},
|
||||
});
|
||||
|
||||
@@ -54,7 +54,7 @@ export const useListAgentsQuery = <TData = t.AgentListResponse>(
|
||||
};
|
||||
|
||||
/**
|
||||
* Hook for retrieving details about a single agent
|
||||
* Hook for retrieving basic details about a single agent (VIEW permission)
|
||||
*/
|
||||
export const useGetAgentByIdQuery = (
|
||||
agent_id: string,
|
||||
@@ -75,3 +75,26 @@ export const useGetAgentByIdQuery = (
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Hook for retrieving full agent details including sensitive configuration (EDIT permission)
|
||||
*/
|
||||
export const useGetExpandedAgentByIdQuery = (
|
||||
agent_id: string,
|
||||
config?: UseQueryOptions<t.Agent>,
|
||||
): QueryObserverResult<t.Agent> => {
|
||||
return useQuery<t.Agent>(
|
||||
[QueryKeys.agent, agent_id, 'expanded'],
|
||||
() =>
|
||||
dataService.getExpandedAgentById({
|
||||
agent_id,
|
||||
}),
|
||||
{
|
||||
refetchOnWindowFocus: false,
|
||||
refetchOnReconnect: false,
|
||||
refetchOnMount: false,
|
||||
retry: false,
|
||||
...config,
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
@@ -125,8 +125,7 @@ export const useEndpoints = ({
|
||||
if (ep === EModelEndpoint.agents && agents.length > 0) {
|
||||
result.models = agents.map((agent) => ({
|
||||
name: agent.id,
|
||||
isGlobal:
|
||||
(instanceProjectId != null && agent.projectIds?.includes(instanceProjectId)) ?? false,
|
||||
isGlobal: agent.isPublic ?? false,
|
||||
}));
|
||||
result.agentNames = agents.reduce((acc, agent) => {
|
||||
acc[agent.id] = agent.name || '';
|
||||
|
||||
@@ -1,43 +1,46 @@
|
||||
import { useState, useMemo } from 'react';
|
||||
import { useDrop } from 'react-dnd';
|
||||
import { useRecoilValue } from 'recoil';
|
||||
import { NativeTypes } from 'react-dnd-html5-backend';
|
||||
import { useQueryClient } from '@tanstack/react-query';
|
||||
import { useRecoilValue, useSetRecoilState } from 'recoil';
|
||||
import {
|
||||
Constants,
|
||||
QueryKeys,
|
||||
Constants,
|
||||
EModelEndpoint,
|
||||
isAgentsEndpoint,
|
||||
isEphemeralAgent,
|
||||
EToolResources,
|
||||
AgentCapabilities,
|
||||
isAssistantsEndpoint,
|
||||
} from 'librechat-data-provider';
|
||||
import type * as t from 'librechat-data-provider';
|
||||
import type { DropTargetMonitor } from 'react-dnd';
|
||||
import useFileHandling from './useFileHandling';
|
||||
import type * as t from 'librechat-data-provider';
|
||||
import store, { ephemeralAgentByConvoId } from '~/store';
|
||||
import useFileHandling from './useFileHandling';
|
||||
|
||||
export default function useDragHelpers() {
|
||||
const queryClient = useQueryClient();
|
||||
const [showModal, setShowModal] = useState(false);
|
||||
const [draggedFiles, setDraggedFiles] = useState<File[]>([]);
|
||||
const conversation = useRecoilValue(store.conversationByIndex(0)) || undefined;
|
||||
const key = useMemo(
|
||||
() => conversation?.conversationId ?? Constants.NEW_CONVO,
|
||||
[conversation?.conversationId],
|
||||
const setEphemeralAgent = useSetRecoilState(
|
||||
ephemeralAgentByConvoId(conversation?.conversationId ?? Constants.NEW_CONVO),
|
||||
);
|
||||
const ephemeralAgent = useRecoilValue(ephemeralAgentByConvoId(key));
|
||||
|
||||
const handleOptionSelect = (toolResource: string | undefined) => {
|
||||
const handleOptionSelect = (toolResource: EToolResources | undefined) => {
|
||||
/** File search is not automatically enabled to simulate legacy behavior */
|
||||
if (toolResource && toolResource !== EToolResources.file_search) {
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
[toolResource]: true,
|
||||
}));
|
||||
}
|
||||
handleFiles(draggedFiles, toolResource);
|
||||
setShowModal(false);
|
||||
setDraggedFiles([]);
|
||||
};
|
||||
|
||||
const isAgents = useMemo(
|
||||
() =>
|
||||
isAgentsEndpoint(conversation?.endpoint) ||
|
||||
isEphemeralAgent(conversation?.endpoint, ephemeralAgent),
|
||||
[conversation?.endpoint, ephemeralAgent],
|
||||
() => !isAssistantsEndpoint(conversation?.endpoint),
|
||||
[conversation?.endpoint],
|
||||
);
|
||||
|
||||
const { handleFiles } = useFileHandling({
|
||||
|
||||
@@ -15,12 +15,11 @@ import BookmarkPanel from '~/components/SidePanel/Bookmarks/BookmarkPanel';
|
||||
import MemoryViewer from '~/components/SidePanel/Memories/MemoryViewer';
|
||||
import PanelSwitch from '~/components/SidePanel/Builder/PanelSwitch';
|
||||
import PromptsAccordion from '~/components/Prompts/PromptsAccordion';
|
||||
import { Blocks, MCPIcon, AttachmentIcon } from '~/components/svg';
|
||||
import Parameters from '~/components/SidePanel/Parameters/Panel';
|
||||
import FilesPanel from '~/components/SidePanel/Files/Panel';
|
||||
import MCPPanel from '~/components/SidePanel/MCP/MCPPanel';
|
||||
import { Blocks, AttachmentIcon } from '~/components/svg';
|
||||
import { useGetStartupConfig } from '~/data-provider';
|
||||
import MCPIcon from '~/components/ui/MCPIcon';
|
||||
import { useHasAccess } from '~/hooks';
|
||||
|
||||
export default function useSideNavLinks({
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
export * from './useMCPSelect';
|
||||
export * from './useToolToggle';
|
||||
export { default as useAuthCodeTool } from './useAuthCodeTool';
|
||||
export { default as usePluginInstall } from './usePluginInstall';
|
||||
export { default as useCodeApiKeyForm } from './useCodeApiKeyForm';
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// client/src/hooks/Plugins/useCodeApiKeyForm.ts
|
||||
import { useState, useCallback } from 'react';
|
||||
import { useRef, useState, useCallback } from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import type { ApiKeyFormData } from '~/common';
|
||||
import useAuthCodeTool from '~/hooks/Plugins/useAuthCodeTool';
|
||||
@@ -12,6 +12,8 @@ export default function useCodeApiKeyForm({
|
||||
onRevoke?: () => void;
|
||||
}) {
|
||||
const methods = useForm<ApiKeyFormData>();
|
||||
const menuTriggerRef = useRef<HTMLButtonElement>(null);
|
||||
const badgeTriggerRef = useRef<HTMLInputElement>(null);
|
||||
const [isDialogOpen, setIsDialogOpen] = useState(false);
|
||||
const { installTool, removeTool } = useAuthCodeTool({ isEntityTool: true });
|
||||
const { reset } = methods;
|
||||
@@ -39,5 +41,7 @@ export default function useCodeApiKeyForm({
|
||||
setIsDialogOpen,
|
||||
handleRevokeApiKey,
|
||||
onSubmit: onSubmitHandler,
|
||||
badgeTriggerRef,
|
||||
menuTriggerRef,
|
||||
};
|
||||
}
|
||||
|
||||
121
client/src/hooks/Plugins/useMCPSelect.ts
Normal file
121
client/src/hooks/Plugins/useMCPSelect.ts
Normal file
@@ -0,0 +1,121 @@
|
||||
import { useRef, useEffect, useCallback, useMemo } from 'react';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import { Constants, LocalStorageKeys, EModelEndpoint } from 'librechat-data-provider';
|
||||
import type { TPlugin, TPluginAuthConfig } from 'librechat-data-provider';
|
||||
import { useAvailableToolsQuery } from '~/data-provider';
|
||||
import useLocalStorage from '~/hooks/useLocalStorageAlt';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
|
||||
const storageCondition = (value: unknown, rawCurrentValue?: string | null) => {
|
||||
if (rawCurrentValue) {
|
||||
try {
|
||||
const currentValue = rawCurrentValue?.trim() ?? '';
|
||||
if (currentValue.length > 2) {
|
||||
return true;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return Array.isArray(value) && value.length > 0;
|
||||
};
|
||||
|
||||
interface UseMCPSelectOptions {
|
||||
conversationId?: string | null;
|
||||
}
|
||||
|
||||
export interface McpServerInfo {
|
||||
name: string;
|
||||
pluginKey: string;
|
||||
authConfig?: TPluginAuthConfig[];
|
||||
authenticated?: boolean;
|
||||
}
|
||||
|
||||
export function useMCPSelect({ conversationId }: UseMCPSelectOptions) {
|
||||
const key = conversationId ?? Constants.NEW_CONVO;
|
||||
const hasSetFetched = useRef<string | null>(null);
|
||||
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key));
|
||||
const { data: mcpToolDetails, isFetched } = useAvailableToolsQuery(EModelEndpoint.agents, {
|
||||
select: (data: TPlugin[]) => {
|
||||
const mcpToolsMap = new Map<string, McpServerInfo>();
|
||||
data.forEach((tool) => {
|
||||
const isMCP = tool.pluginKey.includes(Constants.mcp_delimiter);
|
||||
if (isMCP && tool.chatMenu !== false) {
|
||||
const parts = tool.pluginKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
if (!mcpToolsMap.has(serverName)) {
|
||||
mcpToolsMap.set(serverName, {
|
||||
name: serverName,
|
||||
pluginKey: tool.pluginKey,
|
||||
authConfig: tool.authConfig,
|
||||
authenticated: tool.authenticated,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
return Array.from(mcpToolsMap.values());
|
||||
},
|
||||
});
|
||||
|
||||
const mcpState = useMemo(() => {
|
||||
return ephemeralAgent?.mcp ?? [];
|
||||
}, [ephemeralAgent?.mcp]);
|
||||
|
||||
const setSelectedValues = useCallback(
|
||||
(values: string[] | null | undefined) => {
|
||||
if (!values) {
|
||||
return;
|
||||
}
|
||||
if (!Array.isArray(values)) {
|
||||
return;
|
||||
}
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
mcp: values,
|
||||
}));
|
||||
},
|
||||
[setEphemeralAgent],
|
||||
);
|
||||
|
||||
const [mcpValues, setMCPValues] = useLocalStorage<string[]>(
|
||||
`${LocalStorageKeys.LAST_MCP_}${key}`,
|
||||
mcpState,
|
||||
setSelectedValues,
|
||||
storageCondition,
|
||||
);
|
||||
|
||||
const [isPinned, setIsPinned] = useLocalStorage<boolean>(
|
||||
`${LocalStorageKeys.PIN_MCP_}${key}`,
|
||||
true,
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (hasSetFetched.current === key) {
|
||||
return;
|
||||
}
|
||||
if (!isFetched) {
|
||||
return;
|
||||
}
|
||||
hasSetFetched.current = key;
|
||||
if ((mcpToolDetails?.length ?? 0) > 0) {
|
||||
setMCPValues(mcpValues.filter((mcp) => mcpToolDetails?.some((tool) => tool.name === mcp)));
|
||||
return;
|
||||
}
|
||||
setMCPValues([]);
|
||||
}, [isFetched, setMCPValues, mcpToolDetails, key, mcpValues]);
|
||||
|
||||
const mcpServerNames = useMemo(() => {
|
||||
return (mcpToolDetails ?? []).map((tool) => tool.name);
|
||||
}, [mcpToolDetails]);
|
||||
|
||||
return {
|
||||
mcpValues,
|
||||
setMCPValues,
|
||||
mcpServerNames,
|
||||
ephemeralAgent,
|
||||
mcpToolDetails,
|
||||
setEphemeralAgent,
|
||||
isPinned,
|
||||
setIsPinned,
|
||||
};
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState, useCallback } from 'react';
|
||||
import { useRef, useState, useCallback } from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import useAuthSearchTool from '~/hooks/Plugins/useAuthSearchTool';
|
||||
import type { SearchApiKeyFormData } from '~/hooks/Plugins/useAuthSearchTool';
|
||||
@@ -11,6 +11,8 @@ export default function useSearchApiKeyForm({
|
||||
onRevoke?: () => void;
|
||||
}) {
|
||||
const methods = useForm<SearchApiKeyFormData>();
|
||||
const menuTriggerRef = useRef<HTMLButtonElement>(null);
|
||||
const badgeTriggerRef = useRef<HTMLInputElement>(null);
|
||||
const [isDialogOpen, setIsDialogOpen] = useState(false);
|
||||
const { installTool, removeTool } = useAuthSearchTool({ isEntityTool: true });
|
||||
const { reset } = methods;
|
||||
@@ -38,5 +40,7 @@ export default function useSearchApiKeyForm({
|
||||
setIsDialogOpen,
|
||||
handleRevokeApiKey,
|
||||
onSubmit: onSubmitHandler,
|
||||
badgeTriggerRef,
|
||||
menuTriggerRef,
|
||||
};
|
||||
}
|
||||
|
||||
119
client/src/hooks/Plugins/useToolToggle.ts
Normal file
119
client/src/hooks/Plugins/useToolToggle.ts
Normal file
@@ -0,0 +1,119 @@
|
||||
import { useRef, useEffect, useCallback, useMemo } from 'react';
|
||||
import { useRecoilState } from 'recoil';
|
||||
import debounce from 'lodash/debounce';
|
||||
import { Constants, LocalStorageKeys } from 'librechat-data-provider';
|
||||
import type { VerifyToolAuthResponse } from 'librechat-data-provider';
|
||||
import type { UseQueryOptions } from '@tanstack/react-query';
|
||||
import { useVerifyAgentToolAuth } from '~/data-provider';
|
||||
import useLocalStorage from '~/hooks/useLocalStorageAlt';
|
||||
import { ephemeralAgentByConvoId } from '~/store';
|
||||
|
||||
const storageCondition = (value: unknown, rawCurrentValue?: string | null) => {
|
||||
if (rawCurrentValue) {
|
||||
try {
|
||||
const currentValue = rawCurrentValue?.trim() ?? '';
|
||||
if (currentValue === 'true' && value === false) {
|
||||
return true;
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return value !== undefined && value !== null && value !== '' && value !== false;
|
||||
};
|
||||
|
||||
interface UseToolToggleOptions {
|
||||
conversationId?: string | null;
|
||||
toolKey: string;
|
||||
localStorageKey: LocalStorageKeys;
|
||||
isAuthenticated?: boolean;
|
||||
setIsDialogOpen?: (open: boolean) => void;
|
||||
/** Options for auth verification */
|
||||
authConfig?: {
|
||||
toolId: string;
|
||||
queryOptions?: UseQueryOptions<VerifyToolAuthResponse>;
|
||||
};
|
||||
}
|
||||
|
||||
export function useToolToggle({
|
||||
conversationId,
|
||||
toolKey,
|
||||
localStorageKey,
|
||||
isAuthenticated: externalIsAuthenticated,
|
||||
setIsDialogOpen,
|
||||
authConfig,
|
||||
}: UseToolToggleOptions) {
|
||||
const key = conversationId ?? Constants.NEW_CONVO;
|
||||
const [ephemeralAgent, setEphemeralAgent] = useRecoilState(ephemeralAgentByConvoId(key));
|
||||
|
||||
const authQuery = useVerifyAgentToolAuth(
|
||||
{ toolId: authConfig?.toolId || '' },
|
||||
{
|
||||
enabled: !!authConfig?.toolId,
|
||||
...authConfig?.queryOptions,
|
||||
},
|
||||
);
|
||||
|
||||
const isAuthenticated = useMemo(
|
||||
() =>
|
||||
externalIsAuthenticated ?? (authConfig ? (authQuery?.data?.authenticated ?? false) : false),
|
||||
[externalIsAuthenticated, authConfig, authQuery.data?.authenticated],
|
||||
);
|
||||
|
||||
const isToolEnabled = useMemo(() => {
|
||||
return ephemeralAgent?.[toolKey] ?? false;
|
||||
}, [ephemeralAgent, toolKey]);
|
||||
|
||||
/** Track previous value to prevent infinite loops */
|
||||
const prevIsToolEnabled = useRef(isToolEnabled);
|
||||
|
||||
const [toggleState, setToggleState] = useLocalStorage<boolean>(
|
||||
`${localStorageKey}${key}`,
|
||||
isToolEnabled,
|
||||
undefined,
|
||||
storageCondition,
|
||||
);
|
||||
|
||||
const [isPinned, setIsPinned] = useLocalStorage<boolean>(`${localStorageKey}pinned`, false);
|
||||
|
||||
const handleChange = useCallback(
|
||||
({ e, isChecked }: { e?: React.ChangeEvent<HTMLInputElement>; isChecked: boolean }) => {
|
||||
if (isAuthenticated !== undefined && !isAuthenticated && setIsDialogOpen) {
|
||||
setIsDialogOpen(true);
|
||||
e?.preventDefault?.();
|
||||
return;
|
||||
}
|
||||
setToggleState(isChecked);
|
||||
setEphemeralAgent((prev) => ({
|
||||
...prev,
|
||||
[toolKey]: isChecked,
|
||||
}));
|
||||
},
|
||||
[setToggleState, setIsDialogOpen, isAuthenticated, setEphemeralAgent, toolKey],
|
||||
);
|
||||
|
||||
const debouncedChange = useMemo(
|
||||
() => debounce(handleChange, 50, { leading: true }),
|
||||
[handleChange],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (prevIsToolEnabled.current !== isToolEnabled) {
|
||||
setToggleState(isToolEnabled);
|
||||
}
|
||||
prevIsToolEnabled.current = isToolEnabled;
|
||||
}, [isToolEnabled, setToggleState]);
|
||||
|
||||
return {
|
||||
toggleState,
|
||||
handleChange,
|
||||
isToolEnabled,
|
||||
setToggleState,
|
||||
ephemeralAgent,
|
||||
debouncedChange,
|
||||
setEphemeralAgent,
|
||||
authData: authQuery?.data,
|
||||
isPinned,
|
||||
setIsPinned,
|
||||
};
|
||||
}
|
||||
@@ -7,10 +7,8 @@ import {
|
||||
Constants,
|
||||
/* @ts-ignore */
|
||||
createPayload,
|
||||
isAgentsEndpoint,
|
||||
LocalStorageKeys,
|
||||
removeNullishValues,
|
||||
isAssistantsEndpoint,
|
||||
} from 'librechat-data-provider';
|
||||
import type { TMessage, TPayload, TSubmission, EventSubmission } from 'librechat-data-provider';
|
||||
import type { EventHandlerParams } from './useEventHandlers';
|
||||
@@ -100,9 +98,7 @@ export default function useSSE(
|
||||
|
||||
const payloadData = createPayload(submission);
|
||||
let { payload } = payloadData;
|
||||
if (isAssistantsEndpoint(payload.endpoint) || isAgentsEndpoint(payload.endpoint)) {
|
||||
payload = removeNullishValues(payload) as TPayload;
|
||||
}
|
||||
payload = removeNullishValues(payload) as TPayload;
|
||||
|
||||
let textIndex = null;
|
||||
|
||||
|
||||
@@ -35,3 +35,4 @@ export { default as useOnClickOutside } from './useOnClickOutside';
|
||||
export { default as useSpeechToText } from './Input/useSpeechToText';
|
||||
export { default as useTextToSpeech } from './Input/useTextToSpeech';
|
||||
export { default as useGenerationsByLatest } from './useGenerationsByLatest';
|
||||
export { useResourcePermissions } from './useResourcePermissions';
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user