Compare commits
78 Commits
chore/tigh
...
add-model-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4fe600990c | ||
|
|
543b617e1c | ||
|
|
c827fdd10e | ||
|
|
ac608ded46 | ||
|
|
0e00f357a6 | ||
|
|
c465d7b732 | ||
|
|
aba0a93d1d | ||
|
|
a49b2b2833 | ||
|
|
e0ebb7097e | ||
|
|
9a79635012 | ||
|
|
ce19abc968 | ||
|
|
49cd3894aa | ||
|
|
da4aa37493 | ||
|
|
5a14ee9c6a | ||
|
|
3394aa5030 | ||
|
|
cee0579e0e | ||
|
|
d6c173c94b | ||
|
|
beff848a3f | ||
|
|
b9bc3123d6 | ||
|
|
639c7ad6ad | ||
|
|
822e2310ce | ||
|
|
2a0a8f6beb | ||
|
|
a6fd32a15a | ||
|
|
80a1a57fde | ||
|
|
3576391482 | ||
|
|
55557f7cc8 | ||
|
|
d7d02766ea | ||
|
|
627f0bffe5 | ||
|
|
8d1d95371f | ||
|
|
8bcdc041b2 | ||
|
|
9b6395d955 | ||
|
|
ad1503abdc | ||
|
|
a6d7ebf22e | ||
|
|
cebf140bce | ||
|
|
cc0cf359a2 | ||
|
|
3547873bc4 | ||
|
|
50b7bd6643 | ||
|
|
81186312ef | ||
|
|
4ec7bcb60f | ||
|
|
c78fd0fc83 | ||
|
|
d711fc7852 | ||
|
|
6af7efd0f4 | ||
|
|
d57e7aec73 | ||
|
|
e4e25aaf2b | ||
|
|
e8ddd279fd | ||
|
|
b742c8c7f9 | ||
|
|
803ade8601 | ||
|
|
dcd96c29c5 | ||
|
|
53c31b85d0 | ||
|
|
d07c2b3475 | ||
|
|
a434d28579 | ||
|
|
d82a63642d | ||
|
|
9585db14ba | ||
|
|
c191af6c9b | ||
|
|
39346d6b8e | ||
|
|
28d63dab71 | ||
|
|
49d1cefe71 | ||
|
|
0262c25989 | ||
|
|
90b037a67f | ||
|
|
fc8fd489d6 | ||
|
|
81b32e400a | ||
|
|
ae732b2ebc | ||
|
|
7e7e75714e | ||
|
|
ff54cbffd9 | ||
|
|
74e029e78f | ||
|
|
75324e1c7e | ||
|
|
949682ef0f | ||
|
|
66bd419baa | ||
|
|
aa42759ffd | ||
|
|
52e59e40be | ||
|
|
a955097faf | ||
|
|
b6413b06bc | ||
|
|
e6cebdf2b6 | ||
|
|
3eb6debe6a | ||
|
|
8780a78165 | ||
|
|
9dbf153489 | ||
|
|
4799593e1a | ||
|
|
a199b87478 |
57
.env.example
57
.env.example
@@ -15,6 +15,20 @@ HOST=localhost
|
||||
PORT=3080
|
||||
|
||||
MONGO_URI=mongodb://127.0.0.1:27017/LibreChat
|
||||
#The maximum number of connections in the connection pool. */
|
||||
MONGO_MAX_POOL_SIZE=
|
||||
#The minimum number of connections in the connection pool. */
|
||||
MONGO_MIN_POOL_SIZE=
|
||||
#The maximum number of connections that may be in the process of being established concurrently by the connection pool. */
|
||||
MONGO_MAX_CONNECTING=
|
||||
#The maximum number of milliseconds that a connection can remain idle in the pool before being removed and closed. */
|
||||
MONGO_MAX_IDLE_TIME_MS=
|
||||
#The maximum time in milliseconds that a thread can wait for a connection to become available. */
|
||||
MONGO_WAIT_QUEUE_TIMEOUT_MS=
|
||||
# Set to false to disable automatic index creation for all models associated with this connection. */
|
||||
MONGO_AUTO_INDEX=
|
||||
# Set to `false` to disable Mongoose automatically calling `createCollection()` on every model created on this connection. */
|
||||
MONGO_AUTO_CREATE=
|
||||
|
||||
DOMAIN_CLIENT=http://localhost:3080
|
||||
DOMAIN_SERVER=http://localhost:3080
|
||||
@@ -465,6 +479,21 @@ OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for
|
||||
# Set to true to use the OpenID Connect end session endpoint for logout
|
||||
OPENID_USE_END_SESSION_ENDPOINT=
|
||||
|
||||
#========================#
|
||||
# SharePoint Integration #
|
||||
#========================#
|
||||
# Requires Entra ID (OpenID) authentication to be configured
|
||||
|
||||
# Enable SharePoint file picker in chat and agent panels
|
||||
# ENABLE_SHAREPOINT_FILEPICKER=true
|
||||
|
||||
# SharePoint tenant base URL (e.g., https://yourtenant.sharepoint.com)
|
||||
# SHAREPOINT_BASE_URL=https://yourtenant.sharepoint.com
|
||||
|
||||
# Microsoft Graph API And SharePoint scopes for file picker
|
||||
# SHAREPOINT_PICKER_SHAREPOINT_SCOPE==https://yourtenant.sharepoint.com/AllSites.Read
|
||||
# SHAREPOINT_PICKER_GRAPH_SCOPE=Files.Read.All
|
||||
#========================#
|
||||
|
||||
# SAML
|
||||
# Note: If OpenID is enabled, SAML authentication will be automatically disabled.
|
||||
@@ -492,6 +521,21 @@ SAML_IMAGE_URL=
|
||||
# SAML_USE_AUTHN_RESPONSE_SIGNED=
|
||||
|
||||
|
||||
#===============================================#
|
||||
# Microsoft Graph API / Entra ID Integration #
|
||||
#===============================================#
|
||||
|
||||
# Enable Entra ID people search integration in permissions/sharing system
|
||||
# When enabled, the people picker will search both local database and Entra ID
|
||||
USE_ENTRA_ID_FOR_PEOPLE_SEARCH=false
|
||||
|
||||
# When enabled, entra id groups owners will be considered as members of the group
|
||||
ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS=false
|
||||
|
||||
# Microsoft Graph API scopes needed for people/group search
|
||||
# Default scopes provide access to user profiles and group memberships
|
||||
OPENID_GRAPH_SCOPES=User.Read,People.Read,GroupMember.Read.All
|
||||
|
||||
# LDAP
|
||||
LDAP_URL=
|
||||
LDAP_BIND_DN=
|
||||
@@ -698,3 +742,16 @@ OPENWEATHER_API_KEY=
|
||||
# JINA_API_KEY=your_jina_api_key
|
||||
# or
|
||||
# COHERE_API_KEY=your_cohere_api_key
|
||||
|
||||
#======================#
|
||||
# MCP Configuration #
|
||||
#======================#
|
||||
|
||||
# Treat 401/403 responses as OAuth requirement when no oauth metadata found
|
||||
# MCP_OAUTH_ON_AUTH_ERROR=true
|
||||
|
||||
# Timeout for OAuth detection requests in milliseconds
|
||||
# MCP_OAUTH_DETECTION_TIMEOUT=5000
|
||||
|
||||
# Cache connection status checks for this many milliseconds to avoid expensive verification
|
||||
# MCP_CONNECTION_CHECK_TTL=60000
|
||||
|
||||
4
.github/CONTRIBUTING.md
vendored
4
.github/CONTRIBUTING.md
vendored
@@ -147,7 +147,7 @@ Apply the following naming conventions to branches, labels, and other Git-relate
|
||||
## 8. Module Import Conventions
|
||||
|
||||
- `npm` packages first,
|
||||
- from shortest line (top) to longest (bottom)
|
||||
- from longest line (top) to shortest (bottom)
|
||||
|
||||
- Followed by typescript types (pertains to data-provider and client workspaces)
|
||||
- longest line (top) to shortest (bottom)
|
||||
@@ -157,6 +157,8 @@ Apply the following naming conventions to branches, labels, and other Git-relate
|
||||
- longest line (top) to shortest (bottom)
|
||||
- imports with alias `~` treated the same as relative import with respect to line length
|
||||
|
||||
**Note:** ESLint will automatically enforce these import conventions when you run `npm run lint --fix` or through pre-commit hooks.
|
||||
|
||||
---
|
||||
|
||||
Please ensure that you adapt this summary to fit the specific context and nuances of your project.
|
||||
|
||||
12
.github/workflows/data-provider.yml
vendored
12
.github/workflows/data-provider.yml
vendored
@@ -1,4 +1,4 @@
|
||||
name: Node.js Package
|
||||
name: Publish `librechat-data-provider` to NPM
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -6,6 +6,12 @@ on:
|
||||
- main
|
||||
paths:
|
||||
- 'packages/data-provider/package.json'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
reason:
|
||||
description: 'Reason for manual trigger'
|
||||
required: false
|
||||
default: 'Manual publish requested'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
@@ -14,7 +20,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 16
|
||||
node-version: 20
|
||||
- run: cd packages/data-provider && npm ci
|
||||
- run: cd packages/data-provider && npm run build
|
||||
|
||||
@@ -25,7 +31,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 16
|
||||
node-version: 20
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
- run: cd packages/data-provider && npm ci
|
||||
- run: cd packages/data-provider && npm run build
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -137,3 +137,4 @@ helm/**/.values.yaml
|
||||
/.openai/
|
||||
/.tabnine/
|
||||
/.codeium
|
||||
*.local.md
|
||||
|
||||
3
.vscode/launch.json
vendored
3
.vscode/launch.json
vendored
@@ -8,7 +8,8 @@
|
||||
"skipFiles": ["<node_internals>/**"],
|
||||
"program": "${workspaceFolder}/api/server/index.js",
|
||||
"env": {
|
||||
"NODE_ENV": "production"
|
||||
"NODE_ENV": "production",
|
||||
"NODE_TLS_REJECT_UNAUTHORIZED": "0"
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"envFile": "${workspaceFolder}/.env"
|
||||
|
||||
15
Dockerfile
15
Dockerfile
@@ -19,7 +19,12 @@ WORKDIR /app
|
||||
|
||||
USER node
|
||||
|
||||
COPY --chown=node:node . .
|
||||
COPY --chown=node:node package.json package-lock.json ./
|
||||
COPY --chown=node:node api/package.json ./api/package.json
|
||||
COPY --chown=node:node client/package.json ./client/package.json
|
||||
COPY --chown=node:node packages/data-provider/package.json ./packages/data-provider/package.json
|
||||
COPY --chown=node:node packages/data-schemas/package.json ./packages/data-schemas/package.json
|
||||
COPY --chown=node:node packages/api/package.json ./packages/api/package.json
|
||||
|
||||
RUN \
|
||||
# Allow mounting of these files, which have no default
|
||||
@@ -29,7 +34,11 @@ RUN \
|
||||
npm config set fetch-retry-maxtimeout 600000 ; \
|
||||
npm config set fetch-retries 5 ; \
|
||||
npm config set fetch-retry-mintimeout 15000 ; \
|
||||
npm install --no-audit; \
|
||||
npm ci --no-audit
|
||||
|
||||
COPY --chown=node:node . .
|
||||
|
||||
RUN \
|
||||
# React client build
|
||||
NODE_OPTIONS="--max-old-space-size=2048" npm run frontend; \
|
||||
npm prune --production; \
|
||||
@@ -47,4 +56,4 @@ CMD ["npm", "run", "backend"]
|
||||
# WORKDIR /usr/share/nginx/html
|
||||
# COPY --from=node /app/client/dist /usr/share/nginx/html
|
||||
# COPY client/nginx.conf /etc/nginx/conf.d/default.conf
|
||||
# ENTRYPOINT ["nginx", "-g", "daemon off;"]
|
||||
# ENTRYPOINT ["nginx", "-g", "daemon off;"]
|
||||
|
||||
147
MODEL_SPEC_FOLDERS.md
Normal file
147
MODEL_SPEC_FOLDERS.md
Normal file
@@ -0,0 +1,147 @@
|
||||
# Model Spec Subfolder Support
|
||||
|
||||
This enhancement adds the ability to organize model specs into subfolders/categories for better organization and user experience.
|
||||
|
||||
## Feature Overview
|
||||
|
||||
Model specs can now be grouped into folders by adding an optional `folder` field to each spec. This helps organize related models together, making it easier for users to find and select the appropriate model for their needs.
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Usage
|
||||
|
||||
Add a `folder` field to any model spec in your `librechat.yaml`:
|
||||
|
||||
```yaml
|
||||
modelSpecs:
|
||||
list:
|
||||
- name: "gpt4_turbo"
|
||||
label: "GPT-4 Turbo"
|
||||
folder: "OpenAI Models" # This spec will appear under "OpenAI Models" folder
|
||||
preset:
|
||||
endpoint: "openAI"
|
||||
model: "gpt-4-turbo-preview"
|
||||
```
|
||||
|
||||
### Folder Structure
|
||||
|
||||
- **With Folder**: Model specs with the `folder` field will be grouped under that folder name
|
||||
- **Without Folder**: Model specs without the `folder` field appear at the root level
|
||||
- **Multiple Folders**: You can create as many folders as needed to organize your models
|
||||
- **Alphabetical Sorting**: Folders are sorted alphabetically, and specs within folders are sorted by their `order` field or label
|
||||
|
||||
### Example Configuration
|
||||
|
||||
```yaml
|
||||
modelSpecs:
|
||||
list:
|
||||
# OpenAI Models Category
|
||||
- name: "gpt4_turbo"
|
||||
label: "GPT-4 Turbo"
|
||||
folder: "OpenAI Models"
|
||||
preset:
|
||||
endpoint: "openAI"
|
||||
model: "gpt-4-turbo-preview"
|
||||
|
||||
- name: "gpt35_turbo"
|
||||
label: "GPT-3.5 Turbo"
|
||||
folder: "OpenAI Models"
|
||||
preset:
|
||||
endpoint: "openAI"
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
# Anthropic Models Category
|
||||
- name: "claude3_opus"
|
||||
label: "Claude 3 Opus"
|
||||
folder: "Anthropic Models"
|
||||
preset:
|
||||
endpoint: "anthropic"
|
||||
model: "claude-3-opus-20240229"
|
||||
|
||||
# Root level model (no folder)
|
||||
- name: "quick_chat"
|
||||
label: "Quick Chat"
|
||||
preset:
|
||||
endpoint: "openAI"
|
||||
model: "gpt-3.5-turbo"
|
||||
```
|
||||
|
||||
## UI Features
|
||||
|
||||
### Folder Display
|
||||
- Folders are displayed with expand/collapse functionality
|
||||
- Folder icons change between open/closed states
|
||||
- Indentation shows the hierarchy clearly
|
||||
|
||||
### Search Integration
|
||||
- When searching for models, the folder path is shown for context
|
||||
- Search works across all models regardless of folder structure
|
||||
|
||||
### User Experience
|
||||
- Folders start expanded by default for easy access
|
||||
- Click on folder header to expand/collapse
|
||||
- Selected model is highlighted with a checkmark
|
||||
- Folder state is preserved during the session
|
||||
|
||||
## Benefits
|
||||
|
||||
1. **Better Organization**: Group related models together (e.g., by provider, capability, or use case)
|
||||
2. **Improved Navigation**: Users can quickly find models in organized categories
|
||||
3. **Scalability**: Handles large numbers of model specs without overwhelming the UI
|
||||
4. **Backward Compatible**: Existing configurations without folders continue to work
|
||||
5. **Flexible Structure**: Mix foldered and non-foldered specs as needed
|
||||
|
||||
## Use Cases
|
||||
|
||||
### By Provider
|
||||
```yaml
|
||||
folder: "OpenAI Models"
|
||||
folder: "Anthropic Models"
|
||||
folder: "Google Models"
|
||||
```
|
||||
|
||||
### By Capability
|
||||
```yaml
|
||||
folder: "Vision Models"
|
||||
folder: "Code Models"
|
||||
folder: "Creative Writing"
|
||||
```
|
||||
|
||||
### By Performance Tier
|
||||
```yaml
|
||||
folder: "Premium Models"
|
||||
folder: "Standard Models"
|
||||
folder: "Budget Models"
|
||||
```
|
||||
|
||||
### By Department/Team
|
||||
```yaml
|
||||
folder: "Engineering Team"
|
||||
folder: "Marketing Team"
|
||||
folder: "Research Team"
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Type Changes
|
||||
- Added optional `folder?: string` field to `TModelSpec` type
|
||||
- Updated `tModelSpecSchema` to include the folder field validation
|
||||
|
||||
### Components
|
||||
- Created `ModelSpecFolder` component for rendering folder structure
|
||||
- Updated `ModelSelector` to use folder-aware rendering
|
||||
- Enhanced search results to show folder context
|
||||
|
||||
### Behavior
|
||||
- Folders are collapsible with state management
|
||||
- Models are sorted within folders by order/label
|
||||
- Root-level models appear after all folders
|
||||
|
||||
## Migration
|
||||
|
||||
No migration needed - the feature is fully backward compatible. Existing model specs without the `folder` field will continue to work and appear at the root level.
|
||||
|
||||
## See Also
|
||||
|
||||
- `librechat.example.subfolder.yaml` - Complete example configuration
|
||||
- GitHub Issue #9165 - Original feature request
|
||||
@@ -27,7 +27,6 @@ const {
|
||||
const { getModelMaxTokens, getModelMaxOutputTokens, matchModelName } = require('~/utils');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { encodeAndFormatDocuments } = require('~/server/services/Files/documents');
|
||||
const { sleep } = require('~/server/utils');
|
||||
const BaseClient = require('./BaseClient');
|
||||
const { logger } = require('~/config');
|
||||
@@ -313,33 +312,6 @@ class AnthropicClient extends BaseClient {
|
||||
return files;
|
||||
}
|
||||
|
||||
async addDocuments(message, attachments) {
|
||||
// Only process documents
|
||||
const documentResult = await encodeAndFormatDocuments(
|
||||
this.options.req,
|
||||
attachments,
|
||||
EModelEndpoint.anthropic,
|
||||
);
|
||||
|
||||
message.documents =
|
||||
documentResult.documents && documentResult.documents.length
|
||||
? documentResult.documents
|
||||
: undefined;
|
||||
|
||||
return documentResult.files;
|
||||
}
|
||||
|
||||
async processAttachments(message, attachments) {
|
||||
// Process both images and documents
|
||||
const [imageFiles, documentFiles] = await Promise.all([
|
||||
this.addImageURLs(message, attachments),
|
||||
this.addDocuments(message, attachments),
|
||||
]);
|
||||
|
||||
// Combine files from both processors
|
||||
return [...imageFiles, ...documentFiles];
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {number} params.promptTokens
|
||||
@@ -410,7 +382,7 @@ class AnthropicClient extends BaseClient {
|
||||
};
|
||||
}
|
||||
|
||||
const files = await this.processAttachments(latestMessage, attachments);
|
||||
const files = await this.addImageURLs(latestMessage, attachments);
|
||||
|
||||
this.options.attachments = files;
|
||||
}
|
||||
@@ -969,7 +941,7 @@ class AnthropicClient extends BaseClient {
|
||||
const content = `<conversation_context>
|
||||
${convo}
|
||||
</conversation_context>
|
||||
|
||||
|
||||
Please generate a title for this conversation.`;
|
||||
|
||||
const titleMessage = { role: 'user', content };
|
||||
|
||||
@@ -37,6 +37,8 @@ class BaseClient {
|
||||
this.conversationId;
|
||||
/** @type {string} */
|
||||
this.responseMessageId;
|
||||
/** @type {string} */
|
||||
this.parentMessageId;
|
||||
/** @type {TAttachment[]} */
|
||||
this.attachments;
|
||||
/** The key for the usage object's input tokens
|
||||
@@ -185,7 +187,8 @@ class BaseClient {
|
||||
this.user = user;
|
||||
const saveOptions = this.getSaveOptions();
|
||||
this.abortController = opts.abortController ?? new AbortController();
|
||||
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
|
||||
const requestConvoId = overrideConvoId ?? opts.conversationId;
|
||||
const conversationId = requestConvoId ?? crypto.randomUUID();
|
||||
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
|
||||
const userMessageId =
|
||||
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
@@ -210,11 +213,12 @@ class BaseClient {
|
||||
...opts,
|
||||
user,
|
||||
head,
|
||||
saveOptions,
|
||||
userMessageId,
|
||||
requestConvoId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
userMessageId,
|
||||
responseMessageId,
|
||||
saveOptions,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -233,11 +237,12 @@ class BaseClient {
|
||||
const {
|
||||
user,
|
||||
head,
|
||||
saveOptions,
|
||||
userMessageId,
|
||||
requestConvoId,
|
||||
conversationId,
|
||||
parentMessageId,
|
||||
userMessageId,
|
||||
responseMessageId,
|
||||
saveOptions,
|
||||
} = await this.setMessageOptions(opts);
|
||||
|
||||
const userMessage = opts.isEdited
|
||||
@@ -259,7 +264,8 @@ class BaseClient {
|
||||
}
|
||||
|
||||
if (typeof opts?.onStart === 'function') {
|
||||
opts.onStart(userMessage, responseMessageId);
|
||||
const isNewConvo = !requestConvoId && parentMessageId === Constants.NO_PARENT;
|
||||
opts.onStart(userMessage, responseMessageId, isNewConvo);
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -614,15 +620,19 @@ class BaseClient {
|
||||
this.currentMessages.push(userMessage);
|
||||
}
|
||||
|
||||
/**
|
||||
* When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
|
||||
* this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
|
||||
*/
|
||||
const parentMessageId = isEdited ? head : userMessage.messageId;
|
||||
this.parentMessageId = parentMessageId;
|
||||
let {
|
||||
prompt: payload,
|
||||
tokenCountMap,
|
||||
promptTokens,
|
||||
} = await this.buildMessages(
|
||||
this.currentMessages,
|
||||
// When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
|
||||
// this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
|
||||
isEdited ? head : userMessage.messageId,
|
||||
parentMessageId,
|
||||
this.getBuildMessagesOptions(opts),
|
||||
opts,
|
||||
);
|
||||
@@ -1233,7 +1243,7 @@ class BaseClient {
|
||||
{},
|
||||
);
|
||||
|
||||
await this.processAttachments(message, files, this.visionMode);
|
||||
await this.addImageURLs(message, files, this.visionMode);
|
||||
|
||||
this.message_file_map[message.messageId] = files;
|
||||
return message;
|
||||
|
||||
@@ -268,7 +268,7 @@ class GoogleClient extends BaseClient {
|
||||
const formattedMessages = [];
|
||||
const attachments = await this.options.attachments;
|
||||
const latestMessage = { ...messages[messages.length - 1] };
|
||||
const files = await this.processAttachments(latestMessage, attachments, VisionModes.generative);
|
||||
const files = await this.addImageURLs(latestMessage, attachments, VisionModes.generative);
|
||||
this.options.attachments = files;
|
||||
messages[messages.length - 1] = latestMessage;
|
||||
|
||||
@@ -312,20 +312,6 @@ class GoogleClient extends BaseClient {
|
||||
return files;
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
async addDocuments(message, attachments) {
|
||||
// GoogleClient doesn't support document processing yet
|
||||
// Return empty results for consistency
|
||||
return [];
|
||||
}
|
||||
|
||||
async processAttachments(message, attachments, mode = '') {
|
||||
// For GoogleClient, only process images
|
||||
const imageFiles = await this.addImageURLs(message, attachments, mode);
|
||||
const documentFiles = await this.addDocuments(message, attachments);
|
||||
return [...imageFiles, ...documentFiles];
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds the augmented prompt for attachments
|
||||
* TODO: Add File API Support
|
||||
@@ -359,7 +345,7 @@ class GoogleClient extends BaseClient {
|
||||
|
||||
const { prompt } = await this.buildMessagesPrompt(messages, parentMessageId);
|
||||
|
||||
const files = await this.processAttachments(latestMessage, attachments);
|
||||
const files = await this.addImageURLs(latestMessage, attachments);
|
||||
|
||||
this.options.attachments = files;
|
||||
|
||||
|
||||
@@ -372,19 +372,6 @@ class OpenAIClient extends BaseClient {
|
||||
return files;
|
||||
}
|
||||
|
||||
async addDocuments(message, attachments) {
|
||||
// OpenAI doesn't support native document processing yet
|
||||
// Return empty results for consistency
|
||||
return [];
|
||||
}
|
||||
|
||||
async processAttachments(message, attachments) {
|
||||
// For OpenAI, only process images
|
||||
const imageFiles = await this.addImageURLs(message, attachments);
|
||||
const documentFiles = await this.addDocuments(message, attachments);
|
||||
return [...imageFiles, ...documentFiles];
|
||||
}
|
||||
|
||||
async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) {
|
||||
let orderedMessages = this.constructor.getMessagesForConversation({
|
||||
messages,
|
||||
@@ -413,7 +400,7 @@ class OpenAIClient extends BaseClient {
|
||||
};
|
||||
}
|
||||
|
||||
const files = await this.processAttachments(
|
||||
const files = await this.addImageURLs(
|
||||
orderedMessages[orderedMessages.length - 1],
|
||||
attachments,
|
||||
);
|
||||
@@ -666,8 +653,10 @@ class OpenAIClient extends BaseClient {
|
||||
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
|
||||
configOptions.baseOptions = {
|
||||
headers: resolveHeaders({
|
||||
...headers,
|
||||
...configOptions?.baseOptions?.headers,
|
||||
headers: {
|
||||
...headers,
|
||||
...configOptions?.baseOptions?.headers,
|
||||
},
|
||||
}),
|
||||
};
|
||||
}
|
||||
@@ -762,7 +751,7 @@ class OpenAIClient extends BaseClient {
|
||||
groupMap,
|
||||
});
|
||||
|
||||
this.options.headers = resolveHeaders(headers);
|
||||
this.options.headers = resolveHeaders({ headers });
|
||||
this.options.reverseProxyUrl = baseURL ?? null;
|
||||
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
|
||||
this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
@@ -1194,7 +1183,7 @@ ${convo}
|
||||
modelGroupMap,
|
||||
groupMap,
|
||||
});
|
||||
opts.defaultHeaders = resolveHeaders(headers);
|
||||
opts.defaultHeaders = resolveHeaders({ headers });
|
||||
this.langchainProxy = extractBaseURL(baseURL);
|
||||
this.apiKey = azureOptions.azureOpenAIApiKey;
|
||||
|
||||
|
||||
@@ -3,61 +3,24 @@ const { EModelEndpoint, ContentTypes } = require('librechat-data-provider');
|
||||
const { HumanMessage, AIMessage, SystemMessage } = require('@langchain/core/messages');
|
||||
|
||||
/**
|
||||
* Formats a message with document attachments for specific endpoints.
|
||||
* Formats a message to OpenAI Vision API payload format.
|
||||
*
|
||||
* @param {Object} params - The parameters for formatting.
|
||||
* @param {Object} params.message - The message object to format.
|
||||
* @param {Array<Object>} [params.documents] - The document attachments for the message.
|
||||
* @param {string} [params.message.role] - The role of the message sender (must be 'user').
|
||||
* @param {string} [params.message.content] - The text content of the message.
|
||||
* @param {EModelEndpoint} [params.endpoint] - Identifier for specific endpoint handling
|
||||
* @returns {(Object)} - The formatted message.
|
||||
*/
|
||||
const formatDocumentMessage = ({ message, documents, endpoint }) => {
|
||||
const contentParts = [];
|
||||
|
||||
// Add documents first (for Anthropic PDFs)
|
||||
if (documents && documents.length > 0) {
|
||||
contentParts.push(...documents);
|
||||
}
|
||||
|
||||
// Add text content
|
||||
contentParts.push({ type: ContentTypes.TEXT, text: message.content });
|
||||
|
||||
if (endpoint === EModelEndpoint.anthropic) {
|
||||
message.content = contentParts;
|
||||
return message;
|
||||
}
|
||||
|
||||
// For other endpoints, might need different handling
|
||||
message.content = contentParts;
|
||||
return message;
|
||||
};
|
||||
|
||||
/**
|
||||
* Formats a message with vision capabilities (image_urls) for specific endpoints.
|
||||
*
|
||||
* @param {Object} params - The parameters for formatting.
|
||||
* @param {Object} params.message - The message object to format.
|
||||
* @param {Array<string>} [params.image_urls] - The image_urls to attach to the message.
|
||||
* @param {EModelEndpoint} [params.endpoint] - Identifier for specific endpoint handling
|
||||
* @returns {(Object)} - The formatted message.
|
||||
*/
|
||||
const formatVisionMessage = ({ message, image_urls, endpoint }) => {
|
||||
const contentParts = [];
|
||||
|
||||
// Add images
|
||||
if (image_urls && image_urls.length > 0) {
|
||||
contentParts.push(...image_urls);
|
||||
}
|
||||
|
||||
// Add text content
|
||||
contentParts.push({ type: ContentTypes.TEXT, text: message.content });
|
||||
|
||||
if (endpoint === EModelEndpoint.anthropic) {
|
||||
message.content = contentParts;
|
||||
message.content = [...image_urls, { type: ContentTypes.TEXT, text: message.content }];
|
||||
return message;
|
||||
}
|
||||
|
||||
message.content = [{ type: ContentTypes.TEXT, text: message.content }, ...image_urls];
|
||||
|
||||
return message;
|
||||
};
|
||||
|
||||
@@ -95,18 +58,7 @@ const formatMessage = ({ message, userName, assistantName, endpoint, langChain =
|
||||
content,
|
||||
};
|
||||
|
||||
const { image_urls, documents } = message;
|
||||
|
||||
// Handle documents
|
||||
if (Array.isArray(documents) && documents.length > 0 && role === 'user') {
|
||||
return formatDocumentMessage({
|
||||
message: formattedMessage,
|
||||
documents: message.documents,
|
||||
endpoint,
|
||||
});
|
||||
}
|
||||
|
||||
// Handle images
|
||||
const { image_urls } = message;
|
||||
if (Array.isArray(image_urls) && image_urls.length > 0 && role === 'user') {
|
||||
return formatVisionMessage({
|
||||
message: formattedMessage,
|
||||
@@ -194,21 +146,7 @@ const formatAgentMessages = (payload) => {
|
||||
message.content = [{ type: ContentTypes.TEXT, [ContentTypes.TEXT]: message.content }];
|
||||
}
|
||||
if (message.role !== 'assistant') {
|
||||
// Check if message has documents and preserve array structure
|
||||
const hasDocuments =
|
||||
Array.isArray(message.content) &&
|
||||
message.content.some((part) => part && part.type === 'document');
|
||||
|
||||
if (hasDocuments && message.role === 'user') {
|
||||
// For user messages with documents, create HumanMessage directly with array content
|
||||
messages.push(new HumanMessage({ content: message.content }));
|
||||
} else if (hasDocuments && message.role === 'system') {
|
||||
// For system messages with documents, create SystemMessage directly with array content
|
||||
messages.push(new SystemMessage({ content: message.content }));
|
||||
} else {
|
||||
// Use regular formatting for messages without documents
|
||||
messages.push(formatMessage({ message, langChain: true }));
|
||||
}
|
||||
messages.push(formatMessage({ message, langChain: true }));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -301,8 +239,6 @@ const formatAgentMessages = (payload) => {
|
||||
|
||||
module.exports = {
|
||||
formatMessage,
|
||||
formatDocumentMessage,
|
||||
formatVisionMessage,
|
||||
formatFromLangChain,
|
||||
formatAgentMessages,
|
||||
formatLangChainMessages,
|
||||
|
||||
@@ -245,7 +245,7 @@ describe('AnthropicClient', () => {
|
||||
});
|
||||
|
||||
describe('Claude 4 model headers', () => {
|
||||
it('should add "prompt-caching" beta header for claude-sonnet-4 model', () => {
|
||||
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
model: 'claude-sonnet-4-20250514',
|
||||
@@ -255,10 +255,30 @@ describe('AnthropicClient', () => {
|
||||
expect(anthropicClient._options.defaultHeaders).toBeDefined();
|
||||
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31',
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model formats', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelVariations = [
|
||||
'claude-sonnet-4-20250514',
|
||||
'claude-sonnet-4-latest',
|
||||
'anthropic/claude-sonnet-4-20250514',
|
||||
];
|
||||
|
||||
modelVariations.forEach((model) => {
|
||||
const modelOptions = { model };
|
||||
client.setOptions({ modelOptions, promptCache: true });
|
||||
const anthropicClient = client.getClient(modelOptions);
|
||||
expect(anthropicClient._options.defaultHeaders).toBeDefined();
|
||||
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" beta header for claude-opus-4 model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
@@ -273,20 +293,6 @@ describe('AnthropicClient', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" beta header for claude-4-sonnet model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
model: 'claude-4-sonnet-20250514',
|
||||
};
|
||||
client.setOptions({ modelOptions, promptCache: true });
|
||||
const anthropicClient = client.getClient(modelOptions);
|
||||
expect(anthropicClient._options.defaultHeaders).toBeDefined();
|
||||
expect(anthropicClient._options.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(anthropicClient._options.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31',
|
||||
);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" beta header for claude-4-opus model', () => {
|
||||
const client = new AnthropicClient('test-api-key');
|
||||
const modelOptions = {
|
||||
|
||||
@@ -579,6 +579,8 @@ describe('BaseClient', () => {
|
||||
expect(onStart).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ text: 'Hello, world!' }),
|
||||
expect.any(String),
|
||||
/** `isNewConvo` */
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ const axios = require('axios');
|
||||
const { tool } = require('@langchain/core/tools');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Tools, EToolResources } = require('librechat-data-provider');
|
||||
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||
const { generateShortLivedToken } = require('~/server/services/AuthService');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
@@ -22,14 +23,24 @@ const primeFiles = async (options) => {
|
||||
const file_ids = tool_resources?.[EToolResources.file_search]?.file_ids ?? [];
|
||||
const agentResourceIds = new Set(file_ids);
|
||||
const resourceFiles = tool_resources?.[EToolResources.file_search]?.files ?? [];
|
||||
const dbFiles = (
|
||||
(await getFiles(
|
||||
{ file_id: { $in: file_ids } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
{ userId: req?.user?.id, agentId },
|
||||
)) ?? []
|
||||
).concat(resourceFiles);
|
||||
|
||||
// Get all files first
|
||||
const allFiles = (await getFiles({ file_id: { $in: file_ids } }, null, { text: 0 })) ?? [];
|
||||
|
||||
// Filter by access if user and agent are provided
|
||||
let dbFiles;
|
||||
if (req?.user?.id && agentId) {
|
||||
dbFiles = await filterFilesByAgentAccess({
|
||||
files: allFiles,
|
||||
userId: req.user.id,
|
||||
role: req.user.role,
|
||||
agentId,
|
||||
});
|
||||
} else {
|
||||
dbFiles = allFiles;
|
||||
}
|
||||
|
||||
dbFiles = dbFiles.concat(resourceFiles);
|
||||
|
||||
let toolContext = `- Note: Semantic search is available through the ${Tools.file_search} tool but no files are currently loaded. Request the user to upload documents to search through.`;
|
||||
|
||||
@@ -114,11 +125,13 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
}
|
||||
|
||||
const formattedResults = validResults
|
||||
.flatMap((result) =>
|
||||
.flatMap((result, fileIndex) =>
|
||||
result.data.map(([docInfo, distance]) => ({
|
||||
filename: docInfo.metadata.source.split('/').pop(),
|
||||
content: docInfo.page_content,
|
||||
distance,
|
||||
file_id: files[fileIndex]?.file_id,
|
||||
page: docInfo.metadata.page || null,
|
||||
})),
|
||||
)
|
||||
// TODO: results should be sorted by relevance, not distance
|
||||
@@ -128,18 +141,37 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
|
||||
|
||||
const formattedString = formattedResults
|
||||
.map(
|
||||
(result) =>
|
||||
`File: ${result.filename}\nRelevance: ${1.0 - result.distance.toFixed(4)}\nContent: ${
|
||||
(result, index) =>
|
||||
`File: ${result.filename}\nAnchor: \\ue202turn0file${index} (${result.filename})\nRelevance: ${(1.0 - result.distance).toFixed(4)}\nContent: ${
|
||||
result.content
|
||||
}\n`,
|
||||
)
|
||||
.join('\n---\n');
|
||||
|
||||
return formattedString;
|
||||
const sources = formattedResults.map((result) => ({
|
||||
type: 'file',
|
||||
fileId: result.file_id,
|
||||
content: result.content,
|
||||
fileName: result.filename,
|
||||
relevance: 1.0 - result.distance,
|
||||
pages: result.page ? [result.page] : [],
|
||||
pageRelevance: result.page ? { [result.page]: 1.0 - result.distance } : {},
|
||||
}));
|
||||
|
||||
return [formattedString, { [Tools.file_search]: { sources } }];
|
||||
},
|
||||
{
|
||||
name: Tools.file_search,
|
||||
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.`,
|
||||
responseFormat: 'content_and_artifact',
|
||||
description: `Performs semantic search across attached "${Tools.file_search}" documents using natural language queries. This tool analyzes the content of uploaded files to find relevant information, quotes, and passages that best match your query. Use this to extract specific information or find relevant sections within the available documents.
|
||||
|
||||
**CITE FILE SEARCH RESULTS:**
|
||||
Use anchor markers immediately after statements derived from file content. Reference the filename in your text:
|
||||
- File citation: "The document.pdf states that... \\ue202turn0file0"
|
||||
- Page reference: "According to report.docx... \\ue202turn0file1"
|
||||
- Multi-file: "Multiple sources confirm... \\ue200\\ue202turn0file0\\ue202turn0file1\\ue201"
|
||||
|
||||
**ALWAYS mention the filename in your text before the citation marker. NEVER use markdown links or footnotes.**`,
|
||||
schema: z.object({
|
||||
query: z
|
||||
.string()
|
||||
|
||||
@@ -3,7 +3,7 @@ const { SerpAPI } = require('@langchain/community/tools/serpapi');
|
||||
const { Calculator } = require('@langchain/community/tools/calculator');
|
||||
const { mcpToolPattern, loadWebSearchAuth } = require('@librechat/api');
|
||||
const { EnvVar, createCodeExecutionTool, createSearchTool } = require('@librechat/agents');
|
||||
const { Tools, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
||||
const { Tools, Constants, EToolResources, replaceSpecialVars } = require('librechat-data-provider');
|
||||
const {
|
||||
availableTools,
|
||||
manifestToolMap,
|
||||
@@ -24,9 +24,9 @@ const {
|
||||
const { primeFiles: primeCodeFiles } = require('~/server/services/Files/Code/process');
|
||||
const { createFileSearchTool, primeFiles: primeSearchFiles } = require('./fileSearch');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { createMCPTool, createMCPTools } = require('~/server/services/MCP');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { createMCPTool } = require('~/server/services/MCP');
|
||||
|
||||
/**
|
||||
* Validates the availability and authentication of tools for a user based on environment variables or user-specific plugin authentication values.
|
||||
@@ -123,6 +123,8 @@ const getAuthFields = (toolKey) => {
|
||||
*
|
||||
* @param {object} object
|
||||
* @param {string} object.user
|
||||
* @param {Record<string, Record<string, string>>} [object.userMCPAuthMap]
|
||||
* @param {AbortSignal} [object.signal]
|
||||
* @param {Pick<Agent, 'id' | 'provider' | 'model'>} [object.agent]
|
||||
* @param {string} [object.model]
|
||||
* @param {EModelEndpoint} [object.endpoint]
|
||||
@@ -137,7 +139,9 @@ const loadTools = async ({
|
||||
user,
|
||||
agent,
|
||||
model,
|
||||
signal,
|
||||
endpoint,
|
||||
userMCPAuthMap,
|
||||
tools = [],
|
||||
options = {},
|
||||
functions = true,
|
||||
@@ -231,6 +235,7 @@ const loadTools = async ({
|
||||
/** @type {Record<string, string>} */
|
||||
const toolContextMap = {};
|
||||
const cachedTools = (await getCachedTools({ userId: user, includeGlobal: true })) ?? {};
|
||||
const requestedMCPTools = {};
|
||||
|
||||
for (const tool of tools) {
|
||||
if (tool === Tools.execute_code) {
|
||||
@@ -299,14 +304,35 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
};
|
||||
continue;
|
||||
} else if (tool && cachedTools && mcpToolPattern.test(tool)) {
|
||||
requestedTools[tool] = async () =>
|
||||
const [toolName, serverName] = tool.split(Constants.mcp_delimiter);
|
||||
if (toolName === Constants.mcp_all) {
|
||||
const currentMCPGenerator = async (index) =>
|
||||
createMCPTools({
|
||||
req: options.req,
|
||||
res: options.res,
|
||||
index,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
model: agent?.model ?? model,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
signal,
|
||||
});
|
||||
requestedMCPTools[serverName] = [currentMCPGenerator];
|
||||
continue;
|
||||
}
|
||||
const currentMCPGenerator = async (index) =>
|
||||
createMCPTool({
|
||||
index,
|
||||
req: options.req,
|
||||
res: options.res,
|
||||
toolKey: tool,
|
||||
userMCPAuthMap,
|
||||
model: agent?.model ?? model,
|
||||
provider: agent?.provider ?? endpoint,
|
||||
signal,
|
||||
});
|
||||
requestedMCPTools[serverName] = requestedMCPTools[serverName] || [];
|
||||
requestedMCPTools[serverName].push(currentMCPGenerator);
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -346,6 +372,34 @@ Current Date & Time: ${replaceSpecialVars({ text: '{{iso_datetime}}' })}
|
||||
}
|
||||
|
||||
const loadedTools = (await Promise.all(toolPromises)).flatMap((plugin) => plugin || []);
|
||||
const mcpToolPromises = [];
|
||||
/** MCP server tools are initialized sequentially by server */
|
||||
let index = -1;
|
||||
for (const [serverName, generators] of Object.entries(requestedMCPTools)) {
|
||||
index++;
|
||||
for (const generator of generators) {
|
||||
try {
|
||||
if (generator && generators.length === 1) {
|
||||
mcpToolPromises.push(
|
||||
generator(index).catch((error) => {
|
||||
logger.error(`Error loading ${serverName} tools:`, error);
|
||||
return null;
|
||||
}),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
const mcpTool = await generator(index);
|
||||
if (Array.isArray(mcpTool)) {
|
||||
loadedTools.push(...mcpTool);
|
||||
} else if (mcpTool) {
|
||||
loadedTools.push(mcpTool);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error loading MCP tool for server ${serverName}:`, error);
|
||||
}
|
||||
}
|
||||
}
|
||||
loadedTools.push(...(await Promise.all(mcpToolPromises)).flatMap((plugin) => plugin || []));
|
||||
return { loadedTools, toolContextMap };
|
||||
};
|
||||
|
||||
|
||||
3
api/cache/cacheConfig.js
vendored
3
api/cache/cacheConfig.js
vendored
@@ -52,7 +52,8 @@ const cacheConfig = {
|
||||
REDIS_CONNECT_TIMEOUT: math(process.env.REDIS_CONNECT_TIMEOUT, 10000),
|
||||
/** Queue commands when disconnected */
|
||||
REDIS_ENABLE_OFFLINE_QUEUE: isEnabled(process.env.REDIS_ENABLE_OFFLINE_QUEUE ?? 'true'),
|
||||
|
||||
/** Enable redis cluster without the need of multiple URIs */
|
||||
USE_REDIS_CLUSTER: isEnabled(process.env.USE_REDIS_CLUSTER ?? 'false'),
|
||||
CI: isEnabled(process.env.CI),
|
||||
DEBUG_MEMORY_CACHE: isEnabled(process.env.DEBUG_MEMORY_CACHE),
|
||||
|
||||
|
||||
33
api/cache/cacheConfig.spec.js
vendored
33
api/cache/cacheConfig.spec.js
vendored
@@ -14,6 +14,7 @@ describe('cacheConfig', () => {
|
||||
delete process.env.REDIS_KEY_PREFIX_VAR;
|
||||
delete process.env.REDIS_KEY_PREFIX;
|
||||
delete process.env.USE_REDIS;
|
||||
delete process.env.USE_REDIS_CLUSTER;
|
||||
delete process.env.REDIS_PING_INTERVAL;
|
||||
delete process.env.FORCED_IN_MEMORY_CACHE_NAMESPACES;
|
||||
|
||||
@@ -101,6 +102,38 @@ describe('cacheConfig', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('USE_REDIS_CLUSTER configuration', () => {
|
||||
test('should default to false when USE_REDIS_CLUSTER is not set', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.USE_REDIS_CLUSTER).toBe(false);
|
||||
});
|
||||
|
||||
test('should be false when USE_REDIS_CLUSTER is set to false', () => {
|
||||
process.env.USE_REDIS_CLUSTER = 'false';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.USE_REDIS_CLUSTER).toBe(false);
|
||||
});
|
||||
|
||||
test('should be true when USE_REDIS_CLUSTER is set to true', () => {
|
||||
process.env.USE_REDIS_CLUSTER = 'true';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.USE_REDIS_CLUSTER).toBe(true);
|
||||
});
|
||||
|
||||
test('should work with USE_REDIS enabled and REDIS_URI set', () => {
|
||||
process.env.USE_REDIS_CLUSTER = 'true';
|
||||
process.env.USE_REDIS = 'true';
|
||||
process.env.REDIS_URI = 'redis://localhost:6379';
|
||||
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
expect(cacheConfig.USE_REDIS_CLUSTER).toBe(true);
|
||||
expect(cacheConfig.USE_REDIS).toBe(true);
|
||||
expect(cacheConfig.REDIS_URI).toBe('redis://localhost:6379');
|
||||
});
|
||||
});
|
||||
|
||||
describe('REDIS_CA file reading', () => {
|
||||
test('should be null when REDIS_CA is not set', () => {
|
||||
const { cacheConfig } = require('./cacheConfig');
|
||||
|
||||
1
api/cache/getLogStores.js
vendored
1
api/cache/getLogStores.js
vendored
@@ -31,7 +31,6 @@ const namespaces = {
|
||||
[CacheKeys.SAML_SESSION]: sessionCache(CacheKeys.SAML_SESSION),
|
||||
|
||||
[CacheKeys.ROLES]: standardCache(CacheKeys.ROLES),
|
||||
[CacheKeys.MCP_TOOLS]: standardCache(CacheKeys.MCP_TOOLS),
|
||||
[CacheKeys.CONFIG_STORE]: standardCache(CacheKeys.CONFIG_STORE),
|
||||
[CacheKeys.STATIC_CONFIG]: standardCache(CacheKeys.STATIC_CONFIG),
|
||||
[CacheKeys.PENDING_REQ]: standardCache(CacheKeys.PENDING_REQ),
|
||||
|
||||
45
api/cache/redisClients.js
vendored
45
api/cache/redisClients.js
vendored
@@ -38,7 +38,7 @@ if (cacheConfig.USE_REDIS) {
|
||||
const targetError = 'READONLY';
|
||||
if (err.message.includes(targetError)) {
|
||||
logger.warn('ioredis reconnecting due to READONLY error');
|
||||
return true;
|
||||
return 2; // Return retry delay instead of boolean
|
||||
}
|
||||
return false;
|
||||
},
|
||||
@@ -48,26 +48,29 @@ if (cacheConfig.USE_REDIS) {
|
||||
};
|
||||
|
||||
ioredisClient =
|
||||
urls.length === 1
|
||||
urls.length === 1 && !cacheConfig.USE_REDIS_CLUSTER
|
||||
? new IoRedis(cacheConfig.REDIS_URI, redisOptions)
|
||||
: new IoRedis.Cluster(cacheConfig.REDIS_URI, {
|
||||
redisOptions,
|
||||
clusterRetryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis cluster giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
: new IoRedis.Cluster(
|
||||
urls.map((url) => ({ host: url.hostname, port: parseInt(url.port, 10) || 6379 })),
|
||||
{
|
||||
redisOptions,
|
||||
clusterRetryStrategy: (times) => {
|
||||
if (
|
||||
cacheConfig.REDIS_RETRY_MAX_ATTEMPTS > 0 &&
|
||||
times > cacheConfig.REDIS_RETRY_MAX_ATTEMPTS
|
||||
) {
|
||||
logger.error(
|
||||
`ioredis cluster giving up after ${cacheConfig.REDIS_RETRY_MAX_ATTEMPTS} reconnection attempts`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
const delay = Math.min(times * 100, cacheConfig.REDIS_RETRY_MAX_DELAY);
|
||||
logger.info(`ioredis cluster reconnecting... attempt ${times}, delay ${delay}ms`);
|
||||
return delay;
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
},
|
||||
enableOfflineQueue: cacheConfig.REDIS_ENABLE_OFFLINE_QUEUE,
|
||||
});
|
||||
);
|
||||
|
||||
ioredisClient.on('error', (err) => {
|
||||
logger.error('ioredis client error:', err);
|
||||
@@ -145,10 +148,10 @@ if (cacheConfig.USE_REDIS) {
|
||||
};
|
||||
|
||||
keyvRedisClient =
|
||||
urls.length === 1
|
||||
urls.length === 1 && !cacheConfig.USE_REDIS_CLUSTER
|
||||
? createClient({ url: cacheConfig.REDIS_URI, ...redisOptions })
|
||||
: createCluster({
|
||||
rootNodes: cacheConfig.REDIS_URI.split(',').map((url) => ({ url })),
|
||||
rootNodes: urls.map((url) => ({ url: url.href })),
|
||||
defaults: redisOptions,
|
||||
});
|
||||
|
||||
|
||||
@@ -1,27 +1,13 @@
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const { EventSource } = require('eventsource');
|
||||
const { Time } = require('librechat-data-provider');
|
||||
const { MCPManager, FlowStateManager } = require('@librechat/api');
|
||||
const logger = require('./winston');
|
||||
|
||||
global.EventSource = EventSource;
|
||||
|
||||
/** @type {MCPManager} */
|
||||
let mcpManager = null;
|
||||
let flowManager = null;
|
||||
|
||||
/**
|
||||
* @param {string} [userId] - Optional user ID, to avoid disconnecting the current user.
|
||||
* @returns {MCPManager}
|
||||
*/
|
||||
function getMCPManager(userId) {
|
||||
if (!mcpManager) {
|
||||
mcpManager = MCPManager.getInstance();
|
||||
} else {
|
||||
mcpManager.checkIdleConnections(userId);
|
||||
}
|
||||
return mcpManager;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Keyv} flowsCache
|
||||
* @returns {FlowStateManager}
|
||||
@@ -37,6 +23,7 @@ function getFlowStateManager(flowsCache) {
|
||||
|
||||
module.exports = {
|
||||
logger,
|
||||
getMCPManager,
|
||||
createMCPManager: MCPManager.createInstance,
|
||||
getMCPManager: MCPManager.getInstance,
|
||||
getFlowStateManager,
|
||||
};
|
||||
|
||||
@@ -1,11 +1,34 @@
|
||||
require('dotenv').config();
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
|
||||
const mongoose = require('mongoose');
|
||||
const MONGO_URI = process.env.MONGO_URI;
|
||||
|
||||
if (!MONGO_URI) {
|
||||
throw new Error('Please define the MONGO_URI environment variable');
|
||||
}
|
||||
/** The maximum number of connections in the connection pool. */
|
||||
const maxPoolSize = parseInt(process.env.MONGO_MAX_POOL_SIZE) || undefined;
|
||||
/** The minimum number of connections in the connection pool. */
|
||||
const minPoolSize = parseInt(process.env.MONGO_MIN_POOL_SIZE) || undefined;
|
||||
/** The maximum number of connections that may be in the process of being established concurrently by the connection pool. */
|
||||
const maxConnecting = parseInt(process.env.MONGO_MAX_CONNECTING) || undefined;
|
||||
/** The maximum number of milliseconds that a connection can remain idle in the pool before being removed and closed. */
|
||||
const maxIdleTimeMS = parseInt(process.env.MONGO_MAX_IDLE_TIME_MS) || undefined;
|
||||
/** The maximum time in milliseconds that a thread can wait for a connection to become available. */
|
||||
const waitQueueTimeoutMS = parseInt(process.env.MONGO_WAIT_QUEUE_TIMEOUT_MS) || undefined;
|
||||
/** Set to false to disable automatic index creation for all models associated with this connection. */
|
||||
const autoIndex =
|
||||
process.env.MONGO_AUTO_INDEX != undefined
|
||||
? isEnabled(process.env.MONGO_AUTO_INDEX) || false
|
||||
: undefined;
|
||||
|
||||
/** Set to `false` to disable Mongoose automatically calling `createCollection()` on every model created on this connection. */
|
||||
const autoCreate =
|
||||
process.env.MONGO_AUTO_CREATE != undefined
|
||||
? isEnabled(process.env.MONGO_AUTO_CREATE) || false
|
||||
: undefined;
|
||||
/**
|
||||
* Global is used here to maintain a cached connection across hot reloads
|
||||
* in development. This prevents connections growing exponentially
|
||||
@@ -26,13 +49,21 @@ async function connectDb() {
|
||||
if (!cached.promise || disconnected) {
|
||||
const opts = {
|
||||
bufferCommands: false,
|
||||
...(maxPoolSize ? { maxPoolSize } : {}),
|
||||
...(minPoolSize ? { minPoolSize } : {}),
|
||||
...(maxConnecting ? { maxConnecting } : {}),
|
||||
...(maxIdleTimeMS ? { maxIdleTimeMS } : {}),
|
||||
...(waitQueueTimeoutMS ? { waitQueueTimeoutMS } : {}),
|
||||
...(autoIndex != undefined ? { autoIndex } : {}),
|
||||
...(autoCreate != undefined ? { autoCreate } : {}),
|
||||
// useNewUrlParser: true,
|
||||
// useUnifiedTopology: true,
|
||||
// bufferMaxEntries: 0,
|
||||
// useFindAndModify: true,
|
||||
// useCreateIndex: true
|
||||
};
|
||||
|
||||
logger.info('Mongo Connection options');
|
||||
logger.info(JSON.stringify(opts, null, 2));
|
||||
mongoose.set('strictQuery', true);
|
||||
cached.promise = mongoose.connect(MONGO_URI, opts).then((mongoose) => {
|
||||
return mongoose;
|
||||
|
||||
@@ -3,6 +3,7 @@ module.exports = {
|
||||
clearMocks: true,
|
||||
roots: ['<rootDir>'],
|
||||
coverageDirectory: 'coverage',
|
||||
testTimeout: 30000, // 30 seconds timeout for all tests
|
||||
setupFiles: [
|
||||
'./test/jestSetup.js',
|
||||
'./test/__mocks__/logger.js',
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
const mongoose = require('mongoose');
|
||||
const crypto = require('node:crypto');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
|
||||
const { ResourceType, SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
|
||||
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_all, mcp_delimiter } =
|
||||
require('librechat-data-provider').Constants;
|
||||
const { CONFIG_STORE, STARTUP_CONFIG } = require('librechat-data-provider').CacheKeys;
|
||||
const {
|
||||
getProjectByName,
|
||||
addAgentIdsToProject,
|
||||
removeAgentIdsFromProject,
|
||||
removeAgentFromAllProjects,
|
||||
removeAgentIdsFromProject,
|
||||
addAgentIdsToProject,
|
||||
getProjectByName,
|
||||
} = require('./Project');
|
||||
const { removeAllPermissions } = require('~/server/services/PermissionService');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { getActions } = require('./Action');
|
||||
const { Agent } = require('~/db/models');
|
||||
|
||||
@@ -23,7 +22,7 @@ const { Agent } = require('~/db/models');
|
||||
* @throws {Error} If the agent creation fails.
|
||||
*/
|
||||
const createAgent = async (agentData) => {
|
||||
const { author, ...versionData } = agentData;
|
||||
const { author: _author, ...versionData } = agentData;
|
||||
const timestamp = new Date();
|
||||
const initialAgentData = {
|
||||
...agentData,
|
||||
@@ -34,7 +33,9 @@ const createAgent = async (agentData) => {
|
||||
updatedAt: timestamp,
|
||||
},
|
||||
],
|
||||
category: agentData.category || 'general',
|
||||
};
|
||||
|
||||
return (await Agent.create(initialAgentData)).toObject();
|
||||
};
|
||||
|
||||
@@ -77,6 +78,7 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
tools.push(Tools.web_search);
|
||||
}
|
||||
|
||||
const addedServers = new Set();
|
||||
if (mcpServers.size > 0) {
|
||||
for (const toolName of Object.keys(availableTools)) {
|
||||
if (!toolName.includes(mcp_delimiter)) {
|
||||
@@ -84,9 +86,17 @@ const loadEphemeralAgent = async ({ req, agent_id, endpoint, model_parameters: _
|
||||
}
|
||||
const mcpServer = toolName.split(mcp_delimiter)?.[1];
|
||||
if (mcpServer && mcpServers.has(mcpServer)) {
|
||||
addedServers.add(mcpServer);
|
||||
tools.push(toolName);
|
||||
}
|
||||
}
|
||||
|
||||
for (const mcpServer of mcpServers) {
|
||||
if (addedServers.has(mcpServer)) {
|
||||
continue;
|
||||
}
|
||||
tools.push(`${mcp_all}${mcp_delimiter}${mcpServer}`);
|
||||
}
|
||||
}
|
||||
|
||||
const instructions = req.body.promptPrefix;
|
||||
@@ -131,29 +141,7 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
|
||||
}
|
||||
|
||||
agent.version = agent.versions ? agent.versions.length : 0;
|
||||
|
||||
if (agent.author.toString() === req.user.id) {
|
||||
return agent;
|
||||
}
|
||||
|
||||
if (!agent.projectIds) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const cache = getLogStores(CONFIG_STORE);
|
||||
/** @type {TStartupConfig} */
|
||||
const cachedStartupConfig = await cache.get(STARTUP_CONFIG);
|
||||
let { instanceProjectId } = cachedStartupConfig ?? {};
|
||||
if (!instanceProjectId) {
|
||||
instanceProjectId = (await getProjectByName(GLOBAL_PROJECT_NAME, '_id'))._id.toString();
|
||||
}
|
||||
|
||||
for (const projectObjectId of agent.projectIds) {
|
||||
const projectId = projectObjectId.toString();
|
||||
if (projectId === instanceProjectId) {
|
||||
return agent;
|
||||
}
|
||||
}
|
||||
return agent;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -183,7 +171,7 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
|
||||
'actionsHash', // Exclude actionsHash from direct comparison
|
||||
];
|
||||
|
||||
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
|
||||
const { $push: _$push, $pull: _$pull, $addToSet: _$addToSet, ...directUpdates } = updateData;
|
||||
|
||||
if (Object.keys(directUpdates).length === 0 && !actionsHash) {
|
||||
return null;
|
||||
@@ -202,54 +190,116 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
|
||||
|
||||
let isMatch = true;
|
||||
for (const field of importantFields) {
|
||||
if (!wouldBeVersion[field] && !lastVersion[field]) {
|
||||
const wouldBeValue = wouldBeVersion[field];
|
||||
const lastVersionValue = lastVersion[field];
|
||||
|
||||
// Skip if both are undefined/null
|
||||
if (!wouldBeValue && !lastVersionValue) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (Array.isArray(wouldBeVersion[field]) && Array.isArray(lastVersion[field])) {
|
||||
if (wouldBeVersion[field].length !== lastVersion[field].length) {
|
||||
// Handle arrays
|
||||
if (Array.isArray(wouldBeValue) || Array.isArray(lastVersionValue)) {
|
||||
// Normalize: treat undefined/null as empty array for comparison
|
||||
let wouldBeArr;
|
||||
if (Array.isArray(wouldBeValue)) {
|
||||
wouldBeArr = wouldBeValue;
|
||||
} else if (wouldBeValue == null) {
|
||||
wouldBeArr = [];
|
||||
} else {
|
||||
wouldBeArr = [wouldBeValue];
|
||||
}
|
||||
|
||||
let lastVersionArr;
|
||||
if (Array.isArray(lastVersionValue)) {
|
||||
lastVersionArr = lastVersionValue;
|
||||
} else if (lastVersionValue == null) {
|
||||
lastVersionArr = [];
|
||||
} else {
|
||||
lastVersionArr = [lastVersionValue];
|
||||
}
|
||||
|
||||
if (wouldBeArr.length !== lastVersionArr.length) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
|
||||
// Special handling for projectIds (MongoDB ObjectIds)
|
||||
if (field === 'projectIds') {
|
||||
const wouldBeIds = wouldBeVersion[field].map((id) => id.toString()).sort();
|
||||
const versionIds = lastVersion[field].map((id) => id.toString()).sort();
|
||||
const wouldBeIds = wouldBeArr.map((id) => id.toString()).sort();
|
||||
const versionIds = lastVersionArr.map((id) => id.toString()).sort();
|
||||
|
||||
if (!wouldBeIds.every((id, i) => id === versionIds[i])) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Handle arrays of objects like tool_kwargs
|
||||
else if (typeof wouldBeVersion[field][0] === 'object' && wouldBeVersion[field][0] !== null) {
|
||||
const sortedWouldBe = [...wouldBeVersion[field]].map((item) => JSON.stringify(item)).sort();
|
||||
const sortedVersion = [...lastVersion[field]].map((item) => JSON.stringify(item)).sort();
|
||||
// Handle arrays of objects
|
||||
else if (
|
||||
wouldBeArr.length > 0 &&
|
||||
typeof wouldBeArr[0] === 'object' &&
|
||||
wouldBeArr[0] !== null
|
||||
) {
|
||||
const sortedWouldBe = [...wouldBeArr].map((item) => JSON.stringify(item)).sort();
|
||||
const sortedVersion = [...lastVersionArr].map((item) => JSON.stringify(item)).sort();
|
||||
|
||||
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const sortedWouldBe = [...wouldBeVersion[field]].sort();
|
||||
const sortedVersion = [...lastVersion[field]].sort();
|
||||
const sortedWouldBe = [...wouldBeArr].sort();
|
||||
const sortedVersion = [...lastVersionArr].sort();
|
||||
|
||||
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (field === 'model_parameters') {
|
||||
const wouldBeParams = wouldBeVersion[field] || {};
|
||||
const lastVersionParams = lastVersion[field] || {};
|
||||
if (JSON.stringify(wouldBeParams) !== JSON.stringify(lastVersionParams)) {
|
||||
}
|
||||
// Handle objects
|
||||
else if (typeof wouldBeValue === 'object' && wouldBeValue !== null) {
|
||||
const lastVersionObj =
|
||||
typeof lastVersionValue === 'object' && lastVersionValue !== null ? lastVersionValue : {};
|
||||
|
||||
// For empty objects, normalize the comparison
|
||||
const wouldBeKeys = Object.keys(wouldBeValue);
|
||||
const lastVersionKeys = Object.keys(lastVersionObj);
|
||||
|
||||
// If both are empty objects, they're equal
|
||||
if (wouldBeKeys.length === 0 && lastVersionKeys.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise do a deep comparison
|
||||
if (JSON.stringify(wouldBeValue) !== JSON.stringify(lastVersionObj)) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Handle primitive values
|
||||
else {
|
||||
// For primitives, handle the case where one is undefined and the other is a default value
|
||||
if (wouldBeValue !== lastVersionValue) {
|
||||
// Special handling for boolean false vs undefined
|
||||
if (
|
||||
typeof wouldBeValue === 'boolean' &&
|
||||
wouldBeValue === false &&
|
||||
lastVersionValue === undefined
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
// Special handling for empty string vs undefined
|
||||
if (
|
||||
typeof wouldBeValue === 'string' &&
|
||||
wouldBeValue === '' &&
|
||||
lastVersionValue === undefined
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
} else if (wouldBeVersion[field] !== lastVersion[field]) {
|
||||
isMatch = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -278,7 +328,14 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
|
||||
|
||||
const currentAgent = await Agent.findOne(searchParameter);
|
||||
if (currentAgent) {
|
||||
const { __v, _id, id, versions, author, ...versionData } = currentAgent.toObject();
|
||||
const {
|
||||
__v,
|
||||
_id,
|
||||
id: __id,
|
||||
versions,
|
||||
author: _author,
|
||||
...versionData
|
||||
} = currentAgent.toObject();
|
||||
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
|
||||
|
||||
let actionsHash = null;
|
||||
@@ -458,12 +515,117 @@ const deleteAgent = async (searchParameter) => {
|
||||
const agent = await Agent.findOneAndDelete(searchParameter);
|
||||
if (agent) {
|
||||
await removeAgentFromAllProjects(agent.id);
|
||||
await removeAllPermissions({
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
});
|
||||
}
|
||||
return agent;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get agents by accessible IDs with optional cursor-based pagination.
|
||||
* @param {Object} params - The parameters for getting accessible agents.
|
||||
* @param {Array} [params.accessibleIds] - Array of agent ObjectIds the user has ACL access to.
|
||||
* @param {Object} [params.otherParams] - Additional query parameters (including author filter).
|
||||
* @param {number} [params.limit] - Number of agents to return (max 100). If not provided, returns all agents.
|
||||
* @param {string} [params.after] - Cursor for pagination - get agents after this cursor. // base64 encoded JSON string with updatedAt and _id.
|
||||
* @returns {Promise<Object>} A promise that resolves to an object containing the agents data and pagination info.
|
||||
*/
|
||||
const getListAgentsByAccess = async ({
|
||||
accessibleIds = [],
|
||||
otherParams = {},
|
||||
limit = null,
|
||||
after = null,
|
||||
}) => {
|
||||
const isPaginated = limit !== null && limit !== undefined;
|
||||
const normalizedLimit = isPaginated ? Math.min(Math.max(1, parseInt(limit) || 20), 100) : null;
|
||||
|
||||
// Build base query combining ACL accessible agents with other filters
|
||||
const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };
|
||||
|
||||
// Add cursor condition
|
||||
if (after) {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
const cursorCondition = {
|
||||
$or: [
|
||||
{ updatedAt: { $lt: new Date(updatedAt) } },
|
||||
{ updatedAt: new Date(updatedAt), _id: { $gt: new mongoose.Types.ObjectId(_id) } },
|
||||
],
|
||||
};
|
||||
|
||||
// Merge cursor condition with base query
|
||||
if (Object.keys(baseQuery).length > 0) {
|
||||
baseQuery.$and = [{ ...baseQuery }, cursorCondition];
|
||||
// Remove the original conditions from baseQuery to avoid duplication
|
||||
Object.keys(baseQuery).forEach((key) => {
|
||||
if (key !== '$and') delete baseQuery[key];
|
||||
});
|
||||
} else {
|
||||
Object.assign(baseQuery, cursorCondition);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Invalid cursor:', error.message);
|
||||
}
|
||||
}
|
||||
|
||||
let query = Agent.find(baseQuery, {
|
||||
id: 1,
|
||||
_id: 1,
|
||||
name: 1,
|
||||
avatar: 1,
|
||||
author: 1,
|
||||
projectIds: 1,
|
||||
description: 1,
|
||||
updatedAt: 1,
|
||||
category: 1,
|
||||
support_contact: 1,
|
||||
is_promoted: 1,
|
||||
}).sort({ updatedAt: -1, _id: 1 });
|
||||
|
||||
// Only apply limit if pagination is requested
|
||||
if (isPaginated) {
|
||||
query = query.limit(normalizedLimit + 1);
|
||||
}
|
||||
|
||||
const agents = await query.lean();
|
||||
|
||||
const hasMore = isPaginated ? agents.length > normalizedLimit : false;
|
||||
const data = (isPaginated ? agents.slice(0, normalizedLimit) : agents).map((agent) => {
|
||||
if (agent.author) {
|
||||
agent.author = agent.author.toString();
|
||||
}
|
||||
return agent;
|
||||
});
|
||||
|
||||
// Generate next cursor only if paginated
|
||||
let nextCursor = null;
|
||||
if (isPaginated && hasMore && data.length > 0) {
|
||||
const lastAgent = agents[normalizedLimit - 1];
|
||||
nextCursor = Buffer.from(
|
||||
JSON.stringify({
|
||||
updatedAt: lastAgent.updatedAt.toISOString(),
|
||||
_id: lastAgent._id.toString(),
|
||||
}),
|
||||
).toString('base64');
|
||||
}
|
||||
|
||||
return {
|
||||
object: 'list',
|
||||
data,
|
||||
first_id: data.length > 0 ? data[0].id : null,
|
||||
last_id: data.length > 0 ? data[data.length - 1].id : null,
|
||||
has_more: hasMore,
|
||||
after: nextCursor,
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Get all agents.
|
||||
* @deprecated Use getListAgentsByAccess for ACL-aware agent listing
|
||||
* @param {Object} searchParameter - The search parameters to find matching agents.
|
||||
* @param {string} searchParameter.author - The user ID of the agent's author.
|
||||
* @returns {Promise<Object>} A promise that resolves to an object containing the agents data and pagination info.
|
||||
@@ -482,13 +644,15 @@ const getListAgents = async (searchParameter) => {
|
||||
const agents = (
|
||||
await Agent.find(query, {
|
||||
id: 1,
|
||||
_id: 0,
|
||||
_id: 1,
|
||||
name: 1,
|
||||
avatar: 1,
|
||||
author: 1,
|
||||
projectIds: 1,
|
||||
description: 1,
|
||||
// @deprecated - isCollaborative replaced by ACL permissions
|
||||
isCollaborative: 1,
|
||||
category: 1,
|
||||
}).lean()
|
||||
).map((agent) => {
|
||||
if (agent.author?.toString() !== author) {
|
||||
@@ -654,6 +818,14 @@ const generateActionMetadataHash = async (actionIds, actions) => {
|
||||
|
||||
return hashHex;
|
||||
};
|
||||
/**
|
||||
* Counts the number of promoted agents.
|
||||
* @returns {Promise<number>} - The count of promoted agents
|
||||
*/
|
||||
const countPromotedAgents = async () => {
|
||||
const count = await Agent.countDocuments({ is_promoted: true });
|
||||
return count;
|
||||
};
|
||||
|
||||
/**
|
||||
* Load a default agent based on the endpoint
|
||||
@@ -671,6 +843,8 @@ module.exports = {
|
||||
revertAgentVersion,
|
||||
updateAgentProjects,
|
||||
addAgentResourceFile,
|
||||
getListAgentsByAccess,
|
||||
removeAgentResourceFiles,
|
||||
generateActionMetadataHash,
|
||||
countPromotedAgents,
|
||||
};
|
||||
|
||||
@@ -14,6 +14,7 @@ const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { AccessRoleIds, ResourceType, PrincipalType } = require('librechat-data-provider');
|
||||
const {
|
||||
getAgent,
|
||||
loadAgent,
|
||||
@@ -21,13 +22,16 @@ const {
|
||||
updateAgent,
|
||||
deleteAgent,
|
||||
getListAgents,
|
||||
getListAgentsByAccess,
|
||||
revertAgentVersion,
|
||||
updateAgentProjects,
|
||||
addAgentResourceFile,
|
||||
removeAgentResourceFiles,
|
||||
generateActionMetadataHash,
|
||||
revertAgentVersion,
|
||||
} = require('./Agent');
|
||||
const permissionService = require('~/server/services/PermissionService');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { AclEntry } = require('~/db/models');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
|
||||
@@ -407,12 +411,26 @@ describe('models/Agent', () => {
|
||||
|
||||
describe('Agent CRUD Operations', () => {
|
||||
let mongoServer;
|
||||
let AccessRole;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize models
|
||||
const dbModels = require('~/db/models');
|
||||
AccessRole = dbModels.AccessRole;
|
||||
|
||||
// Create necessary access roles for agents
|
||||
await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.AGENT_OWNER,
|
||||
name: 'Owner',
|
||||
description: 'Full control over agents',
|
||||
resourceType: ResourceType.AGENT,
|
||||
permBits: 15, // VIEW | EDIT | DELETE | SHARE
|
||||
});
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
@@ -468,6 +486,51 @@ describe('models/Agent', () => {
|
||||
expect(agentAfterDelete).toBeNull();
|
||||
});
|
||||
|
||||
test('should remove ACL entries when deleting an agent', async () => {
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Create agent
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Agent With Permissions',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Grant permissions (simulating sharing)
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: authorId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_OWNER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Verify ACL entry exists
|
||||
const aclEntriesBefore = await AclEntry.find({
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
});
|
||||
expect(aclEntriesBefore).toHaveLength(1);
|
||||
|
||||
// Delete the agent
|
||||
await deleteAgent({ id: agentId });
|
||||
|
||||
// Verify agent is deleted
|
||||
const agentAfterDelete = await getAgent({ id: agentId });
|
||||
expect(agentAfterDelete).toBeNull();
|
||||
|
||||
// Verify ACL entries are removed
|
||||
const aclEntriesAfter = await AclEntry.find({
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
});
|
||||
expect(aclEntriesAfter).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('should list agents by author', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const otherAuthorId = new mongoose.Types.ObjectId();
|
||||
@@ -1237,6 +1300,335 @@ describe('models/Agent', () => {
|
||||
expect(secondUpdate.versions).toHaveLength(3);
|
||||
});
|
||||
|
||||
test('should detect changes in support_contact fields', async () => {
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Create agent with initial support_contact
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Agent with Support Contact',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
support_contact: {
|
||||
name: 'Initial Support',
|
||||
email: 'initial@support.com',
|
||||
},
|
||||
});
|
||||
|
||||
// Update support_contact name only
|
||||
const firstUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'Updated Support',
|
||||
email: 'initial@support.com',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(firstUpdate.versions).toHaveLength(2);
|
||||
expect(firstUpdate.support_contact.name).toBe('Updated Support');
|
||||
expect(firstUpdate.support_contact.email).toBe('initial@support.com');
|
||||
|
||||
// Update support_contact email only
|
||||
const secondUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'Updated Support',
|
||||
email: 'updated@support.com',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(secondUpdate.versions).toHaveLength(3);
|
||||
expect(secondUpdate.support_contact.email).toBe('updated@support.com');
|
||||
|
||||
// Try to update with same support_contact - should be detected as duplicate but return successfully
|
||||
const duplicateUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'Updated Support',
|
||||
email: 'updated@support.com',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Should not create a new version
|
||||
expect(duplicateUpdate.versions).toHaveLength(3);
|
||||
expect(duplicateUpdate.version).toBe(3);
|
||||
expect(duplicateUpdate.support_contact.email).toBe('updated@support.com');
|
||||
});
|
||||
|
||||
test('should handle support_contact from empty to populated', async () => {
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Create agent without support_contact
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Agent without Support',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
});
|
||||
|
||||
// Verify support_contact is undefined since it wasn't provided
|
||||
expect(agent.support_contact).toBeUndefined();
|
||||
|
||||
// Update to add support_contact
|
||||
const updated = await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'New Support Team',
|
||||
email: 'support@example.com',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(updated.versions).toHaveLength(2);
|
||||
expect(updated.support_contact.name).toBe('New Support Team');
|
||||
expect(updated.support_contact.email).toBe('support@example.com');
|
||||
});
|
||||
|
||||
test('should handle support_contact edge cases in isDuplicateVersion', async () => {
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Create agent with support_contact
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Edge Case Agent',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
support_contact: {
|
||||
name: 'Support',
|
||||
email: 'support@test.com',
|
||||
},
|
||||
});
|
||||
|
||||
// Update to empty support_contact
|
||||
const emptyUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {},
|
||||
},
|
||||
);
|
||||
|
||||
expect(emptyUpdate.versions).toHaveLength(2);
|
||||
expect(emptyUpdate.support_contact).toEqual({});
|
||||
|
||||
// Update back to populated support_contact
|
||||
const repopulated = await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'Support',
|
||||
email: 'support@test.com',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(repopulated.versions).toHaveLength(3);
|
||||
|
||||
// Verify all versions have correct support_contact
|
||||
const finalAgent = await getAgent({ id: agentId });
|
||||
expect(finalAgent.versions[0].support_contact).toEqual({
|
||||
name: 'Support',
|
||||
email: 'support@test.com',
|
||||
});
|
||||
expect(finalAgent.versions[1].support_contact).toEqual({});
|
||||
expect(finalAgent.versions[2].support_contact).toEqual({
|
||||
name: 'Support',
|
||||
email: 'support@test.com',
|
||||
});
|
||||
});
|
||||
|
||||
test('should preserve support_contact in version history', async () => {
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Create agent
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Version History Test',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
support_contact: {
|
||||
name: 'Initial Contact',
|
||||
email: 'initial@test.com',
|
||||
},
|
||||
});
|
||||
|
||||
// Multiple updates with different support_contact values
|
||||
await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'Second Contact',
|
||||
email: 'second@test.com',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'Third Contact',
|
||||
email: 'third@test.com',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const finalAgent = await getAgent({ id: agentId });
|
||||
|
||||
// Verify version history
|
||||
expect(finalAgent.versions).toHaveLength(3);
|
||||
expect(finalAgent.versions[0].support_contact).toEqual({
|
||||
name: 'Initial Contact',
|
||||
email: 'initial@test.com',
|
||||
});
|
||||
expect(finalAgent.versions[1].support_contact).toEqual({
|
||||
name: 'Second Contact',
|
||||
email: 'second@test.com',
|
||||
});
|
||||
expect(finalAgent.versions[2].support_contact).toEqual({
|
||||
name: 'Third Contact',
|
||||
email: 'third@test.com',
|
||||
});
|
||||
|
||||
// Current state should match last version
|
||||
expect(finalAgent.support_contact).toEqual({
|
||||
name: 'Third Contact',
|
||||
email: 'third@test.com',
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle partial support_contact updates', async () => {
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
|
||||
// Create agent with full support_contact
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Partial Update Test',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
support_contact: {
|
||||
name: 'Original Name',
|
||||
email: 'original@email.com',
|
||||
},
|
||||
});
|
||||
|
||||
// MongoDB's findOneAndUpdate will replace the entire support_contact object
|
||||
// So we need to verify that partial updates still work correctly
|
||||
const updated = await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'New Name',
|
||||
email: '', // Empty email
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(updated.versions).toHaveLength(2);
|
||||
expect(updated.support_contact.name).toBe('New Name');
|
||||
expect(updated.support_contact.email).toBe('');
|
||||
|
||||
// Verify isDuplicateVersion works with partial changes - should return successfully without creating new version
|
||||
const duplicateUpdate = await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
support_contact: {
|
||||
name: 'New Name',
|
||||
email: '',
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Should not create a new version since content is the same
|
||||
expect(duplicateUpdate.versions).toHaveLength(2);
|
||||
expect(duplicateUpdate.version).toBe(2);
|
||||
expect(duplicateUpdate.support_contact.name).toBe('New Name');
|
||||
expect(duplicateUpdate.support_contact.email).toBe('');
|
||||
});
|
||||
|
||||
// Edge Cases
|
||||
describe.each([
|
||||
{
|
||||
operation: 'add',
|
||||
name: 'empty file_id',
|
||||
needsAgent: true,
|
||||
params: { tool_resource: 'file_search', file_id: '' },
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
operation: 'add',
|
||||
name: 'non-existent agent',
|
||||
needsAgent: false,
|
||||
params: { tool_resource: 'file_search', file_id: 'file123' },
|
||||
shouldResolve: false,
|
||||
error: 'Agent not found for adding resource file',
|
||||
},
|
||||
])('addAgentResourceFile with $name', ({ needsAgent, params, shouldResolve, error }) => {
|
||||
test(`should ${shouldResolve ? 'resolve' : 'reject'}`, async () => {
|
||||
const agent = needsAgent ? await createBasicAgent() : null;
|
||||
const agent_id = needsAgent ? agent.id : `agent_${uuidv4()}`;
|
||||
|
||||
if (shouldResolve) {
|
||||
await expect(addAgentResourceFile({ agent_id, ...params })).resolves.toBeDefined();
|
||||
} else {
|
||||
await expect(addAgentResourceFile({ agent_id, ...params })).rejects.toThrow(error);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe.each([
|
||||
{
|
||||
name: 'empty files array',
|
||||
files: [],
|
||||
needsAgent: true,
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: 'non-existent tool_resource',
|
||||
files: [{ tool_resource: 'non_existent_tool', file_id: 'file123' }],
|
||||
needsAgent: true,
|
||||
shouldResolve: true,
|
||||
},
|
||||
{
|
||||
name: 'non-existent agent',
|
||||
files: [{ tool_resource: 'file_search', file_id: 'file123' }],
|
||||
needsAgent: false,
|
||||
shouldResolve: false,
|
||||
error: 'Agent not found for removing resource files',
|
||||
},
|
||||
])('removeAgentResourceFiles with $name', ({ files, needsAgent, shouldResolve, error }) => {
|
||||
test(`should ${shouldResolve ? 'resolve' : 'reject'}`, async () => {
|
||||
const agent = needsAgent ? await createBasicAgent() : null;
|
||||
const agent_id = needsAgent ? agent.id : `agent_${uuidv4()}`;
|
||||
|
||||
if (shouldResolve) {
|
||||
const result = await removeAgentResourceFiles({ agent_id, files });
|
||||
expect(result).toBeDefined();
|
||||
if (agent) {
|
||||
expect(result.id).toBe(agent.id);
|
||||
}
|
||||
} else {
|
||||
await expect(removeAgentResourceFiles({ agent_id, files })).rejects.toThrow(error);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
test('should handle extremely large version history', async () => {
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
@@ -1612,7 +2004,7 @@ describe('models/Agent', () => {
|
||||
expect(result.version).toBe(1);
|
||||
});
|
||||
|
||||
test('should return null when user is not author and agent has no projectIds', async () => {
|
||||
test('should return agent even when user is not author (permissions checked at route level)', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
@@ -1633,7 +2025,11 @@ describe('models/Agent', () => {
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
});
|
||||
|
||||
expect(result).toBeFalsy();
|
||||
// With the new permission system, loadAgent returns the agent regardless of permissions
|
||||
// Permission checks are handled at the route level via middleware
|
||||
expect(result).toBeTruthy();
|
||||
expect(result.id).toBe(agentId);
|
||||
expect(result.name).toBe('Test Agent');
|
||||
});
|
||||
|
||||
test('should handle ephemeral agent with no MCP servers', async () => {
|
||||
@@ -1741,7 +2137,7 @@ describe('models/Agent', () => {
|
||||
}
|
||||
});
|
||||
|
||||
test('should handle loadAgent with agent from different project', async () => {
|
||||
test('should return agent from different project (permissions checked at route level)', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentId = `agent_${uuidv4()}`;
|
||||
@@ -1764,7 +2160,11 @@ describe('models/Agent', () => {
|
||||
model_parameters: { model: 'gpt-4' },
|
||||
});
|
||||
|
||||
expect(result).toBeFalsy();
|
||||
// With the new permission system, loadAgent returns the agent regardless of permissions
|
||||
// Permission checks are handled at the route level via middleware
|
||||
expect(result).toBeTruthy();
|
||||
expect(result.id).toBe(agentId);
|
||||
expect(result.name).toBe('Project Agent');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -2557,6 +2957,299 @@ describe('models/Agent', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Support Contact Field', () => {
|
||||
let mongoServer;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
}, 20000);
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await Agent.deleteMany({});
|
||||
});
|
||||
|
||||
it('should not create subdocument with ObjectId for support_contact', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentData = {
|
||||
id: 'agent_test_support',
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userId,
|
||||
support_contact: {
|
||||
name: 'Support Team',
|
||||
email: 'support@example.com',
|
||||
},
|
||||
};
|
||||
|
||||
// Create agent
|
||||
const agent = await createAgent(agentData);
|
||||
|
||||
// Verify support_contact is stored correctly
|
||||
expect(agent.support_contact).toBeDefined();
|
||||
expect(agent.support_contact.name).toBe('Support Team');
|
||||
expect(agent.support_contact.email).toBe('support@example.com');
|
||||
|
||||
// Verify no _id field is created in support_contact
|
||||
expect(agent.support_contact._id).toBeUndefined();
|
||||
|
||||
// Fetch from database to double-check
|
||||
const dbAgent = await Agent.findOne({ id: agentData.id });
|
||||
expect(dbAgent.support_contact).toBeDefined();
|
||||
expect(dbAgent.support_contact.name).toBe('Support Team');
|
||||
expect(dbAgent.support_contact.email).toBe('support@example.com');
|
||||
expect(dbAgent.support_contact._id).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle empty support_contact correctly', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentData = {
|
||||
id: 'agent_test_empty_support',
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userId,
|
||||
support_contact: {},
|
||||
};
|
||||
|
||||
const agent = await createAgent(agentData);
|
||||
|
||||
// Verify empty support_contact is stored as empty object
|
||||
expect(agent.support_contact).toEqual({});
|
||||
expect(agent.support_contact._id).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should handle missing support_contact correctly', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const agentData = {
|
||||
id: 'agent_test_no_support',
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userId,
|
||||
};
|
||||
|
||||
const agent = await createAgent(agentData);
|
||||
|
||||
// Verify support_contact is undefined when not provided
|
||||
expect(agent.support_contact).toBeUndefined();
|
||||
});
|
||||
|
||||
describe('getListAgentsByAccess - Security Tests', () => {
|
||||
let userA, userB;
|
||||
let agentA1, agentA2, agentA3;
|
||||
|
||||
beforeEach(async () => {
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
await Agent.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
|
||||
// Create two users
|
||||
userA = new mongoose.Types.ObjectId();
|
||||
userB = new mongoose.Types.ObjectId();
|
||||
|
||||
// Create agents for user A
|
||||
agentA1 = await createAgent({
|
||||
id: `agent_${uuidv4().slice(0, 12)}`,
|
||||
name: 'Agent A1',
|
||||
description: 'User A agent 1',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
});
|
||||
|
||||
agentA2 = await createAgent({
|
||||
id: `agent_${uuidv4().slice(0, 12)}`,
|
||||
name: 'Agent A2',
|
||||
description: 'User A agent 2',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
});
|
||||
|
||||
agentA3 = await createAgent({
|
||||
id: `agent_${uuidv4().slice(0, 12)}`,
|
||||
name: 'Agent A3',
|
||||
description: 'User A agent 3',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
});
|
||||
});
|
||||
|
||||
test('should return empty list when user has no accessible agents (empty accessibleIds)', async () => {
|
||||
// User B has no agents and no shared agents
|
||||
const result = await getListAgentsByAccess({
|
||||
accessibleIds: [],
|
||||
otherParams: {},
|
||||
});
|
||||
|
||||
expect(result.data).toHaveLength(0);
|
||||
expect(result.has_more).toBe(false);
|
||||
expect(result.first_id).toBeNull();
|
||||
expect(result.last_id).toBeNull();
|
||||
});
|
||||
|
||||
test('should not return other users agents when accessibleIds is empty', async () => {
|
||||
// User B trying to list agents with empty accessibleIds should not see User A's agents
|
||||
const result = await getListAgentsByAccess({
|
||||
accessibleIds: [],
|
||||
otherParams: { author: userB },
|
||||
});
|
||||
|
||||
expect(result.data).toHaveLength(0);
|
||||
expect(result.has_more).toBe(false);
|
||||
});
|
||||
|
||||
test('should only return agents in accessibleIds list', async () => {
|
||||
// Give User B access to only one of User A's agents
|
||||
const accessibleIds = [agentA1._id];
|
||||
|
||||
const result = await getListAgentsByAccess({
|
||||
accessibleIds,
|
||||
otherParams: {},
|
||||
});
|
||||
|
||||
expect(result.data).toHaveLength(1);
|
||||
expect(result.data[0].id).toBe(agentA1.id);
|
||||
expect(result.data[0].name).toBe('Agent A1');
|
||||
});
|
||||
|
||||
test('should return multiple accessible agents when provided', async () => {
|
||||
// Give User B access to two of User A's agents
|
||||
const accessibleIds = [agentA1._id, agentA3._id];
|
||||
|
||||
const result = await getListAgentsByAccess({
|
||||
accessibleIds,
|
||||
otherParams: {},
|
||||
});
|
||||
|
||||
expect(result.data).toHaveLength(2);
|
||||
const returnedIds = result.data.map((agent) => agent.id);
|
||||
expect(returnedIds).toContain(agentA1.id);
|
||||
expect(returnedIds).toContain(agentA3.id);
|
||||
expect(returnedIds).not.toContain(agentA2.id);
|
||||
});
|
||||
|
||||
test('should respect other query parameters while enforcing accessibleIds', async () => {
|
||||
// Give access to all agents but filter by name
|
||||
const accessibleIds = [agentA1._id, agentA2._id, agentA3._id];
|
||||
|
||||
const result = await getListAgentsByAccess({
|
||||
accessibleIds,
|
||||
otherParams: { name: 'Agent A2' },
|
||||
});
|
||||
|
||||
expect(result.data).toHaveLength(1);
|
||||
expect(result.data[0].id).toBe(agentA2.id);
|
||||
});
|
||||
|
||||
test('should handle pagination correctly with accessibleIds filter', async () => {
|
||||
// Create more agents
|
||||
const moreAgents = [];
|
||||
for (let i = 4; i <= 10; i++) {
|
||||
const agent = await createAgent({
|
||||
id: `agent_${uuidv4().slice(0, 12)}`,
|
||||
name: `Agent A${i}`,
|
||||
description: `User A agent ${i}`,
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
});
|
||||
moreAgents.push(agent);
|
||||
}
|
||||
|
||||
// Give access to all agents
|
||||
const allAgentIds = [agentA1, agentA2, agentA3, ...moreAgents].map((a) => a._id);
|
||||
|
||||
// First page
|
||||
const page1 = await getListAgentsByAccess({
|
||||
accessibleIds: allAgentIds,
|
||||
otherParams: {},
|
||||
limit: 5,
|
||||
});
|
||||
|
||||
expect(page1.data).toHaveLength(5);
|
||||
expect(page1.has_more).toBe(true);
|
||||
expect(page1.after).toBeTruthy();
|
||||
|
||||
// Second page
|
||||
const page2 = await getListAgentsByAccess({
|
||||
accessibleIds: allAgentIds,
|
||||
otherParams: {},
|
||||
limit: 5,
|
||||
after: page1.after,
|
||||
});
|
||||
|
||||
expect(page2.data).toHaveLength(5);
|
||||
expect(page2.has_more).toBe(false);
|
||||
|
||||
// Verify no overlap between pages
|
||||
const page1Ids = page1.data.map((a) => a.id);
|
||||
const page2Ids = page2.data.map((a) => a.id);
|
||||
const intersection = page1Ids.filter((id) => page2Ids.includes(id));
|
||||
expect(intersection).toHaveLength(0);
|
||||
});
|
||||
|
||||
test('should return empty list when accessibleIds contains non-existent IDs', async () => {
|
||||
// Try with non-existent agent IDs
|
||||
const fakeIds = [new mongoose.Types.ObjectId(), new mongoose.Types.ObjectId()];
|
||||
|
||||
const result = await getListAgentsByAccess({
|
||||
accessibleIds: fakeIds,
|
||||
otherParams: {},
|
||||
});
|
||||
|
||||
expect(result.data).toHaveLength(0);
|
||||
expect(result.has_more).toBe(false);
|
||||
});
|
||||
|
||||
test('should handle undefined accessibleIds as empty array', async () => {
|
||||
// When accessibleIds is undefined, it should be treated as empty array
|
||||
const result = await getListAgentsByAccess({
|
||||
accessibleIds: undefined,
|
||||
otherParams: {},
|
||||
});
|
||||
|
||||
expect(result.data).toHaveLength(0);
|
||||
expect(result.has_more).toBe(false);
|
||||
});
|
||||
|
||||
test('should combine accessibleIds with author filter correctly', async () => {
|
||||
// Create an agent for User B
|
||||
const agentB1 = await createAgent({
|
||||
id: `agent_${uuidv4().slice(0, 12)}`,
|
||||
name: 'Agent B1',
|
||||
description: 'User B agent 1',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userB,
|
||||
});
|
||||
|
||||
// Give User B access to one of User A's agents
|
||||
const accessibleIds = [agentA1._id, agentB1._id];
|
||||
|
||||
// Filter by author should further restrict the results
|
||||
const result = await getListAgentsByAccess({
|
||||
accessibleIds,
|
||||
otherParams: { author: userB },
|
||||
});
|
||||
|
||||
expect(result.data).toHaveLength(1);
|
||||
expect(result.data[0].id).toBe(agentB1.id);
|
||||
expect(result.data[0].author).toBe(userB.toString());
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
function createBasicAgent(overrides = {}) {
|
||||
const defaults = {
|
||||
id: `agent_${uuidv4()}`,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { EToolResources, FileContext, Constants } = require('librechat-data-provider');
|
||||
const { getProjectByName } = require('./Project');
|
||||
const { getAgent } = require('./Agent');
|
||||
const { EToolResources, FileContext } = require('librechat-data-provider');
|
||||
const { File } = require('~/db/models');
|
||||
|
||||
/**
|
||||
@@ -14,124 +12,17 @@ const findFileById = async (file_id, options = {}) => {
|
||||
return await File.findOne({ file_id, ...options }).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
* Checks if a user has access to multiple files through a shared agent (batch operation)
|
||||
* @param {string} userId - The user ID to check access for
|
||||
* @param {string[]} fileIds - Array of file IDs to check
|
||||
* @param {string} agentId - The agent ID that might grant access
|
||||
* @returns {Promise<Map<string, boolean>>} Map of fileId to access status
|
||||
*/
|
||||
const hasAccessToFilesViaAgent = async (userId, fileIds, agentId, checkCollaborative = true) => {
|
||||
const accessMap = new Map();
|
||||
|
||||
// Initialize all files as no access
|
||||
fileIds.forEach((fileId) => accessMap.set(fileId, false));
|
||||
|
||||
try {
|
||||
const agent = await getAgent({ id: agentId });
|
||||
|
||||
if (!agent) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if user is the author - if so, grant access to all files
|
||||
if (agent.author.toString() === userId) {
|
||||
fileIds.forEach((fileId) => accessMap.set(fileId, true));
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if agent is shared with the user via projects
|
||||
if (!agent.projectIds || agent.projectIds.length === 0) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check if agent is in global project
|
||||
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
|
||||
if (
|
||||
!globalProject ||
|
||||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString())
|
||||
) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Agent is globally shared - check if it's collaborative
|
||||
if (checkCollaborative && !agent.isCollaborative) {
|
||||
return accessMap;
|
||||
}
|
||||
|
||||
// Check which files are actually attached
|
||||
const attachedFileIds = new Set();
|
||||
if (agent.tool_resources) {
|
||||
for (const [_resourceType, resource] of Object.entries(agent.tool_resources)) {
|
||||
if (resource?.file_ids && Array.isArray(resource.file_ids)) {
|
||||
resource.file_ids.forEach((fileId) => attachedFileIds.add(fileId));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Grant access only to files that are attached to this agent
|
||||
fileIds.forEach((fileId) => {
|
||||
if (attachedFileIds.has(fileId)) {
|
||||
accessMap.set(fileId, true);
|
||||
}
|
||||
});
|
||||
|
||||
return accessMap;
|
||||
} catch (error) {
|
||||
logger.error('[hasAccessToFilesViaAgent] Error checking file access:', error);
|
||||
return accessMap;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves files matching a given filter, sorted by the most recently updated.
|
||||
* @param {Object} filter - The filter criteria to apply.
|
||||
* @param {Object} [_sortOptions] - Optional sort parameters.
|
||||
* @param {Object|String} [selectFields={ text: 0 }] - Fields to include/exclude in the query results.
|
||||
* Default excludes the 'text' field.
|
||||
* @param {Object} [options] - Additional options
|
||||
* @param {string} [options.userId] - User ID for access control
|
||||
* @param {string} [options.agentId] - Agent ID that might grant access to files
|
||||
* @returns {Promise<Array<MongoFile>>} A promise that resolves to an array of file documents.
|
||||
*/
|
||||
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }, options = {}) => {
|
||||
const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
|
||||
const sortOptions = { updatedAt: -1, ..._sortOptions };
|
||||
const files = await File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
|
||||
// If userId and agentId are provided, filter files based on access
|
||||
if (options.userId && options.agentId) {
|
||||
// Collect file IDs that need access check
|
||||
const filesToCheck = [];
|
||||
const ownedFiles = [];
|
||||
|
||||
for (const file of files) {
|
||||
if (file.user && file.user.toString() === options.userId) {
|
||||
ownedFiles.push(file);
|
||||
} else {
|
||||
filesToCheck.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
if (filesToCheck.length === 0) {
|
||||
return ownedFiles;
|
||||
}
|
||||
|
||||
// Batch check access for all non-owned files
|
||||
const fileIds = filesToCheck.map((f) => f.file_id);
|
||||
const accessMap = await hasAccessToFilesViaAgent(
|
||||
options.userId,
|
||||
fileIds,
|
||||
options.agentId,
|
||||
false,
|
||||
);
|
||||
|
||||
// Filter files based on access
|
||||
const accessibleFiles = filesToCheck.filter((file) => accessMap.get(file.file_id));
|
||||
|
||||
return [...ownedFiles, ...accessibleFiles];
|
||||
}
|
||||
|
||||
return files;
|
||||
return await File.find(filter).select(selectFields).sort(sortOptions).lean();
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -285,5 +176,4 @@ module.exports = {
|
||||
deleteFiles,
|
||||
deleteFileByFilter,
|
||||
batchUpdateFiles,
|
||||
hasAccessToFilesViaAgent,
|
||||
};
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { fileSchema } = require('@librechat/data-schemas');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
const { projectSchema } = require('@librechat/data-schemas');
|
||||
const { createModels } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const {
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
} = require('librechat-data-provider');
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
const { getFiles, createFile } = require('./File');
|
||||
const { getProjectByName } = require('./Project');
|
||||
const { seedDefaultRoles } = require('~/models');
|
||||
const { createAgent } = require('./Agent');
|
||||
|
||||
let File;
|
||||
let Agent;
|
||||
let Project;
|
||||
let AclEntry;
|
||||
let User;
|
||||
let modelsToCleanup = [];
|
||||
|
||||
describe('File Access Control', () => {
|
||||
let mongoServer;
|
||||
@@ -19,13 +25,41 @@ describe('File Access Control', () => {
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
File = mongoose.models.File || mongoose.model('File', fileSchema);
|
||||
Agent = mongoose.models.Agent || mongoose.model('Agent', agentSchema);
|
||||
Project = mongoose.models.Project || mongoose.model('Project', projectSchema);
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize all models
|
||||
const models = createModels(mongoose);
|
||||
|
||||
// Track which models we're adding
|
||||
modelsToCleanup = Object.keys(models);
|
||||
|
||||
// Register models on mongoose.models so methods can access them
|
||||
const dbModels = require('~/db/models');
|
||||
Object.assign(mongoose.models, dbModels);
|
||||
|
||||
File = dbModels.File;
|
||||
Agent = dbModels.Agent;
|
||||
AclEntry = dbModels.AclEntry;
|
||||
User = dbModels.User;
|
||||
|
||||
// Seed default roles
|
||||
await seedDefaultRoles();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Clean up all collections before disconnecting
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
|
||||
// Clear only the models we added
|
||||
for (const modelName of modelsToCleanup) {
|
||||
if (mongoose.models[modelName]) {
|
||||
delete mongoose.models[modelName];
|
||||
}
|
||||
}
|
||||
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
@@ -33,16 +67,33 @@ describe('File Access Control', () => {
|
||||
beforeEach(async () => {
|
||||
await File.deleteMany({});
|
||||
await Agent.deleteMany({});
|
||||
await Project.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
await User.deleteMany({});
|
||||
// Don't delete AccessRole as they are seeded defaults needed for tests
|
||||
});
|
||||
|
||||
describe('hasAccessToFilesViaAgent', () => {
|
||||
it('should efficiently check access for multiple files at once', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create files
|
||||
for (const fileId of fileIds) {
|
||||
await createFile({
|
||||
@@ -54,13 +105,12 @@ describe('File Access Control', () => {
|
||||
}
|
||||
|
||||
// Create agent with only first two files attached
|
||||
await createAgent({
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileIds[0], fileIds[1]],
|
||||
@@ -68,15 +118,24 @@ describe('File Access Control', () => {
|
||||
},
|
||||
});
|
||||
|
||||
// Get or create global project
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
|
||||
// Share agent globally
|
||||
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
|
||||
// Grant EDIT permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access for all files
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId: agent.id, // Use agent.id which is the custom UUID
|
||||
});
|
||||
|
||||
// Should have access only to the first two files
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
@@ -86,10 +145,18 @@ describe('File Access Control', () => {
|
||||
});
|
||||
|
||||
it('should grant access to all files when user is the agent author', async () => {
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4(), uuidv4()];
|
||||
|
||||
// Create author user
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
@@ -105,8 +172,13 @@ describe('File Access Control', () => {
|
||||
});
|
||||
|
||||
// Check access as the author
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(authorId, fileIds, agentId);
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: authorId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
});
|
||||
|
||||
// Author should have access to all files
|
||||
expect(accessMap.get(fileIds[0])).toBe(true);
|
||||
@@ -115,31 +187,58 @@ describe('File Access Control', () => {
|
||||
});
|
||||
|
||||
it('should handle non-existent agent gracefully', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, 'non-existent-agent');
|
||||
// Create user
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId: 'non-existent-agent',
|
||||
});
|
||||
|
||||
// Should have no access to any files
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
|
||||
it('should deny access when agent is not collaborative', async () => {
|
||||
const userId = new mongoose.Types.ObjectId().toString();
|
||||
const authorId = new mongoose.Types.ObjectId().toString();
|
||||
it('should deny access when user only has VIEW permission', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create agent with files but isCollaborative: false
|
||||
await createAgent({
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent with files
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Non-Collaborative Agent',
|
||||
name: 'View-Only Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: false,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
@@ -147,17 +246,26 @@ describe('File Access Control', () => {
|
||||
},
|
||||
});
|
||||
|
||||
// Get or create global project
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
|
||||
// Share agent globally
|
||||
await Agent.updateOne({ id: agentId }, { $push: { projectIds: globalProject._id } });
|
||||
// Grant only VIEW permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access for files
|
||||
const { hasAccessToFilesViaAgent } = require('./File');
|
||||
const accessMap = await hasAccessToFilesViaAgent(userId, fileIds, agentId);
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds,
|
||||
agentId,
|
||||
});
|
||||
|
||||
// Should have no access to any files when isCollaborative is false
|
||||
// Should have no access to any files when only VIEW permission
|
||||
expect(accessMap.get(fileIds[0])).toBe(false);
|
||||
expect(accessMap.get(fileIds[1])).toBe(false);
|
||||
});
|
||||
@@ -172,18 +280,28 @@ describe('File Access Control', () => {
|
||||
const sharedFileId = `file_${uuidv4()}`;
|
||||
const inaccessibleFileId = `file_${uuidv4()}`;
|
||||
|
||||
// Create/get global project using getProjectByName which will upsert
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME);
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create agent with shared file
|
||||
await createAgent({
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Shared Agent',
|
||||
provider: 'test',
|
||||
model: 'test-model',
|
||||
author: authorId,
|
||||
projectIds: [globalProject._id],
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [sharedFileId],
|
||||
@@ -191,6 +309,16 @@ describe('File Access Control', () => {
|
||||
},
|
||||
});
|
||||
|
||||
// Grant EDIT permission to user on the agent
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Create files
|
||||
await createFile({
|
||||
file_id: ownedFileId,
|
||||
@@ -220,14 +348,22 @@ describe('File Access Control', () => {
|
||||
bytes: 300,
|
||||
});
|
||||
|
||||
// Get files with access control
|
||||
const files = await getFiles(
|
||||
// Get all files first
|
||||
const allFiles = await getFiles(
|
||||
{ file_id: { $in: [ownedFileId, sharedFileId, inaccessibleFileId] } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
{ userId: userId.toString(), agentId },
|
||||
);
|
||||
|
||||
// Then filter by access control
|
||||
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||
const files = await filterFilesByAgentAccess({
|
||||
files: allFiles,
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
agentId,
|
||||
});
|
||||
|
||||
expect(files).toHaveLength(2);
|
||||
expect(files.map((f) => f.file_id)).toContain(ownedFileId);
|
||||
expect(files.map((f) => f.file_id)).toContain(sharedFileId);
|
||||
@@ -261,4 +397,166 @@ describe('File Access Control', () => {
|
||||
expect(files).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Role-based file permissions', () => {
|
||||
it('should optimize permission checks when role is provided', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileIds = [uuidv4(), uuidv4()];
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
role: 'ADMIN', // User has ADMIN role
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create files
|
||||
for (const fileId of fileIds) {
|
||||
await createFile({
|
||||
file_id: fileId,
|
||||
user: authorId,
|
||||
filename: `${fileId}.txt`,
|
||||
filepath: `/uploads/${fileId}.txt`,
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
}
|
||||
|
||||
// Create agent with files
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: fileIds,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant permission to ADMIN role
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: 'ADMIN',
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Check access with role provided (should avoid DB query)
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
const accessMapWithRole = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: 'ADMIN',
|
||||
fileIds,
|
||||
agentId: agent.id,
|
||||
});
|
||||
|
||||
// User should have access through their ADMIN role
|
||||
expect(accessMapWithRole.get(fileIds[0])).toBe(true);
|
||||
expect(accessMapWithRole.get(fileIds[1])).toBe(true);
|
||||
|
||||
// Check access without role (will query DB to get user's role)
|
||||
const accessMapWithoutRole = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
fileIds,
|
||||
agentId: agent.id,
|
||||
});
|
||||
|
||||
// Should have same result
|
||||
expect(accessMapWithoutRole.get(fileIds[0])).toBe(true);
|
||||
expect(accessMapWithoutRole.get(fileIds[1])).toBe(true);
|
||||
});
|
||||
|
||||
it('should deny access when user role changes', async () => {
|
||||
const userId = new mongoose.Types.ObjectId();
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agentId = uuidv4();
|
||||
const fileId = uuidv4();
|
||||
|
||||
// Create users
|
||||
await User.create({
|
||||
_id: userId,
|
||||
email: 'user@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
role: 'EDITOR',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
email: 'author@example.com',
|
||||
emailVerified: true,
|
||||
provider: 'local',
|
||||
});
|
||||
|
||||
// Create file
|
||||
await createFile({
|
||||
file_id: fileId,
|
||||
user: authorId,
|
||||
filename: 'test.txt',
|
||||
filepath: '/uploads/test.txt',
|
||||
type: 'text/plain',
|
||||
bytes: 100,
|
||||
});
|
||||
|
||||
// Create agent
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant permission to EDITOR role only
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.ROLE,
|
||||
principalId: 'EDITOR',
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files/permissions');
|
||||
|
||||
// Check with EDITOR role - should have access
|
||||
const accessAsEditor = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: 'EDITOR',
|
||||
fileIds: [fileId],
|
||||
agentId: agent.id,
|
||||
});
|
||||
expect(accessAsEditor.get(fileId)).toBe(true);
|
||||
|
||||
// Simulate role change to USER - should lose access
|
||||
const accessAsUser = await hasAccessToFilesViaAgent({
|
||||
userId: userId,
|
||||
role: SystemRoles.USER,
|
||||
fileIds: [fileId],
|
||||
agentId: agent.id,
|
||||
});
|
||||
expect(accessAsUser.get(fileId)).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider');
|
||||
const {
|
||||
getProjectByName,
|
||||
addGroupIdsToProject,
|
||||
removeGroupIdsFromProject,
|
||||
Constants,
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
SystemCategories,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
removeGroupFromAllProjects,
|
||||
removeGroupIdsFromProject,
|
||||
addGroupIdsToProject,
|
||||
getProjectByName,
|
||||
} = require('./Project');
|
||||
const { removeAllPermissions } = require('~/server/services/PermissionService');
|
||||
const { PromptGroup, Prompt } = require('~/db/models');
|
||||
const { escapeRegExp } = require('~/server/utils');
|
||||
|
||||
@@ -100,10 +106,6 @@ const getAllPromptGroups = async (req, filter) => {
|
||||
try {
|
||||
const { name, ...query } = filter;
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
@@ -153,10 +155,6 @@ const getPromptGroups = async (req, filter) => {
|
||||
const validatedPageNumber = Math.max(parseInt(pageNumber, 10), 1);
|
||||
const validatedPageSize = Math.max(parseInt(pageSize, 10), 1);
|
||||
|
||||
if (!query.author) {
|
||||
throw new Error('Author is required');
|
||||
}
|
||||
|
||||
let searchShared = true;
|
||||
let searchSharedOnly = false;
|
||||
if (name) {
|
||||
@@ -221,12 +219,16 @@ const getPromptGroups = async (req, filter) => {
|
||||
* @returns {Promise<TDeletePromptGroupResponse>}
|
||||
*/
|
||||
const deletePromptGroup = async ({ _id, author, role }) => {
|
||||
const query = { _id, author };
|
||||
const groupQuery = { groupId: new ObjectId(_id), author };
|
||||
if (role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
delete groupQuery.author;
|
||||
// Build query - with ACL, author is optional
|
||||
const query = { _id };
|
||||
const groupQuery = { groupId: new ObjectId(_id) };
|
||||
|
||||
// Legacy: Add author filter if provided (backward compatibility)
|
||||
if (author && role !== SystemRoles.ADMIN) {
|
||||
query.author = author;
|
||||
groupQuery.author = author;
|
||||
}
|
||||
|
||||
const response = await PromptGroup.deleteOne(query);
|
||||
|
||||
if (!response || response.deletedCount === 0) {
|
||||
@@ -235,13 +237,140 @@ const deletePromptGroup = async ({ _id, author, role }) => {
|
||||
|
||||
await Prompt.deleteMany(groupQuery);
|
||||
await removeGroupFromAllProjects(_id);
|
||||
|
||||
try {
|
||||
await removeAllPermissions({ resourceType: ResourceType.PROMPTGROUP, resourceId: _id });
|
||||
} catch (error) {
|
||||
logger.error('Error removing promptGroup permissions:', error);
|
||||
}
|
||||
|
||||
return { message: 'Prompt group deleted successfully' };
|
||||
};
|
||||
|
||||
/**
|
||||
* Get prompt groups by accessible IDs with optional cursor-based pagination.
|
||||
* @param {Object} params - The parameters for getting accessible prompt groups.
|
||||
* @param {Array} [params.accessibleIds] - Array of prompt group ObjectIds the user has ACL access to.
|
||||
* @param {Object} [params.otherParams] - Additional query parameters (including author filter).
|
||||
* @param {number} [params.limit] - Number of prompt groups to return (max 100). If not provided, returns all prompt groups.
|
||||
* @param {string} [params.after] - Cursor for pagination - get prompt groups after this cursor. // base64 encoded JSON string with updatedAt and _id.
|
||||
* @returns {Promise<Object>} A promise that resolves to an object containing the prompt groups data and pagination info.
|
||||
*/
|
||||
async function getListPromptGroupsByAccess({
|
||||
accessibleIds = [],
|
||||
otherParams = {},
|
||||
limit = null,
|
||||
after = null,
|
||||
}) {
|
||||
const isPaginated = limit !== null && limit !== undefined;
|
||||
const normalizedLimit = isPaginated ? Math.min(Math.max(1, parseInt(limit) || 20), 100) : null;
|
||||
|
||||
// Build base query combining ACL accessible prompt groups with other filters
|
||||
const baseQuery = { ...otherParams, _id: { $in: accessibleIds } };
|
||||
|
||||
// Add cursor condition
|
||||
if (after) {
|
||||
try {
|
||||
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
|
||||
const { updatedAt, _id } = cursor;
|
||||
|
||||
const cursorCondition = {
|
||||
$or: [
|
||||
{ updatedAt: { $lt: new Date(updatedAt) } },
|
||||
{ updatedAt: new Date(updatedAt), _id: { $gt: new ObjectId(_id) } },
|
||||
],
|
||||
};
|
||||
|
||||
// Merge cursor condition with base query
|
||||
if (Object.keys(baseQuery).length > 0) {
|
||||
baseQuery.$and = [{ ...baseQuery }, cursorCondition];
|
||||
// Remove the original conditions from baseQuery to avoid duplication
|
||||
Object.keys(baseQuery).forEach((key) => {
|
||||
if (key !== '$and') delete baseQuery[key];
|
||||
});
|
||||
} else {
|
||||
Object.assign(baseQuery, cursorCondition);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn('Invalid cursor:', error.message);
|
||||
}
|
||||
}
|
||||
|
||||
// Build aggregation pipeline
|
||||
const pipeline = [{ $match: baseQuery }, { $sort: { updatedAt: -1, _id: 1 } }];
|
||||
|
||||
// Only apply limit if pagination is requested
|
||||
if (isPaginated) {
|
||||
pipeline.push({ $limit: normalizedLimit + 1 });
|
||||
}
|
||||
|
||||
// Add lookup for production prompt
|
||||
pipeline.push(
|
||||
{
|
||||
$lookup: {
|
||||
from: 'prompts',
|
||||
localField: 'productionId',
|
||||
foreignField: '_id',
|
||||
as: 'productionPrompt',
|
||||
},
|
||||
},
|
||||
{ $unwind: { path: '$productionPrompt', preserveNullAndEmptyArrays: true } },
|
||||
{
|
||||
$project: {
|
||||
name: 1,
|
||||
numberOfGenerations: 1,
|
||||
oneliner: 1,
|
||||
category: 1,
|
||||
projectIds: 1,
|
||||
productionId: 1,
|
||||
author: 1,
|
||||
authorName: 1,
|
||||
createdAt: 1,
|
||||
updatedAt: 1,
|
||||
'productionPrompt.prompt': 1,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
const promptGroups = await PromptGroup.aggregate(pipeline).exec();
|
||||
|
||||
const hasMore = isPaginated ? promptGroups.length > normalizedLimit : false;
|
||||
const data = (isPaginated ? promptGroups.slice(0, normalizedLimit) : promptGroups).map(
|
||||
(group) => {
|
||||
if (group.author) {
|
||||
group.author = group.author.toString();
|
||||
}
|
||||
return group;
|
||||
},
|
||||
);
|
||||
|
||||
// Generate next cursor only if paginated
|
||||
let nextCursor = null;
|
||||
if (isPaginated && hasMore && data.length > 0) {
|
||||
const lastGroup = promptGroups[normalizedLimit - 1];
|
||||
nextCursor = Buffer.from(
|
||||
JSON.stringify({
|
||||
updatedAt: lastGroup.updatedAt.toISOString(),
|
||||
_id: lastGroup._id.toString(),
|
||||
}),
|
||||
).toString('base64');
|
||||
}
|
||||
|
||||
return {
|
||||
object: 'list',
|
||||
data,
|
||||
first_id: data.length > 0 ? data[0]._id.toString() : null,
|
||||
last_id: data.length > 0 ? data[data.length - 1]._id.toString() : null,
|
||||
has_more: hasMore,
|
||||
after: nextCursor,
|
||||
};
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getPromptGroups,
|
||||
deletePromptGroup,
|
||||
getAllPromptGroups,
|
||||
getListPromptGroupsByAccess,
|
||||
/**
|
||||
* Create a prompt and its respective group
|
||||
* @param {TCreatePromptRecord} saveData
|
||||
@@ -430,6 +559,16 @@ module.exports = {
|
||||
.lean();
|
||||
|
||||
if (remainingPrompts.length === 0) {
|
||||
// Remove all ACL entries for the promptGroup when deleting the last prompt
|
||||
try {
|
||||
await removeAllPermissions({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: groupId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error removing promptGroup permissions:', error);
|
||||
}
|
||||
|
||||
await PromptGroup.deleteOne({ _id: groupId });
|
||||
await removeGroupFromAllProjects(groupId);
|
||||
|
||||
|
||||
564
api/models/Prompt.spec.js
Normal file
564
api/models/Prompt.spec.js
Normal file
@@ -0,0 +1,564 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
PermissionBits,
|
||||
} = require('librechat-data-provider');
|
||||
|
||||
// Mock the config/connect module to prevent connection attempts during tests
|
||||
jest.mock('../../config/connect', () => jest.fn().mockResolvedValue(true));
|
||||
|
||||
const dbModels = require('~/db/models');
|
||||
|
||||
// Disable console for tests
|
||||
logger.silent = true;
|
||||
|
||||
let mongoServer;
|
||||
let Prompt, PromptGroup, AclEntry, AccessRole, User, Group, Project;
|
||||
let promptFns, permissionService;
|
||||
let testUsers, testGroups, testRoles;
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up MongoDB memory server
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize models
|
||||
Prompt = dbModels.Prompt;
|
||||
PromptGroup = dbModels.PromptGroup;
|
||||
AclEntry = dbModels.AclEntry;
|
||||
AccessRole = dbModels.AccessRole;
|
||||
User = dbModels.User;
|
||||
Group = dbModels.Group;
|
||||
Project = dbModels.Project;
|
||||
|
||||
promptFns = require('~/models/Prompt');
|
||||
permissionService = require('~/server/services/PermissionService');
|
||||
|
||||
// Create test data
|
||||
await setupTestData();
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
async function setupTestData() {
|
||||
// Create access roles for promptGroups
|
||||
testRoles = {
|
||||
viewer: await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
name: 'Viewer',
|
||||
description: 'Can view promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits: PermissionBits.VIEW,
|
||||
}),
|
||||
editor: await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
|
||||
name: 'Editor',
|
||||
description: 'Can view and edit promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits: PermissionBits.VIEW | PermissionBits.EDIT,
|
||||
}),
|
||||
owner: await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
name: 'Owner',
|
||||
description: 'Full control over promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits:
|
||||
PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE,
|
||||
}),
|
||||
};
|
||||
|
||||
// Create test users
|
||||
testUsers = {
|
||||
owner: await User.create({
|
||||
name: 'Prompt Owner',
|
||||
email: 'owner@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
editor: await User.create({
|
||||
name: 'Prompt Editor',
|
||||
email: 'editor@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
viewer: await User.create({
|
||||
name: 'Prompt Viewer',
|
||||
email: 'viewer@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
admin: await User.create({
|
||||
name: 'Admin User',
|
||||
email: 'admin@example.com',
|
||||
role: SystemRoles.ADMIN,
|
||||
}),
|
||||
noAccess: await User.create({
|
||||
name: 'No Access User',
|
||||
email: 'noaccess@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
};
|
||||
|
||||
// Create test groups
|
||||
testGroups = {
|
||||
editors: await Group.create({
|
||||
name: 'Prompt Editors',
|
||||
description: 'Group with editor access',
|
||||
}),
|
||||
viewers: await Group.create({
|
||||
name: 'Prompt Viewers',
|
||||
description: 'Group with viewer access',
|
||||
}),
|
||||
};
|
||||
|
||||
await Project.create({
|
||||
name: 'Global',
|
||||
description: 'Global project',
|
||||
promptGroupIds: [],
|
||||
});
|
||||
}
|
||||
|
||||
describe('Prompt ACL Permissions', () => {
|
||||
describe('Creating Prompts with Permissions', () => {
|
||||
it('should grant owner permissions when creating a prompt', async () => {
|
||||
// First create a group
|
||||
const testGroup = await PromptGroup.create({
|
||||
name: 'Test Group',
|
||||
category: 'testing',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new mongoose.Types.ObjectId(),
|
||||
});
|
||||
|
||||
const promptData = {
|
||||
prompt: {
|
||||
prompt: 'Test prompt content',
|
||||
name: 'Test Prompt',
|
||||
type: 'text',
|
||||
groupId: testGroup._id,
|
||||
},
|
||||
author: testUsers.owner._id,
|
||||
};
|
||||
|
||||
await promptFns.savePrompt(promptData);
|
||||
|
||||
// Manually grant permissions as would happen in the route
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Check ACL entry
|
||||
const aclEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
});
|
||||
|
||||
expect(aclEntry).toBeTruthy();
|
||||
expect(aclEntry.permBits).toBe(testRoles.owner.permBits);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Accessing Prompts', () => {
|
||||
let testPromptGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a prompt group
|
||||
testPromptGroup = await PromptGroup.create({
|
||||
name: 'Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create a prompt
|
||||
await Prompt.create({
|
||||
prompt: 'Test prompt for access control',
|
||||
name: 'Access Test Prompt',
|
||||
author: testUsers.owner._id,
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
// Grant owner permissions
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('owner should have full access to their prompt', async () => {
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
|
||||
const canEdit = await permissionService.checkPermission({
|
||||
userId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(canEdit).toBe(true);
|
||||
});
|
||||
|
||||
it('user with viewer role should only have view access', async () => {
|
||||
// Grant viewer permissions
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
const canView = await permissionService.checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const canEdit = await permissionService.checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(canView).toBe(true);
|
||||
expect(canEdit).toBe(false);
|
||||
});
|
||||
|
||||
it('user without permissions should have no access', async () => {
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.noAccess._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(false);
|
||||
});
|
||||
|
||||
it('admin should have access regardless of permissions', async () => {
|
||||
// Admin users should work through normal permission system
|
||||
// The middleware layer handles admin bypass, not the permission service
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.admin._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
// Without explicit permissions, even admin won't have access at this layer
|
||||
expect(hasAccess).toBe(false);
|
||||
|
||||
// The actual admin bypass happens in the middleware layer (`canAccessPromptViaGroup`/`canAccessPromptGroupResource`)
|
||||
// which checks req.user.role === SystemRoles.ADMIN
|
||||
});
|
||||
});
|
||||
|
||||
describe('Group-based Access', () => {
|
||||
let testPromptGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a prompt group first
|
||||
testPromptGroup = await PromptGroup.create({
|
||||
name: 'Group Access Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
await Prompt.create({
|
||||
prompt: 'Group access test prompt',
|
||||
name: 'Group Test',
|
||||
author: testUsers.owner._id,
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
// Add users to groups
|
||||
await User.findByIdAndUpdate(testUsers.editor._id, {
|
||||
$push: { groups: testGroups.editors._id },
|
||||
});
|
||||
|
||||
await User.findByIdAndUpdate(testUsers.viewer._id, {
|
||||
$push: { groups: testGroups.viewers._id },
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
await User.updateMany({}, { $set: { groups: [] } });
|
||||
});
|
||||
|
||||
it('group members should inherit group permissions', async () => {
|
||||
// Create a prompt group
|
||||
const testPromptGroup = await PromptGroup.create({
|
||||
name: 'Group Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
const { addUserToGroup } = require('~/models');
|
||||
await addUserToGroup(testUsers.editor._id, testGroups.editors._id);
|
||||
|
||||
const prompt = await promptFns.savePrompt({
|
||||
author: testUsers.owner._id,
|
||||
prompt: {
|
||||
prompt: 'Group test prompt',
|
||||
name: 'Group Test',
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
},
|
||||
});
|
||||
|
||||
// Check if savePrompt returned an error
|
||||
if (!prompt || !prompt.prompt) {
|
||||
throw new Error(`Failed to save prompt: ${prompt?.message || 'Unknown error'}`);
|
||||
}
|
||||
|
||||
// Grant edit permissions to the group
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.GROUP,
|
||||
principalId: testGroups.editors._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Check if group member has access
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.editor._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
|
||||
// Check that non-member doesn't have access
|
||||
const nonMemberAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
expect(nonMemberAccess).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Public Access', () => {
|
||||
let publicPromptGroup, privatePromptGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create separate prompt groups for public and private access
|
||||
publicPromptGroup = await PromptGroup.create({
|
||||
name: 'Public Access Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
privatePromptGroup = await PromptGroup.create({
|
||||
name: 'Private Access Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create prompts in their respective groups
|
||||
await Prompt.create({
|
||||
prompt: 'Public prompt',
|
||||
name: 'Public',
|
||||
author: testUsers.owner._id,
|
||||
groupId: publicPromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
await Prompt.create({
|
||||
prompt: 'Private prompt',
|
||||
name: 'Private',
|
||||
author: testUsers.owner._id,
|
||||
groupId: privatePromptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
// Grant public view access to publicPromptGroup
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.PUBLIC,
|
||||
principalId: null,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: publicPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Grant only owner access to privatePromptGroup
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: privatePromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('public prompt should be accessible to any user', async () => {
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.noAccess._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: publicPromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
includePublic: true,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(true);
|
||||
});
|
||||
|
||||
it('private prompt should not be accessible to unauthorized users', async () => {
|
||||
const hasAccess = await permissionService.checkPermission({
|
||||
userId: testUsers.noAccess._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: privatePromptGroup._id,
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
includePublic: true,
|
||||
});
|
||||
|
||||
expect(hasAccess).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Prompt Deletion', () => {
|
||||
let testPromptGroup;
|
||||
|
||||
it('should remove ACL entries when prompt is deleted', async () => {
|
||||
testPromptGroup = await PromptGroup.create({
|
||||
name: 'Deletion Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
const prompt = await promptFns.savePrompt({
|
||||
author: testUsers.owner._id,
|
||||
prompt: {
|
||||
prompt: 'To be deleted',
|
||||
name: 'Delete Test',
|
||||
groupId: testPromptGroup._id,
|
||||
type: 'text',
|
||||
},
|
||||
});
|
||||
|
||||
// Check if savePrompt returned an error
|
||||
if (!prompt || !prompt.prompt) {
|
||||
throw new Error(`Failed to save prompt: ${prompt?.message || 'Unknown error'}`);
|
||||
}
|
||||
|
||||
const testPromptId = prompt.prompt._id;
|
||||
const promptGroupId = testPromptGroup._id;
|
||||
|
||||
// Grant permission
|
||||
await permissionService.grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Verify ACL entry exists
|
||||
const beforeDelete = await AclEntry.find({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
});
|
||||
expect(beforeDelete).toHaveLength(1);
|
||||
|
||||
// Delete the prompt
|
||||
await promptFns.deletePrompt({
|
||||
promptId: testPromptId,
|
||||
groupId: promptGroupId,
|
||||
author: testUsers.owner._id,
|
||||
role: SystemRoles.USER,
|
||||
});
|
||||
|
||||
// Verify ACL entries are removed
|
||||
const aclEntries = await AclEntry.find({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testPromptGroup._id,
|
||||
});
|
||||
|
||||
expect(aclEntries).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Backwards Compatibility', () => {
|
||||
it('should handle prompts without ACL entries gracefully', async () => {
|
||||
// Create a prompt group first
|
||||
const promptGroup = await PromptGroup.create({
|
||||
name: 'Legacy Test Group',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create a prompt without ACL entries (legacy prompt)
|
||||
const legacyPrompt = await Prompt.create({
|
||||
prompt: 'Legacy prompt without ACL',
|
||||
name: 'Legacy',
|
||||
author: testUsers.owner._id,
|
||||
groupId: promptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
// The system should handle this gracefully
|
||||
const prompt = await promptFns.getPrompt({ _id: legacyPrompt._id });
|
||||
expect(prompt).toBeTruthy();
|
||||
expect(prompt._id.toString()).toBe(legacyPrompt._id.toString());
|
||||
});
|
||||
});
|
||||
});
|
||||
280
api/models/PromptGroupMigration.spec.js
Normal file
280
api/models/PromptGroupMigration.spec.js
Normal file
@@ -0,0 +1,280 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
Constants,
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
PrincipalModel,
|
||||
PermissionBits,
|
||||
} = require('librechat-data-provider');
|
||||
|
||||
// Mock the config/connect module to prevent connection attempts during tests
|
||||
jest.mock('../../config/connect', () => jest.fn().mockResolvedValue(true));
|
||||
|
||||
// Disable console for tests
|
||||
logger.silent = true;
|
||||
|
||||
describe('PromptGroup Migration Script', () => {
|
||||
let mongoServer;
|
||||
let Prompt, PromptGroup, AclEntry, AccessRole, User, Project;
|
||||
let migrateToPromptGroupPermissions;
|
||||
let testOwner, testProject;
|
||||
let ownerRole, viewerRole;
|
||||
|
||||
beforeAll(async () => {
|
||||
// Set up MongoDB memory server
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize models
|
||||
const dbModels = require('~/db/models');
|
||||
Prompt = dbModels.Prompt;
|
||||
PromptGroup = dbModels.PromptGroup;
|
||||
AclEntry = dbModels.AclEntry;
|
||||
AccessRole = dbModels.AccessRole;
|
||||
User = dbModels.User;
|
||||
Project = dbModels.Project;
|
||||
|
||||
// Create test user
|
||||
testOwner = await User.create({
|
||||
name: 'Test Owner',
|
||||
email: 'owner@test.com',
|
||||
role: 'USER',
|
||||
});
|
||||
|
||||
// Create test project with the proper name
|
||||
const projectName = Constants.GLOBAL_PROJECT_NAME || 'instance';
|
||||
testProject = await Project.create({
|
||||
name: projectName,
|
||||
description: 'Global project',
|
||||
promptGroupIds: [],
|
||||
});
|
||||
|
||||
// Create promptGroup access roles
|
||||
ownerRole = await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
name: 'Owner',
|
||||
description: 'Full control over promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits:
|
||||
PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE,
|
||||
});
|
||||
|
||||
viewerRole = await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
name: 'Viewer',
|
||||
description: 'Can view promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
|
||||
name: 'Editor',
|
||||
description: 'Can view and edit promptGroups',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits: PermissionBits.VIEW | PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
// Import migration function
|
||||
const migration = require('../../config/migrate-prompt-permissions');
|
||||
migrateToPromptGroupPermissions = migration.migrateToPromptGroupPermissions;
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clean up before each test
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
// Reset the project's promptGroupIds array
|
||||
testProject.promptGroupIds = [];
|
||||
await testProject.save();
|
||||
});
|
||||
|
||||
it('should categorize promptGroups correctly in dry run', async () => {
|
||||
// Create global prompt group (in Global project)
|
||||
const globalPromptGroup = await PromptGroup.create({
|
||||
name: 'Global Group',
|
||||
author: testOwner._id,
|
||||
authorName: testOwner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create private prompt group (not in any project)
|
||||
await PromptGroup.create({
|
||||
name: 'Private Group',
|
||||
author: testOwner._id,
|
||||
authorName: testOwner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Add global group to project's promptGroupIds array
|
||||
testProject.promptGroupIds = [globalPromptGroup._id];
|
||||
await testProject.save();
|
||||
|
||||
const result = await migrateToPromptGroupPermissions({ dryRun: true });
|
||||
|
||||
expect(result.dryRun).toBe(true);
|
||||
expect(result.summary.total).toBe(2);
|
||||
expect(result.summary.globalViewAccess).toBe(1);
|
||||
expect(result.summary.privateGroups).toBe(1);
|
||||
});
|
||||
|
||||
it('should grant appropriate permissions during migration', async () => {
|
||||
// Create prompt groups
|
||||
const globalPromptGroup = await PromptGroup.create({
|
||||
name: 'Global Group',
|
||||
author: testOwner._id,
|
||||
authorName: testOwner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
const privatePromptGroup = await PromptGroup.create({
|
||||
name: 'Private Group',
|
||||
author: testOwner._id,
|
||||
authorName: testOwner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Add global group to project's promptGroupIds array
|
||||
testProject.promptGroupIds = [globalPromptGroup._id];
|
||||
await testProject.save();
|
||||
|
||||
const result = await migrateToPromptGroupPermissions({ dryRun: false });
|
||||
|
||||
expect(result.migrated).toBe(2);
|
||||
expect(result.errors).toBe(0);
|
||||
expect(result.ownerGrants).toBe(2);
|
||||
expect(result.publicViewGrants).toBe(1);
|
||||
|
||||
// Check global promptGroup permissions
|
||||
const globalOwnerEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: globalPromptGroup._id,
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testOwner._id,
|
||||
});
|
||||
expect(globalOwnerEntry).toBeTruthy();
|
||||
expect(globalOwnerEntry.permBits).toBe(ownerRole.permBits);
|
||||
|
||||
const globalPublicEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: globalPromptGroup._id,
|
||||
principalType: PrincipalType.PUBLIC,
|
||||
});
|
||||
expect(globalPublicEntry).toBeTruthy();
|
||||
expect(globalPublicEntry.permBits).toBe(viewerRole.permBits);
|
||||
|
||||
// Check private promptGroup permissions
|
||||
const privateOwnerEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: privatePromptGroup._id,
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testOwner._id,
|
||||
});
|
||||
expect(privateOwnerEntry).toBeTruthy();
|
||||
expect(privateOwnerEntry.permBits).toBe(ownerRole.permBits);
|
||||
|
||||
const privatePublicEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: privatePromptGroup._id,
|
||||
principalType: PrincipalType.PUBLIC,
|
||||
});
|
||||
expect(privatePublicEntry).toBeNull();
|
||||
});
|
||||
|
||||
it('should skip promptGroups that already have ACL entries', async () => {
|
||||
// Create prompt groups
|
||||
const promptGroup1 = await PromptGroup.create({
|
||||
name: 'Group 1',
|
||||
author: testOwner._id,
|
||||
authorName: testOwner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
const promptGroup2 = await PromptGroup.create({
|
||||
name: 'Group 2',
|
||||
author: testOwner._id,
|
||||
authorName: testOwner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Grant permission to one promptGroup manually (simulating it already has ACL)
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testOwner._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: promptGroup1._id,
|
||||
permBits: ownerRole.permBits,
|
||||
roleId: ownerRole._id,
|
||||
grantedBy: testOwner._id,
|
||||
grantedAt: new Date(),
|
||||
});
|
||||
|
||||
const result = await migrateToPromptGroupPermissions({ dryRun: false });
|
||||
|
||||
// Should only migrate promptGroup2, skip promptGroup1
|
||||
expect(result.migrated).toBe(1);
|
||||
expect(result.errors).toBe(0);
|
||||
|
||||
// Verify promptGroup2 now has permissions
|
||||
const group2Entry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: promptGroup2._id,
|
||||
});
|
||||
expect(group2Entry).toBeTruthy();
|
||||
});
|
||||
|
||||
it('should handle promptGroups with prompts correctly', async () => {
|
||||
// Create a promptGroup with some prompts
|
||||
const promptGroup = await PromptGroup.create({
|
||||
name: 'Group with Prompts',
|
||||
author: testOwner._id,
|
||||
authorName: testOwner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create some prompts in this group
|
||||
await Prompt.create({
|
||||
prompt: 'First prompt',
|
||||
author: testOwner._id,
|
||||
groupId: promptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
await Prompt.create({
|
||||
prompt: 'Second prompt',
|
||||
author: testOwner._id,
|
||||
groupId: promptGroup._id,
|
||||
type: 'text',
|
||||
});
|
||||
|
||||
const result = await migrateToPromptGroupPermissions({ dryRun: false });
|
||||
|
||||
expect(result.migrated).toBe(1);
|
||||
expect(result.errors).toBe(0);
|
||||
|
||||
// Verify the promptGroup has permissions
|
||||
const groupEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: promptGroup._id,
|
||||
});
|
||||
expect(groupEntry).toBeTruthy();
|
||||
|
||||
// Verify no prompt-level permissions were created
|
||||
const promptEntries = await AclEntry.find({
|
||||
resourceType: 'prompt',
|
||||
});
|
||||
expect(promptEntries).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
@@ -2,7 +2,6 @@ const {
|
||||
CacheKeys,
|
||||
SystemRoles,
|
||||
roleDefaults,
|
||||
PermissionTypes,
|
||||
permissionsSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
@@ -17,7 +16,7 @@ const { Role } = require('~/db/models');
|
||||
*
|
||||
* @param {string} roleName - The name of the role to find or create.
|
||||
* @param {string|string[]} [fieldsToSelect] - The fields to include or exclude in the returned document.
|
||||
* @returns {Promise<Object>} A plain object representing the role document.
|
||||
* @returns {Promise<IRole>} Role document.
|
||||
*/
|
||||
const getRoleByName = async function (roleName, fieldsToSelect = null) {
|
||||
const cache = getLogStores(CacheKeys.ROLES);
|
||||
@@ -73,8 +72,9 @@ const updateRoleByName = async function (roleName, updates) {
|
||||
* Updates access permissions for a specific role and multiple permission types.
|
||||
* @param {string} roleName - The role to update.
|
||||
* @param {Object.<PermissionTypes, Object.<Permissions, boolean>>} permissionsUpdate - Permissions to update and their values.
|
||||
* @param {IRole} [roleData] - Optional role data to use instead of fetching from the database.
|
||||
*/
|
||||
async function updateAccessPermissions(roleName, permissionsUpdate) {
|
||||
async function updateAccessPermissions(roleName, permissionsUpdate, roleData) {
|
||||
// Filter and clean the permission updates based on our schema definition.
|
||||
const updates = {};
|
||||
for (const [permissionType, permissions] of Object.entries(permissionsUpdate)) {
|
||||
@@ -87,7 +87,7 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
|
||||
}
|
||||
|
||||
try {
|
||||
const role = await getRoleByName(roleName);
|
||||
const role = roleData ?? (await getRoleByName(roleName));
|
||||
if (!role) {
|
||||
return;
|
||||
}
|
||||
@@ -114,7 +114,6 @@ async function updateAccessPermissions(roleName, permissionsUpdate) {
|
||||
}
|
||||
}
|
||||
|
||||
// Process the current updates
|
||||
for (const [permissionType, permissions] of Object.entries(updates)) {
|
||||
const currentTypePermissions = currentPermissions[permissionType] || {};
|
||||
updatedPermissions[permissionType] = { ...currentTypePermissions };
|
||||
|
||||
@@ -22,6 +22,7 @@ const {
|
||||
} = require('./Message');
|
||||
const { getConvoTitle, getConvo, saveConvo, deleteConvos } = require('./Conversation');
|
||||
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
|
||||
const { File } = require('~/db/models');
|
||||
|
||||
module.exports = {
|
||||
...methods,
|
||||
@@ -51,4 +52,6 @@ module.exports = {
|
||||
getPresets,
|
||||
savePreset,
|
||||
deletePresets,
|
||||
|
||||
Files: File,
|
||||
};
|
||||
|
||||
@@ -49,9 +49,10 @@
|
||||
"@langchain/google-vertexai": "^0.2.13",
|
||||
"@langchain/openai": "^0.5.18",
|
||||
"@langchain/textsplitters": "^0.1.0",
|
||||
"@librechat/agents": "^2.4.75",
|
||||
"@librechat/agents": "^2.4.76",
|
||||
"@librechat/api": "*",
|
||||
"@librechat/data-schemas": "*",
|
||||
"@microsoft/microsoft-graph-client": "^3.0.7",
|
||||
"@modelcontextprotocol/sdk": "^1.17.1",
|
||||
"@node-saml/passport-saml": "^5.1.0",
|
||||
"@waylaidwanderer/fetch-event-source": "^3.0.1",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { logger } = require('~/config');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
|
||||
// WeakMap to hold temporary data associated with requests
|
||||
/** WeakMap to hold temporary data associated with requests */
|
||||
const requestDataMap = new WeakMap();
|
||||
|
||||
const FinalizationRegistry = global.FinalizationRegistry || null;
|
||||
@@ -23,7 +23,7 @@ const clientRegistry = FinalizationRegistry
|
||||
} else {
|
||||
logger.debug('[FinalizationRegistry] Cleaning up client');
|
||||
}
|
||||
} catch (e) {
|
||||
} catch {
|
||||
// Ignore errors
|
||||
}
|
||||
})
|
||||
@@ -55,6 +55,9 @@ function disposeClient(client) {
|
||||
if (client.responseMessageId) {
|
||||
client.responseMessageId = null;
|
||||
}
|
||||
if (client.parentMessageId) {
|
||||
client.parentMessageId = null;
|
||||
}
|
||||
if (client.message_file_map) {
|
||||
client.message_file_map = null;
|
||||
}
|
||||
@@ -334,7 +337,7 @@ function disposeClient(client) {
|
||||
}
|
||||
}
|
||||
client.options = null;
|
||||
} catch (e) {
|
||||
} catch {
|
||||
// Ignore errors during disposal
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ const {
|
||||
} = require('~/server/services/AuthService');
|
||||
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
|
||||
const { getOpenIdConfig } = require('~/strategies');
|
||||
const { getGraphApiToken } = require('~/server/services/GraphTokenService');
|
||||
|
||||
const registrationController = async (req, res) => {
|
||||
try {
|
||||
@@ -83,7 +84,7 @@ const refreshController = async (req, res) => {
|
||||
}
|
||||
try {
|
||||
const payload = jwt.verify(refreshToken, process.env.JWT_REFRESH_SECRET);
|
||||
const user = await getUserById(payload.id, '-password -__v -totpSecret');
|
||||
const user = await getUserById(payload.id, '-password -__v -totpSecret -backupCodes');
|
||||
if (!user) {
|
||||
return res.status(401).redirect('/login');
|
||||
}
|
||||
@@ -118,9 +119,54 @@ const refreshController = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
const graphTokenController = async (req, res) => {
|
||||
try {
|
||||
// Validate user is authenticated via Entra ID
|
||||
if (!req.user.openidId || req.user.provider !== 'openid') {
|
||||
return res.status(403).json({
|
||||
message: 'Microsoft Graph access requires Entra ID authentication',
|
||||
});
|
||||
}
|
||||
|
||||
// Check if OpenID token reuse is active (required for on-behalf-of flow)
|
||||
if (!isEnabled(process.env.OPENID_REUSE_TOKENS)) {
|
||||
return res.status(403).json({
|
||||
message: 'SharePoint integration requires OpenID token reuse to be enabled',
|
||||
});
|
||||
}
|
||||
|
||||
// Extract access token from Authorization header
|
||||
const authHeader = req.headers.authorization;
|
||||
if (!authHeader || !authHeader.startsWith('Bearer ')) {
|
||||
return res.status(401).json({
|
||||
message: 'Valid authorization token required',
|
||||
});
|
||||
}
|
||||
|
||||
// Get scopes from query parameters
|
||||
const scopes = req.query.scopes;
|
||||
if (!scopes) {
|
||||
return res.status(400).json({
|
||||
message: 'Graph API scopes are required as query parameter',
|
||||
});
|
||||
}
|
||||
|
||||
const accessToken = authHeader.substring(7); // Remove 'Bearer ' prefix
|
||||
const tokenResponse = await getGraphApiToken(req.user, accessToken, scopes);
|
||||
|
||||
res.json(tokenResponse);
|
||||
} catch (error) {
|
||||
logger.error('[graphTokenController] Failed to obtain Graph API token:', error);
|
||||
res.status(500).json({
|
||||
message: 'Failed to obtain Microsoft Graph token',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
refreshController,
|
||||
registrationController,
|
||||
resetPasswordController,
|
||||
resetPasswordRequestController,
|
||||
graphTokenController,
|
||||
};
|
||||
|
||||
484
api/server/controllers/PermissionsController.js
Normal file
484
api/server/controllers/PermissionsController.js
Normal file
@@ -0,0 +1,484 @@
|
||||
/**
|
||||
* @import { TUpdateResourcePermissionsRequest, TUpdateResourcePermissionsResponse } from 'librechat-data-provider'
|
||||
*/
|
||||
|
||||
const mongoose = require('mongoose');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ResourceType, PrincipalType } = require('librechat-data-provider');
|
||||
const {
|
||||
bulkUpdateResourcePermissions,
|
||||
ensureGroupPrincipalExists,
|
||||
getEffectivePermissions,
|
||||
ensurePrincipalExists,
|
||||
getAvailableRoles,
|
||||
} = require('~/server/services/PermissionService');
|
||||
const { AclEntry } = require('~/db/models');
|
||||
const {
|
||||
searchPrincipals: searchLocalPrincipals,
|
||||
sortPrincipalsByRelevance,
|
||||
calculateRelevanceScore,
|
||||
} = require('~/models');
|
||||
const {
|
||||
entraIdPrincipalFeatureEnabled,
|
||||
searchEntraIdPrincipals,
|
||||
} = require('~/server/services/GraphApiService');
|
||||
|
||||
/**
|
||||
* Generic controller for resource permission endpoints
|
||||
* Delegates validation and logic to PermissionService
|
||||
*/
|
||||
|
||||
/**
|
||||
* Validates that the resourceType is one of the supported enum values
|
||||
* @param {string} resourceType - The resource type to validate
|
||||
* @throws {Error} If resourceType is not valid
|
||||
*/
|
||||
const validateResourceType = (resourceType) => {
|
||||
const validTypes = Object.values(ResourceType);
|
||||
if (!validTypes.includes(resourceType)) {
|
||||
throw new Error(`Invalid resourceType: ${resourceType}. Valid types: ${validTypes.join(', ')}`);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Bulk update permissions for a resource (grant, update, remove)
|
||||
* @route PUT /api/{resourceType}/{resourceId}/permissions
|
||||
* @param {Object} req - Express request object
|
||||
* @param {Object} req.params - Route parameters
|
||||
* @param {string} req.params.resourceType - Resource type (e.g., 'agent')
|
||||
* @param {string} req.params.resourceId - Resource ID
|
||||
* @param {TUpdateResourcePermissionsRequest} req.body - Request body
|
||||
* @param {Object} res - Express response object
|
||||
* @returns {Promise<TUpdateResourcePermissionsResponse>} Updated permissions response
|
||||
*/
|
||||
const updateResourcePermissions = async (req, res) => {
|
||||
try {
|
||||
const { resourceType, resourceId } = req.params;
|
||||
validateResourceType(resourceType);
|
||||
|
||||
/** @type {TUpdateResourcePermissionsRequest} */
|
||||
const { updated, removed, public: isPublic, publicAccessRoleId } = req.body;
|
||||
const { id: userId } = req.user;
|
||||
|
||||
// Prepare principals for the service call
|
||||
const updatedPrincipals = [];
|
||||
const revokedPrincipals = [];
|
||||
|
||||
// Add updated principals
|
||||
if (updated && Array.isArray(updated)) {
|
||||
updatedPrincipals.push(...updated);
|
||||
}
|
||||
|
||||
// Add public permission if enabled
|
||||
if (isPublic && publicAccessRoleId) {
|
||||
updatedPrincipals.push({
|
||||
type: PrincipalType.PUBLIC,
|
||||
id: null,
|
||||
accessRoleId: publicAccessRoleId,
|
||||
});
|
||||
}
|
||||
|
||||
// Prepare authentication context for enhanced group member fetching
|
||||
const useEntraId = entraIdPrincipalFeatureEnabled(req.user);
|
||||
const authHeader = req.headers.authorization;
|
||||
const accessToken =
|
||||
authHeader && authHeader.startsWith('Bearer ') ? authHeader.substring(7) : null;
|
||||
const authContext =
|
||||
useEntraId && accessToken
|
||||
? {
|
||||
accessToken,
|
||||
sub: req.user.openidId,
|
||||
}
|
||||
: null;
|
||||
|
||||
// Ensure updated principals exist in the database before processing permissions
|
||||
const validatedPrincipals = [];
|
||||
for (const principal of updatedPrincipals) {
|
||||
try {
|
||||
let principalId;
|
||||
|
||||
if (principal.type === PrincipalType.PUBLIC) {
|
||||
principalId = null; // Public principals don't need database records
|
||||
} else if (principal.type === PrincipalType.ROLE) {
|
||||
principalId = principal.id; // Role principals use role name as ID
|
||||
} else if (principal.type === PrincipalType.USER) {
|
||||
principalId = await ensurePrincipalExists(principal);
|
||||
} else if (principal.type === PrincipalType.GROUP) {
|
||||
// Pass authContext to enable member fetching for Entra ID groups when available
|
||||
principalId = await ensureGroupPrincipalExists(principal, authContext);
|
||||
} else {
|
||||
logger.error(`Unsupported principal type: ${principal.type}`);
|
||||
continue; // Skip invalid principal types
|
||||
}
|
||||
|
||||
// Update the principal with the validated ID for ACL operations
|
||||
validatedPrincipals.push({
|
||||
...principal,
|
||||
id: principalId,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error ensuring principal exists:', {
|
||||
principal: {
|
||||
type: principal.type,
|
||||
id: principal.id,
|
||||
name: principal.name,
|
||||
source: principal.source,
|
||||
},
|
||||
error: error.message,
|
||||
});
|
||||
// Continue with other principals instead of failing the entire operation
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Add removed principals
|
||||
if (removed && Array.isArray(removed)) {
|
||||
revokedPrincipals.push(...removed);
|
||||
}
|
||||
|
||||
// If public is disabled, add public to revoked list
|
||||
if (!isPublic) {
|
||||
revokedPrincipals.push({
|
||||
type: PrincipalType.PUBLIC,
|
||||
id: null,
|
||||
});
|
||||
}
|
||||
|
||||
const results = await bulkUpdateResourcePermissions({
|
||||
resourceType,
|
||||
resourceId,
|
||||
updatedPrincipals: validatedPrincipals,
|
||||
revokedPrincipals,
|
||||
grantedBy: userId,
|
||||
});
|
||||
|
||||
/** @type {TUpdateResourcePermissionsResponse} */
|
||||
const response = {
|
||||
message: 'Permissions updated successfully',
|
||||
results: {
|
||||
principals: results.granted,
|
||||
public: isPublic || false,
|
||||
publicAccessRoleId: isPublic ? publicAccessRoleId : undefined,
|
||||
},
|
||||
};
|
||||
|
||||
res.status(200).json(response);
|
||||
} catch (error) {
|
||||
logger.error('Error updating resource permissions:', error);
|
||||
res.status(400).json({
|
||||
error: 'Failed to update permissions',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get principals with their permission roles for a resource (UI-friendly format)
|
||||
* Uses efficient aggregation pipeline to join User/Group data in single query
|
||||
* @route GET /api/permissions/{resourceType}/{resourceId}
|
||||
*/
|
||||
const getResourcePermissions = async (req, res) => {
|
||||
try {
|
||||
const { resourceType, resourceId } = req.params;
|
||||
validateResourceType(resourceType);
|
||||
|
||||
// Use aggregation pipeline for efficient single-query data retrieval
|
||||
const results = await AclEntry.aggregate([
|
||||
// Match ACL entries for this resource
|
||||
{
|
||||
$match: {
|
||||
resourceType,
|
||||
resourceId: mongoose.Types.ObjectId.isValid(resourceId)
|
||||
? mongoose.Types.ObjectId.createFromHexString(resourceId)
|
||||
: resourceId,
|
||||
},
|
||||
},
|
||||
// Lookup AccessRole information
|
||||
{
|
||||
$lookup: {
|
||||
from: 'accessroles',
|
||||
localField: 'roleId',
|
||||
foreignField: '_id',
|
||||
as: 'role',
|
||||
},
|
||||
},
|
||||
// Lookup User information (for user principals)
|
||||
{
|
||||
$lookup: {
|
||||
from: 'users',
|
||||
localField: 'principalId',
|
||||
foreignField: '_id',
|
||||
as: 'userInfo',
|
||||
},
|
||||
},
|
||||
// Lookup Group information (for group principals)
|
||||
{
|
||||
$lookup: {
|
||||
from: 'groups',
|
||||
localField: 'principalId',
|
||||
foreignField: '_id',
|
||||
as: 'groupInfo',
|
||||
},
|
||||
},
|
||||
// Project final structure
|
||||
{
|
||||
$project: {
|
||||
principalType: 1,
|
||||
principalId: 1,
|
||||
accessRoleId: { $arrayElemAt: ['$role.accessRoleId', 0] },
|
||||
userInfo: { $arrayElemAt: ['$userInfo', 0] },
|
||||
groupInfo: { $arrayElemAt: ['$groupInfo', 0] },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
const principals = [];
|
||||
let publicPermission = null;
|
||||
|
||||
// Process aggregation results
|
||||
for (const result of results) {
|
||||
if (result.principalType === PrincipalType.PUBLIC) {
|
||||
publicPermission = {
|
||||
public: true,
|
||||
publicAccessRoleId: result.accessRoleId,
|
||||
};
|
||||
} else if (result.principalType === PrincipalType.USER && result.userInfo) {
|
||||
principals.push({
|
||||
type: PrincipalType.USER,
|
||||
id: result.userInfo._id.toString(),
|
||||
name: result.userInfo.name || result.userInfo.username,
|
||||
email: result.userInfo.email,
|
||||
avatar: result.userInfo.avatar,
|
||||
source: !result.userInfo._id ? 'entra' : 'local',
|
||||
idOnTheSource: result.userInfo.idOnTheSource || result.userInfo._id.toString(),
|
||||
accessRoleId: result.accessRoleId,
|
||||
});
|
||||
} else if (result.principalType === PrincipalType.GROUP && result.groupInfo) {
|
||||
principals.push({
|
||||
type: PrincipalType.GROUP,
|
||||
id: result.groupInfo._id.toString(),
|
||||
name: result.groupInfo.name,
|
||||
email: result.groupInfo.email,
|
||||
description: result.groupInfo.description,
|
||||
avatar: result.groupInfo.avatar,
|
||||
source: result.groupInfo.source || 'local',
|
||||
idOnTheSource: result.groupInfo.idOnTheSource || result.groupInfo._id.toString(),
|
||||
accessRoleId: result.accessRoleId,
|
||||
});
|
||||
} else if (result.principalType === PrincipalType.ROLE) {
|
||||
principals.push({
|
||||
type: PrincipalType.ROLE,
|
||||
/** Role name as ID */
|
||||
id: result.principalId,
|
||||
/** Display the role name */
|
||||
name: result.principalId,
|
||||
description: `System role: ${result.principalId}`,
|
||||
accessRoleId: result.accessRoleId,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Return response in format expected by frontend
|
||||
const response = {
|
||||
resourceType,
|
||||
resourceId,
|
||||
principals,
|
||||
public: publicPermission?.public || false,
|
||||
...(publicPermission?.publicAccessRoleId && {
|
||||
publicAccessRoleId: publicPermission.publicAccessRoleId,
|
||||
}),
|
||||
};
|
||||
|
||||
res.status(200).json(response);
|
||||
} catch (error) {
|
||||
logger.error('Error getting resource permissions principals:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to get permissions principals',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get available roles for a resource type
|
||||
* @route GET /api/{resourceType}/roles
|
||||
*/
|
||||
const getResourceRoles = async (req, res) => {
|
||||
try {
|
||||
const { resourceType } = req.params;
|
||||
validateResourceType(resourceType);
|
||||
|
||||
const roles = await getAvailableRoles({ resourceType });
|
||||
|
||||
res.status(200).json(
|
||||
roles.map((role) => ({
|
||||
accessRoleId: role.accessRoleId,
|
||||
name: role.name,
|
||||
description: role.description,
|
||||
permBits: role.permBits,
|
||||
})),
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error('Error getting resource roles:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to get roles',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get user's effective permission bitmask for a resource
|
||||
* @route GET /api/{resourceType}/{resourceId}/effective
|
||||
*/
|
||||
const getUserEffectivePermissions = async (req, res) => {
|
||||
try {
|
||||
const { resourceType, resourceId } = req.params;
|
||||
validateResourceType(resourceType);
|
||||
|
||||
const { id: userId } = req.user;
|
||||
|
||||
const permissionBits = await getEffectivePermissions({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType,
|
||||
resourceId,
|
||||
});
|
||||
|
||||
res.status(200).json({
|
||||
permissionBits,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error getting user effective permissions:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to get effective permissions',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Search for users and groups to grant permissions
|
||||
* Supports hybrid local database + Entra ID search when configured
|
||||
* @route GET /api/permissions/search-principals
|
||||
*/
|
||||
const searchPrincipals = async (req, res) => {
|
||||
try {
|
||||
const { q: query, limit = 20, types } = req.query;
|
||||
|
||||
if (!query || query.trim().length === 0) {
|
||||
return res.status(400).json({
|
||||
error: 'Query parameter "q" is required and must not be empty',
|
||||
});
|
||||
}
|
||||
|
||||
if (query.trim().length < 2) {
|
||||
return res.status(400).json({
|
||||
error: 'Query must be at least 2 characters long',
|
||||
});
|
||||
}
|
||||
|
||||
const searchLimit = Math.min(Math.max(1, parseInt(limit) || 10), 50);
|
||||
|
||||
let typeFilters = null;
|
||||
if (types) {
|
||||
const typesArray = Array.isArray(types) ? types : types.split(',');
|
||||
const validTypes = typesArray.filter((t) =>
|
||||
[PrincipalType.USER, PrincipalType.GROUP, PrincipalType.ROLE].includes(t),
|
||||
);
|
||||
typeFilters = validTypes.length > 0 ? validTypes : null;
|
||||
}
|
||||
|
||||
const localResults = await searchLocalPrincipals(query.trim(), searchLimit, typeFilters);
|
||||
let allPrincipals = [...localResults];
|
||||
|
||||
const useEntraId = entraIdPrincipalFeatureEnabled(req.user);
|
||||
|
||||
if (useEntraId && localResults.length < searchLimit) {
|
||||
try {
|
||||
let graphType = 'all';
|
||||
if (typeFilters && typeFilters.length === 1) {
|
||||
const graphTypeMap = {
|
||||
[PrincipalType.USER]: 'users',
|
||||
[PrincipalType.GROUP]: 'groups',
|
||||
};
|
||||
const mappedType = graphTypeMap[typeFilters[0]];
|
||||
if (mappedType) {
|
||||
graphType = mappedType;
|
||||
}
|
||||
}
|
||||
|
||||
const authHeader = req.headers.authorization;
|
||||
const accessToken =
|
||||
authHeader && authHeader.startsWith('Bearer ') ? authHeader.substring(7) : null;
|
||||
|
||||
if (accessToken) {
|
||||
const graphResults = await searchEntraIdPrincipals(
|
||||
accessToken,
|
||||
req.user.openidId,
|
||||
query.trim(),
|
||||
graphType,
|
||||
searchLimit - localResults.length,
|
||||
);
|
||||
|
||||
const localEmails = new Set(
|
||||
localResults.map((p) => p.email?.toLowerCase()).filter(Boolean),
|
||||
);
|
||||
const localGroupSourceIds = new Set(
|
||||
localResults.map((p) => p.idOnTheSource).filter(Boolean),
|
||||
);
|
||||
|
||||
for (const principal of graphResults) {
|
||||
const isDuplicateByEmail =
|
||||
principal.email && localEmails.has(principal.email.toLowerCase());
|
||||
const isDuplicateBySourceId =
|
||||
principal.idOnTheSource && localGroupSourceIds.has(principal.idOnTheSource);
|
||||
|
||||
if (!isDuplicateByEmail && !isDuplicateBySourceId) {
|
||||
allPrincipals.push(principal);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (graphError) {
|
||||
logger.warn('Graph API search failed, falling back to local results:', graphError.message);
|
||||
}
|
||||
}
|
||||
const scoredResults = allPrincipals.map((item) => ({
|
||||
...item,
|
||||
_searchScore: calculateRelevanceScore(item, query.trim()),
|
||||
}));
|
||||
|
||||
const finalResults = sortPrincipalsByRelevance(scoredResults)
|
||||
.slice(0, searchLimit)
|
||||
.map((result) => {
|
||||
const { _searchScore, ...resultWithoutScore } = result;
|
||||
return resultWithoutScore;
|
||||
});
|
||||
|
||||
res.status(200).json({
|
||||
query: query.trim(),
|
||||
limit: searchLimit,
|
||||
types: typeFilters,
|
||||
results: finalResults,
|
||||
count: finalResults.length,
|
||||
sources: {
|
||||
local: finalResults.filter((r) => r.source === 'local').length,
|
||||
entra: finalResults.filter((r) => r.source === 'entra').length,
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('Error searching principals:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to search principals',
|
||||
details: error.message,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
updateResourcePermissions,
|
||||
getResourcePermissions,
|
||||
getResourceRoles,
|
||||
getUserEffectivePermissions,
|
||||
searchPrincipals,
|
||||
};
|
||||
@@ -4,11 +4,18 @@ const {
|
||||
getToolkitKey,
|
||||
checkPluginAuth,
|
||||
filterUniquePlugins,
|
||||
convertMCPToolToPlugin,
|
||||
convertMCPToolsToPlugins,
|
||||
} = require('@librechat/api');
|
||||
const { getCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
const {
|
||||
getCachedTools,
|
||||
setCachedTools,
|
||||
mergeUserTools,
|
||||
getCustomConfig,
|
||||
} = require('~/server/services/Config');
|
||||
const { loadAndFormatTools } = require('~/server/services/ToolService');
|
||||
const { availableTools, toolkits } = require('~/app/clients/tools');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { getMCPManager } = require('~/config');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const getAvailablePluginsController = async (req, res) => {
|
||||
@@ -22,6 +29,7 @@ const getAvailablePluginsController = async (req, res) => {
|
||||
|
||||
/** @type {{ filteredTools: string[], includedTools: string[] }} */
|
||||
const { filteredTools = [], includedTools = [] } = req.app.locals;
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
const pluginManifest = availableTools;
|
||||
|
||||
const uniquePlugins = filterUniquePlugins(pluginManifest);
|
||||
@@ -47,45 +55,6 @@ const getAvailablePluginsController = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
function createServerToolsCallback() {
|
||||
/**
|
||||
* @param {string} serverName
|
||||
* @param {TPlugin[] | null} serverTools
|
||||
*/
|
||||
return async function (serverName, serverTools) {
|
||||
try {
|
||||
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
|
||||
if (!serverName || !mcpToolsCache) {
|
||||
return;
|
||||
}
|
||||
await mcpToolsCache.set(serverName, serverTools);
|
||||
logger.debug(`MCP tools for ${serverName} added to cache.`);
|
||||
} catch (error) {
|
||||
logger.error('Error retrieving MCP tools from cache:', error);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function createGetServerTools() {
|
||||
/**
|
||||
* Retrieves cached server tools
|
||||
* @param {string} serverName
|
||||
* @returns {Promise<TPlugin[] | null>}
|
||||
*/
|
||||
return async function (serverName) {
|
||||
try {
|
||||
const mcpToolsCache = getLogStores(CacheKeys.MCP_TOOLS);
|
||||
if (!mcpToolsCache) {
|
||||
return null;
|
||||
}
|
||||
return await mcpToolsCache.get(serverName);
|
||||
} catch (error) {
|
||||
logger.error('Error retrieving MCP tools from cache:', error);
|
||||
return null;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves and returns a list of available tools, either from a cache or by reading a plugin manifest file.
|
||||
*
|
||||
@@ -101,11 +70,18 @@ function createGetServerTools() {
|
||||
const getAvailableTools = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user?.id;
|
||||
if (!userId) {
|
||||
logger.warn('[getAvailableTools] User ID not found in request');
|
||||
return res.status(401).json({ message: 'Unauthorized' });
|
||||
}
|
||||
const customConfig = await getCustomConfig();
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
const cachedToolsArray = await cache.get(CacheKeys.TOOLS);
|
||||
const cachedUserTools = await getCachedTools({ userId });
|
||||
const userPlugins = convertMCPToolsToPlugins({ functionTools: cachedUserTools, customConfig });
|
||||
const userPlugins =
|
||||
cachedUserTools != null
|
||||
? convertMCPToolsToPlugins({ functionTools: cachedUserTools, customConfig })
|
||||
: undefined;
|
||||
|
||||
if (cachedToolsArray != null && userPlugins != null) {
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...cachedToolsArray]);
|
||||
@@ -113,25 +89,51 @@ const getAvailableTools = async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
// If not in cache, build from manifest
|
||||
let pluginManifest = availableTools;
|
||||
if (customConfig?.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = flowsCache ? getFlowStateManager(flowsCache) : null;
|
||||
const serverToolsCallback = createServerToolsCallback();
|
||||
const getServerTools = createGetServerTools();
|
||||
const mcpTools = await mcpManager.loadManifestTools({
|
||||
flowManager,
|
||||
serverToolsCallback,
|
||||
getServerTools,
|
||||
/** @type {Record<string, FunctionTool> | null} Get tool definitions to filter which tools are actually available */
|
||||
let toolDefinitions = await getCachedTools({ includeGlobal: true });
|
||||
let prelimCachedTools;
|
||||
|
||||
// TODO: this is a temp fix until app config is refactored
|
||||
if (!toolDefinitions) {
|
||||
toolDefinitions = loadAndFormatTools({
|
||||
adminFilter: req.app.locals?.filteredTools,
|
||||
adminIncluded: req.app.locals?.includedTools,
|
||||
directory: req.app.locals?.paths.structuredTools,
|
||||
});
|
||||
pluginManifest = [...mcpTools, ...pluginManifest];
|
||||
prelimCachedTools = toolDefinitions;
|
||||
}
|
||||
|
||||
/** @type {TPlugin[]} */
|
||||
const uniquePlugins = filterUniquePlugins(pluginManifest);
|
||||
/** @type {import('@librechat/api').LCManifestTool[]} */
|
||||
let pluginManifest = availableTools;
|
||||
if (customConfig?.mcpServers != null) {
|
||||
try {
|
||||
const mcpManager = getMCPManager();
|
||||
const mcpTools = await mcpManager.getAllToolFunctions(userId);
|
||||
prelimCachedTools = prelimCachedTools ?? {};
|
||||
for (const [toolKey, toolData] of Object.entries(mcpTools)) {
|
||||
const plugin = convertMCPToolToPlugin({
|
||||
toolKey,
|
||||
toolData,
|
||||
customConfig,
|
||||
});
|
||||
if (plugin) {
|
||||
pluginManifest.push(plugin);
|
||||
}
|
||||
prelimCachedTools[toolKey] = toolData;
|
||||
}
|
||||
await mergeUserTools({ userId, cachedUserTools, userTools: prelimCachedTools });
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
'[getAvailableTools] Error loading MCP Tools, servers may still be initializing:',
|
||||
error,
|
||||
);
|
||||
}
|
||||
} else if (prelimCachedTools != null) {
|
||||
await setCachedTools(prelimCachedTools, { isGlobal: true });
|
||||
}
|
||||
|
||||
/** @type {TPlugin[]} Deduplicate and authenticate plugins */
|
||||
const uniquePlugins = filterUniquePlugins(pluginManifest);
|
||||
const authenticatedPlugins = uniquePlugins.map((plugin) => {
|
||||
if (checkPluginAuth(plugin)) {
|
||||
return { ...plugin, authenticated: true };
|
||||
@@ -140,8 +142,7 @@ const getAvailableTools = async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
const toolDefinitions = (await getCachedTools({ includeGlobal: true })) || {};
|
||||
|
||||
/** Filter plugins based on availability and add MCP-specific auth config */
|
||||
const toolsOutput = [];
|
||||
for (const plugin of authenticatedPlugins) {
|
||||
const isToolDefined = toolDefinitions[plugin.pluginKey] !== undefined;
|
||||
@@ -157,41 +158,36 @@ const getAvailableTools = async (req, res) => {
|
||||
|
||||
const toolToAdd = { ...plugin };
|
||||
|
||||
if (!plugin.pluginKey.includes(Constants.mcp_delimiter)) {
|
||||
toolsOutput.push(toolToAdd);
|
||||
continue;
|
||||
}
|
||||
if (plugin.pluginKey.includes(Constants.mcp_delimiter)) {
|
||||
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
const serverConfig = customConfig?.mcpServers?.[serverName];
|
||||
|
||||
const parts = plugin.pluginKey.split(Constants.mcp_delimiter);
|
||||
const serverName = parts[parts.length - 1];
|
||||
const serverConfig = customConfig?.mcpServers?.[serverName];
|
||||
|
||||
if (!serverConfig?.customUserVars) {
|
||||
toolsOutput.push(toolToAdd);
|
||||
continue;
|
||||
}
|
||||
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
|
||||
if (customVarKeys.length === 0) {
|
||||
toolToAdd.authConfig = [];
|
||||
toolToAdd.authenticated = true;
|
||||
} else {
|
||||
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}));
|
||||
toolToAdd.authenticated = false;
|
||||
if (serverConfig?.customUserVars) {
|
||||
const customVarKeys = Object.keys(serverConfig.customUserVars);
|
||||
if (customVarKeys.length === 0) {
|
||||
toolToAdd.authConfig = [];
|
||||
toolToAdd.authenticated = true;
|
||||
} else {
|
||||
toolToAdd.authConfig = Object.entries(serverConfig.customUserVars).map(
|
||||
([key, value]) => ({
|
||||
authField: key,
|
||||
label: value.title || key,
|
||||
description: value.description || '',
|
||||
}),
|
||||
);
|
||||
toolToAdd.authenticated = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
toolsOutput.push(toolToAdd);
|
||||
}
|
||||
|
||||
const finalTools = filterUniquePlugins(toolsOutput);
|
||||
await cache.set(CacheKeys.TOOLS, finalTools);
|
||||
|
||||
const dedupedTools = filterUniquePlugins([...userPlugins, ...finalTools]);
|
||||
|
||||
const dedupedTools = filterUniquePlugins([...(userPlugins ?? []), ...finalTools]);
|
||||
res.status(200).json(dedupedTools);
|
||||
} catch (error) {
|
||||
logger.error('[getAvailableTools]', error);
|
||||
|
||||
@@ -13,15 +13,18 @@ jest.mock('@librechat/data-schemas', () => ({
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCustomConfig: jest.fn(),
|
||||
getCachedTools: jest.fn(),
|
||||
setCachedTools: jest.fn(),
|
||||
mergeUserTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/ToolService', () => ({
|
||||
getToolkitKey: jest.fn(),
|
||||
loadAndFormatTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
getMCPManager: jest.fn(() => ({
|
||||
loadManifestTools: jest.fn().mockResolvedValue([]),
|
||||
loadAllManifestTools: jest.fn().mockResolvedValue([]),
|
||||
})),
|
||||
getFlowStateManager: jest.fn(),
|
||||
}));
|
||||
@@ -35,31 +38,31 @@ jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
getToolkitKey: jest.fn(),
|
||||
checkPluginAuth: jest.fn(),
|
||||
filterUniquePlugins: jest.fn(),
|
||||
convertMCPToolsToPlugins: jest.fn(),
|
||||
}));
|
||||
|
||||
// Import the actual module with the function we want to test
|
||||
const { getAvailableTools, getAvailablePluginsController } = require('./PluginController');
|
||||
const {
|
||||
filterUniquePlugins,
|
||||
checkPluginAuth,
|
||||
convertMCPToolsToPlugins,
|
||||
getToolkitKey,
|
||||
} = require('@librechat/api');
|
||||
const { loadAndFormatTools } = require('~/server/services/ToolService');
|
||||
|
||||
describe('PluginController', () => {
|
||||
let mockReq, mockRes, mockCache;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockReq = { user: { id: 'test-user-id' } };
|
||||
mockReq = {
|
||||
user: { id: 'test-user-id' },
|
||||
app: {
|
||||
locals: {
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
filteredTools: null,
|
||||
includedTools: null,
|
||||
},
|
||||
},
|
||||
};
|
||||
mockRes = { status: jest.fn().mockReturnThis(), json: jest.fn() };
|
||||
mockCache = { get: jest.fn(), set: jest.fn() };
|
||||
getLogStores.mockReturnValue(mockCache);
|
||||
|
||||
// Clear availableTools and toolkits arrays before each test
|
||||
require('~/app/clients/tools').availableTools.length = 0;
|
||||
require('~/app/clients/tools').toolkits.length = 0;
|
||||
});
|
||||
|
||||
describe('getAvailablePluginsController', () => {
|
||||
@@ -68,38 +71,39 @@ describe('PluginController', () => {
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to remove duplicate plugins', async () => {
|
||||
// Add plugins with duplicates to availableTools
|
||||
const mockPlugins = [
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First' },
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First duplicate' },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(...mockPlugins);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue(mockPlugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(filterUniquePlugins).toHaveBeenCalled();
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
// The response includes authenticated: true for each plugin when checkPluginAuth returns true
|
||||
expect(mockRes.json).toHaveBeenCalledWith([
|
||||
{ name: 'Plugin1', pluginKey: 'key1', description: 'First', authenticated: true },
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second', authenticated: true },
|
||||
]);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData).toHaveLength(2);
|
||||
expect(responseData[0].pluginKey).toBe('key1');
|
||||
expect(responseData[1].pluginKey).toBe('key2');
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify plugin authentication', async () => {
|
||||
// checkPluginAuth returns false for plugins without authConfig
|
||||
// so authenticated property won't be added
|
||||
const mockPlugin = { name: 'Plugin1', pluginKey: 'key1', description: 'First' };
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue([mockPlugin]);
|
||||
checkPluginAuth.mockReturnValueOnce(true);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(checkPluginAuth).toHaveBeenCalledWith(mockPlugin);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(responseData[0].authenticated).toBe(true);
|
||||
// checkPluginAuth returns false, so authenticated property is not added
|
||||
expect(responseData[0].authenticated).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should return cached plugins when available', async () => {
|
||||
@@ -111,8 +115,7 @@ describe('PluginController', () => {
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
expect(filterUniquePlugins).not.toHaveBeenCalled();
|
||||
expect(checkPluginAuth).not.toHaveBeenCalled();
|
||||
// When cache is hit, we return immediately without processing
|
||||
expect(mockRes.json).toHaveBeenCalledWith(cachedPlugins);
|
||||
});
|
||||
|
||||
@@ -122,10 +125,9 @@ describe('PluginController', () => {
|
||||
{ name: 'Plugin2', pluginKey: 'key2', description: 'Second' },
|
||||
];
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(...mockPlugins);
|
||||
mockReq.app.locals.includedTools = ['key1'];
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue(mockPlugins);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
@@ -139,70 +141,102 @@ describe('PluginController', () => {
|
||||
it('should use convertMCPToolsToPlugins for user-specific MCP tools', async () => {
|
||||
const mockUserTools = {
|
||||
[`tool1${Constants.mcp_delimiter}server1`]: {
|
||||
function: { name: 'tool1', description: 'Tool 1' },
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `tool1${Constants.mcp_delimiter}server1`,
|
||||
description: 'Tool 1',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
const mockConvertedPlugins = [
|
||||
{
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}server1`,
|
||||
description: 'Tool 1',
|
||||
},
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue(mockConvertedPlugins);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: mockUserTools,
|
||||
customConfig: null,
|
||||
});
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
// convertMCPToolsToPlugins should have converted the tool
|
||||
expect(responseData.length).toBeGreaterThan(0);
|
||||
const convertedTool = responseData.find(
|
||||
(tool) => tool.pluginKey === `tool1${Constants.mcp_delimiter}server1`,
|
||||
);
|
||||
expect(convertedTool).toBeDefined();
|
||||
expect(convertedTool.name).toBe('tool1');
|
||||
});
|
||||
|
||||
it('should use filterUniquePlugins to deduplicate combined tools', async () => {
|
||||
const mockUserPlugins = [
|
||||
{ name: 'UserTool', pluginKey: 'user-tool', description: 'User tool' },
|
||||
];
|
||||
const mockManifestPlugins = [
|
||||
const mockUserTools = {
|
||||
'user-tool': {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'user-tool',
|
||||
description: 'User tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockCachedPlugins = [
|
||||
{ name: 'user-tool', pluginKey: 'user-tool', description: 'Duplicate user tool' },
|
||||
{ name: 'ManifestTool', pluginKey: 'manifest-tool', description: 'Manifest tool' },
|
||||
];
|
||||
|
||||
mockCache.get.mockResolvedValue(mockManifestPlugins);
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
convertMCPToolsToPlugins.mockReturnValue(mockUserPlugins);
|
||||
filterUniquePlugins.mockReturnValue([...mockUserPlugins, ...mockManifestPlugins]);
|
||||
mockCache.get.mockResolvedValue(mockCachedPlugins);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
// Should be called to deduplicate the combined array
|
||||
expect(filterUniquePlugins).toHaveBeenLastCalledWith([
|
||||
...mockUserPlugins,
|
||||
...mockManifestPlugins,
|
||||
]);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
// Should have deduplicated tools with same pluginKey
|
||||
const userToolCount = responseData.filter((tool) => tool.pluginKey === 'user-tool').length;
|
||||
expect(userToolCount).toBe(1);
|
||||
});
|
||||
|
||||
it('should use checkPluginAuth to verify authentication status', async () => {
|
||||
const mockPlugin = { name: 'Tool1', pluginKey: 'tool1', description: 'Tool 1' };
|
||||
// Add a plugin to availableTools that will be checked
|
||||
const mockPlugin = {
|
||||
name: 'Tool1',
|
||||
pluginKey: 'tool1',
|
||||
description: 'Tool 1',
|
||||
// No authConfig means checkPluginAuth returns false
|
||||
};
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(mockPlugin);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockPlugin]);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock getCachedTools second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({ tool1: true });
|
||||
// Mock loadAndFormatTools to return tool definitions including our tool
|
||||
loadAndFormatTools.mockReturnValue({
|
||||
tool1: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'tool1',
|
||||
description: 'Tool 1',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(checkPluginAuth).toHaveBeenCalledWith(mockPlugin);
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
const tool = responseData.find((t) => t.pluginKey === 'tool1');
|
||||
expect(tool).toBeDefined();
|
||||
// checkPluginAuth returns false, so authenticated property is not added
|
||||
expect(tool.authenticated).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should use getToolkitKey for toolkit validation', async () => {
|
||||
@@ -213,22 +247,38 @@ describe('PluginController', () => {
|
||||
toolkit: true,
|
||||
};
|
||||
|
||||
require('~/app/clients/tools').availableTools.push(mockToolkit);
|
||||
|
||||
// Mock toolkits to have a mapping
|
||||
require('~/app/clients/tools').toolkits.push({
|
||||
name: 'Toolkit1',
|
||||
pluginKey: 'toolkit1',
|
||||
tools: ['toolkit1_function'],
|
||||
});
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockToolkit]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
getToolkitKey.mockReturnValue('toolkit1');
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock getCachedTools second call to return tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({
|
||||
toolkit1_function: true,
|
||||
// Mock loadAndFormatTools to return tool definitions
|
||||
loadAndFormatTools.mockReturnValue({
|
||||
toolkit1_function: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: 'toolkit1_function',
|
||||
description: 'Toolkit function',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(getToolkitKey).toHaveBeenCalled();
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
expect(Array.isArray(responseData)).toBe(true);
|
||||
const toolkit = responseData.find((t) => t.pluginKey === 'toolkit1');
|
||||
expect(toolkit).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -239,32 +289,33 @@ describe('PluginController', () => {
|
||||
|
||||
const functionTools = {
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: {
|
||||
function: { name: 'test-tool', description: 'A test tool' },
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
description: 'A test tool',
|
||||
parameters: { type: 'object', properties: {} },
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const mockConvertedPlugin = {
|
||||
name: 'test-tool',
|
||||
pluginKey: `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
description: 'A test tool',
|
||||
icon: mcpServers['test-server']?.iconPath,
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
// Mock the MCP manager to return tools
|
||||
const mockMCPManager = {
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue(functionTools),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Mock loadAndFormatTools to return empty object since these are MCP tools
|
||||
loadAndFormatTools.mockReturnValue({});
|
||||
|
||||
getCachedTools.mockResolvedValueOnce(functionTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue([mockConvertedPlugin]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
getToolkitKey.mockReturnValue(undefined);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`test-tool${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
return responseData.find((tool) => tool.name === 'test-tool');
|
||||
return responseData.find(
|
||||
(tool) => tool.pluginKey === `test-tool${Constants.mcp_delimiter}test-server`,
|
||||
);
|
||||
};
|
||||
|
||||
it('should set plugin.icon when iconPath is defined', async () => {
|
||||
@@ -298,19 +349,21 @@ describe('PluginController', () => {
|
||||
},
|
||||
};
|
||||
|
||||
// We need to test the actual flow where MCP manager tools are included
|
||||
const mcpManagerTools = [
|
||||
{
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
authenticated: true,
|
||||
// Mock MCP tools returned by getAllToolFunctions
|
||||
const mcpToolFunctions = {
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
];
|
||||
};
|
||||
|
||||
// Mock the MCP manager to return tools
|
||||
const mockMCPManager = {
|
||||
loadManifestTools: jest.fn().mockResolvedValue(mcpManagerTools),
|
||||
getAllToolFunctions: jest.fn().mockResolvedValue(mcpToolFunctions),
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMCPManager);
|
||||
|
||||
@@ -320,19 +373,11 @@ describe('PluginController', () => {
|
||||
// First call returns user tools (empty in this case)
|
||||
getCachedTools.mockResolvedValueOnce({});
|
||||
|
||||
// Mock convertMCPToolsToPlugins to return empty array for user tools
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
// Mock loadAndFormatTools to return empty object for MCP tools
|
||||
loadAndFormatTools.mockReturnValue({});
|
||||
|
||||
// Mock filterUniquePlugins to pass through
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
|
||||
// Mock checkPluginAuth
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
// Second call returns tool definitions
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
// Second call returns tool definitions including our MCP tool
|
||||
getCachedTools.mockResolvedValueOnce(mcpToolFunctions);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
@@ -373,25 +418,23 @@ describe('PluginController', () => {
|
||||
it('should handle null cachedTools and cachedUserTools', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
convertMCPToolsToPlugins.mockReturnValue(undefined);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock loadAndFormatTools to return empty object when getCachedTools returns null
|
||||
loadAndFormatTools.mockReturnValue({});
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: null,
|
||||
customConfig: null,
|
||||
});
|
||||
// Should handle null values gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
});
|
||||
|
||||
it('should handle when getCachedTools returns undefined', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue(undefined);
|
||||
convertMCPToolsToPlugins.mockReturnValue(undefined);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
// Mock loadAndFormatTools to return empty object when getCachedTools returns undefined
|
||||
loadAndFormatTools.mockReturnValue({});
|
||||
|
||||
// Mock getCachedTools to return undefined for both calls
|
||||
getCachedTools.mockReset();
|
||||
@@ -399,37 +442,40 @@ describe('PluginController', () => {
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(convertMCPToolsToPlugins).toHaveBeenCalledWith({
|
||||
functionTools: undefined,
|
||||
customConfig: null,
|
||||
});
|
||||
// Should handle undefined values gracefully
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
});
|
||||
|
||||
it('should handle cachedToolsArray and userPlugins both being defined', async () => {
|
||||
const cachedTools = [{ name: 'CachedTool', pluginKey: 'cached-tool', description: 'Cached' }];
|
||||
// Use MCP delimiter for the user tool so convertMCPToolsToPlugins works
|
||||
const userTools = {
|
||||
'user-tool': { function: { name: 'user-tool', description: 'User tool' } },
|
||||
[`user-tool${Constants.mcp_delimiter}server1`]: {
|
||||
type: 'function',
|
||||
function: {
|
||||
name: `user-tool${Constants.mcp_delimiter}server1`,
|
||||
description: 'User tool',
|
||||
parameters: {},
|
||||
},
|
||||
},
|
||||
};
|
||||
const userPlugins = [{ name: 'UserTool', pluginKey: 'user-tool', description: 'User tool' }];
|
||||
|
||||
mockCache.get.mockResolvedValue(cachedTools);
|
||||
getCachedTools.mockResolvedValue(userTools);
|
||||
convertMCPToolsToPlugins.mockReturnValue(userPlugins);
|
||||
filterUniquePlugins.mockReturnValue([...userPlugins, ...cachedTools]);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(200);
|
||||
expect(mockRes.json).toHaveBeenCalledWith([...userPlugins, ...cachedTools]);
|
||||
const responseData = mockRes.json.mock.calls[0][0];
|
||||
// Should have both cached and user tools
|
||||
expect(responseData.length).toBeGreaterThanOrEqual(2);
|
||||
});
|
||||
|
||||
it('should handle empty toolDefinitions object', async () => {
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins || []);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
await getAvailableTools(mockReq, mockRes);
|
||||
|
||||
@@ -456,18 +502,6 @@ describe('PluginController', () => {
|
||||
getCustomConfig.mockResolvedValue(customConfig);
|
||||
getCachedTools.mockResolvedValueOnce(mockUserTools);
|
||||
|
||||
const mockPlugin = {
|
||||
name: 'tool1',
|
||||
pluginKey: `tool1${Constants.mcp_delimiter}test-server`,
|
||||
description: 'Tool 1',
|
||||
authenticated: true,
|
||||
authConfig: [],
|
||||
};
|
||||
|
||||
convertMCPToolsToPlugins.mockReturnValue([mockPlugin]);
|
||||
filterUniquePlugins.mockImplementation((plugins) => plugins);
|
||||
checkPluginAuth.mockReturnValue(true);
|
||||
|
||||
getCachedTools.mockResolvedValueOnce({
|
||||
[`tool1${Constants.mcp_delimiter}test-server`]: true,
|
||||
});
|
||||
@@ -483,8 +517,6 @@ describe('PluginController', () => {
|
||||
it('should handle req.app.locals with undefined filteredTools and includedTools', async () => {
|
||||
mockReq.app = { locals: {} };
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
filterUniquePlugins.mockReturnValue([]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
|
||||
await getAvailablePluginsController(mockReq, mockRes);
|
||||
|
||||
@@ -500,14 +532,25 @@ describe('PluginController', () => {
|
||||
toolkit: true,
|
||||
};
|
||||
|
||||
// Ensure req.app.locals is properly mocked
|
||||
mockReq.app = {
|
||||
locals: {
|
||||
filteredTools: [],
|
||||
includedTools: [],
|
||||
paths: { structuredTools: '/mock/path' },
|
||||
},
|
||||
};
|
||||
|
||||
// Add the toolkit to availableTools
|
||||
require('~/app/clients/tools').availableTools.push(mockToolkit);
|
||||
|
||||
mockCache.get.mockResolvedValue(null);
|
||||
getCachedTools.mockResolvedValue({});
|
||||
convertMCPToolsToPlugins.mockReturnValue([]);
|
||||
filterUniquePlugins.mockReturnValue([mockToolkit]);
|
||||
checkPluginAuth.mockReturnValue(false);
|
||||
getToolkitKey.mockReturnValue(undefined);
|
||||
getCustomConfig.mockResolvedValue(null);
|
||||
|
||||
// Mock loadAndFormatTools to return an empty object when toolDefinitions is null
|
||||
loadAndFormatTools.mockReturnValue({});
|
||||
|
||||
// Mock getCachedTools second call to return null
|
||||
getCachedTools.mockResolvedValueOnce({}).mockResolvedValueOnce(null);
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ const verify2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { token, backupCode } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
const user = await getUserById(userId, '_id totpSecret backupCodes');
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
@@ -79,7 +79,7 @@ const confirm2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { token } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
const user = await getUserById(userId, '_id totpSecret');
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA not initiated' });
|
||||
@@ -105,7 +105,7 @@ const disable2FA = async (req, res) => {
|
||||
try {
|
||||
const userId = req.user.id;
|
||||
const { token, backupCode } = req.body;
|
||||
const user = await getUserById(userId);
|
||||
const user = await getUserById(userId, '_id totpSecret backupCodes');
|
||||
|
||||
if (!user || !user.totpSecret) {
|
||||
return res.status(400).json({ message: '2FA is not setup for this user' });
|
||||
|
||||
@@ -24,7 +24,13 @@ const { getMCPManager } = require('~/config');
|
||||
const getUserController = async (req, res) => {
|
||||
/** @type {MongoUser} */
|
||||
const userData = req.user.toObject != null ? req.user.toObject() : { ...req.user };
|
||||
/**
|
||||
* These fields should not exist due to secure field selection, but deletion
|
||||
* is done in case of alternate database incompatibility with Mongo API
|
||||
* */
|
||||
delete userData.password;
|
||||
delete userData.totpSecret;
|
||||
delete userData.backupCodes;
|
||||
if (req.app.locals.fileStrategy === FileSources.s3 && userData.avatar) {
|
||||
const avatarNeedsRefresh = needsRefresh(userData.avatar, 3600);
|
||||
if (!avatarNeedsRefresh) {
|
||||
|
||||
@@ -11,6 +11,7 @@ const {
|
||||
handleToolCalls,
|
||||
ChatModelStreamHandler,
|
||||
} = require('@librechat/agents');
|
||||
const { processFileCitations } = require('~/server/services/Files/Citations');
|
||||
const { processCodeOutput } = require('~/server/services/Files/Code/process');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { saveBase64Image } = require('~/server/services/Files/process');
|
||||
@@ -238,6 +239,31 @@ function createToolEndCallback({ req, res, artifactPromises }) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (output.artifact[Tools.file_search]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
const user = req.user;
|
||||
const attachment = await processFileCitations({
|
||||
user,
|
||||
metadata,
|
||||
toolArtifact: output.artifact,
|
||||
toolCallId: output.tool_call_id,
|
||||
});
|
||||
if (!attachment) {
|
||||
return null;
|
||||
}
|
||||
if (!res.headersSent) {
|
||||
return attachment;
|
||||
}
|
||||
res.write(`event: attachment\ndata: ${JSON.stringify(attachment)}\n\n`);
|
||||
return attachment;
|
||||
})().catch((error) => {
|
||||
logger.error('Error processing file citations:', error);
|
||||
return null;
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
if (output.artifact[Tools.web_search]) {
|
||||
artifactPromises.push(
|
||||
(async () => {
|
||||
|
||||
@@ -33,18 +33,13 @@ const {
|
||||
bedrockInputSchema,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
findPluginAuthsByKeys,
|
||||
getFormattedMemories,
|
||||
deleteMemory,
|
||||
setMemory,
|
||||
} = require('~/models');
|
||||
const { getMCPAuthMap, checkCapability, hasCustomUserVars } = require('~/server/services/Config');
|
||||
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
|
||||
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
|
||||
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
|
||||
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
|
||||
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
|
||||
const { getProviderConfig } = require('~/server/services/Endpoints');
|
||||
const { checkCapability } = require('~/server/services/Config');
|
||||
const BaseClient = require('~/app/clients/BaseClient');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { loadAgent } = require('~/models/Agent');
|
||||
@@ -226,42 +221,6 @@ class AgentClient extends BaseClient {
|
||||
return files;
|
||||
}
|
||||
|
||||
async addDocuments(message, attachments) {
|
||||
const documentResult =
|
||||
await require('~/server/services/Files/documents').encodeAndFormatDocuments(
|
||||
this.options.req,
|
||||
attachments,
|
||||
this.options.agent.provider,
|
||||
);
|
||||
message.documents =
|
||||
documentResult.documents && documentResult.documents.length
|
||||
? documentResult.documents
|
||||
: undefined;
|
||||
return documentResult.files;
|
||||
}
|
||||
|
||||
async processAttachments(message, attachments) {
|
||||
const [imageFiles, documentFiles] = await Promise.all([
|
||||
this.addImageURLs(message, attachments),
|
||||
this.addDocuments(message, attachments),
|
||||
]);
|
||||
|
||||
const allFiles = [...imageFiles, ...documentFiles];
|
||||
const seenFileIds = new Set();
|
||||
const uniqueFiles = [];
|
||||
|
||||
for (const file of allFiles) {
|
||||
if (file.file_id && !seenFileIds.has(file.file_id)) {
|
||||
seenFileIds.add(file.file_id);
|
||||
uniqueFiles.push(file);
|
||||
} else if (!file.file_id) {
|
||||
uniqueFiles.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
return uniqueFiles;
|
||||
}
|
||||
|
||||
async buildMessages(
|
||||
messages,
|
||||
parentMessageId,
|
||||
@@ -295,7 +254,7 @@ class AgentClient extends BaseClient {
|
||||
};
|
||||
}
|
||||
|
||||
const files = await this.processAttachments(
|
||||
const files = await this.addImageURLs(
|
||||
orderedMessages[orderedMessages.length - 1],
|
||||
attachments,
|
||||
);
|
||||
@@ -318,23 +277,6 @@ class AgentClient extends BaseClient {
|
||||
assistantName: this.options?.modelLabel,
|
||||
});
|
||||
|
||||
if (
|
||||
message.documents &&
|
||||
message.documents.length > 0 &&
|
||||
message.role === 'user' &&
|
||||
this.options.agent.provider === EModelEndpoint.anthropic
|
||||
) {
|
||||
const contentParts = [];
|
||||
contentParts.push(...message.documents);
|
||||
if (message.image_urls && message.image_urls.length > 0) {
|
||||
contentParts.push(...message.image_urls);
|
||||
}
|
||||
const textContent =
|
||||
typeof formattedMessage.content === 'string' ? formattedMessage.content : '';
|
||||
contentParts.push({ type: 'text', text: textContent });
|
||||
formattedMessage.content = contentParts;
|
||||
}
|
||||
|
||||
if (message.ocr && i !== orderedMessages.length - 1) {
|
||||
if (typeof formattedMessage.content === 'string') {
|
||||
formattedMessage.content = message.ocr + '\n' + formattedMessage.content;
|
||||
@@ -668,6 +610,7 @@ class AgentClient extends BaseClient {
|
||||
await this.chatCompletion({
|
||||
payload,
|
||||
onProgress: opts.onProgress,
|
||||
userMCPAuthMap: opts.userMCPAuthMap,
|
||||
abortController: opts.abortController,
|
||||
});
|
||||
return this.contentParts;
|
||||
@@ -800,7 +743,13 @@ class AgentClient extends BaseClient {
|
||||
return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
|
||||
}
|
||||
|
||||
async chatCompletion({ payload, abortController = null }) {
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {string | ChatCompletionMessageParam[]} params.payload
|
||||
* @param {Record<string, Record<string, string>>} [params.userMCPAuthMap]
|
||||
* @param {AbortController} [params.abortController]
|
||||
*/
|
||||
async chatCompletion({ payload, userMCPAuthMap, abortController = null }) {
|
||||
/** @type {Partial<GraphRunnableConfig>} */
|
||||
let config;
|
||||
/** @type {ReturnType<createRun>} */
|
||||
@@ -821,6 +770,11 @@ class AgentClient extends BaseClient {
|
||||
last_agent_index: this.agentConfigs?.size ?? 0,
|
||||
user_id: this.user ?? this.options.req.user?.id,
|
||||
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
|
||||
requestBody: {
|
||||
messageId: this.responseMessageId,
|
||||
conversationId: this.conversationId,
|
||||
parentMessageId: this.parentMessageId,
|
||||
},
|
||||
user: this.options.req.user,
|
||||
},
|
||||
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
|
||||
@@ -830,51 +784,6 @@ class AgentClient extends BaseClient {
|
||||
};
|
||||
|
||||
const toolSet = new Set((this.options.agent.tools ?? []).map((tool) => tool && tool.name));
|
||||
|
||||
if (
|
||||
this.options.agent.provider === EModelEndpoint.anthropic &&
|
||||
payload &&
|
||||
Array.isArray(payload)
|
||||
) {
|
||||
let userMessageWithDocs = null;
|
||||
|
||||
if (this.userMessage?.documents) {
|
||||
userMessageWithDocs = this.userMessage;
|
||||
} else if (this.currentMessages?.length > 0) {
|
||||
const lastMessage = this.currentMessages[this.currentMessages.length - 1];
|
||||
if (lastMessage.documents?.length > 0) {
|
||||
userMessageWithDocs = lastMessage;
|
||||
}
|
||||
} else if (this.messages?.length > 0) {
|
||||
const lastMessage = this.messages[this.messages.length - 1];
|
||||
if (lastMessage.documents?.length > 0) {
|
||||
userMessageWithDocs = lastMessage;
|
||||
}
|
||||
}
|
||||
|
||||
if (userMessageWithDocs) {
|
||||
for (const payloadMessage of payload) {
|
||||
if (
|
||||
payloadMessage.role === 'user' &&
|
||||
userMessageWithDocs.text === payloadMessage.content
|
||||
) {
|
||||
if (typeof payloadMessage.content === 'string') {
|
||||
payloadMessage.content = [
|
||||
...userMessageWithDocs.documents,
|
||||
{ type: 'text', text: payloadMessage.content },
|
||||
];
|
||||
} else if (Array.isArray(payloadMessage.content)) {
|
||||
payloadMessage.content = [
|
||||
...userMessageWithDocs.documents,
|
||||
...payloadMessage.content,
|
||||
];
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let { messages: initialMessages, indexTokenCountMap } = formatAgentMessages(
|
||||
payload,
|
||||
this.indexTokenCountMap,
|
||||
@@ -936,7 +845,7 @@ class AgentClient extends BaseClient {
|
||||
|
||||
if (noSystemMessages === true && systemContent?.length) {
|
||||
const latestMessageContent = _messages.pop().content;
|
||||
if (typeof latestMessage !== 'string') {
|
||||
if (typeof latestMessageContent !== 'string') {
|
||||
latestMessageContent[0].text = [systemContent, latestMessageContent[0].text].join('\n');
|
||||
_messages.push(new HumanMessage({ content: latestMessageContent }));
|
||||
} else {
|
||||
@@ -996,21 +905,9 @@ class AgentClient extends BaseClient {
|
||||
run.Graph.contentData = contentData;
|
||||
}
|
||||
|
||||
try {
|
||||
if (await hasCustomUserVars()) {
|
||||
config.configurable.userMCPAuthMap = await getMCPAuthMap({
|
||||
tools: agent.tools,
|
||||
userId: this.options.req.user.id,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent ${agent.id}`,
|
||||
err,
|
||||
);
|
||||
if (userMCPAuthMap != null) {
|
||||
config.configurable.userMCPAuthMap = userMCPAuthMap;
|
||||
}
|
||||
|
||||
await run.processStream({ messages }, config, {
|
||||
keepContent: i !== 0,
|
||||
tokenCounter: createTokenCounter(this.getEncoding()),
|
||||
@@ -1132,6 +1029,7 @@ class AgentClient extends BaseClient {
|
||||
if (attachments && attachments.length > 0) {
|
||||
this.artifactPromises.push(...attachments);
|
||||
}
|
||||
|
||||
await this.recordCollectedUsage({ context: 'message' });
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
@@ -1177,7 +1075,6 @@ class AgentClient extends BaseClient {
|
||||
|
||||
/** @type {import('@librechat/agents').ClientOptions} */
|
||||
let clientOptions = {
|
||||
maxTokens: 75,
|
||||
model: agent.model || agent.model_parameters.model,
|
||||
};
|
||||
|
||||
@@ -1244,15 +1141,13 @@ class AgentClient extends BaseClient {
|
||||
clientOptions.configuration = options.configOptions;
|
||||
}
|
||||
|
||||
const shouldRemoveMaxTokens = /\b(o\d|gpt-[5-9])\b/i.test(clientOptions.model);
|
||||
if (shouldRemoveMaxTokens && clientOptions.maxTokens != null) {
|
||||
if (clientOptions.maxTokens != null) {
|
||||
delete clientOptions.maxTokens;
|
||||
} else if (!shouldRemoveMaxTokens && !clientOptions.maxTokens) {
|
||||
clientOptions.maxTokens = 75;
|
||||
}
|
||||
if (shouldRemoveMaxTokens && clientOptions?.modelKwargs?.max_completion_tokens != null) {
|
||||
if (clientOptions?.modelKwargs?.max_completion_tokens != null) {
|
||||
delete clientOptions.modelKwargs.max_completion_tokens;
|
||||
} else if (shouldRemoveMaxTokens && clientOptions?.modelKwargs?.max_output_tokens != null) {
|
||||
}
|
||||
if (clientOptions?.modelKwargs?.max_output_tokens != null) {
|
||||
delete clientOptions.modelKwargs.max_output_tokens;
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,24 @@ const {
|
||||
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
|
||||
const { saveMessage } = require('~/models');
|
||||
|
||||
function createCloseHandler(abortController) {
|
||||
return function (manual) {
|
||||
if (!manual) {
|
||||
logger.debug('[AgentController] Request closed');
|
||||
}
|
||||
if (!abortController) {
|
||||
return;
|
||||
} else if (abortController.signal.aborted) {
|
||||
return;
|
||||
} else if (abortController.requestCompleted) {
|
||||
return;
|
||||
}
|
||||
|
||||
abortController.abort();
|
||||
logger.debug('[AgentController] Request aborted on close');
|
||||
};
|
||||
}
|
||||
|
||||
const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let {
|
||||
text,
|
||||
@@ -31,7 +49,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
let userMessagePromise;
|
||||
let getAbortData;
|
||||
let client = null;
|
||||
// Initialize as an array
|
||||
let cleanupHandlers = [];
|
||||
|
||||
const newConvo = !conversationId;
|
||||
@@ -62,9 +79,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
// Create a function to handle final cleanup
|
||||
const performCleanup = () => {
|
||||
logger.debug('[AgentController] Performing cleanup');
|
||||
// Make sure cleanupHandlers is an array before iterating
|
||||
if (Array.isArray(cleanupHandlers)) {
|
||||
// Execute all cleanup handlers
|
||||
for (const handler of cleanupHandlers) {
|
||||
try {
|
||||
if (typeof handler === 'function') {
|
||||
@@ -105,8 +120,33 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
};
|
||||
|
||||
try {
|
||||
/** @type {{ client: TAgentClient }} */
|
||||
const result = await initializeClient({ req, res, endpointOption });
|
||||
let prelimAbortController = new AbortController();
|
||||
const prelimCloseHandler = createCloseHandler(prelimAbortController);
|
||||
res.on('close', prelimCloseHandler);
|
||||
const removePrelimHandler = (manual) => {
|
||||
try {
|
||||
prelimCloseHandler(manual);
|
||||
res.removeListener('close', prelimCloseHandler);
|
||||
} catch (e) {
|
||||
logger.error('[AgentController] Error removing close listener', e);
|
||||
}
|
||||
};
|
||||
cleanupHandlers.push(removePrelimHandler);
|
||||
/** @type {{ client: TAgentClient; userMCPAuthMap?: Record<string, Record<string, string>> }} */
|
||||
const result = await initializeClient({
|
||||
req,
|
||||
res,
|
||||
endpointOption,
|
||||
signal: prelimAbortController.signal,
|
||||
});
|
||||
if (prelimAbortController.signal?.aborted) {
|
||||
prelimAbortController = null;
|
||||
throw new Error('Request was aborted before initialization could complete');
|
||||
} else {
|
||||
prelimAbortController = null;
|
||||
removePrelimHandler(true);
|
||||
cleanupHandlers.pop();
|
||||
}
|
||||
client = result.client;
|
||||
|
||||
// Register client with finalization registry if available
|
||||
@@ -138,22 +178,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
};
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
// Simple handler to avoid capturing scope
|
||||
const closeHandler = () => {
|
||||
logger.debug('[AgentController] Request closed');
|
||||
if (!abortController) {
|
||||
return;
|
||||
} else if (abortController.signal.aborted) {
|
||||
return;
|
||||
} else if (abortController.requestCompleted) {
|
||||
return;
|
||||
}
|
||||
|
||||
abortController.abort();
|
||||
logger.debug('[AgentController] Request aborted on close');
|
||||
};
|
||||
|
||||
const closeHandler = createCloseHandler(abortController);
|
||||
res.on('close', closeHandler);
|
||||
cleanupHandlers.push(() => {
|
||||
try {
|
||||
@@ -175,6 +200,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
|
||||
abortController,
|
||||
overrideParentMessageId,
|
||||
isEdited: !!editedContent,
|
||||
userMCPAuthMap: result.userMCPAuthMap,
|
||||
responseMessageId: editedResponseMessageId,
|
||||
progressOptions: {
|
||||
res,
|
||||
|
||||
@@ -5,30 +5,40 @@ const { logger } = require('@librechat/data-schemas');
|
||||
const { agentCreateSchema, agentUpdateSchema } = require('@librechat/api');
|
||||
const {
|
||||
Tools,
|
||||
Constants,
|
||||
FileSources,
|
||||
SystemRoles,
|
||||
FileSources,
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
EToolResources,
|
||||
PermissionBits,
|
||||
actionDelimiter,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getAgent,
|
||||
getListAgentsByAccess,
|
||||
countPromotedAgents,
|
||||
revertAgentVersion,
|
||||
createAgent,
|
||||
updateAgent,
|
||||
deleteAgent,
|
||||
getListAgents,
|
||||
getAgent,
|
||||
} = require('~/models/Agent');
|
||||
const {
|
||||
findPubliclyAccessibleResources,
|
||||
findAccessibleResources,
|
||||
hasPublicPermission,
|
||||
grantPermission,
|
||||
} = require('~/server/services/PermissionService');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const { getFileStrategy } = require('~/server/utils/getFileStrategy');
|
||||
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
|
||||
const { filterFile } = require('~/server/services/Files/process');
|
||||
const { updateAction, getActions } = require('~/models/Action');
|
||||
const { getCachedTools } = require('~/server/services/Config');
|
||||
const { updateAgentProjects } = require('~/models/Agent');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { revertAgentVersion } = require('~/models/Agent');
|
||||
const { deleteFileByFilter } = require('~/models/File');
|
||||
const { getCategoriesWithCounts } = require('~/models');
|
||||
|
||||
const systemTools = {
|
||||
[Tools.execute_code]: true,
|
||||
@@ -42,7 +52,7 @@ const systemTools = {
|
||||
* @param {ServerRequest} req - The request object.
|
||||
* @param {AgentCreateParams} req.body - The request body.
|
||||
* @param {ServerResponse} res - The response object.
|
||||
* @returns {Agent} 201 - success response - application/json
|
||||
* @returns {Promise<Agent>} 201 - success response - application/json
|
||||
*/
|
||||
const createAgentHandler = async (req, res) => {
|
||||
try {
|
||||
@@ -67,6 +77,27 @@ const createAgentHandler = async (req, res) => {
|
||||
}
|
||||
|
||||
const agent = await createAgent(agentData);
|
||||
|
||||
// Automatically grant owner permissions to the creator
|
||||
try {
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_OWNER,
|
||||
grantedBy: userId,
|
||||
});
|
||||
logger.debug(
|
||||
`[createAgent] Granted owner permissions to user ${userId} for agent ${agent.id}`,
|
||||
);
|
||||
} catch (permissionError) {
|
||||
logger.error(
|
||||
`[createAgent] Failed to grant owner permissions for agent ${agent.id}:`,
|
||||
permissionError,
|
||||
);
|
||||
}
|
||||
|
||||
res.status(201).json(agent);
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
@@ -89,21 +120,14 @@ const createAgentHandler = async (req, res) => {
|
||||
* @returns {Promise<Agent>} 200 - success response - application/json
|
||||
* @returns {Error} 404 - Agent not found
|
||||
*/
|
||||
const getAgentHandler = async (req, res) => {
|
||||
const getAgentHandler = async (req, res, expandProperties = false) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const author = req.user.id;
|
||||
|
||||
let query = { id, author };
|
||||
|
||||
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, ['agentIds']);
|
||||
if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) {
|
||||
query = {
|
||||
$or: [{ id, $in: globalProject.agentIds }, query],
|
||||
};
|
||||
}
|
||||
|
||||
const agent = await getAgent(query);
|
||||
// Permissions are validated by middleware before calling this function
|
||||
// Simply load the agent by ID
|
||||
const agent = await getAgent({ id });
|
||||
|
||||
if (!agent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
@@ -120,23 +144,45 @@ const getAgentHandler = async (req, res) => {
|
||||
}
|
||||
|
||||
agent.author = agent.author.toString();
|
||||
|
||||
// @deprecated - isCollaborative replaced by ACL permissions
|
||||
agent.isCollaborative = !!agent.isCollaborative;
|
||||
|
||||
// Check if agent is public
|
||||
const isPublic = await hasPublicPermission({
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
agent.isPublic = isPublic;
|
||||
|
||||
if (agent.author !== author) {
|
||||
delete agent.author;
|
||||
}
|
||||
|
||||
if (!agent.isCollaborative && agent.author !== author && req.user.role !== SystemRoles.ADMIN) {
|
||||
if (!expandProperties) {
|
||||
// VIEW permission: Basic agent info only
|
||||
return res.status(200).json({
|
||||
_id: agent._id,
|
||||
id: agent.id,
|
||||
name: agent.name,
|
||||
description: agent.description,
|
||||
avatar: agent.avatar,
|
||||
author: agent.author,
|
||||
provider: agent.provider,
|
||||
model: agent.model,
|
||||
projectIds: agent.projectIds,
|
||||
// @deprecated - isCollaborative replaced by ACL permissions
|
||||
isCollaborative: agent.isCollaborative,
|
||||
isPublic: agent.isPublic,
|
||||
version: agent.version,
|
||||
// Safe metadata
|
||||
createdAt: agent.createdAt,
|
||||
updatedAt: agent.updatedAt,
|
||||
});
|
||||
}
|
||||
|
||||
// EDIT permission: Full agent details including sensitive configuration
|
||||
return res.status(200).json(agent);
|
||||
} catch (error) {
|
||||
logger.error('[/Agents/:id] Error retrieving agent', error);
|
||||
@@ -157,43 +203,20 @@ const updateAgentHandler = async (req, res) => {
|
||||
try {
|
||||
const id = req.params.id;
|
||||
const validatedData = agentUpdateSchema.parse(req.body);
|
||||
const { projectIds, removeProjectIds, ...updateData } = removeNullishValues(validatedData);
|
||||
const isAdmin = req.user.role === SystemRoles.ADMIN;
|
||||
const { _id, ...updateData } = removeNullishValues(validatedData);
|
||||
const existingAgent = await getAgent({ id });
|
||||
|
||||
if (!existingAgent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
return res.status(403).json({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
}
|
||||
|
||||
/** @type {boolean} */
|
||||
const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0;
|
||||
|
||||
let updatedAgent =
|
||||
Object.keys(updateData).length > 0
|
||||
? await updateAgent({ id }, updateData, {
|
||||
updatingUserId: req.user.id,
|
||||
skipVersioning: isProjectUpdate,
|
||||
})
|
||||
: existingAgent;
|
||||
|
||||
if (isProjectUpdate) {
|
||||
updatedAgent = await updateAgentProjects({
|
||||
user: req.user,
|
||||
agentId: id,
|
||||
projectIds,
|
||||
removeProjectIds,
|
||||
});
|
||||
}
|
||||
|
||||
// Add version count to the response
|
||||
updatedAgent.version = updatedAgent.versions ? updatedAgent.versions.length : 0;
|
||||
|
||||
@@ -321,6 +344,26 @@ const duplicateAgentHandler = async (req, res) => {
|
||||
newAgentData.actions = agentActions;
|
||||
const newAgent = await createAgent(newAgentData);
|
||||
|
||||
// Automatically grant owner permissions to the duplicator
|
||||
try {
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: userId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: newAgent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_OWNER,
|
||||
grantedBy: userId,
|
||||
});
|
||||
logger.debug(
|
||||
`[duplicateAgent] Granted owner permissions to user ${userId} for duplicated agent ${newAgent.id}`,
|
||||
);
|
||||
} catch (permissionError) {
|
||||
logger.error(
|
||||
`[duplicateAgent] Failed to grant owner permissions for duplicated agent ${newAgent.id}:`,
|
||||
permissionError,
|
||||
);
|
||||
}
|
||||
|
||||
return res.status(201).json({
|
||||
agent: newAgent,
|
||||
actions: newActionsList,
|
||||
@@ -347,7 +390,7 @@ const deleteAgentHandler = async (req, res) => {
|
||||
if (!agent) {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
await deleteAgent({ id, author: req.user.id });
|
||||
await deleteAgent({ id });
|
||||
return res.json({ message: 'Agent deleted' });
|
||||
} catch (error) {
|
||||
logger.error('[/Agents/:id] Error deleting Agent', error);
|
||||
@@ -356,7 +399,7 @@ const deleteAgentHandler = async (req, res) => {
|
||||
};
|
||||
|
||||
/**
|
||||
*
|
||||
* Lists agents using ACL-aware permissions (ownership + explicit shares).
|
||||
* @route GET /Agents
|
||||
* @param {object} req - Express Request
|
||||
* @param {object} req.query - Request query
|
||||
@@ -365,9 +408,65 @@ const deleteAgentHandler = async (req, res) => {
|
||||
*/
|
||||
const getListAgentsHandler = async (req, res) => {
|
||||
try {
|
||||
const data = await getListAgents({
|
||||
author: req.user.id,
|
||||
const userId = req.user.id;
|
||||
const { category, search, limit, cursor, promoted } = req.query;
|
||||
let requiredPermission = req.query.requiredPermission;
|
||||
if (typeof requiredPermission === 'string') {
|
||||
requiredPermission = parseInt(requiredPermission, 10);
|
||||
if (isNaN(requiredPermission)) {
|
||||
requiredPermission = PermissionBits.VIEW;
|
||||
}
|
||||
} else if (typeof requiredPermission !== 'number') {
|
||||
requiredPermission = PermissionBits.VIEW;
|
||||
}
|
||||
// Base filter
|
||||
const filter = {};
|
||||
|
||||
// Handle category filter - only apply if category is defined
|
||||
if (category !== undefined && category.trim() !== '') {
|
||||
filter.category = category;
|
||||
}
|
||||
|
||||
// Handle promoted filter - only from query param
|
||||
if (promoted === '1') {
|
||||
filter.is_promoted = true;
|
||||
} else if (promoted === '0') {
|
||||
filter.is_promoted = { $ne: true };
|
||||
}
|
||||
|
||||
// Handle search filter
|
||||
if (search && search.trim() !== '') {
|
||||
filter.$or = [
|
||||
{ name: { $regex: search.trim(), $options: 'i' } },
|
||||
{ description: { $regex: search.trim(), $options: 'i' } },
|
||||
];
|
||||
}
|
||||
// Get agent IDs the user has VIEW access to via ACL
|
||||
const accessibleIds = await findAccessibleResources({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermissions: requiredPermission,
|
||||
});
|
||||
const publiclyAccessibleIds = await findPubliclyAccessibleResources({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
// Use the new ACL-aware function
|
||||
const data = await getListAgentsByAccess({
|
||||
accessibleIds,
|
||||
otherParams: filter,
|
||||
limit,
|
||||
after: cursor,
|
||||
});
|
||||
if (data?.data?.length) {
|
||||
data.data = data.data.map((agent) => {
|
||||
if (publiclyAccessibleIds.some((id) => id.equals(agent._id))) {
|
||||
agent.isPublic = true;
|
||||
}
|
||||
return agent;
|
||||
});
|
||||
}
|
||||
return res.json(data);
|
||||
} catch (error) {
|
||||
logger.error('[/Agents] Error listing Agents', error);
|
||||
@@ -401,7 +500,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id.toString();
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
@@ -412,7 +511,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
|
||||
|
||||
const buffer = await fs.readFile(req.file.path);
|
||||
|
||||
const fileStrategy = req.app.locals.fileStrategy;
|
||||
const fileStrategy = getFileStrategy(req.app.locals, { isAvatar: true });
|
||||
|
||||
const resizedBuffer = await resizeAvatar({
|
||||
userId: req.user.id,
|
||||
@@ -509,7 +608,7 @@ const revertAgentVersionHandler = async (req, res) => {
|
||||
return res.status(404).json({ error: 'Agent not found' });
|
||||
}
|
||||
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id;
|
||||
const isAuthor = existingAgent.author.toString() === req.user.id.toString();
|
||||
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
|
||||
|
||||
if (!hasEditPermission) {
|
||||
@@ -534,7 +633,48 @@ const revertAgentVersionHandler = async (req, res) => {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
};
|
||||
/**
|
||||
* Get all agent categories with counts
|
||||
*
|
||||
* @param {Object} _req - Express request object (unused)
|
||||
* @param {Object} res - Express response object
|
||||
*/
|
||||
const getAgentCategories = async (_req, res) => {
|
||||
try {
|
||||
const categories = await getCategoriesWithCounts();
|
||||
const promotedCount = await countPromotedAgents();
|
||||
const formattedCategories = categories.map((category) => ({
|
||||
value: category.value,
|
||||
label: category.label,
|
||||
count: category.agentCount,
|
||||
description: category.description,
|
||||
}));
|
||||
|
||||
if (promotedCount > 0) {
|
||||
formattedCategories.unshift({
|
||||
value: 'promoted',
|
||||
label: 'Promoted',
|
||||
count: promotedCount,
|
||||
description: 'Our recommended agents',
|
||||
});
|
||||
}
|
||||
|
||||
formattedCategories.push({
|
||||
value: 'all',
|
||||
label: 'All',
|
||||
description: 'All available agents',
|
||||
});
|
||||
|
||||
res.status(200).json(formattedCategories);
|
||||
} catch (error) {
|
||||
logger.error('[/Agents/Marketplace] Error fetching agent categories:', error);
|
||||
res.status(500).json({
|
||||
error: 'Failed to fetch agent categories',
|
||||
userMessage: 'Unable to load categories. Please refresh the page.',
|
||||
suggestion: 'Try refreshing the page or check your network connection',
|
||||
});
|
||||
}
|
||||
};
|
||||
module.exports = {
|
||||
createAgent: createAgentHandler,
|
||||
getAgent: getAgentHandler,
|
||||
@@ -544,4 +684,5 @@ module.exports = {
|
||||
getListAgents: getListAgentsHandler,
|
||||
uploadAgentAvatar: uploadAgentAvatarHandler,
|
||||
revertAgentVersion: revertAgentVersionHandler,
|
||||
getAgentCategories,
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { nanoid } = require('nanoid');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { agentSchema } = require('@librechat/data-schemas');
|
||||
|
||||
@@ -41,7 +42,27 @@ jest.mock('~/models/File', () => ({
|
||||
deleteFileByFilter: jest.fn(),
|
||||
}));
|
||||
|
||||
const { createAgent: createAgentHandler, updateAgent: updateAgentHandler } = require('./v1');
|
||||
jest.mock('~/server/services/PermissionService', () => ({
|
||||
findAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
findPubliclyAccessibleResources: jest.fn().mockResolvedValue([]),
|
||||
grantPermission: jest.fn(),
|
||||
hasPublicPermission: jest.fn().mockResolvedValue(false),
|
||||
}));
|
||||
|
||||
jest.mock('~/models', () => ({
|
||||
getCategoriesWithCounts: jest.fn(),
|
||||
}));
|
||||
|
||||
const {
|
||||
createAgent: createAgentHandler,
|
||||
updateAgent: updateAgentHandler,
|
||||
getListAgents: getListAgentsHandler,
|
||||
} = require('./v1');
|
||||
|
||||
const {
|
||||
findAccessibleResources,
|
||||
findPubliclyAccessibleResources,
|
||||
} = require('~/server/services/PermissionService');
|
||||
|
||||
/**
|
||||
* @type {import('mongoose').Model<import('@librechat/data-schemas').IAgent>}
|
||||
@@ -79,6 +100,7 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
},
|
||||
body: {},
|
||||
params: {},
|
||||
query: {},
|
||||
app: {
|
||||
locals: {
|
||||
fileStrategy: 'local',
|
||||
@@ -235,6 +257,81 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
expect(agentInDb.tool_resources.invalid_resource).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should handle support_contact with empty strings', async () => {
|
||||
const dataWithEmptyContact = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Empty Contact',
|
||||
support_contact: {
|
||||
name: '',
|
||||
email: '',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithEmptyContact;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.name).toBe('Agent with Empty Contact');
|
||||
expect(createdAgent.support_contact).toBeDefined();
|
||||
expect(createdAgent.support_contact.name).toBe('');
|
||||
expect(createdAgent.support_contact.email).toBe('');
|
||||
});
|
||||
|
||||
test('should handle support_contact with valid email', async () => {
|
||||
const dataWithValidContact = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Valid Contact',
|
||||
support_contact: {
|
||||
name: 'Support Team',
|
||||
email: 'support@example.com',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithValidContact;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(201);
|
||||
|
||||
const createdAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(createdAgent.support_contact).toBeDefined();
|
||||
expect(createdAgent.support_contact.name).toBe('Support Team');
|
||||
expect(createdAgent.support_contact.email).toBe('support@example.com');
|
||||
});
|
||||
|
||||
test('should reject support_contact with invalid email', async () => {
|
||||
const dataWithInvalidEmail = {
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
name: 'Agent with Invalid Email',
|
||||
support_contact: {
|
||||
name: 'Support',
|
||||
email: 'not-an-email',
|
||||
},
|
||||
};
|
||||
|
||||
mockReq.body = dataWithInvalidEmail;
|
||||
|
||||
await createAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(400);
|
||||
expect(mockRes.json).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
error: 'Invalid request data',
|
||||
details: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
path: ['support_contact', 'email'],
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test('should handle avatar validation', async () => {
|
||||
const dataWithAvatar = {
|
||||
provider: 'openai',
|
||||
@@ -372,52 +469,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
expect(agentInDb.id).toBe(existingAgentId);
|
||||
});
|
||||
|
||||
test('should reject update from non-author when not collaborative', async () => {
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Unauthorized Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'You do not have permission to modify this non-collaborative agent',
|
||||
});
|
||||
|
||||
// Verify agent was not modified in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Original Agent');
|
||||
});
|
||||
|
||||
test('should allow update from non-author when collaborative', async () => {
|
||||
// First make the agent collaborative
|
||||
await Agent.updateOne({ id: existingAgentId }, { isCollaborative: true });
|
||||
|
||||
const differentUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = differentUserId; // Different user
|
||||
mockReq.params.id = existingAgentId;
|
||||
mockReq.body = {
|
||||
name: 'Collaborative Update',
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).not.toHaveBeenCalledWith(403);
|
||||
expect(mockRes.json).toHaveBeenCalled();
|
||||
|
||||
const updatedAgent = mockRes.json.mock.calls[0][0];
|
||||
expect(updatedAgent.name).toBe('Collaborative Update');
|
||||
// Author field should be removed for non-author
|
||||
expect(updatedAgent.author).toBeUndefined();
|
||||
|
||||
// Verify in database
|
||||
const agentInDb = await Agent.findOne({ id: existingAgentId });
|
||||
expect(agentInDb.name).toBe('Collaborative Update');
|
||||
});
|
||||
|
||||
test('should allow admin to update any agent', async () => {
|
||||
const adminUserId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = adminUserId;
|
||||
@@ -577,45 +628,6 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
expect(agentInDb.__v).not.toBe(99);
|
||||
});
|
||||
|
||||
test('should prevent privilege escalation through isCollaborative', async () => {
|
||||
// Create a non-collaborative agent
|
||||
const authorId = new mongoose.Types.ObjectId();
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${uuidv4()}`,
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
isCollaborative: false,
|
||||
versions: [
|
||||
{
|
||||
name: 'Private Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Try to make it collaborative as a different user
|
||||
const attackerId = new mongoose.Types.ObjectId().toString();
|
||||
mockReq.user.id = attackerId;
|
||||
mockReq.params.id = agent.id;
|
||||
mockReq.body = {
|
||||
isCollaborative: true, // Trying to escalate privileges
|
||||
};
|
||||
|
||||
await updateAgentHandler(mockReq, mockRes);
|
||||
|
||||
// Should be rejected
|
||||
expect(mockRes.status).toHaveBeenCalledWith(403);
|
||||
|
||||
// Verify in database that it's still not collaborative
|
||||
const agentInDb = await Agent.findOne({ id: agent.id });
|
||||
expect(agentInDb.isCollaborative).toBe(false);
|
||||
});
|
||||
|
||||
test('should prevent author hijacking', async () => {
|
||||
const originalAuthorId = new mongoose.Types.ObjectId();
|
||||
const attackerId = new mongoose.Types.ObjectId();
|
||||
@@ -678,4 +690,373 @@ describe('Agent Controllers - Mass Assignment Protection', () => {
|
||||
expect(agentInDb.futureFeature).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getListAgentsHandler - Security Tests', () => {
|
||||
let userA, userB;
|
||||
let agentA1, agentA2, agentA3, agentB1;
|
||||
|
||||
beforeEach(async () => {
|
||||
await Agent.deleteMany({});
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Create two test users
|
||||
userA = new mongoose.Types.ObjectId();
|
||||
userB = new mongoose.Types.ObjectId();
|
||||
|
||||
// Create agents for User A
|
||||
agentA1 = await Agent.create({
|
||||
id: `agent_${nanoid(12)}`,
|
||||
name: 'Agent A1',
|
||||
description: 'User A agent 1',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
versions: [
|
||||
{
|
||||
name: 'Agent A1',
|
||||
description: 'User A agent 1',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
agentA2 = await Agent.create({
|
||||
id: `agent_${nanoid(12)}`,
|
||||
name: 'Agent A2',
|
||||
description: 'User A agent 2',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
versions: [
|
||||
{
|
||||
name: 'Agent A2',
|
||||
description: 'User A agent 2',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
agentA3 = await Agent.create({
|
||||
id: `agent_${nanoid(12)}`,
|
||||
name: 'Agent A3',
|
||||
description: 'User A agent 3',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
category: 'productivity',
|
||||
versions: [
|
||||
{
|
||||
name: 'Agent A3',
|
||||
description: 'User A agent 3',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
category: 'productivity',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
// Create an agent for User B
|
||||
agentB1 = await Agent.create({
|
||||
id: `agent_${nanoid(12)}`,
|
||||
name: 'Agent B1',
|
||||
description: 'User B agent 1',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userB,
|
||||
versions: [
|
||||
{
|
||||
name: 'Agent B1',
|
||||
description: 'User B agent 1',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
test('should return empty list when user has no accessible agents', async () => {
|
||||
// User B has no permissions and no owned agents
|
||||
mockReq.user.id = userB.toString();
|
||||
findAccessibleResources.mockResolvedValue([]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
expect(findAccessibleResources).toHaveBeenCalledWith({
|
||||
userId: userB.toString(),
|
||||
role: 'USER',
|
||||
resourceType: 'agent',
|
||||
requiredPermissions: 1, // VIEW permission
|
||||
});
|
||||
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
object: 'list',
|
||||
data: [],
|
||||
first_id: null,
|
||||
last_id: null,
|
||||
has_more: false,
|
||||
after: null,
|
||||
});
|
||||
});
|
||||
|
||||
test('should not return other users agents when accessibleIds is empty', async () => {
|
||||
// User B trying to see agents with no permissions
|
||||
mockReq.user.id = userB.toString();
|
||||
findAccessibleResources.mockResolvedValue([]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(0);
|
||||
|
||||
// Verify User A's agents are not included
|
||||
const agentIds = response.data.map((a) => a.id);
|
||||
expect(agentIds).not.toContain(agentA1.id);
|
||||
expect(agentIds).not.toContain(agentA2.id);
|
||||
expect(agentIds).not.toContain(agentA3.id);
|
||||
});
|
||||
|
||||
test('should only return agents user has access to', async () => {
|
||||
// User B has access to one of User A's agents
|
||||
mockReq.user.id = userB.toString();
|
||||
findAccessibleResources.mockResolvedValue([agentA1._id]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(1);
|
||||
expect(response.data[0].id).toBe(agentA1.id);
|
||||
expect(response.data[0].name).toBe('Agent A1');
|
||||
});
|
||||
|
||||
test('should return multiple accessible agents', async () => {
|
||||
// User B has access to multiple agents
|
||||
mockReq.user.id = userB.toString();
|
||||
findAccessibleResources.mockResolvedValue([agentA1._id, agentA3._id, agentB1._id]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(3);
|
||||
|
||||
const agentIds = response.data.map((a) => a.id);
|
||||
expect(agentIds).toContain(agentA1.id);
|
||||
expect(agentIds).toContain(agentA3.id);
|
||||
expect(agentIds).toContain(agentB1.id);
|
||||
expect(agentIds).not.toContain(agentA2.id);
|
||||
});
|
||||
|
||||
test('should apply category filter correctly with ACL', async () => {
|
||||
// User has access to all agents but filters by category
|
||||
mockReq.user.id = userB.toString();
|
||||
mockReq.query.category = 'productivity';
|
||||
findAccessibleResources.mockResolvedValue([agentA1._id, agentA2._id, agentA3._id]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(1);
|
||||
expect(response.data[0].id).toBe(agentA3.id);
|
||||
expect(response.data[0].category).toBe('productivity');
|
||||
});
|
||||
|
||||
test('should apply search filter correctly with ACL', async () => {
|
||||
// User has access to multiple agents but searches for specific one
|
||||
mockReq.user.id = userB.toString();
|
||||
mockReq.query.search = 'A2';
|
||||
findAccessibleResources.mockResolvedValue([agentA1._id, agentA2._id, agentA3._id]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(1);
|
||||
expect(response.data[0].id).toBe(agentA2.id);
|
||||
});
|
||||
|
||||
test('should handle pagination with ACL filtering', async () => {
|
||||
// Create more agents for pagination testing
|
||||
const moreAgents = [];
|
||||
for (let i = 4; i <= 10; i++) {
|
||||
const agent = await Agent.create({
|
||||
id: `agent_${nanoid(12)}`,
|
||||
name: `Agent A${i}`,
|
||||
description: `User A agent ${i}`,
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
versions: [
|
||||
{
|
||||
name: `Agent A${i}`,
|
||||
description: `User A agent ${i}`,
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
moreAgents.push(agent);
|
||||
}
|
||||
|
||||
// User has access to all agents
|
||||
const allAgentIds = [agentA1, agentA2, agentA3, ...moreAgents].map((a) => a._id);
|
||||
mockReq.user.id = userB.toString();
|
||||
mockReq.query.limit = '5';
|
||||
findAccessibleResources.mockResolvedValue(allAgentIds);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(5);
|
||||
expect(response.has_more).toBe(true);
|
||||
expect(response.after).toBeTruthy();
|
||||
});
|
||||
|
||||
test('should mark publicly accessible agents', async () => {
|
||||
// User has access to agents, some are public
|
||||
mockReq.user.id = userB.toString();
|
||||
findAccessibleResources.mockResolvedValue([agentA1._id, agentA2._id]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([agentA2._id]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(2);
|
||||
|
||||
const publicAgent = response.data.find((a) => a.id === agentA2.id);
|
||||
const privateAgent = response.data.find((a) => a.id === agentA1.id);
|
||||
|
||||
expect(publicAgent.isPublic).toBe(true);
|
||||
expect(privateAgent.isPublic).toBeUndefined();
|
||||
});
|
||||
|
||||
test('should handle requiredPermission parameter', async () => {
|
||||
// Test with different permission levels
|
||||
mockReq.user.id = userB.toString();
|
||||
mockReq.query.requiredPermission = '15'; // FULL_ACCESS
|
||||
findAccessibleResources.mockResolvedValue([agentA1._id]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
expect(findAccessibleResources).toHaveBeenCalledWith({
|
||||
userId: userB.toString(),
|
||||
role: 'USER',
|
||||
resourceType: 'agent',
|
||||
requiredPermissions: 15,
|
||||
});
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(1);
|
||||
});
|
||||
|
||||
test('should handle promoted filter with ACL', async () => {
|
||||
// Create a promoted agent
|
||||
const promotedAgent = await Agent.create({
|
||||
id: `agent_${nanoid(12)}`,
|
||||
name: 'Promoted Agent',
|
||||
description: 'A promoted agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
is_promoted: true,
|
||||
versions: [
|
||||
{
|
||||
name: 'Promoted Agent',
|
||||
description: 'A promoted agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
is_promoted: true,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
mockReq.user.id = userB.toString();
|
||||
mockReq.query.promoted = '1';
|
||||
findAccessibleResources.mockResolvedValue([agentA1._id, agentA2._id, promotedAgent._id]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(1);
|
||||
expect(response.data[0].id).toBe(promotedAgent.id);
|
||||
expect(response.data[0].is_promoted).toBe(true);
|
||||
});
|
||||
|
||||
test('should handle errors gracefully', async () => {
|
||||
mockReq.user.id = userB.toString();
|
||||
findAccessibleResources.mockRejectedValue(new Error('Permission service error'));
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
expect(mockRes.status).toHaveBeenCalledWith(500);
|
||||
expect(mockRes.json).toHaveBeenCalledWith({
|
||||
error: 'Permission service error',
|
||||
});
|
||||
});
|
||||
|
||||
test('should respect combined filters with ACL', async () => {
|
||||
// Create agents with specific attributes
|
||||
const productivityPromoted = await Agent.create({
|
||||
id: `agent_${nanoid(12)}`,
|
||||
name: 'Productivity Pro',
|
||||
description: 'A promoted productivity agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: userA,
|
||||
category: 'productivity',
|
||||
is_promoted: true,
|
||||
versions: [
|
||||
{
|
||||
name: 'Productivity Pro',
|
||||
description: 'A promoted productivity agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
category: 'productivity',
|
||||
is_promoted: true,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
mockReq.user.id = userB.toString();
|
||||
mockReq.query.category = 'productivity';
|
||||
mockReq.query.promoted = '1';
|
||||
findAccessibleResources.mockResolvedValue([
|
||||
agentA1._id,
|
||||
agentA2._id,
|
||||
agentA3._id,
|
||||
productivityPromoted._id,
|
||||
]);
|
||||
findPubliclyAccessibleResources.mockResolvedValue([]);
|
||||
|
||||
await getListAgentsHandler(mockReq, mockRes);
|
||||
|
||||
const response = mockRes.json.mock.calls[0][0];
|
||||
expect(response.data).toHaveLength(1);
|
||||
expect(response.data[0].id).toBe(productivityPromoted.id);
|
||||
expect(response.data[0].category).toBe('productivity');
|
||||
expect(response.data[0].is_promoted).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -22,10 +22,11 @@ const verify2FAWithTempToken = async (req, res) => {
|
||||
try {
|
||||
payload = jwt.verify(tempToken, process.env.JWT_SECRET);
|
||||
} catch (err) {
|
||||
logger.error('Failed to verify temporary token:', err);
|
||||
return res.status(401).json({ message: 'Invalid or expired temporary token' });
|
||||
}
|
||||
|
||||
const user = await getUserById(payload.userId);
|
||||
const user = await getUserById(payload.userId, '+totpSecret +backupCodes');
|
||||
if (!user || !user.twoFactorEnabled) {
|
||||
return res.status(400).json({ message: '2FA is not enabled for this user' });
|
||||
}
|
||||
@@ -42,11 +43,11 @@ const verify2FAWithTempToken = async (req, res) => {
|
||||
return res.status(401).json({ message: 'Invalid 2FA code or backup code' });
|
||||
}
|
||||
|
||||
// Prepare user data to return (omit sensitive fields).
|
||||
const userData = user.toObject ? user.toObject() : { ...user };
|
||||
delete userData.password;
|
||||
delete userData.__v;
|
||||
delete userData.password;
|
||||
delete userData.totpSecret;
|
||||
delete userData.backupCodes;
|
||||
userData.id = user._id.toString();
|
||||
|
||||
const authToken = await setAuthTokens(user._id, res);
|
||||
|
||||
@@ -14,6 +14,7 @@ const { isEnabled, ErrorController } = require('@librechat/api');
|
||||
const { connectDb, indexSync } = require('~/db');
|
||||
const validateImageRequest = require('./middleware/validateImageRequest');
|
||||
const { jwtLogin, ldapLogin, passportLogin } = require('~/strategies');
|
||||
const { checkMigrations } = require('./services/start/migration');
|
||||
const initializeMCPs = require('./services/initializeMCPs');
|
||||
const configureSocialLogins = require('./socialLogins');
|
||||
const AppService = require('./services/AppService');
|
||||
@@ -115,6 +116,8 @@ const startServer = async () => {
|
||||
app.use('/api/agents', routes.agents);
|
||||
app.use('/api/banner', routes.banner);
|
||||
app.use('/api/memories', routes.memories);
|
||||
app.use('/api/permissions', routes.accessPermissions);
|
||||
|
||||
app.use('/api/tags', routes.tags);
|
||||
app.use('/api/mcp', routes.mcp);
|
||||
|
||||
@@ -143,7 +146,7 @@ const startServer = async () => {
|
||||
logger.info(`Server listening at http://${host == '0.0.0.0' ? 'localhost' : host}:${port}`);
|
||||
}
|
||||
|
||||
initializeMCPs(app);
|
||||
initializeMCPs(app).then(() => checkMigrations());
|
||||
});
|
||||
};
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
|
||||
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
|
||||
const { isAssistantsEndpoint, ErrorTypes, Constants } = require('librechat-data-provider');
|
||||
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
|
||||
const clearPendingReq = require('~/cache/clearPendingReq');
|
||||
const { sendError } = require('~/server/middleware/error');
|
||||
@@ -11,6 +11,10 @@ const { abortRun } = require('./abortRun');
|
||||
|
||||
const abortDataMap = new WeakMap();
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {boolean}
|
||||
*/
|
||||
function cleanupAbortController(abortKey) {
|
||||
if (!abortControllers.has(abortKey)) {
|
||||
return false;
|
||||
@@ -71,6 +75,20 @@ function cleanupAbortController(abortKey) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {string} abortKey
|
||||
* @returns {function(): void}
|
||||
*/
|
||||
function createCleanUpHandler(abortKey) {
|
||||
return function () {
|
||||
try {
|
||||
cleanupAbortController(abortKey);
|
||||
} catch {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, endpoint } = req.body;
|
||||
|
||||
@@ -172,11 +190,15 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
* @param {boolean} [isNewConvo]
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId) => {
|
||||
const onStart = (userMessage, responseMessageId, isNewConvo) => {
|
||||
sendEvent(res, { message: userMessage, created: true });
|
||||
|
||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const prelimAbortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const abortKey = isNewConvo
|
||||
? `${prelimAbortKey}${Constants.COMMON_DIVIDER}${Constants.NEW_CONVO}`
|
||||
: prelimAbortKey;
|
||||
getReqData({ abortKey });
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
const { overrideUserMessageId } = req?.body ?? {};
|
||||
@@ -194,16 +216,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
};
|
||||
|
||||
abortControllers.set(addedAbortKey, { abortController, ...minimalOptions });
|
||||
|
||||
// Use a simple function for cleanup to avoid capturing context
|
||||
const cleanupHandler = () => {
|
||||
try {
|
||||
cleanupAbortController(addedAbortKey);
|
||||
} catch (e) {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
|
||||
const cleanupHandler = createCleanUpHandler(addedAbortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
return;
|
||||
}
|
||||
@@ -216,16 +229,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
};
|
||||
|
||||
abortControllers.set(abortKey, { abortController, ...minimalOptions });
|
||||
|
||||
// Use a simple function for cleanup to avoid capturing context
|
||||
const cleanupHandler = () => {
|
||||
try {
|
||||
cleanupAbortController(abortKey);
|
||||
} catch (e) {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
|
||||
const cleanupHandler = createCleanUpHandler(abortKey);
|
||||
res.on('finish', cleanupHandler);
|
||||
};
|
||||
|
||||
@@ -364,15 +368,7 @@ const handleAbortError = async (res, req, error, data) => {
|
||||
};
|
||||
}
|
||||
|
||||
// Create a simple callback without capturing parent scope
|
||||
const callback = async () => {
|
||||
try {
|
||||
cleanupAbortController(conversationId);
|
||||
} catch (e) {
|
||||
// Ignore cleanup errors
|
||||
}
|
||||
};
|
||||
|
||||
const callback = createCleanUpHandler(conversationId);
|
||||
await sendError(req, res, options, callback);
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { Constants, isAgentsEndpoint, ResourceType } = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
|
||||
/**
|
||||
* Agent ID resolver function for agent_id from request body
|
||||
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
|
||||
* This is used specifically for chat routes where agent_id comes from request body
|
||||
*
|
||||
* @param {string} agentCustomId - Custom agent ID from request body
|
||||
* @returns {Promise<Object|null>} Agent document with _id field, or null if not found
|
||||
*/
|
||||
const resolveAgentIdFromBody = async (agentCustomId) => {
|
||||
// Handle ephemeral agents - they don't need permission checks
|
||||
if (agentCustomId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
return null; // No permission check needed for ephemeral agents
|
||||
}
|
||||
|
||||
return await getAgent({ id: agentCustomId });
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware factory that creates middleware to check agent access permissions from request body.
|
||||
* This middleware is specifically designed for chat routes where the agent_id comes from req.body
|
||||
* instead of route parameters.
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Basic usage for agent chat (requires VIEW permission)
|
||||
* router.post('/chat',
|
||||
* canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }),
|
||||
* buildEndpointOption,
|
||||
* chatController
|
||||
* );
|
||||
*/
|
||||
const canAccessAgentFromBody = (options) => {
|
||||
const { requiredPermission } = options;
|
||||
|
||||
// Validate required options
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number');
|
||||
}
|
||||
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
const { endpoint, agent_id } = req.body;
|
||||
let agentId = agent_id;
|
||||
|
||||
if (!isAgentsEndpoint(endpoint)) {
|
||||
agentId = Constants.EPHEMERAL_AGENT_ID;
|
||||
}
|
||||
|
||||
if (!agentId) {
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: 'agent_id is required in request body',
|
||||
});
|
||||
}
|
||||
|
||||
// Skip permission checks for ephemeral agents
|
||||
if (agentId === Constants.EPHEMERAL_AGENT_ID) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const agentAccessMiddleware = canAccessResource({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermission,
|
||||
resourceIdParam: 'agent_id', // This will be ignored since we use custom resolver
|
||||
idResolver: () => resolveAgentIdFromBody(agentId),
|
||||
});
|
||||
|
||||
const tempReq = {
|
||||
...req,
|
||||
params: {
|
||||
...req.params,
|
||||
agent_id: agentId,
|
||||
},
|
||||
};
|
||||
|
||||
return agentAccessMiddleware(tempReq, res, next);
|
||||
} catch (error) {
|
||||
logger.error('Failed to validate agent access permissions', error);
|
||||
return res.status(500).json({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to validate agent access permissions',
|
||||
});
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
canAccessAgentFromBody,
|
||||
};
|
||||
@@ -0,0 +1,59 @@
|
||||
const { ResourceType } = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
|
||||
/**
|
||||
* Agent ID resolver function
|
||||
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
|
||||
*
|
||||
* @param {string} agentCustomId - Custom agent ID from route parameter
|
||||
* @returns {Promise<Object|null>} Agent document with _id field, or null if not found
|
||||
*/
|
||||
const resolveAgentId = async (agentCustomId) => {
|
||||
return await getAgent({ id: agentCustomId });
|
||||
};
|
||||
|
||||
/**
|
||||
* Agent-specific middleware factory that creates middleware to check agent access permissions.
|
||||
* This middleware extends the generic canAccessResource to handle agent custom ID resolution.
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @param {string} [options.resourceIdParam='id'] - The name of the route parameter containing the agent custom ID
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Basic usage for viewing agents
|
||||
* router.get('/agents/:id',
|
||||
* canAccessAgentResource({ requiredPermission: 1 }),
|
||||
* getAgent
|
||||
* );
|
||||
*
|
||||
* @example
|
||||
* // Custom resource ID parameter and edit permission
|
||||
* router.patch('/agents/:agent_id',
|
||||
* canAccessAgentResource({
|
||||
* requiredPermission: 2,
|
||||
* resourceIdParam: 'agent_id'
|
||||
* }),
|
||||
* updateAgent
|
||||
* );
|
||||
*/
|
||||
const canAccessAgentResource = (options) => {
|
||||
const { requiredPermission, resourceIdParam = 'id' } = options;
|
||||
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error('canAccessAgentResource: requiredPermission is required and must be a number');
|
||||
}
|
||||
|
||||
return canAccessResource({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermission,
|
||||
resourceIdParam,
|
||||
idResolver: resolveAgentId,
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
canAccessAgentResource,
|
||||
};
|
||||
@@ -0,0 +1,385 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { ResourceType, PrincipalType, PrincipalModel } = require('librechat-data-provider');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { canAccessAgentResource } = require('./canAccessAgentResource');
|
||||
const { User, Role, AclEntry } = require('~/db/models');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
|
||||
describe('canAccessAgentResource middleware', () => {
|
||||
let mongoServer;
|
||||
let req, res, next;
|
||||
let testUser;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
await Role.create({
|
||||
name: 'test-role',
|
||||
permissions: {
|
||||
AGENTS: {
|
||||
USE: true,
|
||||
CREATE: true,
|
||||
SHARED_GLOBAL: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Create a test user
|
||||
testUser = await User.create({
|
||||
email: 'test@example.com',
|
||||
name: 'Test User',
|
||||
username: 'testuser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
req = {
|
||||
user: { id: testUser._id, role: testUser.role },
|
||||
params: {},
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
next = jest.fn();
|
||||
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('middleware factory', () => {
|
||||
test('should throw error if requiredPermission is not provided', () => {
|
||||
expect(() => canAccessAgentResource({})).toThrow(
|
||||
'canAccessAgentResource: requiredPermission is required and must be a number',
|
||||
);
|
||||
});
|
||||
|
||||
test('should throw error if requiredPermission is not a number', () => {
|
||||
expect(() => canAccessAgentResource({ requiredPermission: '1' })).toThrow(
|
||||
'canAccessAgentResource: requiredPermission is required and must be a number',
|
||||
);
|
||||
});
|
||||
|
||||
test('should create middleware with default resourceIdParam', () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 });
|
||||
expect(typeof middleware).toBe('function');
|
||||
expect(middleware.length).toBe(3); // Express middleware signature
|
||||
});
|
||||
|
||||
test('should create middleware with custom resourceIdParam', () => {
|
||||
const middleware = canAccessAgentResource({
|
||||
requiredPermission: 2,
|
||||
resourceIdParam: 'agent_id',
|
||||
});
|
||||
expect(typeof middleware).toBe('function');
|
||||
expect(middleware.length).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
describe('permission checking with real agents', () => {
|
||||
test('should allow access when user is the agent author', async () => {
|
||||
// Create an agent owned by the test user
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry for the author (owner permissions)
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions (1+2+4+8)
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should deny access when user is not the author and has no ACL entry', async () => {
|
||||
// Create an agent owned by a different user
|
||||
const otherUser = await User.create({
|
||||
email: 'other@example.com',
|
||||
name: 'Other User',
|
||||
username: 'otheruser',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Other User Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry for the other user (owner)
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to access this agent',
|
||||
});
|
||||
});
|
||||
|
||||
test('should allow access when user has ACL entry with sufficient permissions', async () => {
|
||||
// Create an agent owned by a different user
|
||||
const otherUser = await User.create({
|
||||
email: 'other2@example.com',
|
||||
name: 'Other User 2',
|
||||
username: 'otheruser2',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Shared Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry granting view permission to test user
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 1, // VIEW permission
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should deny access when ACL permissions are insufficient', async () => {
|
||||
// Create an agent owned by a different user
|
||||
const otherUser = await User.create({
|
||||
email: 'other3@example.com',
|
||||
name: 'Other User 3',
|
||||
username: 'otheruser3',
|
||||
role: 'test-role',
|
||||
});
|
||||
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Limited Access Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: otherUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry granting only view permission
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 1, // VIEW permission only
|
||||
grantedBy: otherUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 2 }); // EDIT permission required
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to access this agent',
|
||||
});
|
||||
});
|
||||
|
||||
test('should handle non-existent agent', async () => {
|
||||
req.params.id = 'agent_nonexistent';
|
||||
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(404);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Not Found',
|
||||
message: 'agent not found',
|
||||
});
|
||||
});
|
||||
|
||||
test('should use custom resourceIdParam', async () => {
|
||||
const agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Custom Param Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry for the author
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
req.params.agent_id = agent.id; // Using custom param name
|
||||
|
||||
const middleware = canAccessAgentResource({
|
||||
requiredPermission: 1,
|
||||
resourceIdParam: 'agent_id',
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('permission levels', () => {
|
||||
let agent;
|
||||
|
||||
beforeEach(async () => {
|
||||
agent = await createAgent({
|
||||
id: `agent_${Date.now()}`,
|
||||
name: 'Permission Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
});
|
||||
|
||||
// Create ACL entry with all permissions for the owner
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions (1+2+4+8)
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agent.id;
|
||||
});
|
||||
|
||||
test('should support view permission (1)', async () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 1 });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should support edit permission (2)', async () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 2 });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should support delete permission (4)', async () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 4 });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should support share permission (8)', async () => {
|
||||
const middleware = canAccessAgentResource({ requiredPermission: 8 });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should support combined permissions', async () => {
|
||||
const viewAndEdit = 1 | 2; // 3
|
||||
const middleware = canAccessAgentResource({ requiredPermission: viewAndEdit });
|
||||
await middleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('integration with agent operations', () => {
|
||||
test('should work with agent CRUD operations', async () => {
|
||||
const agentId = `agent_${Date.now()}`;
|
||||
|
||||
// Create agent
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Integration Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: testUser._id,
|
||||
description: 'Testing integration',
|
||||
});
|
||||
|
||||
// Create ACL entry for the author
|
||||
await AclEntry.create({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUser._id,
|
||||
principalModel: PrincipalModel.USER,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
permBits: 15, // All permissions
|
||||
grantedBy: testUser._id,
|
||||
});
|
||||
|
||||
req.params.id = agentId;
|
||||
|
||||
// Test view access
|
||||
const viewMiddleware = canAccessAgentResource({ requiredPermission: 1 });
|
||||
await viewMiddleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Update the agent
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: agentId }, { description: 'Updated description' });
|
||||
|
||||
// Test edit access
|
||||
const editMiddleware = canAccessAgentResource({ requiredPermission: 2 });
|
||||
await editMiddleware(req, res, next);
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,61 @@
|
||||
const { ResourceType } = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getPromptGroup } = require('~/models/Prompt');
|
||||
|
||||
/**
|
||||
* PromptGroup ID resolver function
|
||||
* Resolves promptGroup ID to MongoDB ObjectId
|
||||
*
|
||||
* @param {string} groupId - PromptGroup ID from route parameter
|
||||
* @returns {Promise<Object|null>} PromptGroup document with _id field, or null if not found
|
||||
*/
|
||||
const resolvePromptGroupId = async (groupId) => {
|
||||
return await getPromptGroup({ _id: groupId });
|
||||
};
|
||||
|
||||
/**
|
||||
* PromptGroup-specific middleware factory that creates middleware to check promptGroup access permissions.
|
||||
* This middleware extends the generic canAccessResource to handle promptGroup ID resolution.
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @param {string} [options.resourceIdParam='groupId'] - The name of the route parameter containing the promptGroup ID
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Basic usage for viewing promptGroups
|
||||
* router.get('/prompts/groups/:groupId',
|
||||
* canAccessPromptGroupResource({ requiredPermission: 1 }),
|
||||
* getPromptGroup
|
||||
* );
|
||||
*
|
||||
* @example
|
||||
* // Custom resource ID parameter and edit permission
|
||||
* router.patch('/prompts/groups/:id',
|
||||
* canAccessPromptGroupResource({
|
||||
* requiredPermission: 2,
|
||||
* resourceIdParam: 'id'
|
||||
* }),
|
||||
* updatePromptGroup
|
||||
* );
|
||||
*/
|
||||
const canAccessPromptGroupResource = (options) => {
|
||||
const { requiredPermission, resourceIdParam = 'groupId' } = options;
|
||||
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error(
|
||||
'canAccessPromptGroupResource: requiredPermission is required and must be a number',
|
||||
);
|
||||
}
|
||||
|
||||
return canAccessResource({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermission,
|
||||
resourceIdParam,
|
||||
idResolver: resolvePromptGroupId,
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
canAccessPromptGroupResource,
|
||||
};
|
||||
@@ -0,0 +1,55 @@
|
||||
const { ResourceType } = require('librechat-data-provider');
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { getPrompt } = require('~/models/Prompt');
|
||||
|
||||
/**
|
||||
* Prompt to PromptGroup ID resolver function
|
||||
* Resolves prompt ID to its parent promptGroup ID
|
||||
*
|
||||
* @param {string} promptId - Prompt ID from route parameter
|
||||
* @returns {Promise<Object|null>} Object with promptGroup's _id field, or null if not found
|
||||
*/
|
||||
const resolvePromptToGroupId = async (promptId) => {
|
||||
const prompt = await getPrompt({ _id: promptId });
|
||||
if (!prompt || !prompt.groupId) {
|
||||
return null;
|
||||
}
|
||||
// Return an object with _id that matches the promptGroup ID
|
||||
return { _id: prompt.groupId };
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware factory that checks promptGroup permissions when accessing individual prompts.
|
||||
* This allows permission management at the promptGroup level while still supporting
|
||||
* individual prompt access patterns.
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @param {string} [options.resourceIdParam='promptId'] - The name of the route parameter containing the prompt ID
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Check promptGroup permissions when viewing a prompt
|
||||
* router.get('/prompts/:promptId',
|
||||
* canAccessPromptViaGroup({ requiredPermission: 1 }),
|
||||
* getPrompt
|
||||
* );
|
||||
*/
|
||||
const canAccessPromptViaGroup = (options) => {
|
||||
const { requiredPermission, resourceIdParam = 'promptId' } = options;
|
||||
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error('canAccessPromptViaGroup: requiredPermission is required and must be a number');
|
||||
}
|
||||
|
||||
return canAccessResource({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermission,
|
||||
resourceIdParam,
|
||||
idResolver: resolvePromptToGroupId,
|
||||
});
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
canAccessPromptViaGroup,
|
||||
};
|
||||
158
api/server/middleware/accessResources/canAccessResource.js
Normal file
158
api/server/middleware/accessResources/canAccessResource.js
Normal file
@@ -0,0 +1,158 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { SystemRoles } = require('librechat-data-provider');
|
||||
const { checkPermission } = require('~/server/services/PermissionService');
|
||||
|
||||
/**
|
||||
* Generic base middleware factory that creates middleware to check resource access permissions.
|
||||
* This middleware expects MongoDB ObjectIds as resource identifiers for ACL permission checks.
|
||||
*
|
||||
* @param {Object} options - Configuration options
|
||||
* @param {string} options.resourceType - The type of resource (e.g., 'agent', 'file', 'project')
|
||||
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
|
||||
* @param {string} [options.resourceIdParam='resourceId'] - The name of the route parameter containing the resource ID
|
||||
* @param {Function} [options.idResolver] - Optional function to resolve custom IDs to ObjectIds
|
||||
* @returns {Function} Express middleware function
|
||||
*
|
||||
* @example
|
||||
* // Direct usage with ObjectId (for resources that use MongoDB ObjectId in routes)
|
||||
* router.get('/prompts/:promptId',
|
||||
* canAccessResource({ resourceType: 'prompt', requiredPermission: 1 }),
|
||||
* getPrompt
|
||||
* );
|
||||
*
|
||||
* @example
|
||||
* // Usage with custom ID resolver (for resources that use custom string IDs)
|
||||
* router.get('/agents/:id',
|
||||
* canAccessResource({
|
||||
* resourceType: 'agent',
|
||||
* requiredPermission: 1,
|
||||
* resourceIdParam: 'id',
|
||||
* idResolver: (customId) => resolveAgentId(customId)
|
||||
* }),
|
||||
* getAgent
|
||||
* );
|
||||
*/
|
||||
const canAccessResource = (options) => {
|
||||
const {
|
||||
resourceType,
|
||||
requiredPermission,
|
||||
resourceIdParam = 'resourceId',
|
||||
idResolver = null,
|
||||
} = options;
|
||||
|
||||
if (!resourceType || typeof resourceType !== 'string') {
|
||||
throw new Error('canAccessResource: resourceType is required and must be a string');
|
||||
}
|
||||
|
||||
if (!requiredPermission || typeof requiredPermission !== 'number') {
|
||||
throw new Error('canAccessResource: requiredPermission is required and must be a number');
|
||||
}
|
||||
|
||||
return async (req, res, next) => {
|
||||
try {
|
||||
// Extract resource ID from route parameters
|
||||
const rawResourceId = req.params[resourceIdParam];
|
||||
|
||||
if (!rawResourceId) {
|
||||
logger.warn(`[canAccessResource] Missing ${resourceIdParam} in route parameters`);
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: `${resourceIdParam} is required`,
|
||||
});
|
||||
}
|
||||
|
||||
// Check if user is authenticated
|
||||
if (!req.user || !req.user.id) {
|
||||
logger.warn(
|
||||
`[canAccessResource] Unauthenticated request for ${resourceType} ${rawResourceId}`,
|
||||
);
|
||||
return res.status(401).json({
|
||||
error: 'Unauthorized',
|
||||
message: 'Authentication required',
|
||||
});
|
||||
}
|
||||
// if system admin let through
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
return next();
|
||||
}
|
||||
const userId = req.user.id;
|
||||
let resourceId = rawResourceId;
|
||||
let resourceInfo = null;
|
||||
|
||||
// Resolve custom ID to ObjectId if resolver is provided
|
||||
if (idResolver) {
|
||||
logger.debug(
|
||||
`[canAccessResource] Resolving ${resourceType} custom ID ${rawResourceId} to ObjectId`,
|
||||
);
|
||||
|
||||
const resolutionResult = await idResolver(rawResourceId);
|
||||
|
||||
if (!resolutionResult) {
|
||||
logger.warn(`[canAccessResource] ${resourceType} not found: ${rawResourceId}`);
|
||||
return res.status(404).json({
|
||||
error: 'Not Found',
|
||||
message: `${resourceType} not found`,
|
||||
});
|
||||
}
|
||||
|
||||
// Handle different resolver return formats
|
||||
if (typeof resolutionResult === 'string' || resolutionResult._id) {
|
||||
resourceId = resolutionResult._id || resolutionResult;
|
||||
resourceInfo = typeof resolutionResult === 'object' ? resolutionResult : null;
|
||||
} else {
|
||||
resourceId = resolutionResult;
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`[canAccessResource] Resolved ${resourceType} ${rawResourceId} to ObjectId ${resourceId}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Check permissions using PermissionService with ObjectId
|
||||
const hasPermission = await checkPermission({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType,
|
||||
resourceId,
|
||||
requiredPermission,
|
||||
});
|
||||
|
||||
if (hasPermission) {
|
||||
logger.debug(
|
||||
`[canAccessResource] User ${userId} has permission ${requiredPermission} on ${resourceType} ${rawResourceId} (${resourceId})`,
|
||||
);
|
||||
|
||||
req.resourceAccess = {
|
||||
resourceType,
|
||||
resourceId, // MongoDB ObjectId for ACL operations
|
||||
customResourceId: rawResourceId, // Original ID from route params
|
||||
permission: requiredPermission,
|
||||
userId,
|
||||
...(resourceInfo && { resourceInfo }),
|
||||
};
|
||||
|
||||
return next();
|
||||
}
|
||||
|
||||
logger.warn(
|
||||
`[canAccessResource] User ${userId} denied access to ${resourceType} ${rawResourceId} ` +
|
||||
`(required permission: ${requiredPermission})`,
|
||||
);
|
||||
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: `Insufficient permissions to access this ${resourceType}`,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(`[canAccessResource] Error checking access for ${resourceType}:`, error);
|
||||
return res.status(500).json({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to check resource access permissions',
|
||||
});
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
canAccessResource,
|
||||
};
|
||||
125
api/server/middleware/accessResources/fileAccess.js
Normal file
125
api/server/middleware/accessResources/fileAccess.js
Normal file
@@ -0,0 +1,125 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { PermissionBits, hasPermissions, ResourceType } = require('librechat-data-provider');
|
||||
const { getEffectivePermissions } = require('~/server/services/PermissionService');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getFiles } = require('~/models/File');
|
||||
|
||||
/**
|
||||
* Checks if user has access to a file through agent permissions
|
||||
* Files inherit permissions from agents - if you can view the agent, you can access its files
|
||||
*/
|
||||
const checkAgentBasedFileAccess = async ({ userId, role, fileId }) => {
|
||||
try {
|
||||
// Find agents that have this file in their tool_resources
|
||||
const agentsWithFile = await getAgent({
|
||||
$or: [
|
||||
{ 'tool_resources.file_search.file_ids': fileId },
|
||||
{ 'tool_resources.execute_code.file_ids': fileId },
|
||||
{ 'tool_resources.ocr.file_ids': fileId },
|
||||
],
|
||||
});
|
||||
|
||||
if (!agentsWithFile || agentsWithFile.length === 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if user has access to any of these agents
|
||||
for (const agent of Array.isArray(agentsWithFile) ? agentsWithFile : [agentsWithFile]) {
|
||||
// Check if user is the agent author
|
||||
if (agent.author && agent.author.toString() === userId) {
|
||||
logger.debug(`[fileAccess] User is author of agent ${agent.id}`);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Check ACL permissions for VIEW access on the agent
|
||||
try {
|
||||
const permissions = await getEffectivePermissions({
|
||||
userId,
|
||||
role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id || agent.id,
|
||||
});
|
||||
|
||||
if (hasPermissions(permissions, PermissionBits.VIEW)) {
|
||||
logger.debug(`[fileAccess] User ${userId} has VIEW permissions on agent ${agent.id}`);
|
||||
return true;
|
||||
}
|
||||
} catch (permissionError) {
|
||||
logger.warn(
|
||||
`[fileAccess] Permission check failed for agent ${agent.id}:`,
|
||||
permissionError.message,
|
||||
);
|
||||
// Continue checking other agents
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
} catch (error) {
|
||||
logger.error('[fileAccess] Error checking agent-based access:', error);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Middleware to check if user can access a file
|
||||
* Checks: 1) File ownership, 2) Agent-based access (file inherits agent permissions)
|
||||
*/
|
||||
const fileAccess = async (req, res, next) => {
|
||||
try {
|
||||
const fileId = req.params.file_id;
|
||||
const userId = req.user?.id;
|
||||
const userRole = req.user?.role;
|
||||
if (!fileId) {
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: 'file_id is required',
|
||||
});
|
||||
}
|
||||
|
||||
if (!userId) {
|
||||
return res.status(401).json({
|
||||
error: 'Unauthorized',
|
||||
message: 'Authentication required',
|
||||
});
|
||||
}
|
||||
|
||||
// Get the file
|
||||
const [file] = await getFiles({ file_id: fileId });
|
||||
if (!file) {
|
||||
return res.status(404).json({
|
||||
error: 'Not Found',
|
||||
message: 'File not found',
|
||||
});
|
||||
}
|
||||
|
||||
// Check if user owns the file
|
||||
if (file.user && file.user.toString() === userId) {
|
||||
req.fileAccess = { file };
|
||||
return next();
|
||||
}
|
||||
|
||||
// Check agent-based access (file inherits agent permissions)
|
||||
const hasAgentAccess = await checkAgentBasedFileAccess({ userId, role: userRole, fileId });
|
||||
if (hasAgentAccess) {
|
||||
req.fileAccess = { file };
|
||||
return next();
|
||||
}
|
||||
|
||||
// No access
|
||||
logger.warn(`[fileAccess] User ${userId} denied access to file ${fileId}`);
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to access this file',
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[fileAccess] Error checking file access:', error);
|
||||
return res.status(500).json({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to check file access permissions',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
fileAccess,
|
||||
};
|
||||
13
api/server/middleware/accessResources/index.js
Normal file
13
api/server/middleware/accessResources/index.js
Normal file
@@ -0,0 +1,13 @@
|
||||
const { canAccessResource } = require('./canAccessResource');
|
||||
const { canAccessAgentResource } = require('./canAccessAgentResource');
|
||||
const { canAccessAgentFromBody } = require('./canAccessAgentFromBody');
|
||||
const { canAccessPromptViaGroup } = require('./canAccessPromptViaGroup');
|
||||
const { canAccessPromptGroupResource } = require('./canAccessPromptGroupResource');
|
||||
|
||||
module.exports = {
|
||||
canAccessResource,
|
||||
canAccessAgentResource,
|
||||
canAccessAgentFromBody,
|
||||
canAccessPromptViaGroup,
|
||||
canAccessPromptGroupResource,
|
||||
};
|
||||
82
api/server/middleware/checkPeoplePickerAccess.js
Normal file
82
api/server/middleware/checkPeoplePickerAccess.js
Normal file
@@ -0,0 +1,82 @@
|
||||
const { PrincipalType, PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Middleware to check if user has permission to access people picker functionality
|
||||
* Checks specific permission based on the 'type' query parameter:
|
||||
* - type=user: requires VIEW_USERS permission
|
||||
* - type=group: requires VIEW_GROUPS permission
|
||||
* - type=role: requires VIEW_ROLES permission
|
||||
* - no type (mixed search): requires either VIEW_USERS OR VIEW_GROUPS OR VIEW_ROLES
|
||||
*/
|
||||
const checkPeoplePickerAccess = async (req, res, next) => {
|
||||
try {
|
||||
const user = req.user;
|
||||
if (!user || !user.role) {
|
||||
return res.status(401).json({
|
||||
error: 'Unauthorized',
|
||||
message: 'Authentication required',
|
||||
});
|
||||
}
|
||||
|
||||
const role = await getRoleByName(user.role);
|
||||
if (!role || !role.permissions) {
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: 'No permissions configured for user role',
|
||||
});
|
||||
}
|
||||
|
||||
const { type } = req.query;
|
||||
const peoplePickerPerms = role.permissions[PermissionTypes.PEOPLE_PICKER] || {};
|
||||
const canViewUsers = peoplePickerPerms[Permissions.VIEW_USERS] === true;
|
||||
const canViewGroups = peoplePickerPerms[Permissions.VIEW_GROUPS] === true;
|
||||
const canViewRoles = peoplePickerPerms[Permissions.VIEW_ROLES] === true;
|
||||
|
||||
const permissionChecks = {
|
||||
[PrincipalType.USER]: {
|
||||
hasPermission: canViewUsers,
|
||||
message: 'Insufficient permissions to search for users',
|
||||
},
|
||||
[PrincipalType.GROUP]: {
|
||||
hasPermission: canViewGroups,
|
||||
message: 'Insufficient permissions to search for groups',
|
||||
},
|
||||
[PrincipalType.ROLE]: {
|
||||
hasPermission: canViewRoles,
|
||||
message: 'Insufficient permissions to search for roles',
|
||||
},
|
||||
};
|
||||
|
||||
const check = permissionChecks[type];
|
||||
if (check && !check.hasPermission) {
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: check.message,
|
||||
});
|
||||
}
|
||||
|
||||
if (!type && !canViewUsers && !canViewGroups && !canViewRoles) {
|
||||
return res.status(403).json({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to search for users, groups, or roles',
|
||||
});
|
||||
}
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[checkPeoplePickerAccess][${req.user?.id}] checkPeoplePickerAccess error for req.query.type = ${req.query.type}`,
|
||||
error,
|
||||
);
|
||||
return res.status(500).json({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to check permissions',
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
checkPeoplePickerAccess,
|
||||
};
|
||||
250
api/server/middleware/checkPeoplePickerAccess.spec.js
Normal file
250
api/server/middleware/checkPeoplePickerAccess.spec.js
Normal file
@@ -0,0 +1,250 @@
|
||||
const { PrincipalType, PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { checkPeoplePickerAccess } = require('./checkPeoplePickerAccess');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
jest.mock('~/models/Role');
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
describe('checkPeoplePickerAccess', () => {
|
||||
let req, res, next;
|
||||
|
||||
beforeEach(() => {
|
||||
req = {
|
||||
user: { id: 'user123', role: 'USER' },
|
||||
query: {},
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
next = jest.fn();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should return 401 if user is not authenticated', async () => {
|
||||
req.user = null;
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(401);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Unauthorized',
|
||||
message: 'Authentication required',
|
||||
});
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return 403 if role has no permissions', async () => {
|
||||
getRoleByName.mockResolvedValue(null);
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'No permissions configured for user role',
|
||||
});
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow access when searching for users with VIEW_USERS permission', async () => {
|
||||
req.query.type = PrincipalType.USER;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PEOPLE_PICKER]: {
|
||||
[Permissions.VIEW_USERS]: true,
|
||||
[Permissions.VIEW_GROUPS]: false,
|
||||
[Permissions.VIEW_ROLES]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should deny access when searching for users without VIEW_USERS permission', async () => {
|
||||
req.query.type = PrincipalType.USER;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PEOPLE_PICKER]: {
|
||||
[Permissions.VIEW_USERS]: false,
|
||||
[Permissions.VIEW_GROUPS]: true,
|
||||
[Permissions.VIEW_ROLES]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to search for users',
|
||||
});
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow access when searching for groups with VIEW_GROUPS permission', async () => {
|
||||
req.query.type = PrincipalType.GROUP;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PEOPLE_PICKER]: {
|
||||
[Permissions.VIEW_USERS]: false,
|
||||
[Permissions.VIEW_GROUPS]: true,
|
||||
[Permissions.VIEW_ROLES]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should deny access when searching for groups without VIEW_GROUPS permission', async () => {
|
||||
req.query.type = PrincipalType.GROUP;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PEOPLE_PICKER]: {
|
||||
[Permissions.VIEW_USERS]: true,
|
||||
[Permissions.VIEW_GROUPS]: false,
|
||||
[Permissions.VIEW_ROLES]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to search for groups',
|
||||
});
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow access when searching for roles with VIEW_ROLES permission', async () => {
|
||||
req.query.type = PrincipalType.ROLE;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PEOPLE_PICKER]: {
|
||||
[Permissions.VIEW_USERS]: false,
|
||||
[Permissions.VIEW_GROUPS]: false,
|
||||
[Permissions.VIEW_ROLES]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should deny access when searching for roles without VIEW_ROLES permission', async () => {
|
||||
req.query.type = PrincipalType.ROLE;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PEOPLE_PICKER]: {
|
||||
[Permissions.VIEW_USERS]: true,
|
||||
[Permissions.VIEW_GROUPS]: true,
|
||||
[Permissions.VIEW_ROLES]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to search for roles',
|
||||
});
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should allow mixed search when user has at least one permission', async () => {
|
||||
// No type specified = mixed search
|
||||
req.query.type = undefined;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PEOPLE_PICKER]: {
|
||||
[Permissions.VIEW_USERS]: false,
|
||||
[Permissions.VIEW_GROUPS]: false,
|
||||
[Permissions.VIEW_ROLES]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should deny mixed search when user has no permissions', async () => {
|
||||
// No type specified = mixed search
|
||||
req.query.type = undefined;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {
|
||||
[PermissionTypes.PEOPLE_PICKER]: {
|
||||
[Permissions.VIEW_USERS]: false,
|
||||
[Permissions.VIEW_GROUPS]: false,
|
||||
[Permissions.VIEW_ROLES]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to search for users, groups, or roles',
|
||||
});
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle errors gracefully', async () => {
|
||||
const error = new Error('Database error');
|
||||
getRoleByName.mockRejectedValue(error);
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(logger.error).toHaveBeenCalledWith(
|
||||
'[checkPeoplePickerAccess][user123] checkPeoplePickerAccess error for req.query.type = undefined',
|
||||
error,
|
||||
);
|
||||
expect(res.status).toHaveBeenCalledWith(500);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Internal Server Error',
|
||||
message: 'Failed to check permissions',
|
||||
});
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle missing permissions object gracefully', async () => {
|
||||
req.query.type = PrincipalType.USER;
|
||||
getRoleByName.mockResolvedValue({
|
||||
permissions: {}, // No PEOPLE_PICKER permissions
|
||||
});
|
||||
|
||||
await checkPeoplePickerAccess(req, res, next);
|
||||
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
error: 'Forbidden',
|
||||
message: 'Insufficient permissions to search for users',
|
||||
});
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -8,7 +8,7 @@ const concurrentLimiter = require('./concurrentLimiter');
|
||||
const validateEndpoint = require('./validateEndpoint');
|
||||
const requireLocalAuth = require('./requireLocalAuth');
|
||||
const canDeleteAccount = require('./canDeleteAccount');
|
||||
const setBalanceConfig = require('./setBalanceConfig');
|
||||
const accessResources = require('./accessResources');
|
||||
const requireLdapAuth = require('./requireLdapAuth');
|
||||
const abortMiddleware = require('./abortMiddleware');
|
||||
const checkInviteUser = require('./checkInviteUser');
|
||||
@@ -29,6 +29,7 @@ module.exports = {
|
||||
...validate,
|
||||
...limiters,
|
||||
...roles,
|
||||
...accessResources,
|
||||
noIndex,
|
||||
checkBan,
|
||||
uaParser,
|
||||
@@ -42,7 +43,6 @@ module.exports = {
|
||||
requireLocalAuth,
|
||||
canDeleteAccount,
|
||||
validateEndpoint,
|
||||
setBalanceConfig,
|
||||
concurrentLimiter,
|
||||
checkDomainAllowed,
|
||||
validateMessageReq,
|
||||
|
||||
370
api/server/middleware/roles/access.spec.js
Normal file
370
api/server/middleware/roles/access.spec.js
Normal file
@@ -0,0 +1,370 @@
|
||||
const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { checkAccess, generateCheckAccess } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { Role } = require('~/db/models');
|
||||
|
||||
// Mock the logger from @librechat/data-schemas
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
...jest.requireActual('@librechat/data-schemas'),
|
||||
logger: {
|
||||
warn: jest.fn(),
|
||||
error: jest.fn(),
|
||||
info: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock the cache to use a simple in-memory implementation
|
||||
const mockCache = new Map();
|
||||
jest.mock('~/cache/getLogStores', () => {
|
||||
return jest.fn(() => ({
|
||||
get: jest.fn(async (key) => mockCache.get(key)),
|
||||
set: jest.fn(async (key, value) => mockCache.set(key, value)),
|
||||
clear: jest.fn(async () => mockCache.clear()),
|
||||
}));
|
||||
});
|
||||
|
||||
describe('Access Middleware', () => {
|
||||
let mongoServer;
|
||||
let req, res, next;
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
await mongoose.connection.dropDatabase();
|
||||
mockCache.clear(); // Clear the cache between tests
|
||||
|
||||
// Create test roles
|
||||
await Role.create({
|
||||
name: 'user',
|
||||
permissions: {
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
},
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.UPDATE]: true,
|
||||
[Permissions.READ]: true,
|
||||
[Permissions.OPT_OUT]: true,
|
||||
},
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.WEB_SEARCH]: { [Permissions.USE]: true },
|
||||
},
|
||||
});
|
||||
|
||||
await Role.create({
|
||||
name: 'admin',
|
||||
permissions: {
|
||||
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
},
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.UPDATE]: true,
|
||||
[Permissions.READ]: true,
|
||||
[Permissions.OPT_OUT]: true,
|
||||
},
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: true,
|
||||
[Permissions.CREATE]: true,
|
||||
[Permissions.SHARED_GLOBAL]: true,
|
||||
},
|
||||
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: true },
|
||||
[PermissionTypes.WEB_SEARCH]: { [Permissions.USE]: true },
|
||||
},
|
||||
});
|
||||
|
||||
// Create limited role with no AGENTS permissions
|
||||
await Role.create({
|
||||
name: 'limited',
|
||||
permissions: {
|
||||
// Explicitly set AGENTS permissions to false
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
},
|
||||
// Has permissions for other types
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
[Permissions.USE]: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
req = {
|
||||
user: { id: 'user123', role: 'user' },
|
||||
body: {},
|
||||
originalUrl: '/test',
|
||||
};
|
||||
res = {
|
||||
status: jest.fn().mockReturnThis(),
|
||||
json: jest.fn(),
|
||||
};
|
||||
next = jest.fn();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('checkAccess', () => {
|
||||
test('should return false if user is not provided', async () => {
|
||||
const result = await checkAccess({
|
||||
user: null,
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should return true if user has required permission', async () => {
|
||||
const result = await checkAccess({
|
||||
req: {},
|
||||
user: { id: 'user123', role: 'user' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test('should return false if user lacks required permission', async () => {
|
||||
const result = await checkAccess({
|
||||
req: {},
|
||||
user: { id: 'user123', role: 'user' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.CREATE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false if user has only some of multiple permissions', async () => {
|
||||
// User has USE but not CREATE, so should fail when checking for both
|
||||
const result = await checkAccess({
|
||||
req: {},
|
||||
user: { id: 'user123', role: 'user' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.CREATE, Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should return true if user has all of multiple permissions', async () => {
|
||||
// Admin has both USE and CREATE
|
||||
const result = await checkAccess({
|
||||
req: {},
|
||||
user: { id: 'admin123', role: 'admin' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.CREATE, Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test('should check body properties when permission is not directly granted', async () => {
|
||||
const req = { body: { id: 'agent123' } };
|
||||
const result = await checkAccess({
|
||||
req,
|
||||
user: { id: 'user123', role: 'user' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.UPDATE],
|
||||
bodyProps: {
|
||||
[Permissions.UPDATE]: ['id'],
|
||||
},
|
||||
checkObject: req.body,
|
||||
getRoleByName,
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
test('should return false if role is not found', async () => {
|
||||
const result = await checkAccess({
|
||||
req: {},
|
||||
user: { id: 'user123', role: 'nonexistent' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should return false if role has no permissions for the requested type', async () => {
|
||||
const result = await checkAccess({
|
||||
req: {},
|
||||
user: { id: 'user123', role: 'limited' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
test('should handle admin role with all permissions', async () => {
|
||||
const createResult = await checkAccess({
|
||||
req: {},
|
||||
user: { id: 'admin123', role: 'admin' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.CREATE],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(createResult).toBe(true);
|
||||
|
||||
const shareResult = await checkAccess({
|
||||
req: {},
|
||||
user: { id: 'admin123', role: 'admin' },
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.SHARED_GLOBAL],
|
||||
getRoleByName,
|
||||
});
|
||||
expect(shareResult).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateCheckAccess', () => {
|
||||
test('should call next() when user has required permission', async () => {
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should return 403 when user lacks permission', async () => {
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.CREATE],
|
||||
getRoleByName,
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({ message: 'Forbidden: Insufficient permissions' });
|
||||
});
|
||||
|
||||
test('should check body properties when configured', async () => {
|
||||
req.body = { agentId: 'agent123', description: 'test' };
|
||||
|
||||
const bodyProps = {
|
||||
[Permissions.CREATE]: ['agentId'],
|
||||
};
|
||||
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.CREATE],
|
||||
bodyProps,
|
||||
getRoleByName,
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
expect(res.status).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle database errors gracefully', async () => {
|
||||
// Mock getRoleByName to throw an error
|
||||
const mockGetRoleByName = jest
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Database connection failed'));
|
||||
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName: mockGetRoleByName,
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(500);
|
||||
expect(res.json).toHaveBeenCalledWith({
|
||||
message: expect.stringContaining('Server error:'),
|
||||
});
|
||||
});
|
||||
|
||||
test('should work with multiple permission types', async () => {
|
||||
req.user.role = 'admin';
|
||||
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE, Permissions.CREATE, Permissions.SHARED_GLOBAL],
|
||||
getRoleByName,
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test('should handle missing user gracefully', async () => {
|
||||
req.user = null;
|
||||
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({ message: 'Forbidden: Insufficient permissions' });
|
||||
});
|
||||
|
||||
test('should handle role with no AGENTS permissions', async () => {
|
||||
await Role.create({
|
||||
name: 'noaccess',
|
||||
permissions: {
|
||||
// Explicitly set AGENTS with all permissions false
|
||||
[PermissionTypes.AGENTS]: {
|
||||
[Permissions.USE]: false,
|
||||
[Permissions.CREATE]: false,
|
||||
[Permissions.SHARED_GLOBAL]: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
req.user.role = 'noaccess';
|
||||
|
||||
const middleware = generateCheckAccess({
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
await middleware(req, res, next);
|
||||
|
||||
expect(next).not.toHaveBeenCalled();
|
||||
expect(res.status).toHaveBeenCalledWith(403);
|
||||
expect(res.json).toHaveBeenCalledWith({ message: 'Forbidden: Insufficient permissions' });
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,91 +0,0 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
/**
|
||||
* Middleware to synchronize user balance settings with current balance configuration.
|
||||
* @function
|
||||
* @param {Object} req - Express request object containing user information.
|
||||
* @param {Object} res - Express response object.
|
||||
* @param {import('express').NextFunction} next - Next middleware function.
|
||||
*/
|
||||
const setBalanceConfig = async (req, res, next) => {
|
||||
try {
|
||||
const balanceConfig = await getBalanceConfig();
|
||||
if (!balanceConfig?.enabled) {
|
||||
return next();
|
||||
}
|
||||
if (balanceConfig.startBalance == null) {
|
||||
return next();
|
||||
}
|
||||
|
||||
const userId = req.user._id;
|
||||
const userBalanceRecord = await Balance.findOne({ user: userId }).lean();
|
||||
const updateFields = buildUpdateFields(balanceConfig, userBalanceRecord);
|
||||
|
||||
if (Object.keys(updateFields).length === 0) {
|
||||
return next();
|
||||
}
|
||||
|
||||
await Balance.findOneAndUpdate(
|
||||
{ user: userId },
|
||||
{ $set: updateFields },
|
||||
{ upsert: true, new: true },
|
||||
);
|
||||
|
||||
next();
|
||||
} catch (error) {
|
||||
logger.error('Error setting user balance:', error);
|
||||
next(error);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Build an object containing fields that need updating
|
||||
* @param {Object} config - The balance configuration
|
||||
* @param {Object|null} userRecord - The user's current balance record, if any
|
||||
* @returns {Object} Fields that need updating
|
||||
*/
|
||||
function buildUpdateFields(config, userRecord) {
|
||||
const updateFields = {};
|
||||
|
||||
// Ensure user record has the required fields
|
||||
if (!userRecord) {
|
||||
updateFields.user = userRecord?.user;
|
||||
updateFields.tokenCredits = config.startBalance;
|
||||
}
|
||||
|
||||
if (userRecord?.tokenCredits == null && config.startBalance != null) {
|
||||
updateFields.tokenCredits = config.startBalance;
|
||||
}
|
||||
|
||||
const isAutoRefillConfigValid =
|
||||
config.autoRefillEnabled &&
|
||||
config.refillIntervalValue != null &&
|
||||
config.refillIntervalUnit != null &&
|
||||
config.refillAmount != null;
|
||||
|
||||
if (!isAutoRefillConfigValid) {
|
||||
return updateFields;
|
||||
}
|
||||
|
||||
if (userRecord?.autoRefillEnabled !== config.autoRefillEnabled) {
|
||||
updateFields.autoRefillEnabled = config.autoRefillEnabled;
|
||||
}
|
||||
|
||||
if (userRecord?.refillIntervalValue !== config.refillIntervalValue) {
|
||||
updateFields.refillIntervalValue = config.refillIntervalValue;
|
||||
}
|
||||
|
||||
if (userRecord?.refillIntervalUnit !== config.refillIntervalUnit) {
|
||||
updateFields.refillIntervalUnit = config.refillIntervalUnit;
|
||||
}
|
||||
|
||||
if (userRecord?.refillAmount !== config.refillAmount) {
|
||||
updateFields.refillAmount = config.refillAmount;
|
||||
}
|
||||
|
||||
return updateFields;
|
||||
}
|
||||
|
||||
module.exports = setBalanceConfig;
|
||||
@@ -4,12 +4,14 @@ const mongoose = require('mongoose');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
|
||||
jest.mock('@librechat/api', () => ({
|
||||
...jest.requireActual('@librechat/api'),
|
||||
MCPOAuthHandler: {
|
||||
initiateOAuthFlow: jest.fn(),
|
||||
getFlowState: jest.fn(),
|
||||
completeOAuthFlow: jest.fn(),
|
||||
generateFlowId: jest.fn(),
|
||||
},
|
||||
getUserMCPAuthMap: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('@librechat/data-schemas', () => ({
|
||||
@@ -36,6 +38,7 @@ jest.mock('~/models', () => ({
|
||||
updateToken: jest.fn(),
|
||||
createToken: jest.fn(),
|
||||
deleteTokens: jest.fn(),
|
||||
findPluginAuthsByKeys: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
@@ -44,6 +47,10 @@ jest.mock('~/server/services/Config', () => ({
|
||||
loadCustomConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Config/mcpToolsCache', () => ({
|
||||
updateMCPUserTools: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/MCP', () => ({
|
||||
getMCPSetupData: jest.fn(),
|
||||
getServerConnectionStatus: jest.fn(),
|
||||
@@ -66,6 +73,10 @@ jest.mock('~/server/middleware', () => ({
|
||||
requireJwtAuth: (req, res, next) => next(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Tools/mcp', () => ({
|
||||
reinitMCPServer: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('MCP Routes', () => {
|
||||
let app;
|
||||
let mongoServer;
|
||||
@@ -494,12 +505,9 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when token retrieval throws an unexpected error', async () => {
|
||||
const mockFlowManager = {
|
||||
getFlowState: jest.fn().mockRejectedValue(new Error('Database connection failed')),
|
||||
};
|
||||
|
||||
getLogStores.mockReturnValue({});
|
||||
require('~/config').getFlowStateManager.mockReturnValue(mockFlowManager);
|
||||
getLogStores.mockImplementation(() => {
|
||||
throw new Error('Database connection failed');
|
||||
});
|
||||
|
||||
const response = await request(app).get('/api/mcp/oauth/tokens/test-user-id:error-flow');
|
||||
|
||||
@@ -563,8 +571,8 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('POST /oauth/cancel/:serverName', () => {
|
||||
const { getLogStores } = require('~/cache');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
it('should cancel OAuth flow successfully', async () => {
|
||||
const mockFlowManager = {
|
||||
@@ -644,15 +652,15 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('POST /:serverName/reinitialize', () => {
|
||||
const { loadCustomConfig } = require('~/server/services/Config');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
|
||||
it('should return 404 when server is not found in configuration', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'other-server': {},
|
||||
},
|
||||
});
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue(null),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
|
||||
const response = await request(app).post('/api/mcp/non-existent-server/reinitialize');
|
||||
|
||||
@@ -663,16 +671,11 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle OAuth requirement during reinitialize', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'oauth-server': {
|
||||
customUserVars: {},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const mockMcpManager = {
|
||||
disconnectServer: jest.fn().mockResolvedValue(),
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {},
|
||||
}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getUserConnection: jest.fn().mockImplementation(async ({ oauthStart }) => {
|
||||
if (oauthStart) {
|
||||
@@ -685,12 +688,19 @@ describe('MCP Routes', () => {
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||
success: true,
|
||||
message: "MCP server 'oauth-server' ready for OAuth authentication",
|
||||
serverName: 'oauth-server',
|
||||
oauthRequired: true,
|
||||
oauthUrl: 'https://oauth.example.com/auth',
|
||||
});
|
||||
|
||||
const response = await request(app).post('/api/mcp/oauth-server/reinitialize');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual({
|
||||
success: 'https://oauth.example.com/auth',
|
||||
success: true,
|
||||
message: "MCP server 'oauth-server' ready for OAuth authentication",
|
||||
serverName: 'oauth-server',
|
||||
oauthRequired: true,
|
||||
@@ -699,14 +709,9 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when reinitialize fails with non-OAuth error', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'error-server': {},
|
||||
},
|
||||
});
|
||||
|
||||
const mockMcpManager = {
|
||||
disconnectServer: jest.fn().mockResolvedValue(),
|
||||
getRawConfig: jest.fn().mockReturnValue({}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getUserConnection: jest.fn().mockRejectedValue(new Error('Connection failed')),
|
||||
};
|
||||
@@ -714,6 +719,7 @@ describe('MCP Routes', () => {
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue(null);
|
||||
|
||||
const response = await request(app).post('/api/mcp/error-server/reinitialize');
|
||||
|
||||
@@ -724,7 +730,13 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when unexpected error occurs', async () => {
|
||||
loadCustomConfig.mockRejectedValue(new Error('Config loading failed'));
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockImplementation(() => {
|
||||
throw new Error('Config loading failed');
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||
|
||||
@@ -747,29 +759,17 @@ describe('MCP Routes', () => {
|
||||
expect(response.body).toEqual({ error: 'User not authenticated' });
|
||||
});
|
||||
|
||||
it('should handle errors when fetching custom user variables', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: 'test-key-var',
|
||||
SECRET_TOKEN: 'test-secret-var',
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
getUserPluginAuthValue
|
||||
.mockResolvedValueOnce('test-api-key-value')
|
||||
.mockRejectedValueOnce(new Error('Database error'));
|
||||
|
||||
it('should successfully reinitialize server and cache tools', async () => {
|
||||
const mockUserConnection = {
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
fetchTools: jest.fn().mockResolvedValue([
|
||||
{ name: 'tool1', description: 'Test tool 1', inputSchema: { type: 'object' } },
|
||||
{ name: 'tool2', description: 'Test tool 2', inputSchema: { type: 'object' } },
|
||||
]),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
disconnectServer: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getRawConfig: jest.fn().mockReturnValue({ endpoint: 'http://test-server.com' }),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
};
|
||||
|
||||
@@ -778,44 +778,86 @@ describe('MCP Routes', () => {
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
updateMCPUserTools.mockResolvedValue();
|
||||
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||
success: true,
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
serverName: 'test-server',
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
});
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.success).toBe(true);
|
||||
expect(response.body).toEqual({
|
||||
success: true,
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
serverName: 'test-server',
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
});
|
||||
expect(mockMcpManager.disconnectUserConnection).toHaveBeenCalledWith(
|
||||
'test-user-id',
|
||||
'test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return failure message when reinitialize completely fails', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {},
|
||||
},
|
||||
});
|
||||
it('should handle server with custom user variables', async () => {
|
||||
const mockUserConnection = {
|
||||
fetchTools: jest.fn().mockResolvedValue([]),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
disconnectServer: jest.fn().mockResolvedValue(),
|
||||
mcpConfigs: {},
|
||||
getUserConnection: jest.fn().mockResolvedValue(null),
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
endpoint: 'http://test-server.com',
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
}),
|
||||
disconnectUserConnection: jest.fn().mockResolvedValue(),
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
require('~/config').getFlowStateManager.mockReturnValue({});
|
||||
require('~/cache').getLogStores.mockReturnValue({});
|
||||
require('@librechat/api').getUserMCPAuthMap.mockResolvedValue({
|
||||
'mcp:test-server': {
|
||||
API_KEY: 'api-key-value',
|
||||
},
|
||||
});
|
||||
require('~/models').findPluginAuthsByKeys.mockResolvedValue([
|
||||
{ key: 'API_KEY', value: 'api-key-value' },
|
||||
]);
|
||||
|
||||
const { getCachedTools, setCachedTools } = require('~/server/services/Config');
|
||||
const { Constants } = require('librechat-data-provider');
|
||||
getCachedTools.mockResolvedValue({
|
||||
[`existing-tool${Constants.mcp_delimiter}test-server`]: { type: 'function' },
|
||||
});
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
getCachedTools.mockResolvedValue({});
|
||||
setCachedTools.mockResolvedValue();
|
||||
updateMCPUserTools.mockResolvedValue();
|
||||
|
||||
require('~/server/services/Tools/mcp').reinitMCPServer.mockResolvedValue({
|
||||
success: true,
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
serverName: 'test-server',
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
});
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.success).toBe(false);
|
||||
expect(response.body.message).toBe("Failed to reinitialize MCP server 'test-server'");
|
||||
expect(response.body.success).toBe(true);
|
||||
expect(require('@librechat/api').getUserMCPAuthMap).toHaveBeenCalledWith({
|
||||
userId: 'test-user-id',
|
||||
servers: ['test-server'],
|
||||
findPluginAuthsByKeys: require('~/models').findPluginAuthsByKeys,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -984,21 +1026,19 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
describe('GET /:serverName/auth-values', () => {
|
||||
const { loadCustomConfig } = require('~/server/services/Config');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
|
||||
it('should return auth value flags for server', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
SECRET_TOKEN: 'another-env-var',
|
||||
},
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
SECRET_TOKEN: 'another-env-var',
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
getUserPluginAuthValue.mockResolvedValueOnce('some-api-key-value').mockResolvedValueOnce('');
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
@@ -1017,11 +1057,11 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 404 when server is not found in configuration', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'other-server': {},
|
||||
},
|
||||
});
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue(null),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/non-existent-server/auth-values');
|
||||
|
||||
@@ -1032,16 +1072,15 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle errors when checking auth values', async () => {
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: {
|
||||
API_KEY: 'some-env-var',
|
||||
},
|
||||
},
|
||||
});
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
getUserPluginAuthValue.mockRejectedValue(new Error('Database error'));
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
@@ -1057,7 +1096,13 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should return 500 when auth values check throws unexpected error', async () => {
|
||||
loadCustomConfig.mockRejectedValue(new Error('Config loading failed'));
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockImplementation(() => {
|
||||
throw new Error('Config loading failed');
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
|
||||
@@ -1066,14 +1111,13 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
|
||||
it('should handle customUserVars that is not an object', async () => {
|
||||
const { loadCustomConfig } = require('~/server/services/Config');
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': {
|
||||
customUserVars: 'not-an-object',
|
||||
},
|
||||
},
|
||||
});
|
||||
const mockMcpManager = {
|
||||
getRawConfig: jest.fn().mockReturnValue({
|
||||
customUserVars: 'not-an-object',
|
||||
}),
|
||||
};
|
||||
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
const response = await request(app).get('/api/mcp/test-server/auth-values');
|
||||
|
||||
@@ -1097,98 +1141,6 @@ describe('MCP Routes', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('POST /:serverName/reinitialize - Tool Deletion Coverage', () => {
|
||||
it('should handle null cached tools during reinitialize (triggers || {} fallback)', async () => {
|
||||
const { loadCustomConfig, getCachedTools } = require('~/server/services/Config');
|
||||
|
||||
const mockUserConnection = {
|
||||
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
disconnectServer: jest.fn(),
|
||||
initializeServer: jest.fn(),
|
||||
mcpConfigs: {},
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': { env: { API_KEY: 'test-key' } },
|
||||
},
|
||||
});
|
||||
|
||||
getCachedTools.mockResolvedValue(null);
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
|
||||
|
||||
expect(response.body).toEqual({
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
success: true,
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
serverName: 'test-server',
|
||||
});
|
||||
});
|
||||
|
||||
it('should delete existing cached tools during successful reinitialize', async () => {
|
||||
const {
|
||||
loadCustomConfig,
|
||||
getCachedTools,
|
||||
setCachedTools,
|
||||
} = require('~/server/services/Config');
|
||||
|
||||
const mockUserConnection = {
|
||||
fetchTools: jest.fn().mockResolvedValue([{ name: 'new-tool', description: 'A new tool' }]),
|
||||
};
|
||||
|
||||
const mockMcpManager = {
|
||||
getUserConnection: jest.fn().mockResolvedValue(mockUserConnection),
|
||||
disconnectServer: jest.fn(),
|
||||
initializeServer: jest.fn(),
|
||||
mcpConfigs: {},
|
||||
};
|
||||
require('~/config').getMCPManager.mockReturnValue(mockMcpManager);
|
||||
|
||||
loadCustomConfig.mockResolvedValue({
|
||||
mcpServers: {
|
||||
'test-server': { env: { API_KEY: 'test-key' } },
|
||||
},
|
||||
});
|
||||
|
||||
const existingTools = {
|
||||
'old-tool_mcp_test-server': { type: 'function' },
|
||||
'other-tool_mcp_other-server': { type: 'function' },
|
||||
};
|
||||
getCachedTools.mockResolvedValue(existingTools);
|
||||
|
||||
const response = await request(app).post('/api/mcp/test-server/reinitialize').expect(200);
|
||||
|
||||
expect(response.body).toEqual({
|
||||
message: "MCP server 'test-server' reinitialized successfully",
|
||||
success: true,
|
||||
oauthRequired: false,
|
||||
oauthUrl: null,
|
||||
serverName: 'test-server',
|
||||
});
|
||||
|
||||
expect(setCachedTools).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
'new-tool_mcp_test-server': expect.any(Object),
|
||||
'other-tool_mcp_other-server': { type: 'function' },
|
||||
}),
|
||||
{ userId: 'test-user-id' },
|
||||
);
|
||||
expect(setCachedTools).toHaveBeenCalledWith(
|
||||
expect.not.objectContaining({
|
||||
'old-tool_mcp_test-server': expect.anything(),
|
||||
}),
|
||||
{ userId: 'test-user-id' },
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('GET /:serverName/oauth/callback - Edge Cases', () => {
|
||||
it('should handle OAuth callback without toolFlowId (falsy toolFlowId)', async () => {
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
|
||||
85
api/server/routes/accessPermissions.js
Normal file
85
api/server/routes/accessPermissions.js
Normal file
@@ -0,0 +1,85 @@
|
||||
const express = require('express');
|
||||
const { ResourceType, PermissionBits } = require('librechat-data-provider');
|
||||
const {
|
||||
getUserEffectivePermissions,
|
||||
updateResourcePermissions,
|
||||
getResourcePermissions,
|
||||
getResourceRoles,
|
||||
searchPrincipals,
|
||||
} = require('~/server/controllers/PermissionsController');
|
||||
const { requireJwtAuth, checkBan, uaParser, canAccessResource } = require('~/server/middleware');
|
||||
const { checkPeoplePickerAccess } = require('~/server/middleware/checkPeoplePickerAccess');
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
// Apply common middleware
|
||||
router.use(requireJwtAuth);
|
||||
router.use(checkBan);
|
||||
router.use(uaParser);
|
||||
|
||||
/**
|
||||
* Generic routes for resource permissions
|
||||
* Pattern: /api/permissions/{resourceType}/{resourceId}
|
||||
*/
|
||||
|
||||
/**
|
||||
* GET /api/permissions/search-principals
|
||||
* Search for users and groups to grant permissions
|
||||
*/
|
||||
router.get('/search-principals', checkPeoplePickerAccess, searchPrincipals);
|
||||
|
||||
/**
|
||||
* GET /api/permissions/{resourceType}/roles
|
||||
* Get available roles for a resource type
|
||||
*/
|
||||
router.get('/:resourceType/roles', getResourceRoles);
|
||||
|
||||
/**
|
||||
* GET /api/permissions/{resourceType}/{resourceId}
|
||||
* Get all permissions for a specific resource
|
||||
*/
|
||||
router.get('/:resourceType/:resourceId', getResourcePermissions);
|
||||
|
||||
/**
|
||||
* PUT /api/permissions/{resourceType}/{resourceId}
|
||||
* Bulk update permissions for a specific resource
|
||||
*/
|
||||
router.put(
|
||||
'/:resourceType/:resourceId',
|
||||
// Use middleware that dynamically handles resource type and permissions
|
||||
(req, res, next) => {
|
||||
const { resourceType } = req.params;
|
||||
let middleware;
|
||||
|
||||
if (resourceType === ResourceType.AGENT) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermission: PermissionBits.SHARE,
|
||||
resourceIdParam: 'resourceId',
|
||||
});
|
||||
} else if (resourceType === ResourceType.PROMPTGROUP) {
|
||||
middleware = canAccessResource({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermission: PermissionBits.SHARE,
|
||||
resourceIdParam: 'resourceId',
|
||||
});
|
||||
} else {
|
||||
return res.status(400).json({
|
||||
error: 'Bad Request',
|
||||
message: `Unsupported resource type: ${resourceType}`,
|
||||
});
|
||||
}
|
||||
|
||||
// Execute the middleware
|
||||
middleware(req, res, next);
|
||||
},
|
||||
updateResourcePermissions,
|
||||
);
|
||||
|
||||
/**
|
||||
* GET /api/permissions/{resourceType}/{resourceId}/effective
|
||||
* Get user's effective permissions for a specific resource
|
||||
*/
|
||||
router.get('/:resourceType/:resourceId/effective', getUserEffectivePermissions);
|
||||
|
||||
module.exports = router;
|
||||
@@ -3,16 +3,19 @@ const { nanoid } = require('nanoid');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const {
|
||||
SystemRoles,
|
||||
Permissions,
|
||||
ResourceType,
|
||||
PermissionTypes,
|
||||
actionDelimiter,
|
||||
PermissionBits,
|
||||
removeNullishValues,
|
||||
} = require('librechat-data-provider');
|
||||
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
|
||||
const { findAccessibleResources } = require('~/server/services/PermissionService');
|
||||
const { getAgent, updateAgent, getListAgentsByAccess } = require('~/models/Agent');
|
||||
const { updateAction, getActions, deleteAction } = require('~/models/Action');
|
||||
const { isActionDomainAllowed } = require('~/server/services/domains');
|
||||
const { getAgent, updateAgent } = require('~/models/Agent');
|
||||
const { canAccessAgentResource } = require('~/server/middleware');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
const router = express.Router();
|
||||
@@ -23,12 +26,6 @@ const checkAgentCreate = generateCheckAccess({
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
// If the user has ADMIN role
|
||||
// then action edition is possible even if not owner of the assistant
|
||||
const isAdmin = (req) => {
|
||||
return req.user.role === SystemRoles.ADMIN;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves all user's actions
|
||||
* @route GET /actions/
|
||||
@@ -37,10 +34,23 @@ const isAdmin = (req) => {
|
||||
*/
|
||||
router.get('/', async (req, res) => {
|
||||
try {
|
||||
const admin = isAdmin(req);
|
||||
// If admin, get all actions, otherwise only user's actions
|
||||
const searchParams = admin ? {} : { user: req.user.id };
|
||||
res.json(await getActions(searchParams));
|
||||
const userId = req.user.id;
|
||||
const editableAgentObjectIds = await findAccessibleResources({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
requiredPermissions: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
const agentsResponse = await getListAgentsByAccess({
|
||||
accessibleIds: editableAgentObjectIds,
|
||||
});
|
||||
|
||||
const editableAgentIds = agentsResponse.data.map((agent) => agent.id);
|
||||
const actions =
|
||||
editableAgentIds.length > 0 ? await getActions({ agent_id: { $in: editableAgentIds } }) : [];
|
||||
|
||||
res.json(actions);
|
||||
} catch (error) {
|
||||
res.status(500).json({ error: error.message });
|
||||
}
|
||||
@@ -55,106 +65,111 @@ router.get('/', async (req, res) => {
|
||||
* @param {ActionMetadata} req.body.metadata - Metadata for the action.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
router.post('/:agent_id', checkAgentCreate, async (req, res) => {
|
||||
try {
|
||||
const { agent_id } = req.params;
|
||||
router.post(
|
||||
'/:agent_id',
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'agent_id',
|
||||
}),
|
||||
checkAgentCreate,
|
||||
async (req, res) => {
|
||||
try {
|
||||
const { agent_id } = req.params;
|
||||
|
||||
/** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */
|
||||
const { functions, action_id: _action_id, metadata: _metadata } = req.body;
|
||||
if (!functions.length) {
|
||||
return res.status(400).json({ message: 'No functions provided' });
|
||||
}
|
||||
|
||||
let metadata = await encryptMetadata(removeNullishValues(_metadata, true));
|
||||
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
|
||||
if (!isDomainAllowed) {
|
||||
return res.status(400).json({ message: 'Domain not allowed' });
|
||||
}
|
||||
|
||||
let { domain } = metadata;
|
||||
domain = await domainParser(domain, true);
|
||||
|
||||
if (!domain) {
|
||||
return res.status(400).json({ message: 'No domain provided' });
|
||||
}
|
||||
|
||||
const action_id = _action_id ?? nanoid();
|
||||
const initialPromises = [];
|
||||
const admin = isAdmin(req);
|
||||
|
||||
// If admin, can edit any agent, otherwise only user's agents
|
||||
const agentQuery = admin ? { id: agent_id } : { id: agent_id, author: req.user.id };
|
||||
// TODO: share agents
|
||||
initialPromises.push(getAgent(agentQuery));
|
||||
if (_action_id) {
|
||||
initialPromises.push(getActions({ action_id }, true));
|
||||
}
|
||||
|
||||
/** @type {[Agent, [Action|undefined]]} */
|
||||
const [agent, actions_result] = await Promise.all(initialPromises);
|
||||
if (!agent) {
|
||||
return res.status(404).json({ message: 'Agent not found for adding action' });
|
||||
}
|
||||
|
||||
if (actions_result && actions_result.length) {
|
||||
const action = actions_result[0];
|
||||
metadata = { ...action.metadata, ...metadata };
|
||||
}
|
||||
|
||||
const { actions: _actions = [], author: agent_author } = agent ?? {};
|
||||
const actions = [];
|
||||
for (const action of _actions) {
|
||||
const [_action_domain, current_action_id] = action.split(actionDelimiter);
|
||||
if (current_action_id === action_id) {
|
||||
continue;
|
||||
/** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */
|
||||
const { functions, action_id: _action_id, metadata: _metadata } = req.body;
|
||||
if (!functions.length) {
|
||||
return res.status(400).json({ message: 'No functions provided' });
|
||||
}
|
||||
|
||||
actions.push(action);
|
||||
}
|
||||
|
||||
actions.push(`${domain}${actionDelimiter}${action_id}`);
|
||||
|
||||
/** @type {string[]}} */
|
||||
const { tools: _tools = [] } = agent;
|
||||
|
||||
const tools = _tools
|
||||
.filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id))))
|
||||
.concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`));
|
||||
|
||||
// Force version update since actions are changing
|
||||
const updatedAgent = await updateAgent(
|
||||
agentQuery,
|
||||
{ tools, actions },
|
||||
{
|
||||
updatingUserId: req.user.id,
|
||||
forceVersion: true,
|
||||
},
|
||||
);
|
||||
|
||||
// Only update user field for new actions
|
||||
const actionUpdateData = { metadata, agent_id };
|
||||
if (!actions_result || !actions_result.length) {
|
||||
// For new actions, use the agent owner's user ID
|
||||
actionUpdateData.user = agent_author || req.user.id;
|
||||
}
|
||||
|
||||
/** @type {[Action]} */
|
||||
const updatedAction = await updateAction({ action_id }, actionUpdateData);
|
||||
|
||||
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
|
||||
for (let field of sensitiveFields) {
|
||||
if (updatedAction.metadata[field]) {
|
||||
delete updatedAction.metadata[field];
|
||||
let metadata = await encryptMetadata(removeNullishValues(_metadata, true));
|
||||
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
|
||||
if (!isDomainAllowed) {
|
||||
return res.status(400).json({ message: 'Domain not allowed' });
|
||||
}
|
||||
}
|
||||
|
||||
res.json([updatedAgent, updatedAction]);
|
||||
} catch (error) {
|
||||
const message = 'Trouble updating the Agent Action';
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
});
|
||||
let { domain } = metadata;
|
||||
domain = await domainParser(domain, true);
|
||||
|
||||
if (!domain) {
|
||||
return res.status(400).json({ message: 'No domain provided' });
|
||||
}
|
||||
|
||||
const action_id = _action_id ?? nanoid();
|
||||
const initialPromises = [];
|
||||
|
||||
// Permissions already validated by middleware - load agent directly
|
||||
initialPromises.push(getAgent({ id: agent_id }));
|
||||
if (_action_id) {
|
||||
initialPromises.push(getActions({ action_id }, true));
|
||||
}
|
||||
|
||||
/** @type {[Agent, [Action|undefined]]} */
|
||||
const [agent, actions_result] = await Promise.all(initialPromises);
|
||||
if (!agent) {
|
||||
return res.status(404).json({ message: 'Agent not found for adding action' });
|
||||
}
|
||||
|
||||
if (actions_result && actions_result.length) {
|
||||
const action = actions_result[0];
|
||||
metadata = { ...action.metadata, ...metadata };
|
||||
}
|
||||
|
||||
const { actions: _actions = [], author: agent_author } = agent ?? {};
|
||||
const actions = [];
|
||||
for (const action of _actions) {
|
||||
const [_action_domain, current_action_id] = action.split(actionDelimiter);
|
||||
if (current_action_id === action_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
actions.push(action);
|
||||
}
|
||||
|
||||
actions.push(`${domain}${actionDelimiter}${action_id}`);
|
||||
|
||||
/** @type {string[]}} */
|
||||
const { tools: _tools = [] } = agent;
|
||||
|
||||
const tools = _tools
|
||||
.filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id))))
|
||||
.concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`));
|
||||
|
||||
// Force version update since actions are changing
|
||||
const updatedAgent = await updateAgent(
|
||||
{ id: agent_id },
|
||||
{ tools, actions },
|
||||
{
|
||||
updatingUserId: req.user.id,
|
||||
forceVersion: true,
|
||||
},
|
||||
);
|
||||
|
||||
// Only update user field for new actions
|
||||
const actionUpdateData = { metadata, agent_id };
|
||||
if (!actions_result || !actions_result.length) {
|
||||
// For new actions, use the agent owner's user ID
|
||||
actionUpdateData.user = agent_author || req.user.id;
|
||||
}
|
||||
|
||||
/** @type {[Action]} */
|
||||
const updatedAction = await updateAction({ action_id }, actionUpdateData);
|
||||
|
||||
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
|
||||
for (let field of sensitiveFields) {
|
||||
if (updatedAction.metadata[field]) {
|
||||
delete updatedAction.metadata[field];
|
||||
}
|
||||
}
|
||||
|
||||
res.json([updatedAgent, updatedAction]);
|
||||
} catch (error) {
|
||||
const message = 'Trouble updating the Agent Action';
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
/**
|
||||
* Deletes an action for a specific agent.
|
||||
@@ -163,52 +178,56 @@ router.post('/:agent_id', checkAgentCreate, async (req, res) => {
|
||||
* @param {string} req.params.action_id - The ID of the action to delete.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
router.delete('/:agent_id/:action_id', checkAgentCreate, async (req, res) => {
|
||||
try {
|
||||
const { agent_id, action_id } = req.params;
|
||||
const admin = isAdmin(req);
|
||||
router.delete(
|
||||
'/:agent_id/:action_id',
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'agent_id',
|
||||
}),
|
||||
checkAgentCreate,
|
||||
async (req, res) => {
|
||||
try {
|
||||
const { agent_id, action_id } = req.params;
|
||||
|
||||
// If admin, can delete any agent, otherwise only user's agents
|
||||
const agentQuery = admin ? { id: agent_id } : { id: agent_id, author: req.user.id };
|
||||
const agent = await getAgent(agentQuery);
|
||||
if (!agent) {
|
||||
return res.status(404).json({ message: 'Agent not found for deleting action' });
|
||||
}
|
||||
|
||||
const { tools = [], actions = [] } = agent;
|
||||
|
||||
let domain = '';
|
||||
const updatedActions = actions.filter((action) => {
|
||||
if (action.includes(action_id)) {
|
||||
[domain] = action.split(actionDelimiter);
|
||||
return false;
|
||||
// Permissions already validated by middleware - load agent directly
|
||||
const agent = await getAgent({ id: agent_id });
|
||||
if (!agent) {
|
||||
return res.status(404).json({ message: 'Agent not found for deleting action' });
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
domain = await domainParser(domain, true);
|
||||
const { tools = [], actions = [] } = agent;
|
||||
|
||||
if (!domain) {
|
||||
return res.status(400).json({ message: 'No domain provided' });
|
||||
let domain = '';
|
||||
const updatedActions = actions.filter((action) => {
|
||||
if (action.includes(action_id)) {
|
||||
[domain] = action.split(actionDelimiter);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
domain = await domainParser(domain, true);
|
||||
|
||||
if (!domain) {
|
||||
return res.status(400).json({ message: 'No domain provided' });
|
||||
}
|
||||
|
||||
const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain)));
|
||||
|
||||
// Force version update since actions are being removed
|
||||
await updateAgent(
|
||||
{ id: agent_id },
|
||||
{ tools: updatedTools, actions: updatedActions },
|
||||
{ updatingUserId: req.user.id, forceVersion: true },
|
||||
);
|
||||
await deleteAction({ action_id });
|
||||
res.status(200).json({ message: 'Action deleted successfully' });
|
||||
} catch (error) {
|
||||
const message = 'Trouble deleting the Agent Action';
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
|
||||
const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain)));
|
||||
|
||||
// Force version update since actions are being removed
|
||||
await updateAgent(
|
||||
agentQuery,
|
||||
{ tools: updatedTools, actions: updatedActions },
|
||||
{ updatingUserId: req.user.id, forceVersion: true },
|
||||
);
|
||||
// If admin, can delete any action, otherwise only user's actions
|
||||
const actionQuery = admin ? { action_id } : { action_id, user: req.user.id };
|
||||
await deleteAction(actionQuery);
|
||||
res.status(200).json({ message: 'Action deleted successfully' });
|
||||
} catch (error) {
|
||||
const message = 'Trouble deleting the Agent Action';
|
||||
logger.error(message, error);
|
||||
res.status(500).json({ message });
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
const express = require('express');
|
||||
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { PermissionTypes, Permissions, PermissionBits } = require('librechat-data-provider');
|
||||
const {
|
||||
setHeaders,
|
||||
moderateText,
|
||||
// validateModel,
|
||||
validateConvoAccess,
|
||||
buildEndpointOption,
|
||||
canAccessAgentFromBody,
|
||||
} = require('~/server/middleware');
|
||||
const { initializeClient } = require('~/server/services/Endpoints/agents');
|
||||
const AgentController = require('~/server/controllers/agents/request');
|
||||
@@ -23,8 +24,12 @@ const checkAgentAccess = generateCheckAccess({
|
||||
skipCheck: skipAgentCheck,
|
||||
getRoleByName,
|
||||
});
|
||||
const checkAgentResourceAccess = canAccessAgentFromBody({
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
router.use(checkAgentAccess);
|
||||
router.use(checkAgentResourceAccess);
|
||||
router.use(validateConvoAccess);
|
||||
router.use(buildEndpointOption);
|
||||
router.use(setHeaders);
|
||||
|
||||
@@ -37,4 +37,6 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
|
||||
chatRouter.use('/', chat);
|
||||
router.use('/chat', chatRouter);
|
||||
|
||||
// Add marketplace routes
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
const express = require('express');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const { PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { PermissionTypes, Permissions, PermissionBits } = require('librechat-data-provider');
|
||||
const { requireJwtAuth, canAccessAgentResource } = require('~/server/middleware');
|
||||
const v1 = require('~/server/controllers/agents/v1');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const actions = require('./actions');
|
||||
@@ -44,6 +44,11 @@ router.use('/actions', actions);
|
||||
*/
|
||||
router.use('/tools', tools);
|
||||
|
||||
/**
|
||||
* Get all agent categories with counts
|
||||
* @route GET /agents/marketplace/categories
|
||||
*/
|
||||
router.get('/categories', v1.getAgentCategories);
|
||||
/**
|
||||
* Creates an agent.
|
||||
* @route POST /agents
|
||||
@@ -53,13 +58,38 @@ router.use('/tools', tools);
|
||||
router.post('/', checkAgentCreate, v1.createAgent);
|
||||
|
||||
/**
|
||||
* Retrieves an agent.
|
||||
* Retrieves basic agent information (VIEW permission required).
|
||||
* Returns safe, non-sensitive agent data for viewing purposes.
|
||||
* @route GET /agents/:id
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Agent} 200 - Success response - application/json
|
||||
* @returns {Agent} 200 - Basic agent info - application/json
|
||||
*/
|
||||
router.get('/:id', checkAgentAccess, v1.getAgent);
|
||||
router.get(
|
||||
'/:id',
|
||||
checkAgentAccess,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
v1.getAgent,
|
||||
);
|
||||
|
||||
/**
|
||||
* Retrieves full agent details including sensitive configuration (EDIT permission required).
|
||||
* Returns complete agent data for editing/configuration purposes.
|
||||
* @route GET /agents/:id/expanded
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Agent} 200 - Full agent details - application/json
|
||||
*/
|
||||
router.get(
|
||||
'/:id/expanded',
|
||||
checkAgentAccess,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
(req, res) => v1.getAgent(req, res, true), // Expanded version
|
||||
);
|
||||
/**
|
||||
* Updates an agent.
|
||||
* @route PATCH /agents/:id
|
||||
@@ -67,7 +97,15 @@ router.get('/:id', checkAgentAccess, v1.getAgent);
|
||||
* @param {AgentUpdateParams} req.body - The agent update parameters.
|
||||
* @returns {Agent} 200 - Success response - application/json
|
||||
*/
|
||||
router.patch('/:id', checkGlobalAgentShare, v1.updateAgent);
|
||||
router.patch(
|
||||
'/:id',
|
||||
checkGlobalAgentShare,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
v1.updateAgent,
|
||||
);
|
||||
|
||||
/**
|
||||
* Duplicates an agent.
|
||||
@@ -75,7 +113,15 @@ router.patch('/:id', checkGlobalAgentShare, v1.updateAgent);
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Agent} 201 - Success response - application/json
|
||||
*/
|
||||
router.post('/:id/duplicate', checkAgentCreate, v1.duplicateAgent);
|
||||
router.post(
|
||||
'/:id/duplicate',
|
||||
checkAgentCreate,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
v1.duplicateAgent,
|
||||
);
|
||||
|
||||
/**
|
||||
* Deletes an agent.
|
||||
@@ -83,7 +129,15 @@ router.post('/:id/duplicate', checkAgentCreate, v1.duplicateAgent);
|
||||
* @param {string} req.params.id - Agent identifier.
|
||||
* @returns {Agent} 200 - success response - application/json
|
||||
*/
|
||||
router.delete('/:id', checkAgentCreate, v1.deleteAgent);
|
||||
router.delete(
|
||||
'/:id',
|
||||
checkAgentCreate,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.DELETE,
|
||||
resourceIdParam: 'id',
|
||||
}),
|
||||
v1.deleteAgent,
|
||||
);
|
||||
|
||||
/**
|
||||
* Reverts an agent to a previous version.
|
||||
@@ -110,6 +164,14 @@ router.get('/', checkAgentAccess, v1.getListAgents);
|
||||
* @param {string} [req.body.metadata] - Optional metadata for the agent's avatar.
|
||||
* @returns {Object} 200 - success response - application/json
|
||||
*/
|
||||
avatar.post('/:agent_id/avatar/', checkAgentAccess, v1.uploadAgentAvatar);
|
||||
avatar.post(
|
||||
'/:agent_id/avatar/',
|
||||
checkAgentAccess,
|
||||
canAccessAgentResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'agent_id',
|
||||
}),
|
||||
v1.uploadAgentAvatar,
|
||||
);
|
||||
|
||||
module.exports = { v1: router, avatar };
|
||||
|
||||
@@ -1,72 +1,75 @@
|
||||
const express = require('express');
|
||||
const { createSetBalanceConfig } = require('@librechat/api');
|
||||
const {
|
||||
refreshController,
|
||||
registrationController,
|
||||
resetPasswordController,
|
||||
resetPasswordRequestController,
|
||||
resetPasswordController,
|
||||
registrationController,
|
||||
graphTokenController,
|
||||
refreshController,
|
||||
} = require('~/server/controllers/AuthController');
|
||||
const { loginController } = require('~/server/controllers/auth/LoginController');
|
||||
const { logoutController } = require('~/server/controllers/auth/LogoutController');
|
||||
const { verify2FAWithTempToken } = require('~/server/controllers/auth/TwoFactorAuthController');
|
||||
const {
|
||||
regenerateBackupCodes,
|
||||
disable2FA,
|
||||
confirm2FA,
|
||||
enable2FA,
|
||||
verify2FA,
|
||||
disable2FA,
|
||||
regenerateBackupCodes,
|
||||
confirm2FA,
|
||||
} = require('~/server/controllers/TwoFactorController');
|
||||
const {
|
||||
checkBan,
|
||||
logHeaders,
|
||||
loginLimiter,
|
||||
requireJwtAuth,
|
||||
checkInviteUser,
|
||||
registerLimiter,
|
||||
requireLdapAuth,
|
||||
setBalanceConfig,
|
||||
requireLocalAuth,
|
||||
resetPasswordLimiter,
|
||||
validateRegistration,
|
||||
validatePasswordReset,
|
||||
} = require('~/server/middleware');
|
||||
const { verify2FAWithTempToken } = require('~/server/controllers/auth/TwoFactorAuthController');
|
||||
const { logoutController } = require('~/server/controllers/auth/LogoutController');
|
||||
const { loginController } = require('~/server/controllers/auth/LoginController');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const middleware = require('~/server/middleware');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
const setBalanceConfig = createSetBalanceConfig({
|
||||
getBalanceConfig,
|
||||
Balance,
|
||||
});
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
const ldapAuth = !!process.env.LDAP_URL && !!process.env.LDAP_USER_SEARCH_BASE;
|
||||
//Local
|
||||
router.post('/logout', requireJwtAuth, logoutController);
|
||||
router.post('/logout', middleware.requireJwtAuth, logoutController);
|
||||
router.post(
|
||||
'/login',
|
||||
logHeaders,
|
||||
loginLimiter,
|
||||
checkBan,
|
||||
ldapAuth ? requireLdapAuth : requireLocalAuth,
|
||||
middleware.logHeaders,
|
||||
middleware.loginLimiter,
|
||||
middleware.checkBan,
|
||||
ldapAuth ? middleware.requireLdapAuth : middleware.requireLocalAuth,
|
||||
setBalanceConfig,
|
||||
loginController,
|
||||
);
|
||||
router.post('/refresh', refreshController);
|
||||
router.post(
|
||||
'/register',
|
||||
registerLimiter,
|
||||
checkBan,
|
||||
checkInviteUser,
|
||||
validateRegistration,
|
||||
middleware.registerLimiter,
|
||||
middleware.checkBan,
|
||||
middleware.checkInviteUser,
|
||||
middleware.validateRegistration,
|
||||
registrationController,
|
||||
);
|
||||
router.post(
|
||||
'/requestPasswordReset',
|
||||
resetPasswordLimiter,
|
||||
checkBan,
|
||||
validatePasswordReset,
|
||||
middleware.resetPasswordLimiter,
|
||||
middleware.checkBan,
|
||||
middleware.validatePasswordReset,
|
||||
resetPasswordRequestController,
|
||||
);
|
||||
router.post('/resetPassword', checkBan, validatePasswordReset, resetPasswordController);
|
||||
router.post(
|
||||
'/resetPassword',
|
||||
middleware.checkBan,
|
||||
middleware.validatePasswordReset,
|
||||
resetPasswordController,
|
||||
);
|
||||
|
||||
router.get('/2fa/enable', requireJwtAuth, enable2FA);
|
||||
router.post('/2fa/verify', requireJwtAuth, verify2FA);
|
||||
router.post('/2fa/verify-temp', checkBan, verify2FAWithTempToken);
|
||||
router.post('/2fa/confirm', requireJwtAuth, confirm2FA);
|
||||
router.post('/2fa/disable', requireJwtAuth, disable2FA);
|
||||
router.post('/2fa/backup/regenerate', requireJwtAuth, regenerateBackupCodes);
|
||||
router.get('/2fa/enable', middleware.requireJwtAuth, enable2FA);
|
||||
router.post('/2fa/verify', middleware.requireJwtAuth, verify2FA);
|
||||
router.post('/2fa/verify-temp', middleware.checkBan, verify2FAWithTempToken);
|
||||
router.post('/2fa/confirm', middleware.requireJwtAuth, confirm2FA);
|
||||
router.post('/2fa/disable', middleware.requireJwtAuth, disable2FA);
|
||||
router.post('/2fa/backup/regenerate', middleware.requireJwtAuth, regenerateBackupCodes);
|
||||
|
||||
router.get('/graph-token', middleware.requireJwtAuth, graphTokenController);
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -21,6 +21,9 @@ const publicSharedLinksEnabled =
|
||||
(process.env.ALLOW_SHARED_LINKS_PUBLIC === undefined ||
|
||||
isEnabled(process.env.ALLOW_SHARED_LINKS_PUBLIC));
|
||||
|
||||
const sharePointFilePickerEnabled = isEnabled(process.env.ENABLE_SHAREPOINT_FILEPICKER);
|
||||
const openidReuseTokens = isEnabled(process.env.OPENID_REUSE_TOKENS);
|
||||
|
||||
router.get('/', async function (req, res) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
|
||||
@@ -98,22 +101,30 @@ router.get('/', async function (req, res) {
|
||||
instanceProjectId: instanceProject._id.toString(),
|
||||
bundlerURL: process.env.SANDPACK_BUNDLER_URL,
|
||||
staticBundlerURL: process.env.SANDPACK_STATIC_BUNDLER_URL,
|
||||
sharePointFilePickerEnabled,
|
||||
sharePointBaseUrl: process.env.SHAREPOINT_BASE_URL,
|
||||
sharePointPickerGraphScope: process.env.SHAREPOINT_PICKER_GRAPH_SCOPE,
|
||||
sharePointPickerSharePointScope: process.env.SHAREPOINT_PICKER_SHAREPOINT_SCOPE,
|
||||
openidReuseTokens,
|
||||
};
|
||||
|
||||
payload.mcpServers = {};
|
||||
const config = await getCustomConfig();
|
||||
if (config?.mcpServers != null) {
|
||||
const mcpManager = getMCPManager();
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
|
||||
for (const serverName in config.mcpServers) {
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
payload.mcpServers[serverName] = {
|
||||
customUserVars: serverConfig?.customUserVars || {},
|
||||
chatMenu: serverConfig?.chatMenu,
|
||||
isOAuth: oauthServers.has(serverName),
|
||||
startup: serverConfig?.startup,
|
||||
};
|
||||
try {
|
||||
const mcpManager = getMCPManager();
|
||||
const oauthServers = mcpManager.getOAuthServers();
|
||||
for (const serverName in config.mcpServers) {
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
payload.mcpServers[serverName] = {
|
||||
startup: serverConfig?.startup,
|
||||
chatMenu: serverConfig?.chatMenu,
|
||||
isOAuth: oauthServers?.has(serverName),
|
||||
customUserVars: serverConfig?.customUserVars || {},
|
||||
};
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('Error loading MCP servers', err);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ const express = require('express');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
|
||||
const { filterFile } = require('~/server/services/Files/process');
|
||||
const { getFileStrategy } = require('~/server/utils/getFileStrategy');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
const router = express.Router();
|
||||
@@ -18,7 +19,7 @@ router.post('/', async (req, res) => {
|
||||
throw new Error('User ID is undefined');
|
||||
}
|
||||
|
||||
const fileStrategy = req.app.locals.fileStrategy;
|
||||
const fileStrategy = getFileStrategy(req.app.locals, { isAvatar: true });
|
||||
const desiredFormat = req.app.locals.imageOutputType;
|
||||
const resizedBuffer = await resizeAvatar({
|
||||
userId,
|
||||
|
||||
@@ -2,10 +2,13 @@ const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { createMethods } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const { AccessRoleIds, ResourceType, PrincipalType } = require('librechat-data-provider');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { createFile } = require('~/models/File');
|
||||
|
||||
// Mock dependencies
|
||||
// Only mock the external dependencies that we don't want to test
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
processDeleteRequest: jest.fn().mockResolvedValue({}),
|
||||
filterFile: jest.fn(),
|
||||
@@ -25,31 +28,8 @@ jest.mock('~/server/services/Tools/credentials', () => ({
|
||||
loadAuthValues: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/services/Files/S3/crud', () => ({
|
||||
refreshS3FileUrls: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/cache', () => ({
|
||||
getLogStores: jest.fn(() => ({
|
||||
get: jest.fn(),
|
||||
set: jest.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
jest.mock('~/config', () => ({
|
||||
logger: {
|
||||
error: jest.fn(),
|
||||
warn: jest.fn(),
|
||||
debug: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
const { createFile } = require('~/models/File');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
|
||||
// Import the router after mocks
|
||||
const router = require('./files');
|
||||
// Import the router
|
||||
const router = require('~/server/routes/files/files');
|
||||
|
||||
describe('File Routes - Agent Files Endpoint', () => {
|
||||
let app;
|
||||
@@ -60,13 +40,42 @@ describe('File Routes - Agent Files Endpoint', () => {
|
||||
let fileId1;
|
||||
let fileId2;
|
||||
let fileId3;
|
||||
let File;
|
||||
let User;
|
||||
let Agent;
|
||||
let methods;
|
||||
let AclEntry;
|
||||
// eslint-disable-next-line no-unused-vars
|
||||
let AccessRole;
|
||||
let modelsToCleanup = [];
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize models
|
||||
require('~/db/models');
|
||||
// Initialize all models using createModels
|
||||
const { createModels } = require('@librechat/data-schemas');
|
||||
const models = createModels(mongoose);
|
||||
|
||||
// Track which models we're adding
|
||||
modelsToCleanup = Object.keys(models);
|
||||
|
||||
// Register models on mongoose.models so methods can access them
|
||||
Object.assign(mongoose.models, models);
|
||||
|
||||
// Create methods with our test mongoose instance
|
||||
methods = createMethods(mongoose);
|
||||
|
||||
// Now we can access models from the db/models
|
||||
File = models.File;
|
||||
Agent = models.Agent;
|
||||
AclEntry = models.AclEntry;
|
||||
User = models.User;
|
||||
AccessRole = models.AccessRole;
|
||||
|
||||
// Seed default roles using our methods
|
||||
await methods.seedDefaultRoles();
|
||||
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
@@ -82,88 +91,121 @@ describe('File Routes - Agent Files Endpoint', () => {
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Clear database
|
||||
// Clean up all collections before disconnecting
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
|
||||
authorId = new mongoose.Types.ObjectId().toString();
|
||||
otherUserId = new mongoose.Types.ObjectId().toString();
|
||||
// Clear only the models we added
|
||||
for (const modelName of modelsToCleanup) {
|
||||
if (mongoose.models[modelName]) {
|
||||
delete mongoose.models[modelName];
|
||||
}
|
||||
}
|
||||
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
|
||||
beforeEach(async () => {
|
||||
// Clean up all test data
|
||||
await File.deleteMany({});
|
||||
await Agent.deleteMany({});
|
||||
await User.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
// Don't delete AccessRole as they are seeded defaults needed for tests
|
||||
|
||||
// Create test users
|
||||
authorId = new mongoose.Types.ObjectId();
|
||||
otherUserId = new mongoose.Types.ObjectId();
|
||||
agentId = uuidv4();
|
||||
fileId1 = uuidv4();
|
||||
fileId2 = uuidv4();
|
||||
fileId3 = uuidv4();
|
||||
|
||||
// Create users in database
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
username: 'author',
|
||||
email: 'author@test.com',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: otherUserId,
|
||||
username: 'other',
|
||||
email: 'other@test.com',
|
||||
});
|
||||
|
||||
// Create files
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId1,
|
||||
filename: 'agent-file1.txt',
|
||||
filepath: `/uploads/${authorId}/${fileId1}`,
|
||||
bytes: 1024,
|
||||
filename: 'file1.txt',
|
||||
filepath: '/uploads/file1.txt',
|
||||
bytes: 100,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId2,
|
||||
filename: 'agent-file2.txt',
|
||||
filepath: `/uploads/${authorId}/${fileId2}`,
|
||||
bytes: 2048,
|
||||
filename: 'file2.txt',
|
||||
filepath: '/uploads/file2.txt',
|
||||
bytes: 200,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
await createFile({
|
||||
user: otherUserId,
|
||||
file_id: fileId3,
|
||||
filename: 'user-file.txt',
|
||||
filepath: `/uploads/${otherUserId}/${fileId3}`,
|
||||
bytes: 512,
|
||||
filename: 'file3.txt',
|
||||
filepath: '/uploads/file3.txt',
|
||||
bytes: 300,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Create an agent with files attached
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1, fileId2],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Share the agent globally
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
if (globalProject) {
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
|
||||
}
|
||||
});
|
||||
|
||||
describe('GET /files/agent/:agent_id', () => {
|
||||
it('should return files accessible through the agent for non-author', async () => {
|
||||
it('should return files accessible through the agent for non-author with EDIT permission', async () => {
|
||||
// Create an agent with files attached
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1, fileId2],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant EDIT permission to user on the agent using PermissionService
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
// Mock req.user for this request
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: otherUserId.toString() };
|
||||
next();
|
||||
});
|
||||
|
||||
const response = await request(app).get(`/files/agent/${agentId}`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toHaveLength(2); // Only agent files, not user-owned files
|
||||
|
||||
const fileIds = response.body.map((f) => f.file_id);
|
||||
expect(fileIds).toContain(fileId1);
|
||||
expect(fileIds).toContain(fileId2);
|
||||
expect(fileIds).not.toContain(fileId3); // User's own file not included
|
||||
expect(Array.isArray(response.body)).toBe(true);
|
||||
expect(response.body).toHaveLength(2);
|
||||
expect(response.body.map((f) => f.file_id)).toContain(fileId1);
|
||||
expect(response.body.map((f) => f.file_id)).toContain(fileId2);
|
||||
});
|
||||
|
||||
it('should return 400 when agent_id is not provided', async () => {
|
||||
@@ -176,45 +218,63 @@ describe('File Routes - Agent Files Endpoint', () => {
|
||||
const response = await request(app).get('/files/agent/non-existent-agent');
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual([]); // Empty array for non-existent agent
|
||||
expect(Array.isArray(response.body)).toBe(true);
|
||||
expect(response.body).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return empty array when agent is not collaborative', async () => {
|
||||
// Create a non-collaborative agent
|
||||
const nonCollabAgentId = uuidv4();
|
||||
await createAgent({
|
||||
id: nonCollabAgentId,
|
||||
name: 'Non-Collaborative Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
it('should return empty array when user only has VIEW permission', async () => {
|
||||
// Create an agent with files attached
|
||||
const agent = await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
isCollaborative: false,
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1],
|
||||
file_ids: [fileId1, fileId2],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Share it globally
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
if (globalProject) {
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: nonCollabAgentId }, { projectIds: [globalProject._id] });
|
||||
}
|
||||
// Grant only VIEW permission to user on the agent
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const response = await request(app).get(`/files/agent/${nonCollabAgentId}`);
|
||||
const response = await request(app).get(`/files/agent/${agentId}`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toEqual([]); // Empty array when not collaborative
|
||||
expect(Array.isArray(response.body)).toBe(true);
|
||||
expect(response.body).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return agent files for agent author', async () => {
|
||||
// Create an agent with files attached
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1, fileId2],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Create a new app instance with author authentication
|
||||
const authorApp = express();
|
||||
authorApp.use(express.json());
|
||||
authorApp.use((req, res, next) => {
|
||||
req.user = { id: authorId };
|
||||
req.user = { id: authorId.toString() };
|
||||
req.app = { locals: {} };
|
||||
next();
|
||||
});
|
||||
@@ -223,46 +283,48 @@ describe('File Routes - Agent Files Endpoint', () => {
|
||||
const response = await request(authorApp).get(`/files/agent/${agentId}`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toHaveLength(2); // Agent files for author
|
||||
|
||||
const fileIds = response.body.map((f) => f.file_id);
|
||||
expect(fileIds).toContain(fileId1);
|
||||
expect(fileIds).toContain(fileId2);
|
||||
expect(fileIds).not.toContain(fileId3); // User's own file not included
|
||||
expect(Array.isArray(response.body)).toBe(true);
|
||||
expect(response.body).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('should return files uploaded by other users to shared agent for author', async () => {
|
||||
// Create a file uploaded by another user
|
||||
const anotherUserId = new mongoose.Types.ObjectId();
|
||||
const otherUserFileId = uuidv4();
|
||||
const anotherUserId = new mongoose.Types.ObjectId().toString();
|
||||
|
||||
await User.create({
|
||||
_id: anotherUserId,
|
||||
username: 'another',
|
||||
email: 'another@test.com',
|
||||
});
|
||||
|
||||
await createFile({
|
||||
user: anotherUserId,
|
||||
file_id: otherUserFileId,
|
||||
filename: 'other-user-file.txt',
|
||||
filepath: `/uploads/${anotherUserId}/${otherUserFileId}`,
|
||||
bytes: 4096,
|
||||
filepath: '/uploads/other-user-file.txt',
|
||||
bytes: 400,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Update agent to include the file uploaded by another user
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent(
|
||||
{ id: agentId },
|
||||
{
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1, fileId2, otherUserFileId],
|
||||
},
|
||||
// Create agent to include the file uploaded by another user
|
||||
await createAgent({
|
||||
id: agentId,
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId1, otherUserFileId],
|
||||
},
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
// Create app instance with author authentication
|
||||
// Create a new app instance with author authentication
|
||||
const authorApp = express();
|
||||
authorApp.use(express.json());
|
||||
authorApp.use((req, res, next) => {
|
||||
req.user = { id: authorId };
|
||||
req.user = { id: authorId.toString() };
|
||||
req.app = { locals: {} };
|
||||
next();
|
||||
});
|
||||
@@ -271,12 +333,10 @@ describe('File Routes - Agent Files Endpoint', () => {
|
||||
const response = await request(authorApp).get(`/files/agent/${agentId}`);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body).toHaveLength(3); // Including file from another user
|
||||
|
||||
const fileIds = response.body.map((f) => f.file_id);
|
||||
expect(fileIds).toContain(fileId1);
|
||||
expect(fileIds).toContain(fileId2);
|
||||
expect(fileIds).toContain(otherUserFileId); // File uploaded by another user
|
||||
expect(Array.isArray(response.body)).toBe(true);
|
||||
expect(response.body).toHaveLength(2);
|
||||
expect(response.body.map((f) => f.file_id)).toContain(fileId1);
|
||||
expect(response.body.map((f) => f.file_id)).toContain(otherUserFileId);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,9 +5,10 @@ const {
|
||||
Time,
|
||||
isUUID,
|
||||
CacheKeys,
|
||||
Constants,
|
||||
FileSources,
|
||||
ResourceType,
|
||||
EModelEndpoint,
|
||||
PermissionBits,
|
||||
isAgentsEndpoint,
|
||||
checkOpenAIStorage,
|
||||
} = require('librechat-data-provider');
|
||||
@@ -17,12 +18,15 @@ const {
|
||||
processDeleteRequest,
|
||||
processAgentFileUpload,
|
||||
} = require('~/server/services/Files/process');
|
||||
const { getFiles, batchUpdateFiles, hasAccessToFilesViaAgent } = require('~/models/File');
|
||||
const { fileAccess } = require('~/server/middleware/accessResources/fileAccess');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
|
||||
const { checkPermission } = require('~/server/services/PermissionService');
|
||||
const { loadAuthValues } = require('~/server/services/Tools/credentials');
|
||||
const { refreshS3FileUrls } = require('~/server/services/Files/S3/crud');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { hasAccessToFilesViaAgent } = require('~/server/services/Files');
|
||||
const { getFiles, batchUpdateFiles } = require('~/models/File');
|
||||
const { cleanFileName } = require('~/server/utils/files');
|
||||
const { getAssistant } = require('~/models/Assistant');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { getLogStores } = require('~/cache');
|
||||
@@ -67,29 +71,25 @@ router.get('/agent/:agent_id', async (req, res) => {
|
||||
return res.status(400).json({ error: 'Agent ID is required' });
|
||||
}
|
||||
|
||||
// Get the agent to check ownership and attached files
|
||||
const agent = await getAgent({ id: agent_id });
|
||||
|
||||
if (!agent) {
|
||||
// No agent found, return empty array
|
||||
return res.status(200).json([]);
|
||||
}
|
||||
|
||||
// Check if user has access to the agent
|
||||
if (agent.author.toString() !== userId) {
|
||||
// Non-authors need the agent to be globally shared and collaborative
|
||||
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id');
|
||||
const hasEditPermission = await checkPermission({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
});
|
||||
|
||||
if (
|
||||
!globalProject ||
|
||||
!agent.projectIds.some((pid) => pid.toString() === globalProject._id.toString()) ||
|
||||
!agent.isCollaborative
|
||||
) {
|
||||
if (!hasEditPermission) {
|
||||
return res.status(200).json([]);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all file IDs from agent's tool resources
|
||||
const agentFileIds = [];
|
||||
if (agent.tool_resources) {
|
||||
for (const [, resource] of Object.entries(agent.tool_resources)) {
|
||||
@@ -99,12 +99,10 @@ router.get('/agent/:agent_id', async (req, res) => {
|
||||
}
|
||||
}
|
||||
|
||||
// If no files attached to agent, return empty array
|
||||
if (agentFileIds.length === 0) {
|
||||
return res.status(200).json([]);
|
||||
}
|
||||
|
||||
// Get only the files attached to this agent
|
||||
const files = await getFiles({ file_id: { $in: agentFileIds } }, null, { text: 0 });
|
||||
|
||||
res.status(200).json(files);
|
||||
@@ -153,18 +151,15 @@ router.delete('/', async (req, res) => {
|
||||
|
||||
const ownedFiles = [];
|
||||
const nonOwnedFiles = [];
|
||||
const fileMap = new Map();
|
||||
|
||||
for (const file of dbFiles) {
|
||||
fileMap.set(file.file_id, file);
|
||||
if (file.user.toString() === req.user.id) {
|
||||
if (file.user.toString() === req.user.id.toString()) {
|
||||
ownedFiles.push(file);
|
||||
} else {
|
||||
nonOwnedFiles.push(file);
|
||||
}
|
||||
}
|
||||
|
||||
// If all files are owned by the user, no need for further checks
|
||||
if (nonOwnedFiles.length === 0) {
|
||||
await processDeleteRequest({ req, files: ownedFiles });
|
||||
logger.debug(
|
||||
@@ -177,20 +172,18 @@ router.delete('/', async (req, res) => {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check access for non-owned files
|
||||
let authorizedFiles = [...ownedFiles];
|
||||
let unauthorizedFiles = [];
|
||||
|
||||
if (req.body.agent_id && nonOwnedFiles.length > 0) {
|
||||
// Batch check access for all non-owned files
|
||||
const nonOwnedFileIds = nonOwnedFiles.map((f) => f.file_id);
|
||||
const accessMap = await hasAccessToFilesViaAgent(
|
||||
req.user.id,
|
||||
nonOwnedFileIds,
|
||||
req.body.agent_id,
|
||||
);
|
||||
const accessMap = await hasAccessToFilesViaAgent({
|
||||
userId: req.user.id,
|
||||
role: req.user.role,
|
||||
fileIds: nonOwnedFileIds,
|
||||
agentId: req.body.agent_id,
|
||||
});
|
||||
|
||||
// Separate authorized and unauthorized files
|
||||
for (const file of nonOwnedFiles) {
|
||||
if (accessMap.get(file.file_id)) {
|
||||
authorizedFiles.push(file);
|
||||
@@ -199,7 +192,6 @@ router.delete('/', async (req, res) => {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No agent context, all non-owned files are unauthorized
|
||||
unauthorizedFiles = nonOwnedFiles;
|
||||
}
|
||||
|
||||
@@ -303,42 +295,30 @@ router.get('/code/download/:session_id/:fileId', async (req, res) => {
|
||||
}
|
||||
});
|
||||
|
||||
router.get('/download/:userId/:file_id', async (req, res) => {
|
||||
router.get('/download/:userId/:file_id', fileAccess, async (req, res) => {
|
||||
try {
|
||||
const { userId, file_id } = req.params;
|
||||
logger.debug(`File download requested by user ${userId}: ${file_id}`);
|
||||
|
||||
if (userId !== req.user.id) {
|
||||
logger.warn(`${errorPrefix} forbidden: ${file_id}`);
|
||||
return res.status(403).send('Forbidden');
|
||||
}
|
||||
|
||||
const [file] = await getFiles({ file_id });
|
||||
const errorPrefix = `File download requested by user ${userId}`;
|
||||
|
||||
if (!file) {
|
||||
logger.warn(`${errorPrefix} not found: ${file_id}`);
|
||||
return res.status(404).send('File not found');
|
||||
}
|
||||
|
||||
if (!file.filepath.includes(userId)) {
|
||||
logger.warn(`${errorPrefix} forbidden: ${file_id}`);
|
||||
return res.status(403).send('Forbidden');
|
||||
}
|
||||
// Access already validated by fileAccess middleware
|
||||
const file = req.fileAccess.file;
|
||||
|
||||
if (checkOpenAIStorage(file.source) && !file.model) {
|
||||
logger.warn(`${errorPrefix} has no associated model: ${file_id}`);
|
||||
logger.warn(`File download requested by user ${userId} has no associated model: ${file_id}`);
|
||||
return res.status(400).send('The model used when creating this file is not available');
|
||||
}
|
||||
|
||||
const { getDownloadStream } = getStrategyFunctions(file.source);
|
||||
if (!getDownloadStream) {
|
||||
logger.warn(`${errorPrefix} has no stream method implemented: ${file.source}`);
|
||||
logger.warn(
|
||||
`File download requested by user ${userId} has no stream method implemented: ${file.source}`,
|
||||
);
|
||||
return res.status(501).send('Not Implemented');
|
||||
}
|
||||
|
||||
const setHeaders = () => {
|
||||
res.setHeader('Content-Disposition', `attachment; filename="${file.filename}"`);
|
||||
const cleanedFilename = cleanFileName(file.filename);
|
||||
res.setHeader('Content-Disposition', `attachment; filename="${cleanedFilename}"`);
|
||||
res.setHeader('Content-Type', 'application/octet-stream');
|
||||
res.setHeader('X-File-Metadata', JSON.stringify(file));
|
||||
};
|
||||
@@ -365,12 +345,17 @@ router.get('/download/:userId/:file_id', async (req, res) => {
|
||||
logger.debug(`File ${file_id} downloaded from OpenAI`);
|
||||
passThrough.body.pipe(res);
|
||||
} else {
|
||||
fileStream = getDownloadStream(file_id);
|
||||
fileStream = await getDownloadStream(req, file.filepath);
|
||||
|
||||
fileStream.on('error', (streamError) => {
|
||||
logger.error('[DOWNLOAD ROUTE] Stream error:', streamError);
|
||||
});
|
||||
|
||||
setHeaders();
|
||||
fileStream.pipe(res);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error downloading file:', error);
|
||||
logger.error('[DOWNLOAD ROUTE] Error downloading file:', error);
|
||||
res.status(500).send('Error downloading file');
|
||||
}
|
||||
});
|
||||
@@ -405,7 +390,6 @@ router.post('/', async (req, res) => {
|
||||
message = error.message;
|
||||
}
|
||||
|
||||
// TODO: delete remote file if it exists
|
||||
try {
|
||||
await fs.unlink(req.file.path);
|
||||
cleanup = false;
|
||||
|
||||
@@ -2,10 +2,18 @@ const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { v4: uuidv4 } = require('uuid');
|
||||
const { createMethods } = require('@librechat/data-schemas');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants;
|
||||
const {
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
} = require('librechat-data-provider');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { createFile } = require('~/models/File');
|
||||
|
||||
// Mock dependencies
|
||||
// Only mock the external dependencies that we don't want to test
|
||||
jest.mock('~/server/services/Files/process', () => ({
|
||||
processDeleteRequest: jest.fn().mockResolvedValue({}),
|
||||
filterFile: jest.fn(),
|
||||
@@ -44,9 +52,6 @@ jest.mock('~/config', () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
const { createFile } = require('~/models/File');
|
||||
const { createAgent } = require('~/models/Agent');
|
||||
const { getProjectByName } = require('~/models/Project');
|
||||
const { processDeleteRequest } = require('~/server/services/Files/process');
|
||||
|
||||
// Import the router after mocks
|
||||
@@ -57,22 +62,49 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
let mongoServer;
|
||||
let authorId;
|
||||
let otherUserId;
|
||||
let agentId;
|
||||
let fileId;
|
||||
let File;
|
||||
let Agent;
|
||||
let AclEntry;
|
||||
let User;
|
||||
let methods;
|
||||
let modelsToCleanup = [];
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
await mongoose.connect(mongoServer.getUri());
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize models
|
||||
require('~/db/models');
|
||||
// Initialize all models using createModels
|
||||
const { createModels } = require('@librechat/data-schemas');
|
||||
const models = createModels(mongoose);
|
||||
|
||||
// Track which models we're adding
|
||||
modelsToCleanup = Object.keys(models);
|
||||
|
||||
// Register models on mongoose.models so methods can access them
|
||||
Object.assign(mongoose.models, models);
|
||||
|
||||
// Create methods with our test mongoose instance
|
||||
methods = createMethods(mongoose);
|
||||
|
||||
// Now we can access models from the db/models
|
||||
File = models.File;
|
||||
Agent = models.Agent;
|
||||
AclEntry = models.AclEntry;
|
||||
User = models.User;
|
||||
|
||||
// Seed default roles using our methods
|
||||
await methods.seedDefaultRoles();
|
||||
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Mock authentication middleware
|
||||
app.use((req, res, next) => {
|
||||
req.user = { id: otherUserId || 'default-user' };
|
||||
req.user = {
|
||||
id: otherUserId || 'default-user',
|
||||
role: SystemRoles.USER,
|
||||
};
|
||||
req.app = { locals: {} };
|
||||
next();
|
||||
});
|
||||
@@ -81,6 +113,19 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
// Clean up all collections before disconnecting
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
|
||||
// Clear only the models we added
|
||||
for (const modelName of modelsToCleanup) {
|
||||
if (mongoose.models[modelName]) {
|
||||
delete mongoose.models[modelName];
|
||||
}
|
||||
}
|
||||
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
});
|
||||
@@ -88,48 +133,40 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
beforeEach(async () => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
// Clear database
|
||||
const collections = mongoose.connection.collections;
|
||||
for (const key in collections) {
|
||||
await collections[key].deleteMany({});
|
||||
}
|
||||
// Clear database - clean up all test data
|
||||
await File.deleteMany({});
|
||||
await Agent.deleteMany({});
|
||||
await User.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
// Don't delete AccessRole as they are seeded defaults needed for tests
|
||||
|
||||
authorId = new mongoose.Types.ObjectId().toString();
|
||||
otherUserId = new mongoose.Types.ObjectId().toString();
|
||||
// Create test data
|
||||
authorId = new mongoose.Types.ObjectId();
|
||||
otherUserId = new mongoose.Types.ObjectId();
|
||||
fileId = uuidv4();
|
||||
|
||||
// Create users in database
|
||||
await User.create({
|
||||
_id: authorId,
|
||||
username: 'author',
|
||||
email: 'author@test.com',
|
||||
});
|
||||
|
||||
await User.create({
|
||||
_id: otherUserId,
|
||||
username: 'other',
|
||||
email: 'other@test.com',
|
||||
});
|
||||
|
||||
// Create a file owned by the author
|
||||
await createFile({
|
||||
user: authorId,
|
||||
file_id: fileId,
|
||||
filename: 'test.txt',
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
bytes: 1024,
|
||||
filepath: '/uploads/test.txt',
|
||||
bytes: 100,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Create an agent with the file attached
|
||||
const agent = await createAgent({
|
||||
id: uuidv4(),
|
||||
name: 'Test Agent',
|
||||
author: authorId,
|
||||
model: 'gpt-4',
|
||||
provider: 'openai',
|
||||
isCollaborative: true,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
agentId = agent.id;
|
||||
|
||||
// Share the agent globally
|
||||
const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, '_id');
|
||||
if (globalProject) {
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: agentId }, { projectIds: [globalProject._id] });
|
||||
}
|
||||
});
|
||||
|
||||
describe('DELETE /files', () => {
|
||||
@@ -140,8 +177,8 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
user: otherUserId,
|
||||
file_id: userFileId,
|
||||
filename: 'user-file.txt',
|
||||
filepath: `/uploads/${otherUserId}/${userFileId}`,
|
||||
bytes: 1024,
|
||||
filepath: '/uploads/user-file.txt',
|
||||
bytes: 200,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
@@ -151,7 +188,7 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
files: [
|
||||
{
|
||||
file_id: userFileId,
|
||||
filepath: `/uploads/${otherUserId}/${userFileId}`,
|
||||
filepath: '/uploads/user-file.txt',
|
||||
},
|
||||
],
|
||||
});
|
||||
@@ -168,7 +205,7 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
files: [
|
||||
{
|
||||
file_id: fileId,
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
filepath: '/uploads/test.txt',
|
||||
},
|
||||
],
|
||||
});
|
||||
@@ -180,14 +217,39 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
});
|
||||
|
||||
it('should allow deleting files accessible through shared agent', async () => {
|
||||
// Create an agent with the file attached
|
||||
const agent = await createAgent({
|
||||
id: uuidv4(),
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant EDIT permission to user on the agent
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
agent_id: agentId,
|
||||
agent_id: agent.id,
|
||||
files: [
|
||||
{
|
||||
file_id: fileId,
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
filepath: '/uploads/test.txt',
|
||||
},
|
||||
],
|
||||
});
|
||||
@@ -204,19 +266,44 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
user: authorId,
|
||||
file_id: unattachedFileId,
|
||||
filename: 'unattached.txt',
|
||||
filepath: `/uploads/${authorId}/${unattachedFileId}`,
|
||||
bytes: 1024,
|
||||
filepath: '/uploads/unattached.txt',
|
||||
bytes: 300,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Create an agent without the unattached file
|
||||
const agent = await createAgent({
|
||||
id: uuidv4(),
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId], // Only fileId, not unattachedFileId
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant EDIT permission to user on the agent
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
agent_id: agentId,
|
||||
agent_id: agent.id,
|
||||
files: [
|
||||
{
|
||||
file_id: unattachedFileId,
|
||||
filepath: `/uploads/${authorId}/${unattachedFileId}`,
|
||||
filepath: '/uploads/unattached.txt',
|
||||
},
|
||||
],
|
||||
});
|
||||
@@ -224,6 +311,7 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.message).toBe('You can only delete files you have access to');
|
||||
expect(response.body.unauthorizedFiles).toContain(unattachedFileId);
|
||||
expect(processDeleteRequest).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle mixed authorized and unauthorized files', async () => {
|
||||
@@ -233,8 +321,8 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
user: otherUserId,
|
||||
file_id: userFileId,
|
||||
filename: 'user-file.txt',
|
||||
filepath: `/uploads/${otherUserId}/${userFileId}`,
|
||||
bytes: 1024,
|
||||
filepath: '/uploads/user-file.txt',
|
||||
bytes: 200,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
@@ -244,51 +332,87 @@ describe('File Routes - Delete with Agent Access', () => {
|
||||
user: authorId,
|
||||
file_id: unauthorizedFileId,
|
||||
filename: 'unauthorized.txt',
|
||||
filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
|
||||
bytes: 1024,
|
||||
filepath: '/uploads/unauthorized.txt',
|
||||
bytes: 400,
|
||||
type: 'text/plain',
|
||||
});
|
||||
|
||||
// Create an agent with only fileId attached
|
||||
const agent = await createAgent({
|
||||
id: uuidv4(),
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant EDIT permission to user on the agent
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_EDITOR,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
agent_id: agentId,
|
||||
agent_id: agent.id,
|
||||
files: [
|
||||
{
|
||||
file_id: fileId, // Authorized through agent
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
},
|
||||
{
|
||||
file_id: userFileId, // Owned by user
|
||||
filepath: `/uploads/${otherUserId}/${userFileId}`,
|
||||
},
|
||||
{
|
||||
file_id: unauthorizedFileId, // Not authorized
|
||||
filepath: `/uploads/${authorId}/${unauthorizedFileId}`,
|
||||
},
|
||||
{ file_id: userFileId, filepath: '/uploads/user-file.txt' },
|
||||
{ file_id: fileId, filepath: '/uploads/test.txt' },
|
||||
{ file_id: unauthorizedFileId, filepath: '/uploads/unauthorized.txt' },
|
||||
],
|
||||
});
|
||||
|
||||
expect(response.status).toBe(403);
|
||||
expect(response.body.message).toBe('You can only delete files you have access to');
|
||||
expect(response.body.unauthorizedFiles).toContain(unauthorizedFileId);
|
||||
expect(response.body.unauthorizedFiles).not.toContain(fileId);
|
||||
expect(response.body.unauthorizedFiles).not.toContain(userFileId);
|
||||
expect(processDeleteRequest).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should prevent deleting files when agent is not collaborative', async () => {
|
||||
// Update the agent to be non-collaborative
|
||||
const { updateAgent } = require('~/models/Agent');
|
||||
await updateAgent({ id: agentId }, { isCollaborative: false });
|
||||
it('should prevent deleting files when user lacks EDIT permission on agent', async () => {
|
||||
// Create an agent with the file attached
|
||||
const agent = await createAgent({
|
||||
id: uuidv4(),
|
||||
name: 'Test Agent',
|
||||
provider: 'openai',
|
||||
model: 'gpt-4',
|
||||
author: authorId,
|
||||
tool_resources: {
|
||||
file_search: {
|
||||
file_ids: [fileId],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Grant only VIEW permission to user on the agent
|
||||
const { grantPermission } = require('~/server/services/PermissionService');
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: otherUserId,
|
||||
resourceType: ResourceType.AGENT,
|
||||
resourceId: agent._id,
|
||||
accessRoleId: AccessRoleIds.AGENT_VIEWER,
|
||||
grantedBy: authorId,
|
||||
});
|
||||
|
||||
const response = await request(app)
|
||||
.delete('/files')
|
||||
.send({
|
||||
agent_id: agentId,
|
||||
agent_id: agent.id,
|
||||
files: [
|
||||
{
|
||||
file_id: fileId,
|
||||
filepath: `/uploads/${authorId}/${fileId}`,
|
||||
filepath: '/uploads/test.txt',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
const accessPermissions = require('./accessPermissions');
|
||||
const assistants = require('./assistants');
|
||||
const categories = require('./categories');
|
||||
const tokenizer = require('./tokenizer');
|
||||
@@ -28,6 +29,7 @@ const user = require('./user');
|
||||
const mcp = require('./mcp');
|
||||
|
||||
module.exports = {
|
||||
mcp,
|
||||
edit,
|
||||
auth,
|
||||
keys,
|
||||
@@ -55,5 +57,5 @@ module.exports = {
|
||||
assistants,
|
||||
categories,
|
||||
staticRoute,
|
||||
mcp,
|
||||
accessPermissions,
|
||||
};
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
const { Router } = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { MCPOAuthHandler } = require('@librechat/api');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { setCachedTools, getCachedTools, loadCustomConfig } = require('~/server/services/Config');
|
||||
const { MCPOAuthHandler, getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { getMCPSetupData, getServerConnectionStatus } = require('~/server/services/MCP');
|
||||
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
|
||||
const { updateMCPUserTools } = require('~/server/services/Config/mcpToolsCache');
|
||||
const { getUserPluginAuthValue } = require('~/server/services/PluginService');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { getMCPManager, getFlowStateManager } = require('~/config');
|
||||
const { reinitMCPServer } = require('~/server/services/Tools/mcp');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const { findPluginAuthsByKeys } = require('~/models');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
const router = Router();
|
||||
@@ -142,33 +144,12 @@ router.get('/:serverName/oauth/callback', async (req, res) => {
|
||||
`[MCP OAuth] Successfully reconnected ${serverName} for user ${flowState.userId}`,
|
||||
);
|
||||
|
||||
const userTools = (await getCachedTools({ userId: flowState.userId })) || {};
|
||||
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete userTools[key];
|
||||
}
|
||||
}
|
||||
|
||||
const tools = await userConnection.fetchTools();
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
userTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
await setCachedTools(userTools, { userId: flowState.userId });
|
||||
|
||||
logger.debug(
|
||||
`[MCP OAuth] Cached ${tools.length} tools for ${serverName} user ${flowState.userId}`,
|
||||
);
|
||||
await updateMCPUserTools({
|
||||
userId: flowState.userId,
|
||||
serverName,
|
||||
tools,
|
||||
});
|
||||
} else {
|
||||
logger.debug(`[MCP OAuth] System-level OAuth completed for ${serverName}`);
|
||||
}
|
||||
@@ -315,133 +296,47 @@ router.post('/:serverName/reinitialize', requireJwtAuth, async (req, res) => {
|
||||
|
||||
logger.info(`[MCP Reinitialize] Reinitializing server: ${serverName}`);
|
||||
|
||||
const printConfig = false;
|
||||
const config = await loadCustomConfig(printConfig);
|
||||
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
|
||||
const mcpManager = getMCPManager();
|
||||
const serverConfig = mcpManager.getRawConfig(serverName);
|
||||
if (!serverConfig) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
});
|
||||
}
|
||||
|
||||
const flowsCache = getLogStores(CacheKeys.FLOWS);
|
||||
const flowManager = getFlowStateManager(flowsCache);
|
||||
const mcpManager = getMCPManager();
|
||||
|
||||
await mcpManager.disconnectServer(serverName);
|
||||
logger.info(`[MCP Reinitialize] Disconnected existing server: ${serverName}`);
|
||||
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
mcpManager.mcpConfigs[serverName] = serverConfig;
|
||||
let customUserVars = {};
|
||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||
for (const varName of Object.keys(serverConfig.customUserVars)) {
|
||||
try {
|
||||
const value = await getUserPluginAuthValue(user.id, varName, false);
|
||||
customUserVars[varName] = value;
|
||||
} catch (err) {
|
||||
logger.error(`[MCP Reinitialize] Error fetching ${varName} for user ${user.id}:`, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let userConnection = null;
|
||||
let oauthRequired = false;
|
||||
let oauthUrl = null;
|
||||
|
||||
try {
|
||||
userConnection = await mcpManager.getUserConnection({
|
||||
user,
|
||||
serverName,
|
||||
flowManager,
|
||||
customUserVars,
|
||||
tokenMethods: {
|
||||
findToken,
|
||||
updateToken,
|
||||
createToken,
|
||||
deleteTokens,
|
||||
},
|
||||
returnOnOAuth: true,
|
||||
oauthStart: async (authURL) => {
|
||||
logger.info(`[MCP Reinitialize] OAuth URL received: ${authURL}`);
|
||||
oauthUrl = authURL;
|
||||
oauthRequired = true;
|
||||
},
|
||||
});
|
||||
|
||||
logger.info(`[MCP Reinitialize] Successfully established connection for ${serverName}`);
|
||||
} catch (err) {
|
||||
logger.info(`[MCP Reinitialize] getUserConnection threw error: ${err.message}`);
|
||||
logger.info(
|
||||
`[MCP Reinitialize] OAuth state - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
||||
);
|
||||
|
||||
const isOAuthError =
|
||||
err.message?.includes('OAuth') ||
|
||||
err.message?.includes('authentication') ||
|
||||
err.message?.includes('401');
|
||||
|
||||
const isOAuthFlowInitiated = err.message === 'OAuth flow initiated - return early';
|
||||
|
||||
if (isOAuthError || oauthRequired || isOAuthFlowInitiated) {
|
||||
logger.info(
|
||||
`[MCP Reinitialize] OAuth required for ${serverName} (isOAuthError: ${isOAuthError}, oauthRequired: ${oauthRequired}, isOAuthFlowInitiated: ${isOAuthFlowInitiated})`,
|
||||
);
|
||||
oauthRequired = true;
|
||||
} else {
|
||||
logger.error(
|
||||
`[MCP Reinitialize] Error initializing MCP server ${serverName} for user:`,
|
||||
err,
|
||||
);
|
||||
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
|
||||
}
|
||||
}
|
||||
|
||||
if (userConnection && !oauthRequired) {
|
||||
const userTools = (await getCachedTools({ userId: user.id })) || {};
|
||||
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete userTools[key];
|
||||
}
|
||||
}
|
||||
|
||||
const tools = await userConnection.fetchTools();
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
userTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
await setCachedTools(userTools, { userId: user.id });
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`[MCP Reinitialize] Sending response for ${serverName} - oauthRequired: ${oauthRequired}, oauthUrl: ${oauthUrl ? 'present' : 'null'}`,
|
||||
await mcpManager.disconnectUserConnection(user.id, serverName);
|
||||
logger.info(
|
||||
`[MCP Reinitialize] Disconnected existing user connection for server: ${serverName}`,
|
||||
);
|
||||
|
||||
const getResponseMessage = () => {
|
||||
if (oauthRequired) {
|
||||
return `MCP server '${serverName}' ready for OAuth authentication`;
|
||||
}
|
||||
if (userConnection) {
|
||||
return `MCP server '${serverName}' reinitialized successfully`;
|
||||
}
|
||||
return `Failed to reinitialize MCP server '${serverName}'`;
|
||||
};
|
||||
/** @type {Record<string, Record<string, string>> | undefined} */
|
||||
let userMCPAuthMap;
|
||||
if (serverConfig.customUserVars && typeof serverConfig.customUserVars === 'object') {
|
||||
userMCPAuthMap = await getUserMCPAuthMap({
|
||||
userId: user.id,
|
||||
servers: [serverName],
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
}
|
||||
|
||||
const result = await reinitMCPServer({
|
||||
req,
|
||||
serverName,
|
||||
userMCPAuthMap,
|
||||
});
|
||||
|
||||
if (!result) {
|
||||
return res.status(500).json({ error: 'Failed to reinitialize MCP server for user' });
|
||||
}
|
||||
|
||||
const { success, message, oauthRequired, oauthUrl } = result;
|
||||
|
||||
res.json({
|
||||
success: (userConnection && !oauthRequired) || (oauthRequired && oauthUrl),
|
||||
message: getResponseMessage(),
|
||||
success,
|
||||
message,
|
||||
oauthUrl,
|
||||
serverName,
|
||||
oauthRequired,
|
||||
oauthUrl,
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[MCP Reinitialize] Unexpected error', error);
|
||||
@@ -551,15 +446,14 @@ router.get('/:serverName/auth-values', requireJwtAuth, async (req, res) => {
|
||||
return res.status(401).json({ error: 'User not authenticated' });
|
||||
}
|
||||
|
||||
const printConfig = false;
|
||||
const config = await loadCustomConfig(printConfig);
|
||||
if (!config || !config.mcpServers || !config.mcpServers[serverName]) {
|
||||
const mcpManager = getMCPManager();
|
||||
const serverConfig = mcpManager.getRawConfig(serverName);
|
||||
if (!serverConfig) {
|
||||
return res.status(404).json({
|
||||
error: `MCP server '${serverName}' not found in configuration`,
|
||||
});
|
||||
}
|
||||
|
||||
const serverConfig = config.mcpServers[serverName];
|
||||
const pluginKey = `${Constants.mcp_prefix}${serverName}`;
|
||||
const authValueFlags = {};
|
||||
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
// file deepcode ignore NoRateLimitingForLogin: Rate limiting is handled by the `loginLimiter` middleware
|
||||
const express = require('express');
|
||||
const passport = require('passport');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { randomState } = require('openid-client');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { ErrorTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
checkBan,
|
||||
logHeaders,
|
||||
loginLimiter,
|
||||
setBalanceConfig,
|
||||
checkDomainAllowed,
|
||||
} = require('~/server/middleware');
|
||||
const { isEnabled, createSetBalanceConfig } = require('@librechat/api');
|
||||
const { checkDomainAllowed, loginLimiter, logHeaders, checkBan } = require('~/server/middleware');
|
||||
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
|
||||
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
|
||||
const { getBalanceConfig } = require('~/server/services/Config');
|
||||
const { Balance } = require('~/db/models');
|
||||
|
||||
const setBalanceConfig = createSetBalanceConfig({
|
||||
getBalanceConfig,
|
||||
Balance,
|
||||
});
|
||||
|
||||
const router = express.Router();
|
||||
|
||||
@@ -36,6 +38,7 @@ const oauthHandler = async (req, res) => {
|
||||
req.user.provider == 'openid' &&
|
||||
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
|
||||
) {
|
||||
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
|
||||
setOpenIDAuthTokens(req.user.tokenset, res);
|
||||
} else {
|
||||
await setAuthTokens(req.user._id, res);
|
||||
|
||||
@@ -1,22 +1,45 @@
|
||||
const express = require('express');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { generateCheckAccess } = require('@librechat/api');
|
||||
const { Permissions, SystemRoles, PermissionTypes } = require('librechat-data-provider');
|
||||
const {
|
||||
getPrompt,
|
||||
getPrompts,
|
||||
savePrompt,
|
||||
deletePrompt,
|
||||
getPromptGroup,
|
||||
getPromptGroups,
|
||||
generateCheckAccess,
|
||||
markPublicPromptGroups,
|
||||
buildPromptGroupFilter,
|
||||
formatPromptGroupsResponse,
|
||||
createEmptyPromptGroupsResponse,
|
||||
filterAccessibleIdsBySharedLogic,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
Permissions,
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
PermissionBits,
|
||||
PermissionTypes,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
getListPromptGroupsByAccess,
|
||||
makePromptProduction,
|
||||
updatePromptGroup,
|
||||
deletePromptGroup,
|
||||
createPromptGroup,
|
||||
getAllPromptGroups,
|
||||
// updatePromptLabels,
|
||||
makePromptProduction,
|
||||
getPromptGroup,
|
||||
deletePrompt,
|
||||
getPrompts,
|
||||
savePrompt,
|
||||
getPrompt,
|
||||
} = require('~/models/Prompt');
|
||||
const { requireJwtAuth } = require('~/server/middleware');
|
||||
const {
|
||||
canAccessPromptGroupResource,
|
||||
canAccessPromptViaGroup,
|
||||
requireJwtAuth,
|
||||
} = require('~/server/middleware');
|
||||
const {
|
||||
findPubliclyAccessibleResources,
|
||||
getEffectivePermissions,
|
||||
findAccessibleResources,
|
||||
grantPermission,
|
||||
} = require('~/server/services/PermissionService');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
|
||||
const router = express.Router();
|
||||
@@ -48,43 +71,78 @@ router.use(checkPromptAccess);
|
||||
* Route to get single prompt group by its ID
|
||||
* GET /groups/:groupId
|
||||
*/
|
||||
router.get('/groups/:groupId', async (req, res) => {
|
||||
let groupId = req.params.groupId;
|
||||
const author = req.user.id;
|
||||
router.get(
|
||||
'/groups/:groupId',
|
||||
canAccessPromptGroupResource({
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
}),
|
||||
async (req, res) => {
|
||||
const { groupId } = req.params;
|
||||
|
||||
const query = {
|
||||
_id: groupId,
|
||||
$or: [{ projectIds: { $exists: true, $ne: [], $not: { $size: 0 } } }, { author }],
|
||||
};
|
||||
try {
|
||||
const group = await getPromptGroup({ _id: groupId });
|
||||
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.$or;
|
||||
}
|
||||
if (!group) {
|
||||
return res.status(404).send({ message: 'Prompt group not found' });
|
||||
}
|
||||
|
||||
try {
|
||||
const group = await getPromptGroup(query);
|
||||
|
||||
if (!group) {
|
||||
return res.status(404).send({ message: 'Prompt group not found' });
|
||||
res.status(200).send(group);
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
res.status(500).send({ message: 'Error getting prompt group' });
|
||||
}
|
||||
|
||||
res.status(200).send(group);
|
||||
} catch (error) {
|
||||
logger.error('Error getting prompt group', error);
|
||||
res.status(500).send({ message: 'Error getting prompt group' });
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
/**
|
||||
* Route to fetch all prompt groups
|
||||
* GET /groups
|
||||
* Route to fetch all prompt groups (ACL-aware)
|
||||
* GET /all
|
||||
*/
|
||||
router.get('/all', async (req, res) => {
|
||||
try {
|
||||
const groups = await getAllPromptGroups(req, {
|
||||
author: req.user._id,
|
||||
const userId = req.user.id;
|
||||
const { name, category, ...otherFilters } = req.query;
|
||||
const { filter, searchShared, searchSharedOnly } = buildPromptGroupFilter({
|
||||
name,
|
||||
category,
|
||||
...otherFilters,
|
||||
});
|
||||
res.status(200).send(groups);
|
||||
|
||||
let accessibleIds = await findAccessibleResources({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const publiclyAccessibleIds = await findPubliclyAccessibleResources({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const filteredAccessibleIds = await filterAccessibleIdsBySharedLogic({
|
||||
accessibleIds,
|
||||
searchShared,
|
||||
searchSharedOnly,
|
||||
publicPromptGroupIds: publiclyAccessibleIds,
|
||||
});
|
||||
|
||||
const result = await getListPromptGroupsByAccess({
|
||||
accessibleIds: filteredAccessibleIds,
|
||||
otherParams: filter,
|
||||
});
|
||||
|
||||
if (!result) {
|
||||
return res.status(200).send([]);
|
||||
}
|
||||
|
||||
const { data: promptGroups = [] } = result;
|
||||
if (!promptGroups.length) {
|
||||
return res.status(200).send([]);
|
||||
}
|
||||
|
||||
const groupsWithPublicFlag = markPublicPromptGroups(promptGroups, publiclyAccessibleIds);
|
||||
res.status(200).send(groupsWithPublicFlag);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||
@@ -92,16 +150,72 @@ router.get('/all', async (req, res) => {
|
||||
});
|
||||
|
||||
/**
|
||||
* Route to fetch paginated prompt groups with filters
|
||||
* Route to fetch paginated prompt groups with filters (ACL-aware)
|
||||
* GET /groups
|
||||
*/
|
||||
router.get('/groups', async (req, res) => {
|
||||
try {
|
||||
const filter = req.query;
|
||||
/* Note: The aggregation requires an ObjectId */
|
||||
filter.author = req.user._id;
|
||||
const groups = await getPromptGroups(req, filter);
|
||||
res.status(200).send(groups);
|
||||
const userId = req.user.id;
|
||||
const { pageSize, pageNumber, limit, cursor, name, category, ...otherFilters } = req.query;
|
||||
|
||||
const { filter, searchShared, searchSharedOnly } = buildPromptGroupFilter({
|
||||
name,
|
||||
category,
|
||||
...otherFilters,
|
||||
});
|
||||
|
||||
let actualLimit = limit;
|
||||
let actualCursor = cursor;
|
||||
|
||||
if (pageSize && !limit) {
|
||||
actualLimit = parseInt(pageSize, 10);
|
||||
}
|
||||
|
||||
let accessibleIds = await findAccessibleResources({
|
||||
userId,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const publiclyAccessibleIds = await findPubliclyAccessibleResources({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
requiredPermissions: PermissionBits.VIEW,
|
||||
});
|
||||
|
||||
const filteredAccessibleIds = await filterAccessibleIdsBySharedLogic({
|
||||
accessibleIds,
|
||||
searchShared,
|
||||
searchSharedOnly,
|
||||
publicPromptGroupIds: publiclyAccessibleIds,
|
||||
});
|
||||
|
||||
const result = await getListPromptGroupsByAccess({
|
||||
accessibleIds: filteredAccessibleIds,
|
||||
otherParams: filter,
|
||||
limit: actualLimit,
|
||||
after: actualCursor,
|
||||
});
|
||||
|
||||
if (!result) {
|
||||
const emptyResponse = createEmptyPromptGroupsResponse({ pageNumber, pageSize, actualLimit });
|
||||
return res.status(200).send(emptyResponse);
|
||||
}
|
||||
|
||||
const { data: promptGroups = [], has_more = false, after = null } = result;
|
||||
|
||||
const groupsWithPublicFlag = markPublicPromptGroups(promptGroups, publiclyAccessibleIds);
|
||||
|
||||
const response = formatPromptGroupsResponse({
|
||||
promptGroups: groupsWithPublicFlag,
|
||||
pageNumber,
|
||||
pageSize,
|
||||
actualLimit,
|
||||
hasMore: has_more,
|
||||
after,
|
||||
});
|
||||
|
||||
res.status(200).send(response);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error getting prompt groups' });
|
||||
@@ -109,16 +223,17 @@ router.get('/groups', async (req, res) => {
|
||||
});
|
||||
|
||||
/**
|
||||
* Updates or creates a prompt + promptGroup
|
||||
* Creates a new prompt group with initial prompt
|
||||
* @param {object} req
|
||||
* @param {TCreatePrompt} req.body
|
||||
* @param {Express.Response} res
|
||||
*/
|
||||
const createPrompt = async (req, res) => {
|
||||
const createNewPromptGroup = async (req, res) => {
|
||||
try {
|
||||
const { prompt, group } = req.body;
|
||||
if (!prompt) {
|
||||
return res.status(400).send({ error: 'Prompt is required' });
|
||||
|
||||
if (!prompt || !group || !group.name) {
|
||||
return res.status(400).send({ error: 'Prompt and group name are required' });
|
||||
}
|
||||
|
||||
const saveData = {
|
||||
@@ -128,21 +243,80 @@ const createPrompt = async (req, res) => {
|
||||
authorName: req.user.name,
|
||||
};
|
||||
|
||||
/** @type {TCreatePromptResponse} */
|
||||
let result;
|
||||
if (group && group.name) {
|
||||
result = await createPromptGroup(saveData);
|
||||
} else {
|
||||
result = await savePrompt(saveData);
|
||||
const result = await createPromptGroup(saveData);
|
||||
|
||||
if (result.prompt && result.prompt._id && result.prompt.groupId) {
|
||||
try {
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: req.user.id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: result.prompt.groupId,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: req.user.id,
|
||||
});
|
||||
logger.debug(
|
||||
`[createPromptGroup] Granted owner permissions to user ${req.user.id} for promptGroup ${result.prompt.groupId}`,
|
||||
);
|
||||
} catch (permissionError) {
|
||||
logger.error(
|
||||
`[createPromptGroup] Failed to grant owner permissions for promptGroup ${result.prompt.groupId}:`,
|
||||
permissionError,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error saving prompt' });
|
||||
res.status(500).send({ error: 'Error creating prompt group' });
|
||||
}
|
||||
};
|
||||
|
||||
router.post('/', checkPromptCreate, createPrompt);
|
||||
/**
|
||||
* Adds a new prompt to an existing prompt group
|
||||
* @param {object} req
|
||||
* @param {TCreatePrompt} req.body
|
||||
* @param {Express.Response} res
|
||||
*/
|
||||
const addPromptToGroup = async (req, res) => {
|
||||
try {
|
||||
const { groupId } = req.params;
|
||||
const { prompt } = req.body;
|
||||
|
||||
if (!prompt) {
|
||||
return res.status(400).send({ error: 'Prompt is required' });
|
||||
}
|
||||
|
||||
// Ensure the prompt is associated with the correct group
|
||||
prompt.groupId = groupId;
|
||||
|
||||
const saveData = {
|
||||
prompt,
|
||||
author: req.user.id,
|
||||
authorName: req.user.name,
|
||||
};
|
||||
|
||||
const result = await savePrompt(saveData);
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error adding prompt to group' });
|
||||
}
|
||||
};
|
||||
|
||||
// Create new prompt group (requires CREATE permission)
|
||||
router.post('/', checkPromptCreate, createNewPromptGroup);
|
||||
|
||||
// Add prompt to existing group (requires EDIT permission on the group)
|
||||
router.post(
|
||||
'/groups/:groupId/prompts',
|
||||
checkPromptAccess,
|
||||
canAccessPromptGroupResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
}),
|
||||
addPromptToGroup,
|
||||
);
|
||||
|
||||
/**
|
||||
* Updates a prompt group
|
||||
@@ -168,35 +342,74 @@ const patchPromptGroup = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
router.patch('/groups/:groupId', checkGlobalPromptShare, patchPromptGroup);
|
||||
router.patch(
|
||||
'/groups/:groupId',
|
||||
checkGlobalPromptShare,
|
||||
canAccessPromptGroupResource({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
}),
|
||||
patchPromptGroup,
|
||||
);
|
||||
|
||||
router.patch('/:promptId/tags/production', checkPromptCreate, async (req, res) => {
|
||||
try {
|
||||
router.patch(
|
||||
'/:promptId/tags/production',
|
||||
checkPromptCreate,
|
||||
canAccessPromptViaGroup({
|
||||
requiredPermission: PermissionBits.EDIT,
|
||||
resourceIdParam: 'promptId',
|
||||
}),
|
||||
async (req, res) => {
|
||||
try {
|
||||
const { promptId } = req.params;
|
||||
const result = await makePromptProduction(promptId);
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error updating prompt production' });
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
router.get(
|
||||
'/:promptId',
|
||||
canAccessPromptViaGroup({
|
||||
requiredPermission: PermissionBits.VIEW,
|
||||
resourceIdParam: 'promptId',
|
||||
}),
|
||||
async (req, res) => {
|
||||
const { promptId } = req.params;
|
||||
const result = await makePromptProduction(promptId);
|
||||
res.status(200).send(result);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
res.status(500).send({ error: 'Error updating prompt production' });
|
||||
}
|
||||
});
|
||||
|
||||
router.get('/:promptId', async (req, res) => {
|
||||
const { promptId } = req.params;
|
||||
const author = req.user.id;
|
||||
const query = { _id: promptId, author };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
const prompt = await getPrompt(query);
|
||||
res.status(200).send(prompt);
|
||||
});
|
||||
const prompt = await getPrompt({ _id: promptId });
|
||||
res.status(200).send(prompt);
|
||||
},
|
||||
);
|
||||
|
||||
router.get('/', async (req, res) => {
|
||||
try {
|
||||
const author = req.user.id;
|
||||
const { groupId } = req.query;
|
||||
const query = { groupId, author };
|
||||
|
||||
// If requesting prompts for a specific group, check permissions
|
||||
if (groupId) {
|
||||
const permissions = await getEffectivePermissions({
|
||||
userId: req.user.id,
|
||||
role: req.user.role,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: groupId,
|
||||
});
|
||||
|
||||
if (!(permissions & PermissionBits.VIEW)) {
|
||||
return res
|
||||
.status(403)
|
||||
.send({ error: 'Insufficient permissions to view prompts in this group' });
|
||||
}
|
||||
|
||||
// If user has access, fetch all prompts in the group (not just their own)
|
||||
const prompts = await getPrompts({ groupId });
|
||||
return res.status(200).send(prompts);
|
||||
}
|
||||
|
||||
// If no groupId, return user's own prompts
|
||||
const query = { author };
|
||||
if (req.user.role === SystemRoles.ADMIN) {
|
||||
delete query.author;
|
||||
}
|
||||
@@ -240,7 +453,8 @@ const deletePromptController = async (req, res) => {
|
||||
const deletePromptGroupController = async (req, res) => {
|
||||
try {
|
||||
const { groupId: _id } = req.params;
|
||||
const message = await deletePromptGroup({ _id, author: req.user.id, role: req.user.role });
|
||||
// Don't pass author - permissions are now checked by middleware
|
||||
const message = await deletePromptGroup({ _id, role: req.user.role });
|
||||
res.send(message);
|
||||
} catch (error) {
|
||||
logger.error('Error deleting prompt group', error);
|
||||
@@ -248,7 +462,22 @@ const deletePromptGroupController = async (req, res) => {
|
||||
}
|
||||
};
|
||||
|
||||
router.delete('/:promptId', checkPromptCreate, deletePromptController);
|
||||
router.delete('/groups/:groupId', checkPromptCreate, deletePromptGroupController);
|
||||
router.delete(
|
||||
'/:promptId',
|
||||
checkPromptCreate,
|
||||
canAccessPromptViaGroup({
|
||||
requiredPermission: PermissionBits.DELETE,
|
||||
resourceIdParam: 'promptId',
|
||||
}),
|
||||
deletePromptController,
|
||||
);
|
||||
router.delete(
|
||||
'/groups/:groupId',
|
||||
checkPromptCreate,
|
||||
canAccessPromptGroupResource({
|
||||
requiredPermission: PermissionBits.DELETE,
|
||||
}),
|
||||
deletePromptGroupController,
|
||||
);
|
||||
|
||||
module.exports = router;
|
||||
|
||||
614
api/server/routes/prompts.test.js
Normal file
614
api/server/routes/prompts.test.js
Normal file
@@ -0,0 +1,614 @@
|
||||
const express = require('express');
|
||||
const request = require('supertest');
|
||||
const mongoose = require('mongoose');
|
||||
const { ObjectId } = require('mongodb');
|
||||
const { MongoMemoryServer } = require('mongodb-memory-server');
|
||||
const {
|
||||
SystemRoles,
|
||||
ResourceType,
|
||||
AccessRoleIds,
|
||||
PrincipalType,
|
||||
PermissionBits,
|
||||
} = require('librechat-data-provider');
|
||||
|
||||
// Mock modules before importing
|
||||
jest.mock('~/server/services/Config', () => ({
|
||||
getCachedTools: jest.fn().mockResolvedValue({}),
|
||||
getCustomConfig: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/models/Role', () => ({
|
||||
getRoleByName: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('~/server/middleware', () => ({
|
||||
requireJwtAuth: (req, res, next) => next(),
|
||||
canAccessPromptViaGroup: jest.requireActual('~/server/middleware').canAccessPromptViaGroup,
|
||||
canAccessPromptGroupResource:
|
||||
jest.requireActual('~/server/middleware').canAccessPromptGroupResource,
|
||||
}));
|
||||
|
||||
let app;
|
||||
let mongoServer;
|
||||
let promptRoutes;
|
||||
let Prompt, PromptGroup, AclEntry, AccessRole, User;
|
||||
let testUsers, testRoles;
|
||||
let grantPermission;
|
||||
|
||||
// Helper function to set user in middleware
|
||||
function setTestUser(app, user) {
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...(user.toObject ? user.toObject() : user),
|
||||
id: user.id || user._id.toString(),
|
||||
_id: user._id,
|
||||
name: user.name,
|
||||
role: user.role,
|
||||
};
|
||||
if (user.role === SystemRoles.ADMIN) {
|
||||
console.log('Setting admin user with role:', req.user.role);
|
||||
}
|
||||
next();
|
||||
});
|
||||
}
|
||||
|
||||
beforeAll(async () => {
|
||||
mongoServer = await MongoMemoryServer.create();
|
||||
const mongoUri = mongoServer.getUri();
|
||||
await mongoose.connect(mongoUri);
|
||||
|
||||
// Initialize models
|
||||
const dbModels = require('~/db/models');
|
||||
Prompt = dbModels.Prompt;
|
||||
PromptGroup = dbModels.PromptGroup;
|
||||
AclEntry = dbModels.AclEntry;
|
||||
AccessRole = dbModels.AccessRole;
|
||||
User = dbModels.User;
|
||||
|
||||
// Import permission service
|
||||
const permissionService = require('~/server/services/PermissionService');
|
||||
grantPermission = permissionService.grantPermission;
|
||||
|
||||
// Create test data
|
||||
await setupTestData();
|
||||
|
||||
// Setup Express app
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Mock authentication middleware - default to owner
|
||||
setTestUser(app, testUsers.owner);
|
||||
|
||||
// Import routes after mocks are set up
|
||||
promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await mongoose.disconnect();
|
||||
await mongoServer.stop();
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
async function setupTestData() {
|
||||
// Create access roles for promptGroups
|
||||
testRoles = {
|
||||
viewer: await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
name: 'Viewer',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits: PermissionBits.VIEW,
|
||||
}),
|
||||
editor: await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
|
||||
name: 'Editor',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits: PermissionBits.VIEW | PermissionBits.EDIT,
|
||||
}),
|
||||
owner: await AccessRole.create({
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
name: 'Owner',
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
permBits:
|
||||
PermissionBits.VIEW | PermissionBits.EDIT | PermissionBits.DELETE | PermissionBits.SHARE,
|
||||
}),
|
||||
};
|
||||
|
||||
// Create test users
|
||||
testUsers = {
|
||||
owner: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Owner',
|
||||
email: 'owner@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
viewer: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Viewer',
|
||||
email: 'viewer@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
editor: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Prompt Editor',
|
||||
email: 'editor@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
noAccess: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'No Access',
|
||||
email: 'noaccess@example.com',
|
||||
role: SystemRoles.USER,
|
||||
}),
|
||||
admin: await User.create({
|
||||
id: new ObjectId().toString(),
|
||||
_id: new ObjectId(),
|
||||
name: 'Admin',
|
||||
email: 'admin@example.com',
|
||||
role: SystemRoles.ADMIN,
|
||||
}),
|
||||
};
|
||||
|
||||
// Mock getRoleByName
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
getRoleByName.mockImplementation((roleName) => {
|
||||
switch (roleName) {
|
||||
case SystemRoles.USER:
|
||||
return { permissions: { PROMPTS: { USE: true, CREATE: true } } };
|
||||
case SystemRoles.ADMIN:
|
||||
return { permissions: { PROMPTS: { USE: true, CREATE: true, SHARED_GLOBAL: true } } };
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
describe('Prompt Routes - ACL Permissions', () => {
|
||||
let consoleErrorSpy;
|
||||
|
||||
beforeEach(() => {
|
||||
consoleErrorSpy = jest.spyOn(console, 'error').mockImplementation();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
consoleErrorSpy.mockRestore();
|
||||
});
|
||||
|
||||
// Simple test to verify route is loaded
|
||||
it('should have routes loaded', async () => {
|
||||
// This should at least not crash
|
||||
const response = await request(app).get('/api/prompts/test-404');
|
||||
console.log('Test 404 response status:', response.status);
|
||||
console.log('Test 404 response body:', response.body);
|
||||
// We expect a 401 or 404, not 500
|
||||
expect(response.status).not.toBe(500);
|
||||
});
|
||||
|
||||
describe('POST /api/prompts - Create Prompt', () => {
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('should create a prompt and grant owner permissions', async () => {
|
||||
const promptData = {
|
||||
prompt: {
|
||||
prompt: 'Test prompt content',
|
||||
type: 'text',
|
||||
},
|
||||
group: {
|
||||
name: 'Test Prompt Group',
|
||||
},
|
||||
};
|
||||
|
||||
const response = await request(app).post('/api/prompts').send(promptData);
|
||||
|
||||
if (response.status !== 200) {
|
||||
console.log('POST /api/prompts error status:', response.status);
|
||||
console.log('POST /api/prompts error body:', response.body);
|
||||
console.log('Console errors:', consoleErrorSpy.mock.calls);
|
||||
}
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.prompt).toBeDefined();
|
||||
expect(response.body.prompt.prompt).toBe(promptData.prompt.prompt);
|
||||
|
||||
// Check ACL entry was created
|
||||
const aclEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: response.body.prompt.groupId,
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
});
|
||||
|
||||
expect(aclEntry).toBeTruthy();
|
||||
expect(aclEntry.roleId.toString()).toBe(testRoles.owner._id.toString());
|
||||
});
|
||||
|
||||
it('should create a prompt group with prompt and grant owner permissions', async () => {
|
||||
const promptData = {
|
||||
prompt: {
|
||||
prompt: 'Group prompt content',
|
||||
// Remove 'name' from prompt - it's not in the schema
|
||||
},
|
||||
group: {
|
||||
name: 'Test Group',
|
||||
category: 'testing',
|
||||
},
|
||||
};
|
||||
|
||||
const response = await request(app).post('/api/prompts').send(promptData).expect(200);
|
||||
|
||||
expect(response.body.prompt).toBeDefined();
|
||||
expect(response.body.group).toBeDefined();
|
||||
expect(response.body.group.name).toBe(promptData.group.name);
|
||||
|
||||
// Check ACL entry was created for the promptGroup
|
||||
const aclEntry = await AclEntry.findOne({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: response.body.group._id,
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
});
|
||||
|
||||
expect(aclEntry).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('GET /api/prompts/:promptId - Get Prompt', () => {
|
||||
let testPrompt;
|
||||
let testGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a prompt group first
|
||||
testGroup = await PromptGroup.create({
|
||||
name: 'Test Group',
|
||||
category: 'testing',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create a prompt
|
||||
testPrompt = await Prompt.create({
|
||||
prompt: 'Test prompt for retrieval',
|
||||
name: 'Get Test',
|
||||
author: testUsers.owner._id,
|
||||
type: 'text',
|
||||
groupId: testGroup._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('should retrieve prompt when user has view permissions', async () => {
|
||||
// Grant view permissions on the promptGroup
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
const response = await request(app).get(`/api/prompts/${testPrompt._id}`);
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body._id).toBe(testPrompt._id.toString());
|
||||
expect(response.body.prompt).toBe(testPrompt.prompt);
|
||||
});
|
||||
|
||||
it('should deny access when user has no permissions', async () => {
|
||||
// Change the user to one without access
|
||||
setTestUser(app, testUsers.noAccess);
|
||||
|
||||
const response = await request(app).get(`/api/prompts/${testPrompt._id}`).expect(403);
|
||||
|
||||
// Verify error response
|
||||
expect(response.body.error).toBe('Forbidden');
|
||||
expect(response.body.message).toBe('Insufficient permissions to access this promptGroup');
|
||||
});
|
||||
|
||||
it('should allow admin access without explicit permissions', async () => {
|
||||
// First, reset the app to remove previous middleware
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
|
||||
// Set admin user BEFORE adding routes
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.admin.toObject(),
|
||||
id: testUsers.admin._id.toString(),
|
||||
_id: testUsers.admin._id,
|
||||
name: testUsers.admin.name,
|
||||
role: testUsers.admin.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
|
||||
// Now add the routes
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
console.log('Admin user:', testUsers.admin);
|
||||
console.log('Admin role:', testUsers.admin.role);
|
||||
console.log('SystemRoles.ADMIN:', SystemRoles.ADMIN);
|
||||
|
||||
const response = await request(app).get(`/api/prompts/${testPrompt._id}`).expect(200);
|
||||
|
||||
expect(response.body._id).toBe(testPrompt._id.toString());
|
||||
});
|
||||
});
|
||||
|
||||
describe('DELETE /api/prompts/:promptId - Delete Prompt', () => {
|
||||
let testPrompt;
|
||||
let testGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create group with prompt
|
||||
testGroup = await PromptGroup.create({
|
||||
name: 'Delete Test Group',
|
||||
category: 'testing',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
testPrompt = await Prompt.create({
|
||||
prompt: 'Test prompt for deletion',
|
||||
name: 'Delete Test',
|
||||
author: testUsers.owner._id,
|
||||
type: 'text',
|
||||
groupId: testGroup._id,
|
||||
});
|
||||
|
||||
// Add prompt to group
|
||||
testGroup.productionId = testPrompt._id;
|
||||
testGroup.promptIds = [testPrompt._id];
|
||||
await testGroup.save();
|
||||
|
||||
// Grant owner permissions on the promptGroup
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_OWNER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('should delete prompt when user has delete permissions', async () => {
|
||||
const response = await request(app)
|
||||
.delete(`/api/prompts/${testPrompt._id}`)
|
||||
.query({ groupId: testGroup._id.toString() })
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.prompt).toBe('Prompt deleted successfully');
|
||||
|
||||
// Verify prompt was deleted
|
||||
const deletedPrompt = await Prompt.findById(testPrompt._id);
|
||||
expect(deletedPrompt).toBeNull();
|
||||
|
||||
// Verify ACL entries were removed
|
||||
const aclEntries = await AclEntry.find({
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
});
|
||||
expect(aclEntries).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should deny deletion when user lacks delete permissions', async () => {
|
||||
// Create a prompt as a different user (not the one trying to delete)
|
||||
const authorPrompt = await Prompt.create({
|
||||
prompt: 'Test prompt by another user',
|
||||
name: 'Another User Prompt',
|
||||
author: testUsers.editor._id, // Different author
|
||||
type: 'text',
|
||||
groupId: testGroup._id,
|
||||
});
|
||||
|
||||
// Grant only viewer permissions to viewer user on the promptGroup
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.editor._id,
|
||||
});
|
||||
|
||||
// Recreate app with viewer user
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.viewer.toObject(),
|
||||
id: testUsers.viewer._id.toString(),
|
||||
_id: testUsers.viewer._id,
|
||||
name: testUsers.viewer.name,
|
||||
role: testUsers.viewer.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
await request(app)
|
||||
.delete(`/api/prompts/${authorPrompt._id}`)
|
||||
.query({ groupId: testGroup._id.toString() })
|
||||
.expect(403);
|
||||
|
||||
// Verify prompt still exists
|
||||
const prompt = await Prompt.findById(authorPrompt._id);
|
||||
expect(prompt).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('PATCH /api/prompts/:promptId/tags/production - Make Production', () => {
|
||||
let testPrompt;
|
||||
let testGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create group
|
||||
testGroup = await PromptGroup.create({
|
||||
name: 'Production Test Group',
|
||||
category: 'testing',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
testPrompt = await Prompt.create({
|
||||
prompt: 'Test prompt for production',
|
||||
name: 'Production Test',
|
||||
author: testUsers.owner._id,
|
||||
type: 'text',
|
||||
groupId: testGroup._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('should make prompt production when user has edit permissions', async () => {
|
||||
// Grant edit permissions on the promptGroup
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.owner._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_EDITOR,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Recreate app to ensure fresh middleware
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.owner.toObject(),
|
||||
id: testUsers.owner._id.toString(),
|
||||
_id: testUsers.owner._id,
|
||||
name: testUsers.owner.name,
|
||||
role: testUsers.owner.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
const response = await request(app)
|
||||
.patch(`/api/prompts/${testPrompt._id}/tags/production`)
|
||||
.expect(200);
|
||||
|
||||
expect(response.body.message).toBe('Prompt production made successfully');
|
||||
|
||||
// Verify the group was updated
|
||||
const updatedGroup = await PromptGroup.findById(testGroup._id);
|
||||
expect(updatedGroup.productionId.toString()).toBe(testPrompt._id.toString());
|
||||
});
|
||||
|
||||
it('should deny making production when user lacks edit permissions', async () => {
|
||||
// Grant only view permissions to viewer on the promptGroup
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.USER,
|
||||
principalId: testUsers.viewer._id,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: testGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
|
||||
// Recreate app with viewer user
|
||||
app = express();
|
||||
app.use(express.json());
|
||||
app.use((req, res, next) => {
|
||||
req.user = {
|
||||
...testUsers.viewer.toObject(),
|
||||
id: testUsers.viewer._id.toString(),
|
||||
_id: testUsers.viewer._id,
|
||||
name: testUsers.viewer.name,
|
||||
role: testUsers.viewer.role,
|
||||
};
|
||||
next();
|
||||
});
|
||||
const promptRoutes = require('./prompts');
|
||||
app.use('/api/prompts', promptRoutes);
|
||||
|
||||
await request(app).patch(`/api/prompts/${testPrompt._id}/tags/production`).expect(403);
|
||||
|
||||
// Verify prompt hasn't changed
|
||||
const unchangedGroup = await PromptGroup.findById(testGroup._id);
|
||||
expect(unchangedGroup.productionId.toString()).not.toBe(testPrompt._id.toString());
|
||||
});
|
||||
});
|
||||
|
||||
describe('Public Access', () => {
|
||||
let publicPrompt;
|
||||
let publicGroup;
|
||||
|
||||
beforeEach(async () => {
|
||||
// Create a prompt group
|
||||
publicGroup = await PromptGroup.create({
|
||||
name: 'Public Test Group',
|
||||
category: 'testing',
|
||||
author: testUsers.owner._id,
|
||||
authorName: testUsers.owner.name,
|
||||
productionId: new ObjectId(),
|
||||
});
|
||||
|
||||
// Create a public prompt
|
||||
publicPrompt = await Prompt.create({
|
||||
prompt: 'Public prompt content',
|
||||
name: 'Public Test',
|
||||
author: testUsers.owner._id,
|
||||
type: 'text',
|
||||
groupId: publicGroup._id,
|
||||
});
|
||||
|
||||
// Grant public viewer access on the promptGroup
|
||||
await grantPermission({
|
||||
principalType: PrincipalType.PUBLIC,
|
||||
principalId: null,
|
||||
resourceType: ResourceType.PROMPTGROUP,
|
||||
resourceId: publicGroup._id,
|
||||
accessRoleId: AccessRoleIds.PROMPTGROUP_VIEWER,
|
||||
grantedBy: testUsers.owner._id,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await Prompt.deleteMany({});
|
||||
await PromptGroup.deleteMany({});
|
||||
await AclEntry.deleteMany({});
|
||||
});
|
||||
|
||||
it('should allow any user to view public prompts', async () => {
|
||||
// Change user to someone without explicit permissions
|
||||
setTestUser(app, testUsers.noAccess);
|
||||
|
||||
const response = await request(app).get(`/api/prompts/${publicPrompt._id}`).expect(200);
|
||||
|
||||
expect(response.body._id).toBe(publicPrompt._id.toString());
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,11 +1,13 @@
|
||||
const express = require('express');
|
||||
const {
|
||||
SystemRoles,
|
||||
roleDefaults,
|
||||
PermissionTypes,
|
||||
agentPermissionsSchema,
|
||||
promptPermissionsSchema,
|
||||
memoryPermissionsSchema,
|
||||
agentPermissionsSchema,
|
||||
PermissionTypes,
|
||||
roleDefaults,
|
||||
SystemRoles,
|
||||
marketplacePermissionsSchema,
|
||||
peoplePickerPermissionsSchema,
|
||||
} = require('librechat-data-provider');
|
||||
const { checkAdmin, requireJwtAuth } = require('~/server/middleware');
|
||||
const { updateRoleByName, getRoleByName } = require('~/models/Role');
|
||||
@@ -13,6 +15,81 @@ const { updateRoleByName, getRoleByName } = require('~/models/Role');
|
||||
const router = express.Router();
|
||||
router.use(requireJwtAuth);
|
||||
|
||||
/**
|
||||
* Permission configuration mapping
|
||||
* Maps route paths to their corresponding schemas and permission types
|
||||
*/
|
||||
const permissionConfigs = {
|
||||
prompts: {
|
||||
schema: promptPermissionsSchema,
|
||||
permissionType: PermissionTypes.PROMPTS,
|
||||
errorMessage: 'Invalid prompt permissions.',
|
||||
},
|
||||
agents: {
|
||||
schema: agentPermissionsSchema,
|
||||
permissionType: PermissionTypes.AGENTS,
|
||||
errorMessage: 'Invalid agent permissions.',
|
||||
},
|
||||
memories: {
|
||||
schema: memoryPermissionsSchema,
|
||||
permissionType: PermissionTypes.MEMORIES,
|
||||
errorMessage: 'Invalid memory permissions.',
|
||||
},
|
||||
'people-picker': {
|
||||
schema: peoplePickerPermissionsSchema,
|
||||
permissionType: PermissionTypes.PEOPLE_PICKER,
|
||||
errorMessage: 'Invalid people picker permissions.',
|
||||
},
|
||||
marketplace: {
|
||||
schema: marketplacePermissionsSchema,
|
||||
permissionType: PermissionTypes.MARKETPLACE,
|
||||
errorMessage: 'Invalid marketplace permissions.',
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Generic handler for updating permissions
|
||||
* @param {string} permissionKey - The key from permissionConfigs
|
||||
* @returns {Function} Express route handler
|
||||
*/
|
||||
const createPermissionUpdateHandler = (permissionKey) => {
|
||||
const config = permissionConfigs[permissionKey];
|
||||
|
||||
return async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
const updates = req.body;
|
||||
|
||||
try {
|
||||
const parsedUpdates = config.schema.partial().parse(updates);
|
||||
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
const currentPermissions =
|
||||
role.permissions?.[config.permissionType] || role[config.permissionType] || {};
|
||||
|
||||
const mergedUpdates = {
|
||||
permissions: {
|
||||
...role.permissions,
|
||||
[config.permissionType]: {
|
||||
...currentPermissions,
|
||||
...parsedUpdates,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
|
||||
res.status(200).send(updatedRole);
|
||||
} catch (error) {
|
||||
return res.status(400).send({ message: config.errorMessage, error: error.errors });
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* GET /api/roles/:roleName
|
||||
* Get a specific role by name
|
||||
@@ -45,117 +122,30 @@ router.get('/:roleName', async (req, res) => {
|
||||
* PUT /api/roles/:roleName/prompts
|
||||
* Update prompt permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/prompts', checkAdmin, async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
/** @type {TRole['permissions']['PROMPTS']} */
|
||||
const updates = req.body;
|
||||
|
||||
try {
|
||||
const parsedUpdates = promptPermissionsSchema.partial().parse(updates);
|
||||
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
const currentPermissions =
|
||||
role.permissions?.[PermissionTypes.PROMPTS] || role[PermissionTypes.PROMPTS] || {};
|
||||
|
||||
const mergedUpdates = {
|
||||
permissions: {
|
||||
...role.permissions,
|
||||
[PermissionTypes.PROMPTS]: {
|
||||
...currentPermissions,
|
||||
...parsedUpdates,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
|
||||
res.status(200).send(updatedRole);
|
||||
} catch (error) {
|
||||
return res.status(400).send({ message: 'Invalid prompt permissions.', error: error.errors });
|
||||
}
|
||||
});
|
||||
router.put('/:roleName/prompts', checkAdmin, createPermissionUpdateHandler('prompts'));
|
||||
|
||||
/**
|
||||
* PUT /api/roles/:roleName/agents
|
||||
* Update agent permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/agents', checkAdmin, async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
/** @type {TRole['permissions']['AGENTS']} */
|
||||
const updates = req.body;
|
||||
|
||||
try {
|
||||
const parsedUpdates = agentPermissionsSchema.partial().parse(updates);
|
||||
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
const currentPermissions =
|
||||
role.permissions?.[PermissionTypes.AGENTS] || role[PermissionTypes.AGENTS] || {};
|
||||
|
||||
const mergedUpdates = {
|
||||
permissions: {
|
||||
...role.permissions,
|
||||
[PermissionTypes.AGENTS]: {
|
||||
...currentPermissions,
|
||||
...parsedUpdates,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
|
||||
res.status(200).send(updatedRole);
|
||||
} catch (error) {
|
||||
return res.status(400).send({ message: 'Invalid agent permissions.', error: error.errors });
|
||||
}
|
||||
});
|
||||
router.put('/:roleName/agents', checkAdmin, createPermissionUpdateHandler('agents'));
|
||||
|
||||
/**
|
||||
* PUT /api/roles/:roleName/memories
|
||||
* Update memory permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/memories', checkAdmin, async (req, res) => {
|
||||
const { roleName: _r } = req.params;
|
||||
// TODO: TEMP, use a better parsing for roleName
|
||||
const roleName = _r.toUpperCase();
|
||||
/** @type {TRole['permissions']['MEMORIES']} */
|
||||
const updates = req.body;
|
||||
router.put('/:roleName/memories', checkAdmin, createPermissionUpdateHandler('memories'));
|
||||
|
||||
try {
|
||||
const parsedUpdates = memoryPermissionsSchema.partial().parse(updates);
|
||||
/**
|
||||
* PUT /api/roles/:roleName/people-picker
|
||||
* Update people picker permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/people-picker', checkAdmin, createPermissionUpdateHandler('people-picker'));
|
||||
|
||||
const role = await getRoleByName(roleName);
|
||||
if (!role) {
|
||||
return res.status(404).send({ message: 'Role not found' });
|
||||
}
|
||||
|
||||
const currentPermissions =
|
||||
role.permissions?.[PermissionTypes.MEMORIES] || role[PermissionTypes.MEMORIES] || {};
|
||||
|
||||
const mergedUpdates = {
|
||||
permissions: {
|
||||
...role.permissions,
|
||||
[PermissionTypes.MEMORIES]: {
|
||||
...currentPermissions,
|
||||
...parsedUpdates,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const updatedRole = await updateRoleByName(roleName, mergedUpdates);
|
||||
res.status(200).send(updatedRole);
|
||||
} catch (error) {
|
||||
return res.status(400).send({ message: 'Invalid memory permissions.', error: error.errors });
|
||||
}
|
||||
});
|
||||
/**
|
||||
* PUT /api/roles/:roleName/marketplace
|
||||
* Update marketplace permissions for a specific role
|
||||
*/
|
||||
router.put('/:roleName/marketplace', checkAdmin, createPermissionUpdateHandler('marketplace'));
|
||||
|
||||
module.exports = router;
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
jest.mock('~/models', () => ({
|
||||
initializeRoles: jest.fn(),
|
||||
seedDefaultRoles: jest.fn(),
|
||||
ensureDefaultCategories: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Role', () => ({
|
||||
updateAccessPermissions: jest.fn(),
|
||||
getRoleByName: jest.fn(),
|
||||
getRoleByName: jest.fn().mockResolvedValue(null),
|
||||
updateRoleByName: jest.fn(),
|
||||
}));
|
||||
|
||||
@@ -87,4 +89,76 @@ describe('AppService interface configuration', () => {
|
||||
expect(app.locals.interfaceConfig.bookmarks).toBe(false);
|
||||
expect(loadDefaultInterface).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should correctly configure peoplePicker permissions including roles', async () => {
|
||||
mockLoadCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
peoplePicker: {
|
||||
users: true,
|
||||
groups: true,
|
||||
roles: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
loadDefaultInterface.mockResolvedValue({
|
||||
peoplePicker: {
|
||||
users: true,
|
||||
groups: true,
|
||||
roles: true,
|
||||
},
|
||||
});
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals.interfaceConfig.peoplePicker).toBeDefined();
|
||||
expect(app.locals.interfaceConfig.peoplePicker).toMatchObject({
|
||||
users: true,
|
||||
groups: true,
|
||||
roles: true,
|
||||
});
|
||||
expect(loadDefaultInterface).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle mixed peoplePicker permissions', async () => {
|
||||
mockLoadCustomConfig.mockResolvedValue({
|
||||
interface: {
|
||||
peoplePicker: {
|
||||
users: true,
|
||||
groups: false,
|
||||
roles: true,
|
||||
},
|
||||
},
|
||||
});
|
||||
loadDefaultInterface.mockResolvedValue({
|
||||
peoplePicker: {
|
||||
users: true,
|
||||
groups: false,
|
||||
roles: true,
|
||||
},
|
||||
});
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals.interfaceConfig.peoplePicker.users).toBe(true);
|
||||
expect(app.locals.interfaceConfig.peoplePicker.groups).toBe(false);
|
||||
expect(app.locals.interfaceConfig.peoplePicker.roles).toBe(true);
|
||||
});
|
||||
|
||||
it('should set default peoplePicker permissions when not provided', async () => {
|
||||
mockLoadCustomConfig.mockResolvedValue({});
|
||||
loadDefaultInterface.mockResolvedValue({
|
||||
peoplePicker: {
|
||||
users: true,
|
||||
groups: true,
|
||||
roles: true,
|
||||
},
|
||||
});
|
||||
|
||||
await AppService(app);
|
||||
|
||||
expect(app.locals.interfaceConfig.peoplePicker).toBeDefined();
|
||||
expect(app.locals.interfaceConfig.peoplePicker.users).toBe(true);
|
||||
expect(app.locals.interfaceConfig.peoplePicker.groups).toBe(true);
|
||||
expect(app.locals.interfaceConfig.peoplePicker.roles).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
const { loadMemoryConfig, agentsConfigSetup, loadWebSearchConfig } = require('@librechat/api');
|
||||
const {
|
||||
isEnabled,
|
||||
loadMemoryConfig,
|
||||
agentsConfigSetup,
|
||||
loadWebSearchConfig,
|
||||
} = require('@librechat/api');
|
||||
const {
|
||||
FileSources,
|
||||
loadOCRConfig,
|
||||
@@ -6,12 +11,13 @@ const {
|
||||
getConfigDefaults,
|
||||
} = require('librechat-data-provider');
|
||||
const {
|
||||
checkWebSearchConfig,
|
||||
checkAzureVariables,
|
||||
checkVariables,
|
||||
checkHealth,
|
||||
checkConfig,
|
||||
checkVariables,
|
||||
checkAzureVariables,
|
||||
checkWebSearchConfig,
|
||||
} = require('./start/checks');
|
||||
const { ensureDefaultCategories, seedDefaultRoles, initializeRoles } = require('~/models');
|
||||
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
|
||||
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
|
||||
const { initializeFirebase } = require('./Files/Firebase/initialize');
|
||||
@@ -23,8 +29,6 @@ const { azureConfigSetup } = require('./start/azureOpenAI');
|
||||
const { processModelSpecs } = require('./start/modelSpecs');
|
||||
const { initializeS3 } = require('./Files/S3/initialize');
|
||||
const { loadAndFormatTools } = require('./ToolService');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { initializeRoles } = require('~/models');
|
||||
const { setCachedTools } = require('./Config');
|
||||
const paths = require('~/config/paths');
|
||||
|
||||
@@ -35,6 +39,8 @@ const paths = require('~/config/paths');
|
||||
*/
|
||||
const AppService = async (app) => {
|
||||
await initializeRoles();
|
||||
await seedDefaultRoles();
|
||||
await ensureDefaultCategories();
|
||||
/** @type {TCustomConfig} */
|
||||
const config = (await loadCustomConfig()) ?? {};
|
||||
const configDefaults = getConfigDefaults();
|
||||
@@ -84,6 +90,7 @@ const AppService = async (app) => {
|
||||
const turnstileConfig = loadTurnstileConfig(config, configDefaults);
|
||||
|
||||
const defaultLocals = {
|
||||
config,
|
||||
ocr,
|
||||
paths,
|
||||
memory,
|
||||
|
||||
@@ -28,9 +28,12 @@ jest.mock('./Files/Firebase/initialize', () => ({
|
||||
}));
|
||||
jest.mock('~/models', () => ({
|
||||
initializeRoles: jest.fn(),
|
||||
seedDefaultRoles: jest.fn(),
|
||||
ensureDefaultCategories: jest.fn(),
|
||||
}));
|
||||
jest.mock('~/models/Role', () => ({
|
||||
updateAccessPermissions: jest.fn(),
|
||||
getRoleByName: jest.fn().mockResolvedValue(null),
|
||||
}));
|
||||
jest.mock('./Config', () => ({
|
||||
setCachedTools: jest.fn(),
|
||||
@@ -131,6 +134,9 @@ describe('AppService', () => {
|
||||
expect(process.env.CDN_PROVIDER).toEqual('testStrategy');
|
||||
|
||||
expect(app.locals).toEqual({
|
||||
config: expect.objectContaining({
|
||||
fileStrategy: 'testStrategy',
|
||||
}),
|
||||
socialLogins: ['testLogin'],
|
||||
fileStrategy: 'testStrategy',
|
||||
interfaceConfig: expect.objectContaining({
|
||||
@@ -165,6 +171,9 @@ describe('AppService', () => {
|
||||
agents: {
|
||||
disableBuilder: false,
|
||||
capabilities: expect.arrayContaining([...defaultAgentCapabilities]),
|
||||
maxCitations: 30,
|
||||
maxCitationsPerFile: 7,
|
||||
minRelevanceScore: 0.45,
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -770,6 +779,7 @@ describe('AppService updating app.locals and issuing warnings', () => {
|
||||
|
||||
expect(app.locals).toBeDefined();
|
||||
expect(app.locals.paths).toBeDefined();
|
||||
expect(app.locals.config).toEqual({});
|
||||
expect(app.locals.fileStrategy).toEqual(FileSources.local);
|
||||
expect(app.locals.socialLogins).toEqual(defaultSocialLogins);
|
||||
expect(app.locals.balance).toEqual(
|
||||
@@ -802,6 +812,7 @@ describe('AppService updating app.locals and issuing warnings', () => {
|
||||
|
||||
expect(app.locals).toBeDefined();
|
||||
expect(app.locals.paths).toBeDefined();
|
||||
expect(app.locals.config).toEqual(customConfig);
|
||||
expect(app.locals.fileStrategy).toEqual(customConfig.fileStrategy);
|
||||
expect(app.locals.socialLogins).toEqual(customConfig.registration.socialLogins);
|
||||
expect(app.locals.balance).toEqual(customConfig.balance);
|
||||
@@ -959,4 +970,29 @@ describe('AppService updating app.locals and issuing warnings', () => {
|
||||
expect(app.locals.ocr.strategy).toEqual('mistral_ocr');
|
||||
expect(app.locals.ocr.mistralModel).toEqual('mistral-medium');
|
||||
});
|
||||
|
||||
it('should correctly configure peoplePicker permissions when specified', async () => {
|
||||
const mockConfig = {
|
||||
interface: {
|
||||
peoplePicker: {
|
||||
users: true,
|
||||
groups: true,
|
||||
roles: true,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
require('./Config/loadCustomConfig').mockImplementationOnce(() => Promise.resolve(mockConfig));
|
||||
|
||||
const app = { locals: {} };
|
||||
await AppService(app);
|
||||
|
||||
// Check that interface config includes the permissions
|
||||
expect(app.locals.interfaceConfig.peoplePicker).toBeDefined();
|
||||
expect(app.locals.interfaceConfig.peoplePicker).toMatchObject({
|
||||
users: true,
|
||||
groups: true,
|
||||
roles: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -26,7 +26,7 @@ const ToolCacheKeys = {
|
||||
* @param {string[]} [options.roleIds] - Role IDs for role-based tools
|
||||
* @param {string[]} [options.groupIds] - Group IDs for group-based tools
|
||||
* @param {boolean} [options.includeGlobal=true] - Whether to include global tools
|
||||
* @returns {Promise<Object|null>} The available tools object or null if not cached
|
||||
* @returns {Promise<LCAvailableTools|null>} The available tools object or null if not cached
|
||||
*/
|
||||
async function getCachedTools(options = {}) {
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
@@ -41,13 +41,13 @@ async function getCachedTools(options = {}) {
|
||||
// Future implementation will merge tools from multiple sources
|
||||
// based on user permissions, roles, and groups
|
||||
if (userId) {
|
||||
// Check if we have pre-computed effective tools for this user
|
||||
/** @type {LCAvailableTools | null} Check if we have pre-computed effective tools for this user */
|
||||
const effectiveTools = await cache.get(ToolCacheKeys.EFFECTIVE(userId));
|
||||
if (effectiveTools) {
|
||||
return effectiveTools;
|
||||
}
|
||||
|
||||
// Otherwise, compute from individual sources
|
||||
/** @type {LCAvailableTools | null} Otherwise, compute from individual sources */
|
||||
const toolSources = [];
|
||||
|
||||
if (includeGlobal) {
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { isEnabled, getUserMCPAuthMap } = require('@librechat/api');
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { CacheKeys, EModelEndpoint } = require('librechat-data-provider');
|
||||
const { normalizeEndpointName } = require('~/server/utils');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
@@ -53,31 +52,6 @@ const getCustomEndpointConfig = async (endpoint) => {
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {Object} params
|
||||
* @param {string} params.userId
|
||||
* @param {GenericTool[]} [params.tools]
|
||||
* @param {import('@librechat/data-schemas').PluginAuthMethods['findPluginAuthsByKeys']} params.findPluginAuthsByKeys
|
||||
* @returns {Promise<Record<string, Record<string, string>> | undefined>}
|
||||
*/
|
||||
async function getMCPAuthMap({ userId, tools, findPluginAuthsByKeys }) {
|
||||
try {
|
||||
if (!tools || tools.length === 0) {
|
||||
return;
|
||||
}
|
||||
return await getUserMCPAuthMap({
|
||||
tools,
|
||||
userId,
|
||||
findPluginAuthsByKeys,
|
||||
});
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`[api/server/controllers/agents/client.js #chatCompletion] Error getting custom user vars for agent`,
|
||||
err,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
@@ -88,7 +62,6 @@ async function hasCustomUserVars() {
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getMCPAuthMap,
|
||||
getCustomConfig,
|
||||
getBalanceConfig,
|
||||
hasCustomUserVars,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
const { config } = require('./EndpointService');
|
||||
const getCachedTools = require('./getCachedTools');
|
||||
const getCustomConfig = require('./getCustomConfig');
|
||||
const mcpToolsCache = require('./mcpToolsCache');
|
||||
const loadCustomConfig = require('./loadCustomConfig');
|
||||
const loadConfigModels = require('./loadConfigModels');
|
||||
const loadDefaultModels = require('./loadDefaultModels');
|
||||
@@ -17,5 +18,6 @@ module.exports = {
|
||||
loadAsyncEndpoints,
|
||||
...getCachedTools,
|
||||
...getCustomConfig,
|
||||
...mcpToolsCache,
|
||||
...getEndpointsConfig,
|
||||
};
|
||||
|
||||
@@ -76,10 +76,11 @@ async function loadConfigModels(req) {
|
||||
fetchPromisesMap[uniqueKey] =
|
||||
fetchPromisesMap[uniqueKey] ||
|
||||
fetchModels({
|
||||
user: req.user.id,
|
||||
baseURL: BASE_URL,
|
||||
apiKey: API_KEY,
|
||||
name,
|
||||
apiKey: API_KEY,
|
||||
baseURL: BASE_URL,
|
||||
user: req.user.id,
|
||||
direct: endpoint.directEndpoint,
|
||||
userIdQuery: models.userIdQuery,
|
||||
});
|
||||
uniqueKeyToEndpointsMap[uniqueKey] = uniqueKeyToEndpointsMap[uniqueKey] || [];
|
||||
|
||||
143
api/server/services/Config/mcpToolsCache.js
Normal file
143
api/server/services/Config/mcpToolsCache.js
Normal file
@@ -0,0 +1,143 @@
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys, Constants } = require('librechat-data-provider');
|
||||
const { getCachedTools, setCachedTools } = require('./getCachedTools');
|
||||
const { getLogStores } = require('~/cache');
|
||||
|
||||
/**
|
||||
* Updates MCP tools in the cache for a specific server and user
|
||||
* @param {Object} params - Parameters for updating MCP tools
|
||||
* @param {string} params.userId - User ID
|
||||
* @param {string} params.serverName - MCP server name
|
||||
* @param {Array} params.tools - Array of tool objects from MCP server
|
||||
* @returns {Promise<LCAvailableTools>}
|
||||
*/
|
||||
async function updateMCPUserTools({ userId, serverName, tools }) {
|
||||
try {
|
||||
const userTools = await getCachedTools({ userId });
|
||||
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
for (const key of Object.keys(userTools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete userTools[key];
|
||||
}
|
||||
}
|
||||
|
||||
for (const tool of tools) {
|
||||
const name = `${tool.name}${Constants.mcp_delimiter}${serverName}`;
|
||||
userTools[name] = {
|
||||
type: 'function',
|
||||
['function']: {
|
||||
name,
|
||||
description: tool.description,
|
||||
parameters: tool.inputSchema,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
await setCachedTools(userTools, { userId });
|
||||
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`[MCP Cache] Updated ${tools.length} tools for ${serverName} user ${userId}`);
|
||||
return userTools;
|
||||
} catch (error) {
|
||||
logger.error(`[MCP Cache] Failed to update tools for ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges app-level tools with global tools
|
||||
* @param {import('@librechat/api').LCAvailableTools} appTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function mergeAppTools(appTools) {
|
||||
try {
|
||||
const count = Object.keys(appTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
const cachedTools = await getCachedTools({ includeGlobal: true });
|
||||
const mergedTools = { ...cachedTools, ...appTools };
|
||||
await setCachedTools(mergedTools, { isGlobal: true });
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`Merged ${count} app-level tools`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to merge app-level tools:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges user-level tools with global tools
|
||||
* @param {object} params
|
||||
* @param {string} params.userId
|
||||
* @param {Record<string, FunctionTool>} params.cachedUserTools
|
||||
* @param {import('@librechat/api').LCAvailableTools} params.userTools
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function mergeUserTools({ userId, cachedUserTools, userTools }) {
|
||||
try {
|
||||
if (!userId) {
|
||||
return;
|
||||
}
|
||||
const count = Object.keys(userTools).length;
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
const cachedTools = cachedUserTools ?? (await getCachedTools({ userId }));
|
||||
const mergedTools = { ...cachedTools, ...userTools };
|
||||
await setCachedTools(mergedTools, { userId });
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
logger.debug(`Merged ${count} user-level tools`);
|
||||
} catch (error) {
|
||||
logger.error('Failed to merge user-level tools:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears all MCP tools for a specific server
|
||||
* @param {Object} params - Parameters for clearing MCP tools
|
||||
* @param {string} [params.userId] - User ID (if clearing user-specific tools)
|
||||
* @param {string} params.serverName - MCP server name
|
||||
* @returns {Promise<void>}
|
||||
*/
|
||||
async function clearMCPServerTools({ userId, serverName }) {
|
||||
try {
|
||||
const tools = await getCachedTools({ userId, includeGlobal: !userId });
|
||||
|
||||
// Remove all tools for this server
|
||||
const mcpDelimiter = Constants.mcp_delimiter;
|
||||
let removedCount = 0;
|
||||
for (const key of Object.keys(tools)) {
|
||||
if (key.endsWith(`${mcpDelimiter}${serverName}`)) {
|
||||
delete tools[key];
|
||||
removedCount++;
|
||||
}
|
||||
}
|
||||
|
||||
if (removedCount > 0) {
|
||||
await setCachedTools(tools, userId ? { userId } : { isGlobal: true });
|
||||
|
||||
const cache = getLogStores(CacheKeys.CONFIG_STORE);
|
||||
await cache.delete(CacheKeys.TOOLS);
|
||||
|
||||
logger.debug(
|
||||
`[MCP Cache] Removed ${removedCount} tools for ${serverName}${userId ? ` user ${userId}` : ' (global)'}`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`[MCP Cache] Failed to clear tools for ${serverName}:`, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
mergeAppTools,
|
||||
mergeUserTools,
|
||||
updateMCPUserTools,
|
||||
clearMCPServerTools,
|
||||
};
|
||||
@@ -30,7 +30,13 @@ const { getModelMaxTokens } = require('~/utils');
|
||||
* @param {TEndpointOption} [params.endpointOption]
|
||||
* @param {Set<string>} [params.allowedProviders]
|
||||
* @param {boolean} [params.isInitialAgent]
|
||||
* @returns {Promise<Agent & { tools: StructuredTool[], attachments: Array<MongoFile>, toolContextMap: Record<string, unknown>, maxContextTokens: number }>}
|
||||
* @returns {Promise<Agent & {
|
||||
* tools: StructuredTool[],
|
||||
* attachments: Array<MongoFile>,
|
||||
* toolContextMap: Record<string, unknown>,
|
||||
* maxContextTokens: number,
|
||||
* userMCPAuthMap?: Record<string, Record<string, string>>
|
||||
* }>}
|
||||
*/
|
||||
const initializeAgent = async ({
|
||||
req,
|
||||
@@ -91,16 +97,19 @@ const initializeAgent = async ({
|
||||
});
|
||||
|
||||
const provider = agent.provider;
|
||||
const { tools: structuredTools, toolContextMap } =
|
||||
(await loadTools?.({
|
||||
req,
|
||||
res,
|
||||
provider,
|
||||
agentId: agent.id,
|
||||
tools: agent.tools,
|
||||
model: agent.model,
|
||||
tool_resources,
|
||||
})) ?? {};
|
||||
const {
|
||||
tools: structuredTools,
|
||||
toolContextMap,
|
||||
userMCPAuthMap,
|
||||
} = (await loadTools?.({
|
||||
req,
|
||||
res,
|
||||
provider,
|
||||
agentId: agent.id,
|
||||
tools: agent.tools,
|
||||
model: agent.model,
|
||||
tool_resources,
|
||||
})) ?? {};
|
||||
|
||||
agent.endpoint = provider;
|
||||
const { getOptions, overrideProvider } = await getProviderConfig(provider);
|
||||
@@ -189,6 +198,7 @@ const initializeAgent = async ({
|
||||
tools,
|
||||
attachments,
|
||||
resendFiles,
|
||||
userMCPAuthMap,
|
||||
toolContextMap,
|
||||
useLegacyContent: !!options.useLegacyContent,
|
||||
maxContextTokens: Math.round((agentMaxContextTokens - maxTokens) * 0.9),
|
||||
|
||||
@@ -19,7 +19,10 @@ const AgentClient = require('~/server/controllers/agents/client');
|
||||
const { getAgent } = require('~/models/Agent');
|
||||
const { logViolation } = require('~/cache');
|
||||
|
||||
function createToolLoader() {
|
||||
/**
|
||||
* @param {AbortSignal} signal
|
||||
*/
|
||||
function createToolLoader(signal) {
|
||||
/**
|
||||
* @param {object} params
|
||||
* @param {ServerRequest} params.req
|
||||
@@ -29,7 +32,11 @@ function createToolLoader() {
|
||||
* @param {string} params.provider
|
||||
* @param {string} params.model
|
||||
* @param {AgentToolResources} params.tool_resources
|
||||
* @returns {Promise<{ tools: StructuredTool[], toolContextMap: Record<string, unknown> } | undefined>}
|
||||
* @returns {Promise<{
|
||||
* tools: StructuredTool[],
|
||||
* toolContextMap: Record<string, unknown>,
|
||||
* userMCPAuthMap?: Record<string, Record<string, string>>
|
||||
* } | undefined>}
|
||||
*/
|
||||
return async function loadTools({ req, res, agentId, tools, provider, model, tool_resources }) {
|
||||
const agent = { id: agentId, tools, provider, model };
|
||||
@@ -38,6 +45,7 @@ function createToolLoader() {
|
||||
req,
|
||||
res,
|
||||
agent,
|
||||
signal,
|
||||
tool_resources,
|
||||
});
|
||||
} catch (error) {
|
||||
@@ -46,7 +54,7 @@ function createToolLoader() {
|
||||
};
|
||||
}
|
||||
|
||||
const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
const initializeClient = async ({ req, res, signal, endpointOption }) => {
|
||||
if (!endpointOption) {
|
||||
throw new Error('Endpoint option not provided');
|
||||
}
|
||||
@@ -92,7 +100,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
/** @type {Set<string>} */
|
||||
const allowedProviders = new Set(req?.app?.locals?.[EModelEndpoint.agents]?.allowedProviders);
|
||||
|
||||
const loadTools = createToolLoader();
|
||||
const loadTools = createToolLoader(signal);
|
||||
/** @type {Array<MongoFile>} */
|
||||
const requestFiles = req.body.files ?? [];
|
||||
/** @type {string} */
|
||||
@@ -111,6 +119,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
});
|
||||
|
||||
const agent_ids = primaryConfig.agent_ids;
|
||||
let userMCPAuthMap = primaryConfig.userMCPAuthMap;
|
||||
if (agent_ids?.length) {
|
||||
for (const agentId of agent_ids) {
|
||||
const agent = await getAgent({ id: agentId });
|
||||
@@ -140,6 +149,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
endpointOption,
|
||||
allowedProviders,
|
||||
});
|
||||
Object.assign(userMCPAuthMap, config.userMCPAuthMap ?? {});
|
||||
agentConfigs.set(agentId, config);
|
||||
}
|
||||
}
|
||||
@@ -188,7 +198,7 @@ const initializeClient = async ({ req, res, endpointOption }) => {
|
||||
: EModelEndpoint.agents,
|
||||
});
|
||||
|
||||
return { client };
|
||||
return { client, userMCPAuthMap };
|
||||
};
|
||||
|
||||
module.exports = { initializeClient };
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
const { isEnabled } = require('@librechat/api');
|
||||
const { logger } = require('@librechat/data-schemas');
|
||||
const { CacheKeys } = require('librechat-data-provider');
|
||||
const getLogStores = require('~/cache/getLogStores');
|
||||
const { isEnabled } = require('~/server/utils');
|
||||
const { saveConvo } = require('~/models');
|
||||
const { logger } = require('~/config');
|
||||
|
||||
/**
|
||||
* Add title to conversation in a way that avoids memory retention
|
||||
|
||||
@@ -45,6 +45,10 @@ function getClaudeHeaders(model, supportsCacheControl) {
|
||||
'anthropic-beta':
|
||||
'token-efficient-tools-2025-02-19,output-128k-2025-02-19,prompt-caching-2024-07-31',
|
||||
};
|
||||
} else if (/claude-sonnet-4/.test(model)) {
|
||||
return {
|
||||
'anthropic-beta': 'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
};
|
||||
} else if (
|
||||
/claude-(?:sonnet|opus|haiku)-[4-9]/.test(model) ||
|
||||
/claude-[4-9]-(?:sonnet|opus|haiku)?/.test(model) ||
|
||||
|
||||
@@ -39,8 +39,9 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
|
||||
if (optionsOnly) {
|
||||
clientOptions = Object.assign(
|
||||
{
|
||||
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
|
||||
proxy: PROXY ?? null,
|
||||
userId: req.user.id,
|
||||
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
|
||||
modelOptions: endpointOption?.model_parameters ?? {},
|
||||
},
|
||||
clientOptions,
|
||||
|
||||
@@ -15,6 +15,7 @@ const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = requir
|
||||
* @param {number} [options.modelOptions.topK] - Controls the number of top tokens to consider.
|
||||
* @param {string[]} [options.modelOptions.stop] - Sequences where the API will stop generating further tokens.
|
||||
* @param {boolean} [options.modelOptions.stream] - Whether to stream the response.
|
||||
* @param {string} options.userId - The user ID for tracking and personalization.
|
||||
* @param {string} [options.proxy] - Proxy server URL.
|
||||
* @param {string} [options.reverseProxyUrl] - URL for a reverse proxy, if used.
|
||||
*
|
||||
@@ -47,6 +48,11 @@ function getLLMConfig(apiKey, options = {}) {
|
||||
maxTokens:
|
||||
mergedOptions.maxOutputTokens || anthropicSettings.maxOutputTokens.reset(mergedOptions.model),
|
||||
clientOptions: {},
|
||||
invocationKwargs: {
|
||||
metadata: {
|
||||
user_id: options.userId,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
requestOptions = configureReasoning(requestOptions, systemOptions);
|
||||
|
||||
@@ -1,50 +1,19 @@
|
||||
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
|
||||
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
|
||||
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
|
||||
|
||||
jest.mock('https-proxy-agent', () => ({
|
||||
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
|
||||
}));
|
||||
|
||||
jest.mock('./helpers', () => ({
|
||||
checkPromptCacheSupport: jest.fn(),
|
||||
getClaudeHeaders: jest.fn(),
|
||||
configureReasoning: jest.fn((requestOptions) => requestOptions),
|
||||
}));
|
||||
|
||||
jest.mock('librechat-data-provider', () => ({
|
||||
anthropicSettings: {
|
||||
model: { default: 'claude-3-opus-20240229' },
|
||||
maxOutputTokens: { default: 4096, reset: jest.fn(() => 4096) },
|
||||
thinking: { default: false },
|
||||
promptCache: { default: false },
|
||||
thinkingBudget: { default: null },
|
||||
},
|
||||
removeNullishValues: jest.fn((obj) => {
|
||||
const result = {};
|
||||
for (const key in obj) {
|
||||
if (obj[key] !== null && obj[key] !== undefined) {
|
||||
result[key] = obj[key];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('getLLMConfig', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
checkPromptCacheSupport.mockReturnValue(false);
|
||||
getClaudeHeaders.mockReturnValue(undefined);
|
||||
configureReasoning.mockImplementation((requestOptions) => requestOptions);
|
||||
anthropicSettings.maxOutputTokens.reset.mockReturnValue(4096);
|
||||
});
|
||||
|
||||
it('should create a basic configuration with default values', () => {
|
||||
const result = getLLMConfig('test-api-key', { modelOptions: {} });
|
||||
|
||||
expect(result.llmConfig).toHaveProperty('apiKey', 'test-api-key');
|
||||
expect(result.llmConfig).toHaveProperty('model', anthropicSettings.model.default);
|
||||
expect(result.llmConfig).toHaveProperty('model', 'claude-3-5-sonnet-latest');
|
||||
expect(result.llmConfig).toHaveProperty('stream', true);
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens');
|
||||
});
|
||||
@@ -99,40 +68,73 @@ describe('getLLMConfig', () => {
|
||||
expect(result.llmConfig).toHaveProperty('topP', 0.9);
|
||||
});
|
||||
|
||||
it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => {
|
||||
configureReasoning.mockImplementation((requestOptions) => {
|
||||
requestOptions.thinking = { type: 'enabled' };
|
||||
return requestOptions;
|
||||
});
|
||||
|
||||
it('should NOT include topK and topP for Claude-3-7 models with thinking enabled (hyphen notation)', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('topK');
|
||||
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||
expect(result.llmConfig).toHaveProperty('thinking');
|
||||
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
|
||||
// When thinking is enabled, it uses the default thinkingBudget of 2000
|
||||
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
|
||||
});
|
||||
|
||||
it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => {
|
||||
configureReasoning.mockImplementation((requestOptions) => {
|
||||
requestOptions.thinking = { type: 'enabled' };
|
||||
return requestOptions;
|
||||
});
|
||||
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model', () => {
|
||||
const modelOptions = {
|
||||
model: 'claude-sonnet-4-20250514',
|
||||
promptCache: true,
|
||||
};
|
||||
const result = getLLMConfig('test-key', { modelOptions });
|
||||
const clientOptions = result.llmConfig.clientOptions;
|
||||
expect(clientOptions.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
|
||||
it('should add "prompt-caching" and "context-1m" beta headers for claude-sonnet-4 model formats', () => {
|
||||
const modelVariations = [
|
||||
'claude-sonnet-4-20250514',
|
||||
'claude-sonnet-4-latest',
|
||||
'anthropic/claude-sonnet-4-20250514',
|
||||
];
|
||||
|
||||
modelVariations.forEach((model) => {
|
||||
const modelOptions = { model, promptCache: true };
|
||||
const result = getLLMConfig('test-key', { modelOptions });
|
||||
const clientOptions = result.llmConfig.clientOptions;
|
||||
expect(clientOptions.defaultHeaders).toBeDefined();
|
||||
expect(clientOptions.defaultHeaders).toHaveProperty('anthropic-beta');
|
||||
expect(clientOptions.defaultHeaders['anthropic-beta']).toBe(
|
||||
'prompt-caching-2024-07-31,context-1m-2025-08-07',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should NOT include topK and topP for Claude-3.7 models with thinking enabled (decimal notation)', () => {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3.7-sonnet',
|
||||
topK: 10,
|
||||
topP: 0.9,
|
||||
thinking: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result.llmConfig).not.toHaveProperty('topK');
|
||||
expect(result.llmConfig).not.toHaveProperty('topP');
|
||||
expect(result.llmConfig).toHaveProperty('thinking');
|
||||
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
|
||||
// When thinking is enabled, it uses the default thinkingBudget of 2000
|
||||
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
|
||||
});
|
||||
|
||||
it('should handle custom maxOutputTokens', () => {
|
||||
@@ -233,7 +235,6 @@ describe('getLLMConfig', () => {
|
||||
});
|
||||
|
||||
it('should handle maxOutputTokens when explicitly set to falsy value', () => {
|
||||
anthropicSettings.maxOutputTokens.reset.mockReturnValue(8192);
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-opus',
|
||||
@@ -241,8 +242,8 @@ describe('getLLMConfig', () => {
|
||||
},
|
||||
});
|
||||
|
||||
expect(anthropicSettings.maxOutputTokens.reset).toHaveBeenCalledWith('claude-3-opus');
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 8192);
|
||||
// The actual anthropicSettings.maxOutputTokens.reset('claude-3-opus') returns 4096
|
||||
expect(result.llmConfig).toHaveProperty('maxTokens', 4096);
|
||||
});
|
||||
|
||||
it('should handle both proxy and reverseProxyUrl', () => {
|
||||
@@ -263,9 +264,6 @@ describe('getLLMConfig', () => {
|
||||
});
|
||||
|
||||
it('should handle prompt cache with supported model', () => {
|
||||
checkPromptCacheSupport.mockReturnValue(true);
|
||||
getClaudeHeaders.mockReturnValue({ 'anthropic-beta': 'prompt-caching-2024-07-31' });
|
||||
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-5-sonnet',
|
||||
@@ -273,43 +271,38 @@ describe('getLLMConfig', () => {
|
||||
},
|
||||
});
|
||||
|
||||
expect(checkPromptCacheSupport).toHaveBeenCalledWith('claude-3-5-sonnet');
|
||||
expect(getClaudeHeaders).toHaveBeenCalledWith('claude-3-5-sonnet', true);
|
||||
// claude-3-5-sonnet supports prompt caching and should get the appropriate headers
|
||||
expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({
|
||||
'anthropic-beta': 'prompt-caching-2024-07-31',
|
||||
'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15,prompt-caching-2024-07-31',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle thinking and thinkingBudget options', () => {
|
||||
configureReasoning.mockImplementation((requestOptions, systemOptions) => {
|
||||
if (systemOptions.thinking) {
|
||||
requestOptions.thinking = { type: 'enabled' };
|
||||
}
|
||||
if (systemOptions.thinkingBudget) {
|
||||
requestOptions.thinking = {
|
||||
...requestOptions.thinking,
|
||||
budget_tokens: systemOptions.thinkingBudget,
|
||||
};
|
||||
}
|
||||
return requestOptions;
|
||||
});
|
||||
|
||||
getLLMConfig('test-api-key', {
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
thinking: true,
|
||||
thinkingBudget: 5000,
|
||||
thinkingBudget: 10000, // This exceeds the default max_tokens of 8192
|
||||
},
|
||||
});
|
||||
|
||||
expect(configureReasoning).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({
|
||||
// The function should add thinking configuration for claude-3-7 models
|
||||
expect(result.llmConfig).toHaveProperty('thinking');
|
||||
expect(result.llmConfig.thinking).toHaveProperty('type', 'enabled');
|
||||
// With claude-3-7-sonnet, the max_tokens default is 8192
|
||||
// Budget tokens gets adjusted to 90% of max_tokens (8192 * 0.9 = 7372) when it exceeds max_tokens
|
||||
expect(result.llmConfig.thinking).toHaveProperty('budget_tokens', 7372);
|
||||
|
||||
// Test with budget_tokens within max_tokens limit
|
||||
const result2 = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
model: 'claude-3-7-sonnet',
|
||||
thinking: true,
|
||||
promptCache: false,
|
||||
thinkingBudget: 5000,
|
||||
}),
|
||||
);
|
||||
thinkingBudget: 2000,
|
||||
},
|
||||
});
|
||||
|
||||
expect(result2.llmConfig.thinking).toHaveProperty('budget_tokens', 2000);
|
||||
});
|
||||
|
||||
it('should remove system options from modelOptions', () => {
|
||||
@@ -330,16 +323,6 @@ describe('getLLMConfig', () => {
|
||||
});
|
||||
|
||||
it('should handle all nullish values removal', () => {
|
||||
removeNullishValues.mockImplementation((obj) => {
|
||||
const cleaned = {};
|
||||
Object.entries(obj).forEach(([key, value]) => {
|
||||
if (value !== null && value !== undefined) {
|
||||
cleaned[key] = value;
|
||||
}
|
||||
});
|
||||
return cleaned;
|
||||
});
|
||||
|
||||
const result = getLLMConfig('test-api-key', {
|
||||
modelOptions: {
|
||||
temperature: null,
|
||||
|
||||
@@ -109,14 +109,14 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie
|
||||
|
||||
apiKey = azureOptions.azureOpenAIApiKey;
|
||||
opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion };
|
||||
opts.defaultHeaders = resolveHeaders(
|
||||
{
|
||||
opts.defaultHeaders = resolveHeaders({
|
||||
headers: {
|
||||
...headers,
|
||||
'api-key': apiKey,
|
||||
'OpenAI-Beta': `assistants=${version}`,
|
||||
},
|
||||
req.user,
|
||||
);
|
||||
user: req.user,
|
||||
});
|
||||
opts.model = azureOptions.azureOpenAIApiDeploymentName;
|
||||
|
||||
if (initAppClient) {
|
||||
|
||||
@@ -28,7 +28,11 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
|
||||
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
|
||||
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
|
||||
|
||||
let resolvedHeaders = resolveHeaders(endpointConfig.headers, req.user);
|
||||
let resolvedHeaders = resolveHeaders({
|
||||
headers: endpointConfig.headers,
|
||||
user: req.user,
|
||||
body: req.body,
|
||||
});
|
||||
|
||||
if (CUSTOM_API_KEY.match(envVarRegex)) {
|
||||
throw new Error(`Missing API Key for ${endpoint}.`);
|
||||
|
||||
@@ -64,13 +64,14 @@ describe('custom/initializeClient', () => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('calls resolveHeaders with headers and user', async () => {
|
||||
it('calls resolveHeaders with headers, user, and body for body placeholder support', async () => {
|
||||
const { resolveHeaders } = require('@librechat/api');
|
||||
await initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true });
|
||||
expect(resolveHeaders).toHaveBeenCalledWith(
|
||||
{ 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
|
||||
{ id: 'user-123', email: 'test@example.com' },
|
||||
);
|
||||
expect(resolveHeaders).toHaveBeenCalledWith({
|
||||
headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
|
||||
user: { id: 'user-123', email: 'test@example.com' },
|
||||
body: { endpoint: 'test-endpoint' }, // body - supports {{LIBRECHAT_BODY_*}} placeholders
|
||||
});
|
||||
});
|
||||
|
||||
it('throws if endpoint config is missing', async () => {
|
||||
|
||||
@@ -81,10 +81,10 @@ const initializeClient = async ({
|
||||
serverless = _serverless;
|
||||
|
||||
clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
|
||||
clientOptions.headers = resolveHeaders(
|
||||
{ ...headers, ...(clientOptions.headers ?? {}) },
|
||||
req.user,
|
||||
);
|
||||
clientOptions.headers = resolveHeaders({
|
||||
headers: { ...headers, ...(clientOptions.headers ?? {}) },
|
||||
user: req.user,
|
||||
});
|
||||
|
||||
clientOptions.titleConvo = azureConfig.titleConvo;
|
||||
clientOptions.titleModel = azureConfig.titleModel;
|
||||
|
||||
149
api/server/services/Files/Citations/index.js
Normal file
149
api/server/services/Files/Citations/index.js
Normal file
@@ -0,0 +1,149 @@
|
||||
const { nanoid } = require('nanoid');
|
||||
const { checkAccess } = require('@librechat/api');
|
||||
const { Tools, PermissionTypes, Permissions } = require('librechat-data-provider');
|
||||
const { getCustomConfig } = require('~/server/services/Config/getCustomConfig');
|
||||
const { getRoleByName } = require('~/models/Role');
|
||||
const { logger } = require('~/config');
|
||||
const { Files } = require('~/models');
|
||||
|
||||
/**
|
||||
* Process file search results from tool calls
|
||||
* @param {Object} options
|
||||
* @param {IUser} options.user - The user object
|
||||
* @param {GraphRunnableConfig['configurable']} options.metadata - The metadata
|
||||
* @param {any} options.toolArtifact - The tool artifact containing structured data
|
||||
* @param {string} options.toolCallId - The tool call ID
|
||||
* @returns {Promise<Object|null>} The file search attachment or null
|
||||
*/
|
||||
async function processFileCitations({ user, toolArtifact, toolCallId, metadata }) {
|
||||
try {
|
||||
if (!toolArtifact?.[Tools.file_search]?.sources) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (user) {
|
||||
try {
|
||||
const hasFileCitationsAccess = await checkAccess({
|
||||
user,
|
||||
permissionType: PermissionTypes.FILE_CITATIONS,
|
||||
permissions: [Permissions.USE],
|
||||
getRoleByName,
|
||||
});
|
||||
|
||||
if (!hasFileCitationsAccess) {
|
||||
logger.debug(
|
||||
`[processFileCitations] User ${user.id} does not have FILE_CITATIONS permission`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`[processFileCitations] Permission check failed for FILE_CITATIONS: ${error.message}`,
|
||||
);
|
||||
logger.debug(`[processFileCitations] Proceeding with citations due to permission error`);
|
||||
}
|
||||
}
|
||||
|
||||
const customConfig = await getCustomConfig();
|
||||
const maxCitations = customConfig?.endpoints?.agents?.maxCitations ?? 30;
|
||||
const maxCitationsPerFile = customConfig?.endpoints?.agents?.maxCitationsPerFile ?? 5;
|
||||
const minRelevanceScore = customConfig?.endpoints?.agents?.minRelevanceScore ?? 0.45;
|
||||
|
||||
const sources = toolArtifact[Tools.file_search].sources || [];
|
||||
const filteredSources = sources.filter((source) => source.relevance >= minRelevanceScore);
|
||||
if (filteredSources.length === 0) {
|
||||
logger.debug(
|
||||
`[processFileCitations] No sources above relevance threshold of ${minRelevanceScore}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const selectedSources = applyCitationLimits(filteredSources, maxCitations, maxCitationsPerFile);
|
||||
const enhancedSources = await enhanceSourcesWithMetadata(selectedSources, customConfig);
|
||||
|
||||
if (enhancedSources.length > 0) {
|
||||
const fileSearchAttachment = {
|
||||
type: Tools.file_search,
|
||||
[Tools.file_search]: { sources: enhancedSources },
|
||||
toolCallId: toolCallId,
|
||||
messageId: metadata.run_id,
|
||||
conversationId: metadata.thread_id,
|
||||
name: `${Tools.file_search}_file_search_results_${nanoid()}`,
|
||||
};
|
||||
|
||||
return fileSearchAttachment;
|
||||
}
|
||||
|
||||
return null;
|
||||
} catch (error) {
|
||||
logger.error('[processFileCitations] Error processing file citations:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply citation limits to sources
|
||||
* @param {Array} sources - All sources
|
||||
* @param {number} maxCitations - Maximum total citations
|
||||
* @param {number} maxCitationsPerFile - Maximum citations per file
|
||||
* @returns {Array} Selected sources
|
||||
*/
|
||||
function applyCitationLimits(sources, maxCitations, maxCitationsPerFile) {
|
||||
const byFile = {};
|
||||
sources.forEach((source) => {
|
||||
if (!byFile[source.fileId]) {
|
||||
byFile[source.fileId] = [];
|
||||
}
|
||||
byFile[source.fileId].push(source);
|
||||
});
|
||||
|
||||
const representatives = [];
|
||||
for (const fileId in byFile) {
|
||||
const fileSources = byFile[fileId].sort((a, b) => b.relevance - a.relevance);
|
||||
const selectedFromFile = fileSources.slice(0, maxCitationsPerFile);
|
||||
representatives.push(...selectedFromFile);
|
||||
}
|
||||
|
||||
return representatives.sort((a, b) => b.relevance - a.relevance).slice(0, maxCitations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Enhance sources with file metadata from database
|
||||
* @param {Array} sources - Selected sources
|
||||
* @param {Object} customConfig - Custom configuration
|
||||
* @returns {Promise<Array>} Enhanced sources
|
||||
*/
|
||||
async function enhanceSourcesWithMetadata(sources, customConfig) {
|
||||
const fileIds = [...new Set(sources.map((source) => source.fileId))];
|
||||
|
||||
let fileMetadataMap = {};
|
||||
try {
|
||||
const files = await Files.find({ file_id: { $in: fileIds } });
|
||||
fileMetadataMap = files.reduce((map, file) => {
|
||||
map[file.file_id] = file;
|
||||
return map;
|
||||
}, {});
|
||||
} catch (error) {
|
||||
logger.error('[enhanceSourcesWithMetadata] Error looking up file metadata:', error);
|
||||
}
|
||||
|
||||
return sources.map((source) => {
|
||||
const fileRecord = fileMetadataMap[source.fileId] || {};
|
||||
const configuredStorageType = fileRecord.source || customConfig?.fileStrategy || 'local';
|
||||
|
||||
return {
|
||||
...source,
|
||||
fileName: fileRecord.filename || source.fileName || 'Unknown File',
|
||||
metadata: {
|
||||
...source.metadata,
|
||||
storageType: configuredStorageType,
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
applyCitationLimits,
|
||||
processFileCitations,
|
||||
enhanceSourcesWithMetadata,
|
||||
};
|
||||
@@ -11,6 +11,7 @@ const {
|
||||
imageExtRegex,
|
||||
EToolResources,
|
||||
} = require('librechat-data-provider');
|
||||
const { filterFilesByAgentAccess } = require('~/server/services/Files/permissions');
|
||||
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
|
||||
const { convertImage } = require('~/server/services/Files/images/convert');
|
||||
const { createFile, getFiles, updateFile } = require('~/models/File');
|
||||
@@ -164,14 +165,24 @@ const primeFiles = async (options, apiKey) => {
|
||||
const file_ids = tool_resources?.[EToolResources.execute_code]?.file_ids ?? [];
|
||||
const agentResourceIds = new Set(file_ids);
|
||||
const resourceFiles = tool_resources?.[EToolResources.execute_code]?.files ?? [];
|
||||
const dbFiles = (
|
||||
(await getFiles(
|
||||
{ file_id: { $in: file_ids } },
|
||||
null,
|
||||
{ text: 0 },
|
||||
{ userId: req?.user?.id, agentId },
|
||||
)) ?? []
|
||||
).concat(resourceFiles);
|
||||
|
||||
// Get all files first
|
||||
const allFiles = (await getFiles({ file_id: { $in: file_ids } }, null, { text: 0 })) ?? [];
|
||||
|
||||
// Filter by access if user and agent are provided
|
||||
let dbFiles;
|
||||
if (req?.user?.id && agentId) {
|
||||
dbFiles = await filterFilesByAgentAccess({
|
||||
files: allFiles,
|
||||
userId: req.user.id,
|
||||
role: req.user.role,
|
||||
agentId,
|
||||
});
|
||||
} else {
|
||||
dbFiles = allFiles;
|
||||
}
|
||||
|
||||
dbFiles = dbFiles.concat(resourceFiles);
|
||||
|
||||
const files = [];
|
||||
const sessions = new Map();
|
||||
@@ -225,7 +236,17 @@ const primeFiles = async (options, apiKey) => {
|
||||
entity_id: queryParams.entity_id,
|
||||
apiKey,
|
||||
});
|
||||
await updateFile({ file_id: file.file_id, metadata: { fileIdentifier } });
|
||||
|
||||
// Preserve existing metadata when adding fileIdentifier
|
||||
const updatedMetadata = {
|
||||
...file.metadata, // Preserve existing metadata (like S3 storage info)
|
||||
fileIdentifier, // Add fileIdentifier
|
||||
};
|
||||
|
||||
await updateFile({
|
||||
file_id: file.file_id,
|
||||
metadata: updatedMetadata,
|
||||
});
|
||||
sessions.set(session_id, true);
|
||||
pushFile();
|
||||
} catch (error) {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user