Compare commits

...

42 Commits

Author SHA1 Message Date
Danny Avila
d12122f3db ci: update button accessibility labels in AgentDetail tests 2025-06-27 00:25:07 -04:00
Danny Avila
fddf1196c9 ci: update AclEntry model tests 2025-06-27 00:24:49 -04:00
Danny Avila
90dd0198db chore: remove unused PermissionTypes import from Role model 2025-06-26 23:41:15 -04:00
Danny Avila
7464214f0f ci: update access middleware tests 2025-06-26 23:40:51 -04:00
Danny Avila
451e426e4c chore: update localization for agent detail component and clean up imports 2025-06-26 22:44:15 -04:00
Danny Avila
1b9f155fb8 refactor: remove redundant properties from IAgent interface in agent schema 2025-06-26 22:41:23 -04:00
Danny Avila
e6b9d6b84b refactor: enhance isDuplicateVersion function in Agent model for improved comparison logic
- Introduced handling for undefined/null values in array and object comparisons.
- Normalized array comparisons to treat undefined/null as empty arrays.
- Added deep comparison for objects and improved handling of primitive values.
- Enhanced projectIds comparison to ensure consistent MongoDB ObjectId handling.
2025-06-26 22:41:13 -04:00
Atef Bellaaj
4f7fbdf1c5 eslint fix 2025-06-26 19:32:47 -04:00
Atef Bellaaj
c6e1c65fe7 ESLint fixes 2025-06-26 19:32:30 -04:00
Atef Bellaaj
33c4ef03c3 Fix Unit tests WIP 2025-06-26 19:32:06 -04:00
Atef Bellaaj
37c423eb00 Fix linting WIP 2025-06-26 19:31:49 -04:00
Atef Bellaaj
4c40469951 fix: update agent permission check logic in AgentPanel to simplify condition 2025-06-26 19:28:10 -04:00
Atef Bellaaj
770c810650 feat: add validation and error messages for agent name in AgentConfig and AgentPanel 2025-06-26 19:28:10 -04:00
Atef Bellaaj
39e39ca7f5 fix: remove minlength validation for support contact name in agent schema 2025-06-26 19:28:09 -04:00
Atef Bellaaj
94c1f5f518 selected view only agents injected in the drop down 2025-06-26 19:28:09 -04:00
Atef Bellaaj
6a28d01b20 refactor: consolidate agent marketplace endpoints into main agents API and improve data management consistency
- Remove dedicated marketplace controller and routes, merging functionality into main agents v1 API
  - Add countPromotedAgents function to Agent model for promoted agents count
  - Enhance getListAgents handler with marketplace filtering (category, search, promoted status)
  - Move getAgentCategories from marketplace to v1 controller with same functionality
  - Update agent mutations to invalidate marketplace queries and handle multiple permission levels
  - Improve cache management by updating all agent query variants (VIEW/EDIT permissions)
  - Consolidate agent data access patterns for better maintainability and consistency
  - Remove duplicate marketplace route definitions and middleware
2025-06-26 19:28:09 -04:00
Danny Avila
3f6d7ab7c7 - Add useMarketplaceAgentsInfiniteQuery and useGetAgentCategoriesQuery to client/src/data-provider/Agents/
- Replace manual pagination in AgentGrid with infinite query pattern
  - Update imports to use local data provider instead of librechat-data-provider
  - Add proper permission handling with PERMISSION_BITS.VIEW/EDIT constants
  - Improve agent access control by adding requiredPermission validation in backend
  - Remove manual cursor/state management in favor of infinite query built-ins
  - Maintain existing search and category filtering functionality
2025-06-26 19:28:08 -04:00
Atef Bellaaj
6ebcfdf3e2 feat: add icon property to ProcessedAgentCategory interface 2025-06-26 19:28:08 -04:00
Danny Avila
2eef94d58d refactor: unify agent marketplace to single endpoint with cursor pagination
- Replace multiple marketplace routes with unified /marketplace endpoint
  - Add query string controls: category, search, limit, cursor, promoted, requiredPermission
  - Implement cursor-based pagination replacing page-based system
  - Integrate ACL permissions for proper access control
  - Fix ObjectId constructor error in Agent model
  - Update React components to use unified useGetMarketplaceAgentsQuery hook
  - Enhance type safety and remove deprecated useDynamicAgentQuery
  - Update tests for new marketplace architecture
  -Known issues:
  see more button after category switching + Unit tests
2025-06-26 19:28:07 -04:00
Atef Bellaaj
be7476d530 - Move AgentCategory from api/models to @packages/data-schemas structure
- Add schema, types, methods, and model following codebase conventions
  - Implement auto-seeding of default categories during AppService startup
  - Update marketplace controller to use new data-schemas methods
  - Remove old model file and standalone seed script
2025-06-26 19:28:07 -04:00
Atef Bellaaj
bb149bccc6 refactored and moved agent category methods and schema to data-schema package 2025-06-26 19:28:07 -04:00
“Praneeth
4d753d44e2 bugfix: Enhance Agent and AgentCategory schemas with new fields for category, support contact, and promotion status 2025-06-26 19:28:06 -04:00
Danny Avila
c6d4629fd1 🔐 feat: Granular Role-based Permissions + Entra ID Group Discovery (#7804) 2025-06-26 19:28:04 -04:00
Danny Avila
d471209ced WIP: pre-granular-permissions commit
feat: Add category and support contact fields to Agent schema and UI components

Revert "feat: Add category and support contact fields to Agent schema and UI components"

This reverts commit c43a52b4c9.

Fix: Update import for renderHook in useAgentCategories.spec.tsx

fix: Update icon rendering in AgentCategoryDisplay tests to use empty spans

refactor: Improve category synchronization logic and clean up AgentConfig component

refactor: Remove unused UI flow translations from translation.json

feat: agent marketplace features
2025-06-26 19:22:39 -04:00
Danny Avila
dd67e463e4 📦 chore: bump pbkdf2 to v3.1.3 (#8091) 2025-06-26 19:19:04 -04:00
github-actions[bot]
d60ad61325 🌍 i18n: Update translation.json with latest translations (#8058)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2025-06-26 19:12:46 -04:00
Danny Avila
452151e408 🐛 fix: RAG API failing with OPENID_REUSE_TOKENS Enabled (#8090)
* feat: Implement Short-Lived JWT Token Generation for RAG API

* fix: Update import paths

* fix: Correct environment variable names for OpenID on behalf flow

* fix: Remove unnecessary spaces in OpenID on behalf flow userinfo scope

---------

Co-authored-by: Atef Bellaaj <slalom.bellaaj@external.daimlertruck.com>
2025-06-26 19:10:21 -04:00
Danny Avila
33b4a97b42 🔒 fix: Agents Config/Permission Checks after Streamline Change (#8089)
* refactor: access control logic to TypeScript

* chore: Change EndpointURLs to a constant object for improved type safety

* 🐛 fix: Enhance agent access control by adding skipAgentCheck functionality

* 🐛 fix: Add endpointFileConfig prop to AttachFileMenu and update file handling logic

* 🐛 fix: Update tool handling logic to support optional groupedTools and improve null checks, add dedicated tool dialog for Assistants

* chore: Export Accordion component from UI index for improved modularity

* feat: Add ActivePanelContext for managing active panel state across components

* chore: Replace string IDs with EModelEndpoint constants for assistants and agents in useSideNavLinks

* fix: Integrate access checks for agent creation and deletion routes in actions.js
2025-06-26 18:53:05 -04:00
Sebastien Bruel
9cdc62b655 📂 fix: Prevent Null Reference Errors in File Process (#8084) 2025-06-26 18:51:35 -04:00
Danny Avila
799f0e5810 🐛 fix: Move MemoryEntry and PluginAuth model retrieval inside methods for Runtime Usage 2025-06-25 20:58:34 -04:00
Danny Avila
cbda3cb529 🕐 feat: Configurable Retention Period for Temporary Chats (#8056)
* feat: Add configurable retention period for temporary chats

* Addressing eslint errors

* Fix: failing test due to missing registration

* Update: variable name and use hours instead of days for chat retention

* Addressing comments

* chore: fix import order in Conversation.js

* chore: import order in Message.js

* chore: fix import order in config.ts

* chore: move common methods to packages/api to reduce potential for circular dependencies

* refactor: update temp chat retention config type to Partial<TCustomConfig>

* refactor: remove unused config variable from AppService and update loadCustomConfig tests with logger mock

* refactor: handle model undefined edge case by moving Session model initialization inside methods

---------

Co-authored-by: Rakshit Tiwari <rak1729e@gmail.com>
2025-06-25 17:16:26 -04:00
Karol Potocki
3ab1bd65e5 🐛 fix: Support Bedrock Provider for MCP Image Content Rendering (#8047) 2025-06-25 15:38:24 -04:00
Marlon
c551ba21f5 📜 chore: Update .env.example (#8043)
Update recent Gemini model names and remove deprecated Gemini models from env.example
2025-06-25 15:31:24 -04:00
Danny Avila
c87422a1e0 🧠 feat: Thinking Budget, Include Thoughts, and Dynamic Thinking for Gemini 2.5 (#8055)
* feat: support thinking budget parameter for Gemini 2.5 series (#6949, #7542)

https://ai.google.dev/gemini-api/docs/thinking#set-budget

* refactor: update thinking budget minimum value to -1 for dynamic thinking

- see: https://ai.google.dev/gemini-api/docs/thinking#set-budget

* chore: bump @librechat/agents to v2.4.43

* refactor: rename LLMConfigOptions to OpenAIConfigOptions for clarity and consistency

- Updated type definitions and references in initialize.ts, llm.ts, and openai.ts to reflect the new naming convention.
- Ensured that the OpenAI configuration options are consistently used across the relevant files.

* refactor: port Google LLM methods to TypeScript Package

* chore: update @librechat/agents version to 2.4.43 in package-lock.json and package.json

* refactor: update thinking budget description for clarity and adjust placeholder in parameter settings

* refactor: enhance googleSettings default value for thinking budget to support dynamic adjustment

* chore: update @librechat/agents to v2.4.44 for Vertex Dynamic Thinking workaround

* refactor: rename google config function, update `createRun` types, use `reasoning` as `reasoningKey` for Google

* refactor: simplify placeholder handling in DynamicInput component

* refactor: enhance thinking budget description for clarity and allow automatic decision by setting to "-1"

* refactor: update text styling in OptionHover component for improved readability

* chore: update @librechat/agents dependency to v2.4.46 in package.json and package-lock.json

* chore: update @librechat/api version to 1.2.5 in package.json and package-lock.json

* refactor: enhance `clientOptions` handling by filtering `omitTitleOptions`, add `json` field for Google models

---------

Co-authored-by: ciffelia <15273128+ciffelia@users.noreply.github.com>
2025-06-25 15:14:33 -04:00
Dustin Healy
b169306096 🧪 ci: Add Tests for Custom Endpoint Header Resolution (#8045)
* Enhanced existing tests for the `resolveHeaders` function to cover all user field placeholders and messy scenarios.
* Added basic integration tests for custom endpoints initialization file
2025-06-24 21:11:06 -04:00
Rakshit Tiwari
42977ac0d0 🖼️ feat: Add Optional Client-Side Image Resizing to Prevent Upload Errors (#7909)
* feat: Add optional client-side image resizing to prevent upload errors

* Addressing comments from author

* Addressing eslint errors

* Fixing the naming to clientresize from clientsideresize
2025-06-24 10:43:29 -04:00
Dustin Healy
d9a0fe03ed 🔧 fix: User Placeholders in Headers for Custom Endpoints (#8030)
* hotfix(custom-endpoints): fix user placeholder resolution in headers

* fix: import

---------

Co-authored-by: Danny Avila <danny@librechat.ai>
2025-06-24 08:21:14 -04:00
Danny Avila
d39b99971f 🧠 fix: Agent Title Config & Resource Handling (#8028)
* 🔧 fix: enhance client options handling in AgentClient and set default recursion limit

- Updated the recursion limit to default to 25 if not specified in agentsEConfig.
- Enhanced client options in AgentClient to include model parameters such as apiKey and anthropicApiUrl from agentModelParams.
- Updated requestOptions in the anthropic endpoint to use reverseProxyUrl as anthropicApiUrl.

* Enhance LLM configuration tests with edge case handling

* chore add return type annotation for getCustomEndpointConfig function

* fix: update modelOptions handling to use optional chaining and default to empty object in multiple endpoint initializations

* chore: update @librechat/agents to version 2.4.42

* refactor: streamline agent endpoint configuration and enhance client options handling for title generations

- Introduced a new `getProviderConfig` function to centralize provider configuration logic.
- Updated `AgentClient` to utilize the new provider configuration, improving clarity and maintainability.
- Removed redundant code related to endpoint initialization and model parameter handling.
- Enhanced error logging for missing endpoint configurations.

* fix: add abort handling for image generation and editing in OpenAIImageTools

* ci: enhance getLLMConfig tests to verify fetchOptions and dispatcher properties

* fix: use optional chaining for endpointOption properties in getOptions

* fix: increase title generation timeout from 25s to 45s, pass `endpointOption` to `getOptions`

* fix: update file filtering logic in getToolFilesByIds to ensure text field is properly checked

* fix: add error handling for empty OCR results in uploadMistralOCR and uploadAzureMistralOCR

* fix: enhance error handling in file upload to include 'No OCR result' message

* chore: update error messages in uploadMistralOCR and uploadAzureMistralOCR

* fix: enhance filtering logic in getToolFilesByIds to include context checks for OCR resources to only include files directly attached to agent

---------

Co-authored-by: Matt Burnett <matt.burnett@shopify.com>
2025-06-23 19:44:24 -04:00
Marco Beretta
1b7e044bf5 🤩 style: DialogImage, Update Stylesheet, and Improve Accessibility (#8014)
* 🔧 fix: Adjust typography and border styles for improved readability in markdown components

* 🔧 fix: Enhance code block styling in markdown for better visibility and consistency

* 🔧 fix: Adjust margins and line heights for improved readability in markdown elements

* 🔧 fix: Adjust spacing for horizontal rules in markdown for improved consistency

* 🔧 fix: Refactor DialogImage component for improved quality styling and layout consistency

* 🔧 fix: Enhance zoom and pan functionality in DialogImage component with improved controls and user experience

* 🔧 fix: Improve zoom and pan functionality in DialogImage component with enhanced controls and reset zoom feature
2025-06-23 14:30:15 -04:00
Danny Avila
5c947be455 fix: Minor Menu Issues (#8026)
* fix: Enable portal support in ExportAndShareMenu component

* fix: MCPSubMenu with focus loop and improved button click handling

* chore: remove "tools" header in toolsdropdown
2025-06-23 14:29:21 -04:00
Dustin Healy
2b2f7fe289 feat: Configurable MCP Dropdown Placeholder (#7988)
* new env  variable for mcp label

* 🔄 refactor: Update MCPSelect placeholderText to draw from interface section of librechat.yaml rather than .env

* 🧹 chore: extract mcpServers schema for better maintainability

* 🔄 refactor: Update MCPSelect and useMCPSelect to utilize TPlugin type for better type consistency

* 🔄 refactor: Pass placeholder from startupConfig to MCPSubMenu for improved localization

* 🔄 refactor: Integrate startupConfig into BadgeRowContext and related components for enhanced configuration management

---------

Co-authored-by: mwbrandao <mariana.brandao@nos.pt>
Co-authored-by: Danny Avila <danny@librechat.ai>
2025-06-23 13:21:01 -04:00
Danny Avila
a058963a9f 👤 feat: User Placeholder Variables for Custom Endpoint Headers (#7993)
* 🔧 refactor: move `processMCPEnv` from `librechat-data-provider` and move to `@librechat/api`

* 🔧 refactor: Update resolveHeaders import paths

* 🔧 refactor: Enhance resolveHeaders to support user and custom variables

- Updated resolveHeaders function to accept user and custom user variables for placeholder replacement.
- Modified header resolution in multiple client and controller files to utilize the enhanced resolveHeaders functionality.
- Added comprehensive tests for resolveHeaders to ensure correct processing of user and custom variables.

* 🔧 fix: Update user ID placeholder processing in env.ts

* 🔧 fix: Remove arguments passing this.user rather than req.user

- Updated multiple client and controller files to call resolveHeaders without the user parameter

* 🔧 refactor: Enhance processUserPlaceholders to be more readable / less nested

* 🔧 refactor: Update processUserPlaceholders to pass all tests in mpc.spec.ts and env.spec.ts

* chore: remove legacy ChatGPTClient

* chore: remove LLM initialization code

* chore: initial deprecation removal of `gptPlugins`

* chore: remove cohere-ai dependency from package.json and package-lock.json

* chore: update brace-expansion to version 2.0.2 and add license information

* chore: remove PluginsClient test file

* chore: remove legacy

* ci: remove deprecated sendMessage/getCompletion/chatCompletion tests

---------

Co-authored-by: Dustin Healy <54083382+dustinhealy@users.noreply.github.com>
2025-06-23 12:39:27 -04:00
284 changed files with 21071 additions and 4392 deletions

View File

@@ -58,7 +58,7 @@ DEBUG_CONSOLE=false
# Endpoints #
#===================================================#
# ENDPOINTS=openAI,assistants,azureOpenAI,google,gptPlugins,anthropic
# ENDPOINTS=openAI,assistants,azureOpenAI,google,anthropic
PROXY=
@@ -142,10 +142,10 @@ GOOGLE_KEY=user_provided
# GOOGLE_AUTH_HEADER=true
# Gemini API (AI Studio)
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash,gemini-2.0-flash-lite
# Vertex AI
# GOOGLE_MODELS=gemini-2.5-pro-preview-05-06,gemini-2.5-flash-preview-04-17,gemini-2.0-flash-001,gemini-2.0-flash-exp,gemini-2.0-flash-lite-001,gemini-1.5-pro-002,gemini-1.5-flash-002
# GOOGLE_MODELS=gemini-2.5-pro,gemini-2.5-flash,gemini-2.5-flash-lite-preview-06-17,gemini-2.0-flash-001,gemini-2.0-flash-lite-001
# GOOGLE_TITLE_MODEL=gemini-2.0-flash-lite-001
@@ -453,8 +453,8 @@ OPENID_REUSE_TOKENS=
OPENID_JWKS_URL_CACHE_ENABLED=
OPENID_JWKS_URL_CACHE_TIME= # 600000 ms eq to 10 minutes leave empty to disable caching
#Set to true to trigger token exchange flow to acquire access token for the userinfo endpoint.
OPENID_ON_BEHALF_FLOW_FOR_USERINFRO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFRO_SCOPE = "user.read" # example for Scope Needed for Microsoft Graph API
OPENID_ON_BEHALF_FLOW_FOR_USERINFO_REQUIRED=
OPENID_ON_BEHALF_FLOW_USERINFO_SCOPE="user.read" # example for Scope Needed for Microsoft Graph API
# Set to true to use the OpenID Connect end session endpoint for logout
OPENID_USE_END_SESSION_ENDPOINT=
@@ -485,6 +485,21 @@ SAML_IMAGE_URL=
# SAML_USE_AUTHN_RESPONSE_SIGNED=
#===============================================#
# Microsoft Graph API / Entra ID Integration #
#===============================================#
# Enable Entra ID people search integration in permissions/sharing system
# When enabled, the people picker will search both local database and Entra ID
USE_ENTRA_ID_FOR_PEOPLE_SEARCH=false
# When enabled, entra id groups owners will be considered as members of the group
ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS=false
# Microsoft Graph API scopes needed for people/group search
# Default scopes provide access to user profiles and group memberships
OPENID_GRAPH_SCOPES=User.Read,People.Read,GroupMember.Read.All
# LDAP
LDAP_URL=
LDAP_BIND_DN=
@@ -657,4 +672,4 @@ OPENWEATHER_API_KEY=
# Reranker (Required)
# JINA_API_KEY=your_jina_api_key
# or
# COHERE_API_KEY=your_cohere_api_key
# COHERE_API_KEY=your_cohere_api_key

3
.vscode/launch.json vendored
View File

@@ -8,7 +8,8 @@
"skipFiles": ["<node_internals>/**"],
"program": "${workspaceFolder}/api/server/index.js",
"env": {
"NODE_ENV": "production"
"NODE_ENV": "production",
"NODE_TLS_REJECT_UNAUTHORIZED": "0"
},
"console": "integratedTerminal",
"envFile": "${workspaceFolder}/.env"

View File

@@ -1,804 +0,0 @@
const { Keyv } = require('keyv');
const crypto = require('crypto');
const { CohereClient } = require('cohere-ai');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { constructAzureURL, genAzureChatCompletion } = require('@librechat/api');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const {
ImageDetail,
EModelEndpoint,
resolveHeaders,
CohereConstants,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { createContextHandlers } = require('./prompts');
const { createCoherePayload } = require('./llm');
const { extractBaseURL } = require('~/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const CHATGPT_MODEL = 'gpt-3.5-turbo';
const tokenizersCache = {};
class ChatGPTClient extends BaseClient {
constructor(apiKey, options = {}, cacheOptions = {}) {
super(apiKey, options, cacheOptions);
cacheOptions.namespace = cacheOptions.namespace || 'chatgpt';
this.conversationsCache = new Keyv(cacheOptions);
this.setOptions(options);
}
setOptions(options) {
if (this.options && !this.options.replaceOptions) {
// nested options aren't spread properly, so we need to do this manually
this.options.modelOptions = {
...this.options.modelOptions,
...options.modelOptions,
};
delete options.modelOptions;
// now we can merge options
this.options = {
...this.options,
...options,
};
} else {
this.options = options;
}
if (this.options.openaiApiKey) {
this.apiKey = this.options.openaiApiKey;
}
const modelOptions = this.options.modelOptions || {};
this.modelOptions = {
...modelOptions,
// set some good defaults (check for undefined in some cases because they may be 0)
model: modelOptions.model || CHATGPT_MODEL,
temperature: typeof modelOptions.temperature === 'undefined' ? 0.8 : modelOptions.temperature,
top_p: typeof modelOptions.top_p === 'undefined' ? 1 : modelOptions.top_p,
presence_penalty:
typeof modelOptions.presence_penalty === 'undefined' ? 1 : modelOptions.presence_penalty,
stop: modelOptions.stop,
};
this.isChatGptModel = this.modelOptions.model.includes('gpt-');
const { isChatGptModel } = this;
this.isUnofficialChatGptModel =
this.modelOptions.model.startsWith('text-chat') ||
this.modelOptions.model.startsWith('text-davinci-002-render');
const { isUnofficialChatGptModel } = this;
// Davinci models have a max context length of 4097 tokens.
this.maxContextTokens = this.options.maxContextTokens || (isChatGptModel ? 4095 : 4097);
// I decided to reserve 1024 tokens for the response.
// The max prompt tokens is determined by the max context tokens minus the max response tokens.
// Earlier messages will be dropped until the prompt is within the limit.
this.maxResponseTokens = this.modelOptions.max_tokens || 1024;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
if (this.maxPromptTokens + this.maxResponseTokens > this.maxContextTokens) {
throw new Error(
`maxPromptTokens + max_tokens (${this.maxPromptTokens} + ${this.maxResponseTokens} = ${
this.maxPromptTokens + this.maxResponseTokens
}) must be less than or equal to maxContextTokens (${this.maxContextTokens})`,
);
}
this.userLabel = this.options.userLabel || 'User';
this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
if (isChatGptModel) {
// Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
// Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
// without tripping the stop sequences, so I'm using "||>" instead.
this.startToken = '||>';
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
} else if (isUnofficialChatGptModel) {
this.startToken = '<|im_start|>';
this.endToken = '<|im_end|>';
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true, {
'<|im_start|>': 100264,
'<|im_end|>': 100265,
});
} else {
// Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting
// system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated
// as a single token. So we're using this instead.
this.startToken = '||>';
this.endToken = '';
try {
this.gptEncoder = this.constructor.getTokenizer(this.modelOptions.model, true);
} catch {
this.gptEncoder = this.constructor.getTokenizer('text-davinci-003', true);
}
}
if (!this.modelOptions.stop) {
const stopTokens = [this.startToken];
if (this.endToken && this.endToken !== this.startToken) {
stopTokens.push(this.endToken);
}
stopTokens.push(`\n${this.userLabel}:`);
stopTokens.push('<|diff_marker|>');
// I chose not to do one for `chatGptLabel` because I've never seen it happen
this.modelOptions.stop = stopTokens;
}
if (this.options.reverseProxyUrl) {
this.completionsUrl = this.options.reverseProxyUrl;
} else if (isChatGptModel) {
this.completionsUrl = 'https://api.openai.com/v1/chat/completions';
} else {
this.completionsUrl = 'https://api.openai.com/v1/completions';
}
return this;
}
static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) {
if (tokenizersCache[encoding]) {
return tokenizersCache[encoding];
}
let tokenizer;
if (isModelName) {
tokenizer = encodingForModel(encoding, extendSpecialTokens);
} else {
tokenizer = getEncoding(encoding, extendSpecialTokens);
}
tokenizersCache[encoding] = tokenizer;
return tokenizer;
}
/** @type {getCompletion} */
async getCompletion(input, onProgress, onTokenProgress, abortController = null) {
if (!abortController) {
abortController = new AbortController();
}
let modelOptions = { ...this.modelOptions };
if (typeof onProgress === 'function') {
modelOptions.stream = true;
}
if (this.isChatGptModel) {
modelOptions.messages = input;
} else {
modelOptions.prompt = input;
}
if (this.useOpenRouter && modelOptions.prompt) {
delete modelOptions.stop;
}
const { debug } = this.options;
let baseURL = this.completionsUrl;
if (debug) {
console.debug();
console.debug(baseURL);
console.debug(modelOptions);
console.debug();
}
const opts = {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
};
if (this.isVisionModel) {
modelOptions.max_tokens = 4000;
}
/** @type {TAzureConfig | undefined} */
const azureConfig = this.options?.req?.app?.locals?.[EModelEndpoint.azureOpenAI];
const isAzure = this.azure || this.options.azure;
if (
(isAzure && this.isVisionModel && azureConfig) ||
(azureConfig && this.isVisionModel && this.options.endpoint === EModelEndpoint.azureOpenAI)
) {
const { modelGroupMap, groupMap } = azureConfig;
const {
azureOptions,
baseURL,
headers = {},
serverless,
} = mapModelToAzureConfig({
modelName: modelOptions.model,
modelGroupMap,
groupMap,
});
opts.headers = resolveHeaders(headers);
this.langchainProxy = extractBaseURL(baseURL);
this.apiKey = azureOptions.azureOpenAIApiKey;
const groupName = modelGroupMap[modelOptions.model].group;
this.options.addParams = azureConfig.groupMap[groupName].addParams;
this.options.dropParams = azureConfig.groupMap[groupName].dropParams;
// Note: `forcePrompt` not re-assigned as only chat models are vision models
this.azure = !serverless && azureOptions;
this.azureEndpoint =
!serverless && genAzureChatCompletion(this.azure, modelOptions.model, this);
if (serverless === true) {
this.options.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
this.options.headers['api-key'] = this.apiKey;
}
}
if (this.options.defaultQuery) {
opts.defaultQuery = this.options.defaultQuery;
}
if (this.options.headers) {
opts.headers = { ...opts.headers, ...this.options.headers };
}
if (isAzure) {
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;
baseURL = this.langchainProxy
? constructAzureURL({
baseURL: this.langchainProxy,
azureOptions: this.azure,
})
: this.azureEndpoint.split(/(?<!\/)\/(chat|completion)\//)[0];
if (this.options.forcePrompt) {
baseURL += '/completions';
} else {
baseURL += '/chat/completions';
}
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
opts.headers = { ...opts.headers, 'api-key': this.apiKey };
} else if (this.apiKey) {
opts.headers.Authorization = `Bearer ${this.apiKey}`;
}
if (process.env.OPENAI_ORGANIZATION) {
opts.headers['OpenAI-Organization'] = process.env.OPENAI_ORGANIZATION;
}
if (this.useOpenRouter) {
opts.headers['HTTP-Referer'] = 'https://librechat.ai';
opts.headers['X-Title'] = 'LibreChat';
}
/* hacky fixes for Mistral AI API:
- Re-orders system message to the top of the messages payload, as not allowed anywhere else
- If there is only one message and it's a system message, change the role to user
*/
if (baseURL.includes('https://api.mistral.ai/v1') && modelOptions.messages) {
const { messages } = modelOptions;
const systemMessageIndex = messages.findIndex((msg) => msg.role === 'system');
if (systemMessageIndex > 0) {
const [systemMessage] = messages.splice(systemMessageIndex, 1);
messages.unshift(systemMessage);
}
modelOptions.messages = messages;
if (messages.length === 1 && messages[0].role === 'system') {
modelOptions.messages[0].role = 'user';
}
}
if (this.options.addParams && typeof this.options.addParams === 'object') {
modelOptions = {
...modelOptions,
...this.options.addParams,
};
logger.debug('[ChatGPTClient] chatCompletion: added params', {
addParams: this.options.addParams,
modelOptions,
});
}
if (this.options.dropParams && Array.isArray(this.options.dropParams)) {
this.options.dropParams.forEach((param) => {
delete modelOptions[param];
});
logger.debug('[ChatGPTClient] chatCompletion: dropped params', {
dropParams: this.options.dropParams,
modelOptions,
});
}
if (baseURL.startsWith(CohereConstants.API_URL)) {
const payload = createCoherePayload({ modelOptions });
return await this.cohereChatCompletion({ payload, onTokenProgress });
}
if (baseURL.includes('v1') && !baseURL.includes('/completions') && !this.isChatCompletion) {
baseURL = baseURL.split('v1')[0] + 'v1/completions';
} else if (
baseURL.includes('v1') &&
!baseURL.includes('/chat/completions') &&
this.isChatCompletion
) {
baseURL = baseURL.split('v1')[0] + 'v1/chat/completions';
}
const BASE_URL = new URL(baseURL);
if (opts.defaultQuery) {
Object.entries(opts.defaultQuery).forEach(([key, value]) => {
BASE_URL.searchParams.append(key, value);
});
delete opts.defaultQuery;
}
const completionsURL = BASE_URL.toString();
opts.body = JSON.stringify(modelOptions);
if (modelOptions.stream) {
return new Promise(async (resolve, reject) => {
try {
let done = false;
await fetchEventSource(completionsURL, {
...opts,
signal: abortController.signal,
async onopen(response) {
if (response.status === 200) {
return;
}
if (debug) {
console.debug(response);
}
let error;
try {
const body = await response.text();
error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
error.json = JSON.parse(body);
} catch {
error = error || new Error(`Failed to send message. HTTP ${response.status}`);
}
throw error;
},
onclose() {
if (debug) {
console.debug('Server closed the connection unexpectedly, returning...');
}
// workaround for private API not sending [DONE] event
if (!done) {
onProgress('[DONE]');
resolve();
}
},
onerror(err) {
if (debug) {
console.debug(err);
}
// rethrow to stop the operation
throw err;
},
onmessage(message) {
if (debug) {
console.debug(message);
}
if (!message.data || message.event === 'ping') {
return;
}
if (message.data === '[DONE]') {
onProgress('[DONE]');
resolve();
done = true;
return;
}
onProgress(JSON.parse(message.data));
},
});
} catch (err) {
reject(err);
}
});
}
const response = await fetch(completionsURL, {
...opts,
signal: abortController.signal,
});
if (response.status !== 200) {
const body = await response.text();
const error = new Error(`Failed to send message. HTTP ${response.status} - ${body}`);
error.status = response.status;
try {
error.json = JSON.parse(body);
} catch {
error.body = body;
}
throw error;
}
return response.json();
}
/** @type {cohereChatCompletion} */
async cohereChatCompletion({ payload, onTokenProgress }) {
const cohere = new CohereClient({
token: this.apiKey,
environment: this.completionsUrl,
});
if (!payload.stream) {
const chatResponse = await cohere.chat(payload);
return chatResponse.text;
}
const chatStream = await cohere.chatStream(payload);
let reply = '';
for await (const message of chatStream) {
if (!message) {
continue;
}
if (message.eventType === 'text-generation' && message.text) {
onTokenProgress(message.text);
reply += message.text;
}
/*
Cohere API Chinese Unicode character replacement hotfix.
Should be un-commented when the following issue is resolved:
https://github.com/cohere-ai/cohere-typescript/issues/151
else if (message.eventType === 'stream-end' && message.response) {
reply = message.response.text;
}
*/
}
return reply;
}
async generateTitle(userMessage, botMessage) {
const instructionsPayload = {
role: 'system',
content: `Write an extremely concise subtitle for this conversation with no more than a few words. All words should be capitalized. Exclude punctuation.
||>Message:
${userMessage.message}
||>Response:
${botMessage.message}
||>Title:`,
};
const titleGenClientOptions = JSON.parse(JSON.stringify(this.options));
titleGenClientOptions.modelOptions = {
model: 'gpt-3.5-turbo',
temperature: 0,
presence_penalty: 0,
frequency_penalty: 0,
};
const titleGenClient = new ChatGPTClient(this.apiKey, titleGenClientOptions);
const result = await titleGenClient.getCompletion([instructionsPayload], null);
// remove any non-alphanumeric characters, replace multiple spaces with 1, and then trim
return result.choices[0].message.content
.replace(/[^a-zA-Z0-9' ]/g, '')
.replace(/\s+/g, ' ')
.trim();
}
async sendMessage(message, opts = {}) {
if (opts.clientOptions && typeof opts.clientOptions === 'object') {
this.setOptions(opts.clientOptions);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || crypto.randomUUID();
let conversation =
typeof opts.conversation === 'object'
? opts.conversation
: await this.conversationsCache.get(conversationId);
let isNewConversation = false;
if (!conversation) {
conversation = {
messages: [],
createdAt: Date.now(),
};
isNewConversation = true;
}
const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
const userMessage = {
id: crypto.randomUUID(),
parentMessageId,
role: 'User',
message,
};
conversation.messages.push(userMessage);
// Doing it this way instead of having each message be a separate element in the array seems to be more reliable,
// especially when it comes to keeping the AI in character. It also seems to improve coherency and context retention.
const { prompt: payload, context } = await this.buildPrompt(
conversation.messages,
userMessage.id,
{
isChatGptModel: this.isChatGptModel,
promptPrefix: opts.promptPrefix,
},
);
if (this.options.keepNecessaryMessagesOnly) {
conversation.messages = context;
}
let reply = '';
let result = null;
if (typeof opts.onProgress === 'function') {
await this.getCompletion(
payload,
(progressMessage) => {
if (progressMessage === '[DONE]') {
return;
}
const token = this.isChatGptModel
? progressMessage.choices[0].delta.content
: progressMessage.choices[0].text;
// first event's delta content is always undefined
if (!token) {
return;
}
if (this.options.debug) {
console.debug(token);
}
if (token === this.endToken) {
return;
}
opts.onProgress(token);
reply += token;
},
opts.abortController || new AbortController(),
);
} else {
result = await this.getCompletion(
payload,
null,
opts.abortController || new AbortController(),
);
if (this.options.debug) {
console.debug(JSON.stringify(result));
}
if (this.isChatGptModel) {
reply = result.choices[0].message.content;
} else {
reply = result.choices[0].text.replace(this.endToken, '');
}
}
// avoids some rendering issues when using the CLI app
if (this.options.debug) {
console.debug();
}
reply = reply.trim();
const replyMessage = {
id: crypto.randomUUID(),
parentMessageId: userMessage.id,
role: 'ChatGPT',
message: reply,
};
conversation.messages.push(replyMessage);
const returnData = {
response: replyMessage.message,
conversationId,
parentMessageId: replyMessage.parentMessageId,
messageId: replyMessage.id,
details: result || {},
};
if (shouldGenerateTitle) {
conversation.title = await this.generateTitle(userMessage, replyMessage);
returnData.title = conversation.title;
}
await this.conversationsCache.set(conversationId, conversation);
if (this.options.returnConversation) {
returnData.conversation = conversation;
}
return returnData;
}
async buildPrompt(messages, { isChatGptModel = false, promptPrefix = null }) {
promptPrefix = (promptPrefix || this.options.promptPrefix || '').trim();
// Handle attachments and create augmentedPrompt
if (this.options.attachments) {
const attachments = await this.options.attachments;
const lastMessage = messages[messages.length - 1];
if (this.message_file_map) {
this.message_file_map[lastMessage.messageId] = attachments;
} else {
this.message_file_map = {
[lastMessage.messageId]: attachments,
};
}
const files = await this.addImageURLs(lastMessage, attachments);
this.options.attachments = files;
this.contextHandlers = createContextHandlers(this.options.req, lastMessage.text);
}
if (this.message_file_map) {
this.contextHandlers = createContextHandlers(
this.options.req,
messages[messages.length - 1].text,
);
}
// Calculate image token cost and process embedded files
messages.forEach((message, i) => {
if (this.message_file_map && this.message_file_map[message.messageId]) {
const attachments = this.message_file_map[message.messageId];
for (const file of attachments) {
if (file.embedded) {
this.contextHandlers?.processFile(file);
continue;
}
messages[i].tokenCount =
(messages[i].tokenCount || 0) +
this.calculateImageTokenCost({
width: file.width,
height: file.height,
detail: this.options.imageDetail ?? ImageDetail.auto,
});
}
}
});
if (this.contextHandlers) {
this.augmentedPrompt = await this.contextHandlers.createContext();
promptPrefix = this.augmentedPrompt + promptPrefix;
}
if (promptPrefix) {
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
}
const promptSuffix = `${this.startToken}${this.chatGptLabel}:\n`; // Prompt ChatGPT to respond.
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
let currentTokenCount;
if (isChatGptModel) {
currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
} else {
currentTokenCount = this.getTokenCount(`${promptPrefix}${promptSuffix}`);
}
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
const context = [];
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && messages.length > 0) {
const message = messages.pop();
const roleLabel =
message?.isCreatedByUser || message?.role?.toLowerCase() === 'user'
? this.userLabel
: this.chatGptLabel;
const messageString = `${this.startToken}${roleLabel}:\n${
message?.text ?? message?.message
}${this.endToken}\n`;
let newPromptBody;
if (promptBody || isChatGptModel) {
newPromptBody = `${messageString}${promptBody}`;
} else {
// Always insert prompt prefix before the last user message, if not gpt-3.5-turbo.
// This makes the AI obey the prompt instructions better, which is important for custom instructions.
// After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things
// like "what's the last thing I wrote?".
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
}
context.unshift(message);
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setImmediate(resolve));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = `${promptBody}${promptSuffix}`;
if (isChatGptModel) {
messagePayload.content = prompt;
// Add 3 tokens for Assistant Label priming after all messages have been counted.
currentTokenCount += 3;
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (isChatGptModel) {
return { prompt: [instructionsPayload, messagePayload], context };
}
return { prompt, context, promptTokens: currentTokenCount };
}
getTokenCount(text) {
return this.gptEncoder.encode(text, 'all').length;
}
/**
* Algorithm adapted from "6. Counting tokens for chat API calls" of
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
*
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
*
* @param {Object} message
*/
getTokenCountForMessage(message) {
// Note: gpt-3.5-turbo and gpt-4 may update over time. Use default for these as well as for unknown models
let tokensPerMessage = 3;
let tokensPerName = 1;
if (this.modelOptions.model === 'gpt-3.5-turbo-0301') {
tokensPerMessage = 4;
tokensPerName = -1;
}
let numTokens = tokensPerMessage;
for (let [key, value] of Object.entries(message)) {
numTokens += this.getTokenCount(value);
if (key === 'name') {
numTokens += tokensPerName;
}
}
return numTokens;
}
}
module.exports = ChatGPTClient;

View File

@@ -1,7 +1,7 @@
const { google } = require('googleapis');
const { Tokenizer } = require('@librechat/api');
const { concat } = require('@langchain/core/utils/stream');
const { ChatVertexAI } = require('@langchain/google-vertexai');
const { Tokenizer, getSafetySettings } = require('@librechat/api');
const { ChatGoogleGenerativeAI } = require('@langchain/google-genai');
const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai');
const { HumanMessage, SystemMessage } = require('@langchain/core/messages');
@@ -12,13 +12,13 @@ const {
endpointSettings,
parseTextParts,
EModelEndpoint,
googleSettings,
ContentTypes,
VisionModes,
ErrorTypes,
Constants,
AuthKeys,
} = require('librechat-data-provider');
const { getSafetySettings } = require('~/server/services/Endpoints/google/llm');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { spendTokens } = require('~/models/spendTokens');
const { getModelMaxTokens } = require('~/utils');
@@ -166,6 +166,16 @@ class GoogleClient extends BaseClient {
);
}
// Add thinking configuration
this.modelOptions.thinkingConfig = {
thinkingBudget:
(this.modelOptions.thinking ?? googleSettings.thinking.default)
? this.modelOptions.thinkingBudget
: 0,
};
delete this.modelOptions.thinking;
delete this.modelOptions.thinkingBudget;
this.sender =
this.options.sender ??
getResponseSender({

View File

@@ -5,6 +5,7 @@ const {
isEnabled,
Tokenizer,
createFetch,
resolveHeaders,
constructAzureURL,
genAzureChatCompletion,
createStreamEventHandlers,
@@ -15,7 +16,6 @@ const {
ContentTypes,
parseTextParts,
EModelEndpoint,
resolveHeaders,
KnownEndpoints,
openAISettings,
ImageDetailCost,
@@ -37,7 +37,6 @@ const { addSpaceIfNeeded, sleep } = require('~/server/utils');
const { spendTokens } = require('~/models/spendTokens');
const { handleOpenAIErrors } = require('./tools/util');
const { createLLM, RunManager } = require('./llm');
const ChatGPTClient = require('./ChatGPTClient');
const { summaryBuffer } = require('./memory');
const { runTitleChain } = require('./chains');
const { tokenSplit } = require('./document');
@@ -47,12 +46,6 @@ const { logger } = require('~/config');
class OpenAIClient extends BaseClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.ChatGPTClient = new ChatGPTClient();
this.buildPrompt = this.ChatGPTClient.buildPrompt.bind(this);
/** @type {getCompletion} */
this.getCompletion = this.ChatGPTClient.getCompletion.bind(this);
/** @type {cohereChatCompletion} */
this.cohereChatCompletion = this.ChatGPTClient.cohereChatCompletion.bind(this);
this.contextStrategy = options.contextStrategy
? options.contextStrategy.toLowerCase()
: 'discard';
@@ -379,23 +372,12 @@ class OpenAIClient extends BaseClient {
return files;
}
async buildMessages(
messages,
parentMessageId,
{ isChatCompletion = false, promptPrefix = null },
opts,
) {
async buildMessages(messages, parentMessageId, { promptPrefix = null }, opts) {
let orderedMessages = this.constructor.getMessagesForConversation({
messages,
parentMessageId,
summary: this.shouldSummarize,
});
if (!isChatCompletion) {
return await this.buildPrompt(orderedMessages, {
isChatGptModel: isChatCompletion,
promptPrefix,
});
}
let payload;
let instructions;

View File

@@ -1,542 +0,0 @@
const OpenAIClient = require('./OpenAIClient');
const { CallbackManager } = require('@langchain/core/callbacks/manager');
const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { processFileURL } = require('~/server/services/Files/process');
const { EModelEndpoint } = require('librechat-data-provider');
const { checkBalance } = require('~/models/balanceMethods');
const { formatLangChainMessages } = require('./prompts');
const { extractBaseURL } = require('~/utils');
const { loadTools } = require('./tools/util');
const { logger } = require('~/config');
class PluginsClient extends OpenAIClient {
constructor(apiKey, options = {}) {
super(apiKey, options);
this.sender = options.sender ?? 'Assistant';
this.tools = [];
this.actions = [];
this.setOptions(options);
this.openAIApiKey = this.apiKey;
this.executor = null;
}
setOptions(options) {
this.agentOptions = { ...options.agentOptions };
this.functionsAgent = this.agentOptions?.agent === 'functions';
this.agentIsGpt3 = this.agentOptions?.model?.includes('gpt-3');
super.setOptions(options);
this.isGpt3 = this.modelOptions?.model?.includes('gpt-3');
if (this.options.reverseProxyUrl) {
this.langchainProxy = extractBaseURL(this.options.reverseProxyUrl);
}
}
getSaveOptions() {
return {
artifacts: this.options.artifacts,
chatGptLabel: this.options.chatGptLabel,
modelLabel: this.options.modelLabel,
promptPrefix: this.options.promptPrefix,
tools: this.options.tools,
...this.modelOptions,
agentOptions: this.agentOptions,
iconURL: this.options.iconURL,
greeting: this.options.greeting,
spec: this.options.spec,
};
}
saveLatestAction(action) {
this.actions.push(action);
}
getFunctionModelName(input) {
if (/-(?!0314)\d{4}/.test(input)) {
return input;
} else if (input.includes('gpt-3.5-turbo')) {
return 'gpt-3.5-turbo';
} else if (input.includes('gpt-4')) {
return 'gpt-4';
} else {
return 'gpt-3.5-turbo';
}
}
getBuildMessagesOptions(opts) {
return {
isChatCompletion: true,
promptPrefix: opts.promptPrefix,
abortController: opts.abortController,
};
}
async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
const modelOptions = {
modelName: this.agentOptions.model,
temperature: this.agentOptions.temperature,
};
const model = this.initializeLLM({
...modelOptions,
context: 'plugins',
initialMessageCount: this.currentMessages.length + 1,
});
logger.debug(
`[PluginsClient] Agent Model: ${model.modelName} | Temp: ${model.temperature} | Functions: ${this.functionsAgent}`,
);
// Map Messages to Langchain format
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
userName: this.options?.name,
});
logger.debug('[PluginsClient] pastMessages: ' + pastMessages.length);
// TODO: use readOnly memory, TokenBufferMemory? (both unavailable in LangChainJS)
const memory = new BufferMemory({
llm: model,
chatHistory: new ChatMessageHistory(pastMessages),
});
const { loadedTools } = await loadTools({
user,
model,
tools: this.options.tools,
functions: this.functionsAgent,
options: {
memory,
signal: this.abortController.signal,
openAIApiKey: this.openAIApiKey,
conversationId: this.conversationId,
fileStrategy: this.options.req.app.locals.fileStrategy,
processFileURL,
message,
},
useSpecs: true,
});
if (loadedTools.length === 0) {
return;
}
this.tools = loadedTools;
logger.debug('[PluginsClient] Requested Tools', this.options.tools);
logger.debug(
'[PluginsClient] Loaded Tools',
this.tools.map((tool) => tool.name),
);
const handleAction = (action, runId, callback = null) => {
this.saveLatestAction(action);
logger.debug('[PluginsClient] Latest Agent Action ', this.actions[this.actions.length - 1]);
if (typeof callback === 'function') {
callback(action, runId);
}
};
// initialize agent
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
let customInstructions = (this.options.promptPrefix ?? '').trim();
if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) {
customInstructions = `${customInstructions ?? ''}\n${this.options.artifactsPrompt}`.trim();
}
this.executor = await initializer({
model,
signal,
pastMessages,
tools: this.tools,
customInstructions,
verbose: this.options.debug,
returnIntermediateSteps: true,
customName: this.options.chatGptLabel,
currentDateString: this.currentDateString,
callbackManager: CallbackManager.fromHandlers({
async handleAgentAction(action, runId) {
handleAction(action, runId, onAgentAction);
},
async handleChainEnd(action) {
if (typeof onChainEnd === 'function') {
onChainEnd(action);
}
},
}),
});
logger.debug('[PluginsClient] Loaded agent.');
}
async executorCall(message, { signal, stream, onToolStart, onToolEnd }) {
let errorMessage = '';
const maxAttempts = 1;
for (let attempts = 1; attempts <= maxAttempts; attempts++) {
const errorInput = buildErrorInput({
message,
errorMessage,
actions: this.actions,
functionsAgent: this.functionsAgent,
});
const input = attempts > 1 ? errorInput : message;
logger.debug(`[PluginsClient] Attempt ${attempts} of ${maxAttempts}`);
if (errorMessage.length > 0) {
logger.debug('[PluginsClient] Caught error, input: ' + JSON.stringify(input));
}
try {
this.result = await this.executor.call({ input, signal }, [
{
async handleToolStart(...args) {
await onToolStart(...args);
},
async handleToolEnd(...args) {
await onToolEnd(...args);
},
async handleLLMEnd(output) {
const { generations } = output;
const { text } = generations[0][0];
if (text && typeof stream === 'function') {
await stream(text);
}
},
},
]);
break; // Exit the loop if the function call is successful
} catch (err) {
logger.error('[PluginsClient] executorCall error:', err);
if (attempts === maxAttempts) {
const { run } = this.runManager.getRunByConversationId(this.conversationId);
const defaultOutput = `Encountered an error while attempting to respond: ${err.message}`;
this.result.output = run && run.error ? run.error : defaultOutput;
this.result.errorMessage = run && run.error ? run.error : err.message;
this.result.intermediateSteps = this.actions;
break;
}
}
}
}
/**
*
* @param {TMessage} responseMessage
* @param {Partial<TMessage>} saveOptions
* @param {string} user
* @returns
*/
async handleResponseMessage(responseMessage, saveOptions, user) {
const { output, errorMessage, ...result } = this.result;
logger.debug('[PluginsClient][handleResponseMessage] Output:', {
output,
errorMessage,
...result,
});
const { error } = responseMessage;
if (!error) {
responseMessage.tokenCount = this.getTokenCountForResponse(responseMessage);
responseMessage.completionTokens = this.getTokenCount(responseMessage.text);
}
// Record usage only when completion is skipped as it is already recorded in the agent phase.
if (!this.agentOptions.skipCompletion && !error) {
await this.recordTokenUsage(responseMessage);
}
const databasePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return { ...responseMessage, ...result, databasePromise };
}
async sendMessage(message, opts = {}) {
/** @type {Promise<TMessage>} */
let userMessagePromise;
/** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
if (includedTools.length > 0) {
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
this.options.tools = tools;
} else {
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
this.options.tools = tools;
}
// If a message is edited, no tools can be used.
const completionMode = this.options.tools.length === 0 || opts.isEdited;
if (completionMode) {
this.setOptions(opts);
return super.sendMessage(message, opts);
}
logger.debug('[PluginsClient] sendMessage', { userMessageText: message, opts });
const {
user,
conversationId,
responseMessageId,
saveOptions,
userMessage,
onAgentAction,
onChainEnd,
onToolStart,
onToolEnd,
} = await this.handleStartMethods(message, opts);
if (opts.progressCallback) {
opts.onProgress = opts.progressCallback.call(null, {
...(opts.progressOptions ?? {}),
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
}
this.currentMessages.push(userMessage);
let {
prompt: payload,
tokenCountMap,
promptTokens,
} = await this.buildMessages(
this.currentMessages,
userMessage.messageId,
this.getBuildMessagesOptions({
promptPrefix: null,
abortController: this.abortController,
}),
);
if (tokenCountMap) {
logger.debug('[PluginsClient] tokenCountMap', { tokenCountMap });
if (tokenCountMap[userMessage.messageId]) {
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
logger.debug('[PluginsClient] userMessage.tokenCount', userMessage.tokenCount);
}
this.handleTokenCountMap(tokenCountMap);
}
this.result = {};
if (payload) {
this.currentMessages = payload;
}
if (!this.skipSaveUserMessage) {
userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise,
});
}
}
const balance = this.options.req?.app?.locals?.balance;
if (balance?.enabled) {
await checkBalance({
req: this.options.req,
res: this.options.res,
txData: {
user: this.user,
tokenType: 'prompt',
amount: promptTokens,
debug: this.options.debug,
model: this.modelOptions.model,
endpoint: EModelEndpoint.openAI,
},
});
}
const responseMessage = {
endpoint: EModelEndpoint.gptPlugins,
iconURL: this.options.iconURL,
messageId: responseMessageId,
conversationId,
parentMessageId: userMessage.messageId,
isCreatedByUser: false,
model: this.modelOptions.model,
sender: this.sender,
promptTokens,
};
await this.initialize({
user,
message,
onAgentAction,
onChainEnd,
signal: this.abortController.signal,
onProgress: opts.onProgress,
});
// const stream = async (text) => {
// await this.generateTextStream.call(this, text, opts.onProgress, { delay: 1 });
// };
await this.executorCall(message, {
signal: this.abortController.signal,
// stream,
onToolStart,
onToolEnd,
});
// If message was aborted mid-generation
if (this.result?.errorMessage?.length > 0 && this.result?.errorMessage?.includes('cancel')) {
responseMessage.text = 'Cancelled.';
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
// If error occurred during generation (likely token_balance)
if (this.result?.errorMessage?.length > 0) {
responseMessage.error = true;
responseMessage.text = this.result.output;
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output && this.functionsAgent) {
const partialText = opts.getPartialText();
const trimmedPartial = opts.getPartialText().replaceAll(':::plugin:::\n', '');
responseMessage.text =
trimmedPartial.length === 0 ? `${partialText}${this.result.output}` : partialText;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
if (this.agentOptions.skipCompletion && this.result.output) {
responseMessage.text = this.result.output;
addImages(this.result.intermediateSteps, responseMessage);
await this.generateTextStream(this.result.output, opts.onProgress, { delay: 5 });
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
logger.debug('[PluginsClient] Completion phase: this.result', this.result);
const promptPrefix = buildPromptPrefix({
result: this.result,
message,
functionsAgent: this.functionsAgent,
});
logger.debug('[PluginsClient]', { promptPrefix });
payload = await this.buildCompletionPrompt({
messages: this.currentMessages,
promptPrefix,
});
logger.debug('[PluginsClient] buildCompletionPrompt Payload', payload);
responseMessage.text = await this.sendCompletion(payload, opts);
return await this.handleResponseMessage(responseMessage, saveOptions, user);
}
async buildCompletionPrompt({ messages, promptPrefix: _promptPrefix }) {
logger.debug('[PluginsClient] buildCompletionPrompt messages', messages);
const orderedMessages = messages;
let promptPrefix = _promptPrefix.trim();
// If the prompt prefix doesn't end with the end token, add it.
if (!promptPrefix.endsWith(`${this.endToken}`)) {
promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`;
}
promptPrefix = `${this.startToken}Instructions:\n${promptPrefix}`;
const promptSuffix = `${this.startToken}${this.chatGptLabel ?? 'Assistant'}:\n`;
const instructionsPayload = {
role: 'system',
content: promptPrefix,
};
const messagePayload = {
role: 'system',
content: promptSuffix,
};
if (this.isGpt3) {
instructionsPayload.role = 'user';
messagePayload.role = 'user';
instructionsPayload.content += `\n${promptSuffix}`;
}
// testing if this works with browser endpoint
if (!this.isGpt3 && this.options.reverseProxyUrl) {
instructionsPayload.role = 'user';
}
let currentTokenCount =
this.getTokenCountForMessage(instructionsPayload) +
this.getTokenCountForMessage(messagePayload);
let promptBody = '';
const maxTokenCount = this.maxPromptTokens;
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
const message = orderedMessages.pop();
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
let messageString = `${this.startToken}${roleLabel}:\n${
message.text ?? message.content ?? ''
}${this.endToken}\n`;
let newPromptBody = `${messageString}${promptBody}`;
const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage;
if (newTokenCount > maxTokenCount) {
if (promptBody) {
// This message would put us over the token limit, so don't add it.
return false;
}
// This is the first message, so we can't add it. Just throw an error.
throw new Error(
`Prompt is too long. Max token count is ${maxTokenCount}, but prompt is ${newTokenCount} tokens long.`,
);
}
promptBody = newPromptBody;
currentTokenCount = newTokenCount;
// wait for next tick to avoid blocking the event loop
await new Promise((resolve) => setTimeout(resolve, 0));
return buildPromptBody();
}
return true;
};
await buildPromptBody();
const prompt = promptBody;
messagePayload.content = prompt;
// Add 2 tokens for metadata after all messages have been counted.
currentTokenCount += 2;
if (this.isGpt3 && messagePayload.content.length > 0) {
const context = 'Chat History:\n';
messagePayload.content = `${context}${prompt}`;
currentTokenCount += this.getTokenCount(context);
}
// Use up to `this.maxContextTokens` tokens (prompt + response), but try to leave `this.maxTokens` tokens for the response.
this.modelOptions.max_tokens = Math.min(
this.maxContextTokens - currentTokenCount,
this.maxResponseTokens,
);
if (this.isGpt3) {
messagePayload.content += promptSuffix;
return [instructionsPayload, messagePayload];
}
const result = [messagePayload, instructionsPayload];
if (this.functionsAgent && !this.isGpt3) {
result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`;
}
return result.filter((message) => message.content.length > 0);
}
}
module.exports = PluginsClient;

View File

@@ -1,15 +1,11 @@
const ChatGPTClient = require('./ChatGPTClient');
const OpenAIClient = require('./OpenAIClient');
const PluginsClient = require('./PluginsClient');
const GoogleClient = require('./GoogleClient');
const TextStream = require('./TextStream');
const AnthropicClient = require('./AnthropicClient');
const toolUtils = require('./tools/util');
module.exports = {
ChatGPTClient,
OpenAIClient,
PluginsClient,
GoogleClient,
TextStream,
AnthropicClient,

View File

@@ -1,6 +1,7 @@
const axios = require('axios');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const footer = `Use the context as your learned knowledge to better answer the user.
@@ -18,7 +19,7 @@ function createContextHandlers(req, userMessageContent) {
const queryPromises = [];
const processedFiles = [];
const processedIds = new Set();
const jwtToken = req.headers.authorization.split(' ')[1];
const jwtToken = generateShortLivedToken(req.user.id);
const useFullContext = isEnabled(process.env.RAG_USE_FULL_CONTEXT);
const query = async (file) => {

View File

@@ -531,44 +531,6 @@ describe('OpenAIClient', () => {
});
});
describe('sendMessage/getCompletion/chatCompletion', () => {
afterEach(() => {
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
});
it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => {
const model = 'text-davinci-003';
const onProgress = jest.fn().mockImplementation(() => ({}));
const testClient = new OpenAIClient('test-api-key', {
...defaultOptions,
modelOptions: { model },
});
const getCompletion = jest.spyOn(testClient, 'getCompletion');
await testClient.sendMessage('Hi mom!', { onProgress });
expect(getCompletion).toHaveBeenCalled();
expect(getCompletion.mock.calls.length).toBe(1);
expect(getCompletion.mock.calls[0][0]).toBe('||>User:\nHi mom!\n||>Assistant:\n');
expect(fetchEventSource).toHaveBeenCalled();
expect(fetchEventSource.mock.calls.length).toBe(1);
// Check if the first argument (url) is correct
const firstCallArgs = fetchEventSource.mock.calls[0];
const expectedURL = 'https://api.openai.com/v1/completions';
expect(firstCallArgs[0]).toBe(expectedURL);
const requestBody = JSON.parse(firstCallArgs[1].body);
expect(requestBody).toHaveProperty('model');
expect(requestBody.model).toBe(model);
});
});
describe('checkVisionRequest functionality', () => {
let client;
const attachments = [{ type: 'image/png' }];

View File

@@ -1,314 +0,0 @@
const crypto = require('crypto');
const { Constants } = require('librechat-data-provider');
const { HumanMessage, AIMessage } = require('@langchain/core/messages');
const PluginsClient = require('../PluginsClient');
jest.mock('~/db/connect');
jest.mock('~/models/Conversation', () => {
return function () {
return {
save: jest.fn(),
deleteConvos: jest.fn(),
};
};
});
const defaultAzureOptions = {
azureOpenAIApiInstanceName: 'your-instance-name',
azureOpenAIApiDeploymentName: 'your-deployment-name',
azureOpenAIApiVersion: '2020-07-01-preview',
};
describe('PluginsClient', () => {
let TestAgent;
let options = {
tools: [],
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
};
let parentMessageId;
let conversationId;
const fakeMessages = [];
const userMessage = 'Hello, ChatGPT!';
const apiKey = 'fake-api-key';
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, options);
TestAgent.loadHistory = jest
.fn()
.mockImplementation((conversationId, parentMessageId = null) => {
if (!conversationId) {
TestAgent.currentMessages = [];
return Promise.resolve([]);
}
const orderedMessages = TestAgent.constructor.getMessagesForConversation({
messages: fakeMessages,
parentMessageId,
});
const chatMessages = orderedMessages.map((msg) =>
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanMessage(msg.text)
: new AIMessage(msg.text),
);
TestAgent.currentMessages = orderedMessages;
return Promise.resolve(chatMessages);
});
TestAgent.sendMessage = jest.fn().mockImplementation(async (message, opts = {}) => {
if (opts && typeof opts === 'object') {
TestAgent.setOptions(opts);
}
const conversationId = opts.conversationId || crypto.randomUUID();
const parentMessageId = opts.parentMessageId || Constants.NO_PARENT;
const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
this.pastMessages = await TestAgent.loadHistory(
conversationId,
TestAgent.options?.parentMessageId,
);
const userMessage = {
text: message,
sender: 'ChatGPT',
isCreatedByUser: true,
messageId: userMessageId,
parentMessageId,
conversationId,
};
const response = {
sender: 'ChatGPT',
text: 'Hello, User!',
isCreatedByUser: false,
messageId: crypto.randomUUID(),
parentMessageId: userMessage.messageId,
conversationId,
};
fakeMessages.push(userMessage);
fakeMessages.push(response);
return response;
});
});
test('initializes PluginsClient without crashing', () => {
expect(TestAgent).toBeInstanceOf(PluginsClient);
});
test('check setOptions function', () => {
expect(TestAgent.agentIsGpt3).toBe(true);
});
describe('sendMessage', () => {
test('sendMessage should return a response message', async () => {
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: expect.any(String),
});
const response = await TestAgent.sendMessage(userMessage);
parentMessageId = response.messageId;
conversationId = response.conversationId;
expect(response).toEqual(expectedResult);
});
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
const userMessage = 'Second message in the conversation';
const opts = {
conversationId,
parentMessageId,
};
const expectedResult = expect.objectContaining({
sender: 'ChatGPT',
text: expect.any(String),
isCreatedByUser: false,
messageId: expect.any(String),
parentMessageId: expect.any(String),
conversationId: opts.conversationId,
});
const response = await TestAgent.sendMessage(userMessage, opts);
parentMessageId = response.messageId;
expect(response.conversationId).toEqual(conversationId);
expect(response).toEqual(expectedResult);
});
test('should return chat history', async () => {
const chatMessages = await TestAgent.loadHistory(conversationId, parentMessageId);
expect(TestAgent.currentMessages).toHaveLength(4);
expect(chatMessages[0].text).toEqual(userMessage);
});
});
describe('getFunctionModelName', () => {
let client;
beforeEach(() => {
client = new PluginsClient('dummy_api_key');
});
test('should return the input when it includes a dash followed by four digits', () => {
expect(client.getFunctionModelName('-1234')).toBe('-1234');
expect(client.getFunctionModelName('gpt-4-5678-preview')).toBe('gpt-4-5678-preview');
});
test('should return the input for all function-capable models (`0613` models and above)', () => {
expect(client.getFunctionModelName('gpt-4-0613')).toBe('gpt-4-0613');
expect(client.getFunctionModelName('gpt-4-32k-0613')).toBe('gpt-4-32k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-0613')).toBe('gpt-3.5-turbo-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0613')).toBe('gpt-3.5-turbo-16k-0613');
expect(client.getFunctionModelName('gpt-3.5-turbo-1106')).toBe('gpt-3.5-turbo-1106');
expect(client.getFunctionModelName('gpt-4-1106-preview')).toBe('gpt-4-1106-preview');
expect(client.getFunctionModelName('gpt-4-1106')).toBe('gpt-4-1106');
});
test('should return the corresponding model if input is non-function capable (`0314` models)', () => {
expect(client.getFunctionModelName('gpt-4-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-4-32k-0314')).toBe('gpt-4');
expect(client.getFunctionModelName('gpt-3.5-turbo-0314')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('gpt-3.5-turbo-16k-0314')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-3.5-turbo" when the input includes "gpt-3.5-turbo"', () => {
expect(client.getFunctionModelName('test gpt-3.5-turbo model')).toBe('gpt-3.5-turbo');
});
test('should return "gpt-4" when the input includes "gpt-4"', () => {
expect(client.getFunctionModelName('testing gpt-4')).toBe('gpt-4');
});
test('should return "gpt-3.5-turbo" for input that does not meet any specific condition', () => {
expect(client.getFunctionModelName('random string')).toBe('gpt-3.5-turbo');
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
});
});
describe('Azure OpenAI tests specific to Plugins', () => {
// TODO: add more tests for Azure OpenAI integration with Plugins
// let client;
// beforeEach(() => {
// client = new PluginsClient('dummy_api_key');
// });
test('should not call getFunctionModelName when azure options are set', () => {
const spy = jest.spyOn(PluginsClient.prototype, 'getFunctionModelName');
const model = 'gpt-4-turbo';
// note, without the azure change in PR #1766, `getFunctionModelName` is called twice
const testClient = new PluginsClient('dummy_api_key', {
agentOptions: {
model,
agent: 'functions',
},
azure: defaultAzureOptions,
});
expect(spy).not.toHaveBeenCalled();
expect(testClient.agentOptions.model).toBe(model);
spy.mockRestore();
});
});
describe('sendMessage with filtered tools', () => {
let TestAgent;
const apiKey = 'fake-api-key';
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, {
tools: mockTools,
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
});
TestAgent.options.req = {
app: {
locals: {},
},
};
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
if (includedTools.length > 0) {
const tools = TestAgent.options.tools.filter((plugin) =>
includedTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
} else {
const tools = TestAgent.options.tools.filter(
(plugin) => !filteredTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
}
return {
text: 'Mocked response',
tools: TestAgent.options.tools,
};
});
});
test('should filter out tools when filteredTools is provided', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should only include specified tools when includedTools is provided', async () => {
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should prioritize includedTools over filteredTools', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool1' }),
expect.objectContaining({ name: 'tool2' }),
]),
);
});
test('should not modify tools when no filters are provided', async () => {
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(4);
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
});
});
});

View File

@@ -107,6 +107,12 @@ const getImageEditPromptDescription = () => {
return process.env.IMAGE_EDIT_OAI_PROMPT_DESCRIPTION || DEFAULT_IMAGE_EDIT_PROMPT_DESCRIPTION;
};
function createAbortHandler() {
return function () {
logger.debug('[ImageGenOAI] Image generation aborted');
};
}
/**
* Creates OpenAI Image tools (generation and editing)
* @param {Object} fields - Configuration fields
@@ -201,10 +207,18 @@ function createOpenAIImageTools(fields = {}) {
}
let resp;
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try {
const derivedSignal = runnableConfig?.signal
? AbortSignal.any([runnableConfig.signal])
: undefined;
if (runnableConfig?.signal) {
derivedSignal = AbortSignal.any([runnableConfig.signal]);
abortHandler = createAbortHandler();
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
resp = await openai.images.generate(
{
model: 'gpt-image-1',
@@ -228,6 +242,10 @@ function createOpenAIImageTools(fields = {}) {
logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to generate the image. The OpenAI API may be unavailable:
Error Message: ${error.message}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
}
if (!resp) {
@@ -409,10 +427,17 @@ Error Message: ${error.message}`);
headers['Authorization'] = `Bearer ${apiKey}`;
}
/** @type {AbortSignal} */
let derivedSignal = null;
/** @type {() => void} */
let abortHandler = null;
try {
const derivedSignal = runnableConfig?.signal
? AbortSignal.any([runnableConfig.signal])
: undefined;
if (runnableConfig?.signal) {
derivedSignal = AbortSignal.any([runnableConfig.signal]);
abortHandler = createAbortHandler();
derivedSignal.addEventListener('abort', abortHandler, { once: true });
}
/** @type {import('axios').AxiosRequestConfig} */
const axiosConfig = {
@@ -467,6 +492,10 @@ Error Message: ${error.message}`);
logAxiosError({ error, message });
return returnValue(`Something went wrong when trying to edit the image. The OpenAI API may be unavailable:
Error Message: ${error.message || 'Unknown error'}`);
} finally {
if (abortHandler && derivedSignal) {
derivedSignal.removeEventListener('abort', abortHandler);
}
}
},
{

View File

@@ -1,9 +1,10 @@
const { z } = require('zod');
const axios = require('axios');
const { tool } = require('@langchain/core/tools');
const { logger } = require('@librechat/data-schemas');
const { Tools, EToolResources } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { getFiles } = require('~/models/File');
const { logger } = require('~/config');
/**
*
@@ -59,7 +60,7 @@ const createFileSearchTool = async ({ req, files, entity_id }) => {
if (files.length === 0) {
return 'No files to search. Instruct the user to add files for the search.';
}
const jwtToken = req.headers.authorization.split(' ')[1];
const jwtToken = generateShortLivedToken(req.user.id);
if (!jwtToken) {
return 'There was an error authenticating the file search request.';
}

View File

@@ -1,7 +1,8 @@
const { logger } = require('@librechat/data-schemas');
const { isEnabled, math } = require('@librechat/api');
const { ViolationTypes } = require('librechat-data-provider');
const { isEnabled, math, removePorts } = require('~/server/utils');
const { deleteAllUserSessions } = require('~/models');
const { removePorts } = require('~/server/utils');
const getLogStores = require('./getLogStores');
const { BAN_VIOLATIONS, BAN_INTERVAL } = process.env ?? {};

View File

@@ -1,7 +1,7 @@
const { Keyv } = require('keyv');
const { isEnabled, math } = require('@librechat/api');
const { CacheKeys, ViolationTypes, Time } = require('librechat-data-provider');
const { logFile, violationFile } = require('./keyvFiles');
const { isEnabled, math } = require('~/server/utils');
const keyvRedis = require('./keyvRedis');
const keyvMongo = require('./keyvMongo');

View File

@@ -4,7 +4,7 @@ const { logger } = require('@librechat/data-schemas');
const { SystemRoles, Tools, actionDelimiter } = require('librechat-data-provider');
const { GLOBAL_PROJECT_NAME, EPHEMERAL_AGENT_ID, mcp_delimiter } =
require('librechat-data-provider').Constants;
const { CONFIG_STORE, STARTUP_CONFIG } = require('librechat-data-provider').CacheKeys;
// Default category value for new agents
const {
getProjectByName,
addAgentIdsToProject,
@@ -12,7 +12,9 @@ const {
removeAgentFromAllProjects,
} = require('./Project');
const { getCachedTools } = require('~/server/services/Config');
const getLogStores = require('~/cache/getLogStores');
// Category values are now imported from shared constants
// Schema fields (category, support_contact, is_promoted) are defined in @librechat/data-schemas
const { getActions } = require('./Action');
const { Agent } = require('~/db/models');
@@ -23,7 +25,7 @@ const { Agent } = require('~/db/models');
* @throws {Error} If the agent creation fails.
*/
const createAgent = async (agentData) => {
const { author, ...versionData } = agentData;
const { author: _author, ...versionData } = agentData;
const timestamp = new Date();
const initialAgentData = {
...agentData,
@@ -34,6 +36,7 @@ const createAgent = async (agentData) => {
updatedAt: timestamp,
},
],
category: agentData.category || 'general',
};
return (await Agent.create(initialAgentData)).toObject();
};
@@ -126,29 +129,7 @@ const loadAgent = async ({ req, agent_id, endpoint, model_parameters }) => {
}
agent.version = agent.versions ? agent.versions.length : 0;
if (agent.author.toString() === req.user.id) {
return agent;
}
if (!agent.projectIds) {
return null;
}
const cache = getLogStores(CONFIG_STORE);
/** @type {TStartupConfig} */
const cachedStartupConfig = await cache.get(STARTUP_CONFIG);
let { instanceProjectId } = cachedStartupConfig ?? {};
if (!instanceProjectId) {
instanceProjectId = (await getProjectByName(GLOBAL_PROJECT_NAME, '_id'))._id.toString();
}
for (const projectObjectId of agent.projectIds) {
const projectId = projectObjectId.toString();
if (projectId === instanceProjectId) {
return agent;
}
}
return agent;
};
/**
@@ -178,7 +159,7 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
'actionsHash', // Exclude actionsHash from direct comparison
];
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
const { $push: _$push, $pull: _$pull, $addToSet: _$addToSet, ...directUpdates } = updateData;
if (Object.keys(directUpdates).length === 0 && !actionsHash) {
return null;
@@ -197,54 +178,116 @@ const isDuplicateVersion = (updateData, currentData, versions, actionsHash = nul
let isMatch = true;
for (const field of importantFields) {
if (!wouldBeVersion[field] && !lastVersion[field]) {
const wouldBeValue = wouldBeVersion[field];
const lastVersionValue = lastVersion[field];
// Skip if both are undefined/null
if (!wouldBeValue && !lastVersionValue) {
continue;
}
if (Array.isArray(wouldBeVersion[field]) && Array.isArray(lastVersion[field])) {
if (wouldBeVersion[field].length !== lastVersion[field].length) {
// Handle arrays
if (Array.isArray(wouldBeValue) || Array.isArray(lastVersionValue)) {
// Normalize: treat undefined/null as empty array for comparison
let wouldBeArr;
if (Array.isArray(wouldBeValue)) {
wouldBeArr = wouldBeValue;
} else if (wouldBeValue == null) {
wouldBeArr = [];
} else {
wouldBeArr = [wouldBeValue];
}
let lastVersionArr;
if (Array.isArray(lastVersionValue)) {
lastVersionArr = lastVersionValue;
} else if (lastVersionValue == null) {
lastVersionArr = [];
} else {
lastVersionArr = [lastVersionValue];
}
if (wouldBeArr.length !== lastVersionArr.length) {
isMatch = false;
break;
}
// Special handling for projectIds (MongoDB ObjectIds)
if (field === 'projectIds') {
const wouldBeIds = wouldBeVersion[field].map((id) => id.toString()).sort();
const versionIds = lastVersion[field].map((id) => id.toString()).sort();
const wouldBeIds = wouldBeArr.map((id) => id.toString()).sort();
const versionIds = lastVersionArr.map((id) => id.toString()).sort();
if (!wouldBeIds.every((id, i) => id === versionIds[i])) {
isMatch = false;
break;
}
}
// Handle arrays of objects like tool_kwargs
else if (typeof wouldBeVersion[field][0] === 'object' && wouldBeVersion[field][0] !== null) {
const sortedWouldBe = [...wouldBeVersion[field]].map((item) => JSON.stringify(item)).sort();
const sortedVersion = [...lastVersion[field]].map((item) => JSON.stringify(item)).sort();
// Handle arrays of objects
else if (
wouldBeArr.length > 0 &&
typeof wouldBeArr[0] === 'object' &&
wouldBeArr[0] !== null
) {
const sortedWouldBe = [...wouldBeArr].map((item) => JSON.stringify(item)).sort();
const sortedVersion = [...lastVersionArr].map((item) => JSON.stringify(item)).sort();
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
isMatch = false;
break;
}
} else {
const sortedWouldBe = [...wouldBeVersion[field]].sort();
const sortedVersion = [...lastVersion[field]].sort();
const sortedWouldBe = [...wouldBeArr].sort();
const sortedVersion = [...lastVersionArr].sort();
if (!sortedWouldBe.every((item, i) => item === sortedVersion[i])) {
isMatch = false;
break;
}
}
} else if (field === 'model_parameters') {
const wouldBeParams = wouldBeVersion[field] || {};
const lastVersionParams = lastVersion[field] || {};
if (JSON.stringify(wouldBeParams) !== JSON.stringify(lastVersionParams)) {
}
// Handle objects
else if (typeof wouldBeValue === 'object' && wouldBeValue !== null) {
const lastVersionObj =
typeof lastVersionValue === 'object' && lastVersionValue !== null ? lastVersionValue : {};
// For empty objects, normalize the comparison
const wouldBeKeys = Object.keys(wouldBeValue);
const lastVersionKeys = Object.keys(lastVersionObj);
// If both are empty objects, they're equal
if (wouldBeKeys.length === 0 && lastVersionKeys.length === 0) {
continue;
}
// Otherwise do a deep comparison
if (JSON.stringify(wouldBeValue) !== JSON.stringify(lastVersionObj)) {
isMatch = false;
break;
}
}
// Handle primitive values
else {
// For primitives, handle the case where one is undefined and the other is a default value
if (wouldBeValue !== lastVersionValue) {
// Special handling for boolean false vs undefined
if (
typeof wouldBeValue === 'boolean' &&
wouldBeValue === false &&
lastVersionValue === undefined
) {
continue;
}
// Special handling for empty string vs undefined
if (
typeof wouldBeValue === 'string' &&
wouldBeValue === '' &&
lastVersionValue === undefined
) {
continue;
}
isMatch = false;
break;
}
} else if (wouldBeVersion[field] !== lastVersion[field]) {
isMatch = false;
break;
}
}
@@ -273,7 +316,14 @@ const updateAgent = async (searchParameter, updateData, options = {}) => {
const currentAgent = await Agent.findOne(searchParameter);
if (currentAgent) {
const { __v, _id, id, versions, author, ...versionData } = currentAgent.toObject();
const {
__v,
_id,
id: __id,
versions,
author: _author,
...versionData
} = currentAgent.toObject();
const { $push, $pull, $addToSet, ...directUpdates } = updateData;
let actionsHash = null;
@@ -464,8 +514,113 @@ const deleteAgent = async (searchParameter) => {
return agent;
};
/**
* Get agents by accessible IDs with optional cursor-based pagination.
* @param {Object} params - The parameters for getting accessible agents.
* @param {Array} [params.accessibleIds] - Array of agent ObjectIds the user has ACL access to.
* @param {Object} [params.otherParams] - Additional query parameters (including author filter).
* @param {number} [params.limit] - Number of agents to return (max 100). If not provided, returns all agents.
* @param {string} [params.after] - Cursor for pagination - get agents after this cursor. // base64 encoded JSON string with updatedAt and _id.
* @returns {Promise<Object>} A promise that resolves to an object containing the agents data and pagination info.
*/
const getListAgentsByAccess = async ({
accessibleIds = [],
otherParams = {},
limit = null,
after = null,
}) => {
const isPaginated = limit !== null && limit !== undefined;
const normalizedLimit = isPaginated ? Math.min(Math.max(1, parseInt(limit) || 20), 100) : null;
// Build base query combining ACL accessible agents with other filters
const baseQuery = { ...otherParams };
if (accessibleIds.length > 0) {
baseQuery._id = { $in: accessibleIds };
}
// Add cursor condition
if (after) {
try {
const cursor = JSON.parse(Buffer.from(after, 'base64').toString('utf8'));
const { updatedAt, _id } = cursor;
const cursorCondition = {
$or: [
{ updatedAt: { $lt: new Date(updatedAt) } },
{ updatedAt: new Date(updatedAt), _id: { $gt: new mongoose.Types.ObjectId(_id) } },
],
};
// Merge cursor condition with base query
if (Object.keys(baseQuery).length > 0) {
baseQuery.$and = [{ ...baseQuery }, cursorCondition];
// Remove the original conditions from baseQuery to avoid duplication
Object.keys(baseQuery).forEach((key) => {
if (key !== '$and') delete baseQuery[key];
});
} else {
Object.assign(baseQuery, cursorCondition);
}
} catch (error) {
logger.warn('Invalid cursor:', error.message);
}
}
let query = Agent.find(baseQuery, {
id: 1,
_id: 1,
name: 1,
avatar: 1,
author: 1,
projectIds: 1,
description: 1,
updatedAt: 1,
category: 1,
support_contact: 1,
is_promoted: 1,
}).sort({ updatedAt: -1, _id: 1 });
// Only apply limit if pagination is requested
if (isPaginated) {
query = query.limit(normalizedLimit + 1);
}
const agents = await query.lean();
const hasMore = isPaginated ? agents.length > normalizedLimit : false;
const data = (isPaginated ? agents.slice(0, normalizedLimit) : agents).map((agent) => {
if (agent.author) {
agent.author = agent.author.toString();
}
return agent;
});
// Generate next cursor only if paginated
let nextCursor = null;
if (isPaginated && hasMore && data.length > 0) {
const lastAgent = agents[normalizedLimit - 1];
nextCursor = Buffer.from(
JSON.stringify({
updatedAt: lastAgent.updatedAt.toISOString(),
_id: lastAgent._id.toString(),
}),
).toString('base64');
}
return {
object: 'list',
data,
first_id: data.length > 0 ? data[0].id : null,
last_id: data.length > 0 ? data[data.length - 1].id : null,
has_more: hasMore,
after: nextCursor,
};
};
/**
* Get all agents.
* @deprecated Use getListAgentsByAccess for ACL-aware agent listing
* @param {Object} searchParameter - The search parameters to find matching agents.
* @param {string} searchParameter.author - The user ID of the agent's author.
* @returns {Promise<Object>} A promise that resolves to an object containing the agents data and pagination info.
@@ -484,13 +639,15 @@ const getListAgents = async (searchParameter) => {
const agents = (
await Agent.find(query, {
id: 1,
_id: 0,
_id: 1,
name: 1,
avatar: 1,
author: 1,
projectIds: 1,
description: 1,
// @deprecated - isCollaborative replaced by ACL permissions
isCollaborative: 1,
category: 1,
}).lean()
).map((agent) => {
if (agent.author?.toString() !== author) {
@@ -656,6 +813,14 @@ const generateActionMetadataHash = async (actionIds, actions) => {
return hashHex;
};
/**
* Counts the number of promoted agents.
* @returns {Promise<number>} - The count of promoted agents
*/
const countPromotedAgents = async () => {
const count = await Agent.countDocuments({ is_promoted: true });
return count;
};
/**
* Load a default agent based on the endpoint
@@ -673,6 +838,8 @@ module.exports = {
revertAgentVersion,
updateAgentProjects,
addAgentResourceFile,
getListAgentsByAccess,
removeAgentResourceFiles,
generateActionMetadataHash,
countPromotedAgents,
};

View File

@@ -1633,7 +1633,7 @@ describe('models/Agent', () => {
expect(result.version).toBe(1);
});
test('should return null when user is not author and agent has no projectIds', async () => {
test('should return agent even when user is not author (permissions checked at route level)', async () => {
const authorId = new mongoose.Types.ObjectId();
const userId = new mongoose.Types.ObjectId();
const agentId = `agent_${uuidv4()}`;
@@ -1654,7 +1654,11 @@ describe('models/Agent', () => {
model_parameters: { model: 'gpt-4' },
});
expect(result).toBeFalsy();
// With the new permission system, loadAgent returns the agent regardless of permissions
// Permission checks are handled at the route level via middleware
expect(result).toBeTruthy();
expect(result.id).toBe(agentId);
expect(result.name).toBe('Test Agent');
});
test('should handle ephemeral agent with no MCP servers', async () => {
@@ -1762,7 +1766,7 @@ describe('models/Agent', () => {
}
});
test('should handle loadAgent with agent from different project', async () => {
test('should return agent from different project (permissions checked at route level)', async () => {
const authorId = new mongoose.Types.ObjectId();
const userId = new mongoose.Types.ObjectId();
const agentId = `agent_${uuidv4()}`;
@@ -1785,7 +1789,11 @@ describe('models/Agent', () => {
model_parameters: { model: 'gpt-4' },
});
expect(result).toBeFalsy();
// With the new permission system, loadAgent returns the agent regardless of permissions
// Permission checks are handled at the route level via middleware
expect(result).toBeTruthy();
expect(result.id).toBe(agentId);
expect(result.name).toBe('Project Agent');
});
});
});

View File

@@ -1,4 +1,6 @@
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
const { getMessages, deleteMessages } = require('./Message');
const { Conversation } = require('~/db/models');
@@ -98,10 +100,15 @@ module.exports = {
update.conversationId = newConversationId;
}
if (req.body.isTemporary) {
const expiredAt = new Date();
expiredAt.setDate(expiredAt.getDate() + 30);
update.expiredAt = expiredAt;
if (req?.body?.isTemporary) {
try {
const customConfig = await getCustomConfig();
update.expiredAt = createTempChatExpirationDate(customConfig);
} catch (err) {
logger.error('Error creating temporary chat expiration date:', err);
logger.info(`---\`saveConvo\` context: ${metadata?.context}`);
update.expiredAt = null;
}
} else {
update.expiredAt = null;
}

View File

@@ -1,5 +1,5 @@
const { logger } = require('@librechat/data-schemas');
const { EToolResources } = require('librechat-data-provider');
const { EToolResources, FileContext } = require('librechat-data-provider');
const { File } = require('~/db/models');
/**
@@ -32,19 +32,19 @@ const getFiles = async (filter, _sortOptions, selectFields = { text: 0 }) => {
* @returns {Promise<Array<MongoFile>>} Files that match the criteria
*/
const getToolFilesByIds = async (fileIds, toolResourceSet) => {
if (!fileIds || !fileIds.length) {
if (!fileIds || !fileIds.length || !toolResourceSet?.size) {
return [];
}
try {
const filter = {
file_id: { $in: fileIds },
$or: [],
};
if (toolResourceSet.size) {
filter.$or = [];
if (toolResourceSet.has(EToolResources.ocr)) {
filter.$or.push({ text: { $exists: true, $ne: null }, context: FileContext.agents });
}
if (toolResourceSet.has(EToolResources.file_search)) {
filter.$or.push({ embedded: true });
}

View File

@@ -1,5 +1,7 @@
const { z } = require('zod');
const { logger } = require('@librechat/data-schemas');
const { createTempChatExpirationDate } = require('@librechat/api');
const getCustomConfig = require('~/server/services/Config/loadCustomConfig');
const { Message } = require('~/db/models');
const idSchema = z.string().uuid();
@@ -54,9 +56,14 @@ async function saveMessage(req, params, metadata) {
};
if (req?.body?.isTemporary) {
const expiredAt = new Date();
expiredAt.setDate(expiredAt.getDate() + 30);
update.expiredAt = expiredAt;
try {
const customConfig = await getCustomConfig();
update.expiredAt = createTempChatExpirationDate(customConfig);
} catch (err) {
logger.error('Error creating temporary chat expiration date:', err);
logger.info(`---\`saveMessage\` context: ${metadata?.context}`);
update.expiredAt = null;
}
} else {
update.expiredAt = null;
}

View File

@@ -2,7 +2,6 @@ const {
CacheKeys,
SystemRoles,
roleDefaults,
PermissionTypes,
permissionsSchema,
removeNullishValues,
} = require('librechat-data-provider');

View File

@@ -48,14 +48,14 @@
"@langchain/google-genai": "^0.2.13",
"@langchain/google-vertexai": "^0.2.13",
"@langchain/textsplitters": "^0.1.0",
"@librechat/agents": "^2.4.41",
"@librechat/agents": "^2.4.46",
"@librechat/api": "*",
"@librechat/data-schemas": "*",
"@node-saml/passport-saml": "^5.0.0",
"@microsoft/microsoft-graph-client": "^3.0.7",
"@waylaidwanderer/fetch-event-source": "^3.0.1",
"axios": "^1.8.2",
"bcryptjs": "^2.4.3",
"cohere-ai": "^7.9.1",
"compression": "^1.7.4",
"connect-redis": "^7.1.0",
"cookie": "^0.7.2",

View File

@@ -169,9 +169,6 @@ function disposeClient(client) {
client.isGenerativeModel = null;
}
// Properties specific to OpenAIClient
if (client.ChatGPTClient) {
client.ChatGPTClient = null;
}
if (client.completionsUrl) {
client.completionsUrl = null;
}

View File

@@ -1,17 +1,17 @@
const cookies = require('cookie');
const jwt = require('jsonwebtoken');
const openIdClient = require('openid-client');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
registerUser,
resetPassword,
setAuthTokens,
requestPasswordReset,
setOpenIDAuthTokens,
resetPassword,
setAuthTokens,
registerUser,
} = require('~/server/services/AuthService');
const { findUser, getUserById, deleteAllUserSessions, findSession } = require('~/models');
const { getOpenIdConfig } = require('~/strategies');
const { isEnabled } = require('~/server/utils');
const registrationController = async (req, res) => {
try {

View File

@@ -1,3 +1,5 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { getResponseSender } = require('librechat-data-provider');
const {
handleAbortError,
@@ -10,9 +12,8 @@ const {
clientRegistry,
requestDataMap,
} = require('~/server/cleanup');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { createOnProgress } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const EditController = async (req, res, next, initializeClient) => {
let {
@@ -198,7 +199,7 @@ const EditController = async (req, res, next, initializeClient) => {
const finalUserMessage = reqDataContext.userMessage;
const finalResponseMessage = { ...response };
sendMessage(res, {
sendEvent(res, {
final: true,
conversation,
title: conversation.title,

View File

@@ -0,0 +1,437 @@
/**
* @import { TUpdateResourcePermissionsRequest, TUpdateResourcePermissionsResponse } from 'librechat-data-provider'
*/
const mongoose = require('mongoose');
const { logger } = require('@librechat/data-schemas');
const {
getAvailableRoles,
ensurePrincipalExists,
getEffectivePermissions,
ensureGroupPrincipalExists,
bulkUpdateResourcePermissions,
} = require('~/server/services/PermissionService');
const { AclEntry } = require('~/db/models');
const {
searchPrincipals: searchLocalPrincipals,
sortPrincipalsByRelevance,
calculateRelevanceScore,
} = require('~/models');
const {
searchEntraIdPrincipals,
entraIdPrincipalFeatureEnabled,
} = require('~/server/services/GraphApiService');
/**
* Generic controller for resource permission endpoints
* Delegates validation and logic to PermissionService
*/
/**
* Bulk update permissions for a resource (grant, update, remove)
* @route PUT /api/{resourceType}/{resourceId}/permissions
* @param {Object} req - Express request object
* @param {Object} req.params - Route parameters
* @param {string} req.params.resourceType - Resource type (e.g., 'agent')
* @param {string} req.params.resourceId - Resource ID
* @param {TUpdateResourcePermissionsRequest} req.body - Request body
* @param {Object} res - Express response object
* @returns {Promise<TUpdateResourcePermissionsResponse>} Updated permissions response
*/
const updateResourcePermissions = async (req, res) => {
try {
const { resourceType, resourceId } = req.params;
/** @type {TUpdateResourcePermissionsRequest} */
const { updated, removed, public: isPublic, publicAccessRoleId } = req.body;
const { id: userId } = req.user;
// Prepare principals for the service call
const updatedPrincipals = [];
const revokedPrincipals = [];
// Add updated principals
if (updated && Array.isArray(updated)) {
updatedPrincipals.push(...updated);
}
// Add public permission if enabled
if (isPublic && publicAccessRoleId) {
updatedPrincipals.push({
type: 'public',
id: null,
accessRoleId: publicAccessRoleId,
});
}
// Prepare authentication context for enhanced group member fetching
const useEntraId = entraIdPrincipalFeatureEnabled(req.user);
const authHeader = req.headers.authorization;
const accessToken =
authHeader && authHeader.startsWith('Bearer ') ? authHeader.substring(7) : null;
const authContext =
useEntraId && accessToken
? {
accessToken,
sub: req.user.openidId,
}
: null;
// Ensure updated principals exist in the database before processing permissions
const validatedPrincipals = [];
for (const principal of updatedPrincipals) {
try {
let principalId;
if (principal.type === 'public') {
principalId = null; // Public principals don't need database records
} else if (principal.type === 'user') {
principalId = await ensurePrincipalExists(principal);
} else if (principal.type === 'group') {
// Pass authContext to enable member fetching for Entra ID groups when available
principalId = await ensureGroupPrincipalExists(principal, authContext);
} else {
logger.error(`Unsupported principal type: ${principal.type}`);
continue; // Skip invalid principal types
}
// Update the principal with the validated ID for ACL operations
validatedPrincipals.push({
...principal,
id: principalId,
});
} catch (error) {
logger.error('Error ensuring principal exists:', {
principal: {
type: principal.type,
id: principal.id,
name: principal.name,
source: principal.source,
},
error: error.message,
});
// Continue with other principals instead of failing the entire operation
continue;
}
}
// Add removed principals
if (removed && Array.isArray(removed)) {
revokedPrincipals.push(...removed);
}
// If public is disabled, add public to revoked list
if (!isPublic) {
revokedPrincipals.push({
type: 'public',
id: null,
});
}
const results = await bulkUpdateResourcePermissions({
resourceType,
resourceId,
updatedPrincipals: validatedPrincipals,
revokedPrincipals,
grantedBy: userId,
});
/** @type {TUpdateResourcePermissionsResponse} */
const response = {
message: 'Permissions updated successfully',
results: {
principals: results.granted,
public: isPublic || false,
publicAccessRoleId: isPublic ? publicAccessRoleId : undefined,
},
};
res.status(200).json(response);
} catch (error) {
logger.error('Error updating resource permissions:', error);
res.status(400).json({
error: 'Failed to update permissions',
details: error.message,
});
}
};
/**
* Get principals with their permission roles for a resource (UI-friendly format)
* Uses efficient aggregation pipeline to join User/Group data in single query
* @route GET /api/permissions/{resourceType}/{resourceId}
*/
const getResourcePermissions = async (req, res) => {
try {
const { resourceType, resourceId } = req.params;
// Use aggregation pipeline for efficient single-query data retrieval
const results = await AclEntry.aggregate([
// Match ACL entries for this resource
{
$match: {
resourceType,
resourceId: mongoose.Types.ObjectId.isValid(resourceId)
? mongoose.Types.ObjectId.createFromHexString(resourceId)
: resourceId,
},
},
// Lookup AccessRole information
{
$lookup: {
from: 'accessroles',
localField: 'roleId',
foreignField: '_id',
as: 'role',
},
},
// Lookup User information (for user principals)
{
$lookup: {
from: 'users',
localField: 'principalId',
foreignField: '_id',
as: 'userInfo',
},
},
// Lookup Group information (for group principals)
{
$lookup: {
from: 'groups',
localField: 'principalId',
foreignField: '_id',
as: 'groupInfo',
},
},
// Project final structure
{
$project: {
principalType: 1,
principalId: 1,
accessRoleId: { $arrayElemAt: ['$role.accessRoleId', 0] },
userInfo: { $arrayElemAt: ['$userInfo', 0] },
groupInfo: { $arrayElemAt: ['$groupInfo', 0] },
},
},
]);
const principals = [];
let publicPermission = null;
// Process aggregation results
for (const result of results) {
if (result.principalType === 'public') {
publicPermission = {
public: true,
publicAccessRoleId: result.accessRoleId,
};
} else if (result.principalType === 'user' && result.userInfo) {
principals.push({
type: 'user',
id: result.userInfo._id.toString(),
name: result.userInfo.name || result.userInfo.username,
email: result.userInfo.email,
avatar: result.userInfo.avatar,
source: !result.userInfo._id ? 'entra' : 'local',
idOnTheSource: result.userInfo.idOnTheSource || result.userInfo._id.toString(),
accessRoleId: result.accessRoleId,
});
} else if (result.principalType === 'group' && result.groupInfo) {
principals.push({
type: 'group',
id: result.groupInfo._id.toString(),
name: result.groupInfo.name,
email: result.groupInfo.email,
description: result.groupInfo.description,
avatar: result.groupInfo.avatar,
source: result.groupInfo.source || 'local',
idOnTheSource: result.groupInfo.idOnTheSource || result.groupInfo._id.toString(),
accessRoleId: result.accessRoleId,
});
}
}
// Return response in format expected by frontend
const response = {
resourceType,
resourceId,
principals,
public: publicPermission?.public || false,
...(publicPermission?.publicAccessRoleId && {
publicAccessRoleId: publicPermission.publicAccessRoleId,
}),
};
res.status(200).json(response);
} catch (error) {
logger.error('Error getting resource permissions principals:', error);
res.status(500).json({
error: 'Failed to get permissions principals',
details: error.message,
});
}
};
/**
* Get available roles for a resource type
* @route GET /api/{resourceType}/roles
*/
const getResourceRoles = async (req, res) => {
try {
const { resourceType } = req.params;
const roles = await getAvailableRoles({ resourceType });
res.status(200).json(
roles.map((role) => ({
accessRoleId: role.accessRoleId,
name: role.name,
description: role.description,
permBits: role.permBits,
})),
);
} catch (error) {
logger.error('Error getting resource roles:', error);
res.status(500).json({
error: 'Failed to get roles',
details: error.message,
});
}
};
/**
* Get user's effective permission bitmask for a resource
* @route GET /api/{resourceType}/{resourceId}/effective
*/
const getUserEffectivePermissions = async (req, res) => {
try {
const { resourceType, resourceId } = req.params;
const { id: userId } = req.user;
const permissionBits = await getEffectivePermissions({
userId,
resourceType,
resourceId,
});
res.status(200).json({
permissionBits,
});
} catch (error) {
logger.error('Error getting user effective permissions:', error);
res.status(500).json({
error: 'Failed to get effective permissions',
details: error.message,
});
}
};
/**
* Search for users and groups to grant permissions
* Supports hybrid local database + Entra ID search when configured
* @route GET /api/permissions/search-principals
*/
const searchPrincipals = async (req, res) => {
try {
const { q: query, limit = 20, type } = req.query;
if (!query || query.trim().length === 0) {
return res.status(400).json({
error: 'Query parameter "q" is required and must not be empty',
});
}
if (query.trim().length < 2) {
return res.status(400).json({
error: 'Query must be at least 2 characters long',
});
}
const searchLimit = Math.min(Math.max(1, parseInt(limit) || 10), 50);
const typeFilter = ['user', 'group'].includes(type) ? type : null;
const localResults = await searchLocalPrincipals(query.trim(), searchLimit, typeFilter);
let allPrincipals = [...localResults];
const useEntraId = entraIdPrincipalFeatureEnabled(req.user);
if (useEntraId && localResults.length < searchLimit) {
try {
const graphTypeMap = {
user: 'users',
group: 'groups',
null: 'all',
};
const authHeader = req.headers.authorization;
const accessToken =
authHeader && authHeader.startsWith('Bearer ') ? authHeader.substring(7) : null;
if (accessToken) {
const graphResults = await searchEntraIdPrincipals(
accessToken,
req.user.openidId,
query.trim(),
graphTypeMap[typeFilter],
searchLimit - localResults.length,
);
const localEmails = new Set(
localResults.map((p) => p.email?.toLowerCase()).filter(Boolean),
);
const localGroupSourceIds = new Set(
localResults.map((p) => p.idOnTheSource).filter(Boolean),
);
for (const principal of graphResults) {
const isDuplicateByEmail =
principal.email && localEmails.has(principal.email.toLowerCase());
const isDuplicateBySourceId =
principal.idOnTheSource && localGroupSourceIds.has(principal.idOnTheSource);
if (!isDuplicateByEmail && !isDuplicateBySourceId) {
allPrincipals.push(principal);
}
}
}
} catch (graphError) {
logger.warn('Graph API search failed, falling back to local results:', graphError.message);
}
}
const scoredResults = allPrincipals.map((item) => ({
...item,
_searchScore: calculateRelevanceScore(item, query.trim()),
}));
allPrincipals = sortPrincipalsByRelevance(scoredResults)
.slice(0, searchLimit)
.map((result) => {
const { _searchScore, ...resultWithoutScore } = result;
return resultWithoutScore;
});
res.status(200).json({
query: query.trim(),
limit: searchLimit,
type: typeFilter,
results: allPrincipals,
count: allPrincipals.length,
sources: {
local: allPrincipals.filter((r) => r.source === 'local').length,
entra: allPrincipals.filter((r) => r.source === 'entra').length,
},
});
} catch (error) {
logger.error('Error searching principals:', error);
res.status(500).json({
error: 'Failed to search principals',
details: error.message,
});
}
};
module.exports = {
updateResourcePermissions,
getResourcePermissions,
getResourceRoles,
getUserEffectivePermissions,
searchPrincipals,
};

View File

@@ -4,11 +4,13 @@ const {
sendEvent,
createRun,
Tokenizer,
checkAccess,
memoryInstructions,
createMemoryProcessor,
} = require('@librechat/api');
const {
Callback,
Providers,
GraphEvents,
formatMessage,
formatAgentMessages,
@@ -31,22 +33,29 @@ const {
} = require('librechat-data-provider');
const { DynamicStructuredTool } = require('@langchain/core/tools');
const { getBufferString, HumanMessage } = require('@langchain/core/messages');
const {
getCustomEndpointConfig,
createGetMCPAuthMap,
checkCapability,
} = require('~/server/services/Config');
const { createGetMCPAuthMap, checkCapability } = require('~/server/services/Config');
const { addCacheControl, createContextHandlers } = require('~/app/clients/prompts');
const { initializeAgent } = require('~/server/services/Endpoints/agents/agent');
const { spendTokens, spendStructuredTokens } = require('~/models/spendTokens');
const { getFormattedMemories, deleteMemory, setMemory } = require('~/models');
const { encodeAndFormat } = require('~/server/services/Files/images/encode');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const { checkAccess } = require('~/server/middleware/roles/access');
const { getProviderConfig } = require('~/server/services/Endpoints');
const BaseClient = require('~/app/clients/BaseClient');
const { getRoleByName } = require('~/models/Role');
const { loadAgent } = require('~/models/Agent');
const { getMCPManager } = require('~/config');
const omitTitleOptions = new Set([
'stream',
'thinking',
'streaming',
'clientOptions',
'thinkingConfig',
'thinkingBudget',
'includeThoughts',
'maxOutputTokens',
]);
/**
* @param {ServerRequest} req
* @param {Agent} agent
@@ -393,7 +402,12 @@ class AgentClient extends BaseClient {
if (user.personalization?.memories === false) {
return;
}
const hasAccess = await checkAccess(user, PermissionTypes.MEMORIES, [Permissions.USE]);
const hasAccess = await checkAccess({
user,
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE],
getRoleByName,
});
if (!hasAccess) {
logger.debug(
@@ -677,7 +691,7 @@ class AgentClient extends BaseClient {
hide_sequential_outputs: this.options.agent.hide_sequential_outputs,
user: this.options.req.user,
},
recursionLimit: agentsEConfig?.recursionLimit,
recursionLimit: agentsEConfig?.recursionLimit ?? 25,
signal: abortController.signal,
streamMode: 'values',
version: 'v2',
@@ -983,23 +997,26 @@ class AgentClient extends BaseClient {
throw new Error('Run not initialized');
}
const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator();
const endpoint = this.options.agent.endpoint;
const { req, res } = this.options;
const { req, res, agent } = this.options;
const endpoint = agent.endpoint;
/** @type {import('@librechat/agents').ClientOptions} */
let clientOptions = {
maxTokens: 75,
model: agent.model_parameters.model,
};
let endpointConfig = req.app.locals[endpoint];
const { getOptions, overrideProvider, customEndpointConfig } =
await getProviderConfig(endpoint);
/** @type {TEndpoint | undefined} */
const endpointConfig = req.app.locals[endpoint] ?? customEndpointConfig;
if (!endpointConfig) {
try {
endpointConfig = await getCustomEndpointConfig(endpoint);
} catch (err) {
logger.error(
'[api/server/controllers/agents/client.js #titleConvo] Error getting custom endpoint config',
err,
);
}
logger.warn(
'[api/server/controllers/agents/client.js #titleConvo] Error getting endpoint config',
);
}
if (
endpointConfig &&
endpointConfig.titleModel &&
@@ -1007,30 +1024,50 @@ class AgentClient extends BaseClient {
) {
clientOptions.model = endpointConfig.titleModel;
}
const options = await getOptions({
req,
res,
optionsOnly: true,
overrideEndpoint: endpoint,
overrideModel: clientOptions.model,
endpointOption: { model_parameters: clientOptions },
});
let provider = options.provider ?? overrideProvider ?? agent.provider;
if (
endpoint === EModelEndpoint.azureOpenAI &&
clientOptions.model &&
this.options.agent.model_parameters.model !== clientOptions.model
options.llmConfig?.azureOpenAIApiInstanceName == null
) {
clientOptions =
(
await initOpenAI({
req,
res,
optionsOnly: true,
overrideModel: clientOptions.model,
overrideEndpoint: endpoint,
endpointOption: {
model_parameters: clientOptions,
},
})
)?.llmConfig ?? clientOptions;
provider = Providers.OPENAI;
}
if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
/** @type {import('@librechat/agents').ClientOptions} */
clientOptions = { ...options.llmConfig };
if (options.configOptions) {
clientOptions.configuration = options.configOptions;
}
// Ensure maxTokens is set for non-o1 models
if (!/\b(o\d)\b/i.test(clientOptions.model) && !clientOptions.maxTokens) {
clientOptions.maxTokens = 75;
} else if (/\b(o\d)\b/i.test(clientOptions.model) && clientOptions.maxTokens != null) {
delete clientOptions.maxTokens;
}
clientOptions = Object.assign(
Object.fromEntries(
Object.entries(clientOptions).filter(([key]) => !omitTitleOptions.has(key)),
),
);
if (provider === Providers.GOOGLE) {
clientOptions.json = true;
}
try {
const titleResult = await this.run.generateTitle({
provider,
inputText: text,
contentParts: this.contentParts,
clientOptions,
@@ -1048,8 +1085,10 @@ class AgentClient extends BaseClient {
let input_tokens, output_tokens;
if (item.usage) {
input_tokens = item.usage.input_tokens || item.usage.inputTokens;
output_tokens = item.usage.output_tokens || item.usage.outputTokens;
input_tokens =
item.usage.prompt_tokens || item.usage.input_tokens || item.usage.inputTokens;
output_tokens =
item.usage.completion_tokens || item.usage.output_tokens || item.usage.outputTokens;
} else if (item.tokenUsage) {
input_tokens = item.tokenUsage.promptTokens;
output_tokens = item.tokenUsage.completionTokens;

View File

@@ -1,10 +1,10 @@
// errorHandler.js
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { sendResponse } = require('~/server/middleware/error');
const { recordUsage } = require('~/server/services/Threads');
const { getConvo } = require('~/models/Conversation');
const { sendResponse } = require('~/server/utils');
const getLogStores = require('~/cache/getLogStores');
/**
* @typedef {Object} ErrorHandlerContext
@@ -75,7 +75,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);

View File

@@ -1,106 +0,0 @@
const { HttpsProxyAgent } = require('https-proxy-agent');
const { resolveHeaders } = require('librechat-data-provider');
const { createLLM } = require('~/app/clients/llm');
/**
* Initializes and returns a Language Learning Model (LLM) instance.
*
* @param {Object} options - Configuration options for the LLM.
* @param {string} options.model - The model identifier.
* @param {string} options.modelName - The specific name of the model.
* @param {number} options.temperature - The temperature setting for the model.
* @param {number} options.presence_penalty - The presence penalty for the model.
* @param {number} options.frequency_penalty - The frequency penalty for the model.
* @param {number} options.max_tokens - The maximum number of tokens for the model output.
* @param {boolean} options.streaming - Whether to use streaming for the model output.
* @param {Object} options.context - The context for the conversation.
* @param {number} options.tokenBuffer - The token buffer size.
* @param {number} options.initialMessageCount - The initial message count.
* @param {string} options.conversationId - The ID of the conversation.
* @param {string} options.user - The user identifier.
* @param {string} options.langchainProxy - The langchain proxy URL.
* @param {boolean} options.useOpenRouter - Whether to use OpenRouter.
* @param {Object} options.options - Additional options.
* @param {Object} options.options.headers - Custom headers for the request.
* @param {string} options.options.proxy - Proxy URL.
* @param {Object} options.options.req - The request object.
* @param {Object} options.options.res - The response object.
* @param {boolean} options.options.debug - Whether to enable debug mode.
* @param {string} options.apiKey - The API key for authentication.
* @param {Object} options.azure - Azure-specific configuration.
* @param {Object} options.abortController - The AbortController instance.
* @returns {Object} The initialized LLM instance.
*/
function initializeLLM(options) {
const {
model,
modelName,
temperature,
presence_penalty,
frequency_penalty,
max_tokens,
streaming,
user,
langchainProxy,
useOpenRouter,
options: { headers, proxy },
apiKey,
azure,
} = options;
const modelOptions = {
modelName: modelName || model,
temperature,
presence_penalty,
frequency_penalty,
user,
};
if (max_tokens) {
modelOptions.max_tokens = max_tokens;
}
const configOptions = {};
if (langchainProxy) {
configOptions.basePath = langchainProxy;
}
if (useOpenRouter) {
configOptions.basePath = 'https://openrouter.ai/api/v1';
configOptions.baseOptions = {
headers: {
'HTTP-Referer': 'https://librechat.ai',
'X-Title': 'LibreChat',
},
};
}
if (headers && typeof headers === 'object' && !Array.isArray(headers)) {
configOptions.baseOptions = {
headers: resolveHeaders({
...headers,
...configOptions?.baseOptions?.headers,
}),
};
}
if (proxy) {
configOptions.httpAgent = new HttpsProxyAgent(proxy);
configOptions.httpsAgent = new HttpsProxyAgent(proxy);
}
const llm = createLLM({
modelOptions,
configOptions,
openAIApiKey: apiKey,
azure,
streaming,
});
return llm;
}
module.exports = {
initializeLLM,
};

View File

@@ -1,3 +1,5 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { Constants } = require('librechat-data-provider');
const {
handleAbortError,
@@ -5,9 +7,7 @@ const {
cleanupAbortController,
} = require('~/server/middleware');
const { disposeClient, clientRegistry, requestDataMap } = require('~/server/cleanup');
const { sendMessage } = require('~/server/utils');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AgentController = async (req, res, next, initializeClient, addTitle) => {
let {
@@ -206,7 +206,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => {
// Create a new response object with minimal copies
const finalResponse = { ...response };
sendMessage(res, {
sendEvent(res, {
final: true,
conversation,
title: conversation.title,

View File

@@ -1,11 +1,10 @@
const fs = require('fs').promises;
const { nanoid } = require('nanoid');
const { logger } = require('@librechat/data-schemas');
const { logger, PermissionBits } = require('@librechat/data-schemas');
const {
Tools,
Constants,
FileSources,
SystemRoles,
FileSources,
EToolResources,
actionDelimiter,
} = require('librechat-data-provider');
@@ -14,18 +13,24 @@ const {
createAgent,
updateAgent,
deleteAgent,
getListAgents,
getListAgentsByAccess,
countPromotedAgents,
revertAgentVersion,
} = require('~/models/Agent');
const {
grantPermission,
findAccessibleResources,
findPubliclyAccessibleResources,
hasPublicPermission,
} = require('~/server/services/PermissionService');
const { getStrategyFunctions } = require('~/server/services/Files/strategies');
const { resizeAvatar } = require('~/server/services/Files/images/avatar');
const { refreshS3Url } = require('~/server/services/Files/S3/crud');
const { filterFile } = require('~/server/services/Files/process');
const { updateAction, getActions } = require('~/models/Action');
const { getCachedTools } = require('~/server/services/Config');
const { updateAgentProjects } = require('~/models/Agent');
const { getProjectByName } = require('~/models/Project');
const { revertAgentVersion } = require('~/models/Agent');
const { deleteFileByFilter } = require('~/models/File');
const { getCategoriesWithCounts } = require('~/models');
const systemTools = {
[Tools.execute_code]: true,
@@ -69,6 +74,27 @@ const createAgentHandler = async (req, res) => {
agentData.id = `agent_${nanoid()}`;
const agent = await createAgent(agentData);
// Automatically grant owner permissions to the creator
try {
await grantPermission({
principalType: 'user',
principalId: userId,
resourceType: 'agent',
resourceId: agent._id,
accessRoleId: 'agent_owner',
grantedBy: userId,
});
logger.debug(
`[createAgent] Granted owner permissions to user ${userId} for agent ${agent.id}`,
);
} catch (permissionError) {
logger.error(
`[createAgent] Failed to grant owner permissions for agent ${agent.id}:`,
permissionError,
);
}
res.status(201).json(agent);
} catch (error) {
logger.error('[/Agents] Error creating agent', error);
@@ -87,21 +113,14 @@ const createAgentHandler = async (req, res) => {
* @returns {Promise<Agent>} 200 - success response - application/json
* @returns {Error} 404 - Agent not found
*/
const getAgentHandler = async (req, res) => {
const getAgentHandler = async (req, res, expandProperties = false) => {
try {
const id = req.params.id;
const author = req.user.id;
let query = { id, author };
const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, ['agentIds']);
if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) {
query = {
$or: [{ id, $in: globalProject.agentIds }, query],
};
}
const agent = await getAgent(query);
// Permissions are validated by middleware before calling this function
// Simply load the agent by ID
const agent = await getAgent({ id });
if (!agent) {
return res.status(404).json({ error: 'Agent not found' });
@@ -118,23 +137,45 @@ const getAgentHandler = async (req, res) => {
}
agent.author = agent.author.toString();
// @deprecated - isCollaborative replaced by ACL permissions
agent.isCollaborative = !!agent.isCollaborative;
// Check if agent is public
const isPublic = await hasPublicPermission({
resourceType: 'agent',
resourceId: agent._id,
requiredPermissions: PermissionBits.VIEW,
});
agent.isPublic = isPublic;
if (agent.author !== author) {
delete agent.author;
}
if (!agent.isCollaborative && agent.author !== author && req.user.role !== SystemRoles.ADMIN) {
if (!expandProperties) {
// VIEW permission: Basic agent info only
return res.status(200).json({
_id: agent._id,
id: agent.id,
name: agent.name,
description: agent.description,
avatar: agent.avatar,
author: agent.author,
provider: agent.provider,
model: agent.model,
projectIds: agent.projectIds,
// @deprecated - isCollaborative replaced by ACL permissions
isCollaborative: agent.isCollaborative,
isPublic: agent.isPublic,
version: agent.version,
// Safe metadata
createdAt: agent.createdAt,
updatedAt: agent.updatedAt,
});
}
// EDIT permission: Full agent details including sensitive configuration
return res.status(200).json(agent);
} catch (error) {
logger.error('[/Agents/:id] Error retrieving agent', error);
@@ -154,42 +195,20 @@ const getAgentHandler = async (req, res) => {
const updateAgentHandler = async (req, res) => {
try {
const id = req.params.id;
const { projectIds, removeProjectIds, ...updateData } = req.body;
const isAdmin = req.user.role === SystemRoles.ADMIN;
const { _id, ...updateData } = req.body;
const existingAgent = await getAgent({ id });
const isAuthor = existingAgent.author.toString() === req.user.id;
if (!existingAgent) {
return res.status(404).json({ error: 'Agent not found' });
}
const hasEditPermission = existingAgent.isCollaborative || isAdmin || isAuthor;
if (!hasEditPermission) {
return res.status(403).json({
error: 'You do not have permission to modify this non-collaborative agent',
});
}
/** @type {boolean} */
const isProjectUpdate = (projectIds?.length ?? 0) > 0 || (removeProjectIds?.length ?? 0) > 0;
let updatedAgent =
Object.keys(updateData).length > 0
? await updateAgent({ id }, updateData, {
updatingUserId: req.user.id,
skipVersioning: isProjectUpdate,
})
: existingAgent;
if (isProjectUpdate) {
updatedAgent = await updateAgentProjects({
user: req.user,
agentId: id,
projectIds,
removeProjectIds,
});
}
if (updatedAgent.author) {
updatedAgent.author = updatedAgent.author.toString();
}
@@ -307,6 +326,26 @@ const duplicateAgentHandler = async (req, res) => {
newAgentData.actions = agentActions;
const newAgent = await createAgent(newAgentData);
// Automatically grant owner permissions to the duplicator
try {
await grantPermission({
principalType: 'user',
principalId: userId,
resourceType: 'agent',
resourceId: newAgent._id,
accessRoleId: 'agent_owner',
grantedBy: userId,
});
logger.debug(
`[duplicateAgent] Granted owner permissions to user ${userId} for duplicated agent ${newAgent.id}`,
);
} catch (permissionError) {
logger.error(
`[duplicateAgent] Failed to grant owner permissions for duplicated agent ${newAgent.id}:`,
permissionError,
);
}
return res.status(201).json({
agent: newAgent,
actions: newActionsList,
@@ -333,7 +372,7 @@ const deleteAgentHandler = async (req, res) => {
if (!agent) {
return res.status(404).json({ error: 'Agent not found' });
}
await deleteAgent({ id, author: req.user.id });
await deleteAgent({ id });
return res.json({ message: 'Agent deleted' });
} catch (error) {
logger.error('[/Agents/:id] Error deleting Agent', error);
@@ -342,7 +381,7 @@ const deleteAgentHandler = async (req, res) => {
};
/**
*
* Lists agents using ACL-aware permissions (ownership + explicit shares).
* @route GET /Agents
* @param {object} req - Express Request
* @param {object} req.query - Request query
@@ -351,9 +390,64 @@ const deleteAgentHandler = async (req, res) => {
*/
const getListAgentsHandler = async (req, res) => {
try {
const data = await getListAgents({
author: req.user.id,
const userId = req.user.id;
const { category, search, limit, cursor, promoted } = req.query;
let requiredPermission = req.query.requiredPermission;
if (typeof requiredPermission === 'string') {
requiredPermission = parseInt(requiredPermission, 10);
if (isNaN(requiredPermission)) {
requiredPermission = PermissionBits.VIEW;
}
} else if (typeof requiredPermission !== 'number') {
requiredPermission = PermissionBits.VIEW;
}
// Base filter
const filter = {};
// Handle category filter - only apply if category is defined
if (category !== undefined && category.trim() !== '') {
filter.category = category;
}
// Handle promoted filter - only from query param
if (promoted === '1') {
filter.is_promoted = true;
} else if (promoted === '0') {
filter.is_promoted = { $ne: true };
}
// Handle search filter
if (search && search.trim() !== '') {
filter.$or = [
{ name: { $regex: search.trim(), $options: 'i' } },
{ description: { $regex: search.trim(), $options: 'i' } },
];
}
// Get agent IDs the user has VIEW access to via ACL
const accessibleIds = await findAccessibleResources({
userId,
resourceType: 'agent',
requiredPermissions: requiredPermission,
});
const publiclyAccessibleIds = await findPubliclyAccessibleResources({
resourceType: 'agent',
requiredPermissions: PermissionBits.VIEW,
});
// Use the new ACL-aware function
const data = await getListAgentsByAccess({
accessibleIds,
otherParams: filter,
limit,
after: cursor,
});
if (data?.data?.length) {
data.data = data.data.map((agent) => {
if (publiclyAccessibleIds.some((id) => id.equals(agent._id))) {
agent.isPublic = true;
}
return agent;
});
}
return res.json(data);
} catch (error) {
logger.error('[/Agents] Error listing Agents', error);
@@ -431,7 +525,7 @@ const uploadAgentAvatarHandler = async (req, res) => {
};
promises.push(
await updateAgent({ id: agent_id, author: req.user.id }, data, {
await updateAgent({ id: agent_id }, data, {
updatingUserId: req.user.id,
}),
);
@@ -511,7 +605,48 @@ const revertAgentVersionHandler = async (req, res) => {
res.status(500).json({ error: error.message });
}
};
/**
* Get all agent categories with counts
*
* @param {Object} _req - Express request object (unused)
* @param {Object} res - Express response object
*/
const getAgentCategories = async (_req, res) => {
try {
const categories = await getCategoriesWithCounts();
const promotedCount = await countPromotedAgents();
const formattedCategories = categories.map((category) => ({
value: category.value,
label: category.label,
count: category.agentCount,
description: category.description,
}));
if (promotedCount > 0) {
formattedCategories.unshift({
value: 'promoted',
label: 'Promoted',
count: promotedCount,
description: 'Our recommended agents',
});
}
formattedCategories.push({
value: 'all',
label: 'All',
description: 'All available agents',
});
res.status(200).json(formattedCategories);
} catch (error) {
logger.error('[/Agents/Marketplace] Error fetching agent categories:', error);
res.status(500).json({
error: 'Failed to fetch agent categories',
userMessage: 'Unable to load categories. Please refresh the page.',
suggestion: 'Try refreshing the page or check your network connection',
});
}
};
module.exports = {
createAgent: createAgentHandler,
getAgent: getAgentHandler,
@@ -521,4 +656,5 @@ module.exports = {
getListAgents: getListAgentsHandler,
uploadAgentAvatar: uploadAgentAvatarHandler,
revertAgentVersion: revertAgentVersionHandler,
getAgentCategories,
};

View File

@@ -1,4 +1,7 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Time,
Constants,
@@ -19,20 +22,20 @@ const {
addThreadMetadata,
saveAssistantMessage,
} = require('~/server/services/Threads');
const { sendResponse, sendMessage, sleep, countTokens } = require('~/server/utils');
const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService');
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { formatMessage, createVisionPrompt } = require('~/app/clients/prompts');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { createRunBody } = require('~/server/services/createRunBody');
const { sendResponse } = require('~/server/middleware/error');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* @route POST /
@@ -471,7 +474,7 @@ const chatV1 = async (req, res) => {
await Promise.all(promises);
const sendInitialResponse = () => {
sendMessage(res, {
sendEvent(res, {
sync: true,
conversationId,
// messages: previousMessages,
@@ -587,7 +590,7 @@ const chatV1 = async (req, res) => {
iconURL: endpointOption.iconURL,
};
sendMessage(res, {
sendEvent(res, {
final: true,
conversation,
requestMessage: {

View File

@@ -1,4 +1,7 @@
const { v4 } = require('uuid');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Time,
Constants,
@@ -22,15 +25,14 @@ const { createErrorHandler } = require('~/server/controllers/assistants/errors')
const validateAuthor = require('~/server/middleware/assistants/validateAuthor');
const { createRun, StreamRunManager } = require('~/server/services/Runs');
const { addTitle } = require('~/server/services/Endpoints/assistants');
const { sendMessage, sleep, countTokens } = require('~/server/utils');
const { createRunBody } = require('~/server/services/createRunBody');
const { getTransactions } = require('~/models/Transaction');
const { checkBalance } = require('~/models/balanceMethods');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { countTokens } = require('~/server/utils');
const { getModelMaxTokens } = require('~/utils');
const { getOpenAIClient } = require('./helpers');
const { logger } = require('~/config');
/**
* @route POST /
@@ -309,7 +311,7 @@ const chatV2 = async (req, res) => {
await Promise.all(promises);
const sendInitialResponse = () => {
sendMessage(res, {
sendEvent(res, {
sync: true,
conversationId,
// messages: previousMessages,
@@ -432,7 +434,7 @@ const chatV2 = async (req, res) => {
iconURL: endpointOption.iconURL,
};
sendMessage(res, {
sendEvent(res, {
final: true,
conversation,
requestMessage: {

View File

@@ -1,10 +1,10 @@
// errorHandler.js
const { sendResponse } = require('~/server/utils');
const { logger } = require('~/config');
const getLogStores = require('~/cache/getLogStores');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, ViolationTypes, ContentTypes } = require('librechat-data-provider');
const { getConvo } = require('~/models/Conversation');
const { recordUsage, checkMessageGaps } = require('~/server/services/Threads');
const { sendResponse } = require('~/server/middleware/error');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
/**
* @typedef {Object} ErrorHandlerContext
@@ -78,7 +78,7 @@ const createErrorHandler = ({ req, res, getContext, originPath = '/assistants/ch
} else if (/Files.*are invalid/.test(error.message)) {
const errorMessage = `Files are invalid, or may not have uploaded yet.${
endpoint === 'azureAssistants'
? ' If using Azure OpenAI, files are only available in the region of the assistant\'s model at the time of upload.'
? " If using Azure OpenAI, files are only available in the region of the assistant's model at the time of upload."
: ''
}`;
return sendResponse(req, res, messageData, errorMessage);

View File

@@ -1,5 +1,7 @@
const { nanoid } = require('nanoid');
const { EnvVar } = require('@librechat/agents');
const { checkAccess } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Tools,
AuthType,
@@ -13,9 +15,8 @@ const { processCodeOutput } = require('~/server/services/Files/Code/process');
const { createToolCall, getToolCallsByConvo } = require('~/models/ToolCall');
const { loadAuthValues } = require('~/server/services/Tools/credentials');
const { loadTools } = require('~/app/clients/tools/util');
const { checkAccess } = require('~/server/middleware');
const { getRoleByName } = require('~/models/Role');
const { getMessage } = require('~/models/Message');
const { logger } = require('~/config');
const fieldsMap = {
[Tools.execute_code]: [EnvVar.CODE_API_KEY],
@@ -79,6 +80,7 @@ const verifyToolAuth = async (req, res) => {
throwError: false,
});
} catch (error) {
logger.error('Error loading auth values', error);
res.status(200).json({ authenticated: false, message: AuthType.USER_PROVIDED });
return;
}
@@ -132,7 +134,12 @@ const callTool = async (req, res) => {
logger.debug(`[${toolId}/call] User: ${req.user.id}`);
let hasAccess = true;
if (toolAccessPermType[toolId]) {
hasAccess = await checkAccess(req.user, toolAccessPermType[toolId], [Permissions.USE]);
hasAccess = await checkAccess({
user: req.user,
permissionType: toolAccessPermType[toolId],
permissions: [Permissions.USE],
getRoleByName,
});
}
if (!hasAccess) {
logger.warn(

View File

@@ -118,6 +118,8 @@ const startServer = async () => {
app.use('/api/agents', routes.agents);
app.use('/api/banner', routes.banner);
app.use('/api/memories', routes.memories);
app.use('/api/permissions', routes.accessPermissions);
app.use('/api/tags', routes.tags);
app.use('/api/mcp', routes.mcp);

View File

@@ -1,13 +1,13 @@
// abortMiddleware.js
const { logger } = require('@librechat/data-schemas');
const { countTokens, isEnabled, sendEvent } = require('@librechat/api');
const { isAssistantsEndpoint, ErrorTypes } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const clearPendingReq = require('~/cache/clearPendingReq');
const { sendError } = require('~/server/middleware/error');
const { spendTokens } = require('~/models/spendTokens');
const abortControllers = require('./abortControllers');
const { saveMessage, getConvo } = require('~/models');
const { abortRun } = require('./abortRun');
const { logger } = require('~/config');
const abortDataMap = new WeakMap();
@@ -101,7 +101,7 @@ async function abortMessage(req, res) {
cleanupAbortController(abortKey);
if (res.headersSent && finalEvent) {
return sendMessage(res, finalEvent);
return sendEvent(res, finalEvent);
}
res.setHeader('Content-Type', 'application/json');
@@ -174,7 +174,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
* @param {string} responseMessageId
*/
const onStart = (userMessage, responseMessageId) => {
sendMessage(res, { message: userMessage, created: true });
sendEvent(res, { message: userMessage, created: true });
const abortKey = userMessage?.conversationId ?? req.user.id;
getReqData({ abortKey });

View File

@@ -1,11 +1,11 @@
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, RunStatus, isUUID } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { checkMessageGaps, recordUsage } = require('~/server/services/Threads');
const { deleteMessages } = require('~/models/Message');
const { getConvo } = require('~/models/Conversation');
const getLogStores = require('~/cache/getLogStores');
const { sendMessage } = require('~/server/utils');
const { logger } = require('~/config');
const three_minutes = 1000 * 60 * 3;
@@ -34,7 +34,7 @@ async function abortRun(req, res) {
const [thread_id, run_id] = runValues.split(':');
if (!run_id) {
logger.warn('[abortRun] Couldn\'t find run for cancel request', { thread_id });
logger.warn("[abortRun] Couldn't find run for cancel request", { thread_id });
return res.status(204).send({ message: 'Run not found' });
} else if (run_id === 'cancelled') {
logger.warn('[abortRun] Run already cancelled', { thread_id });
@@ -93,7 +93,7 @@ async function abortRun(req, res) {
};
if (res.headersSent && finalEvent) {
return sendMessage(res, finalEvent);
return sendEvent(res, finalEvent);
}
res.json(finalEvent);

View File

@@ -0,0 +1,97 @@
const { logger } = require('@librechat/data-schemas');
const { Constants, isAgentsEndpoint } = require('librechat-data-provider');
const { canAccessResource } = require('./canAccessResource');
const { getAgent } = require('~/models/Agent');
/**
* Agent ID resolver function for agent_id from request body
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
* This is used specifically for chat routes where agent_id comes from request body
*
* @param {string} agentCustomId - Custom agent ID from request body
* @returns {Promise<Object|null>} Agent document with _id field, or null if not found
*/
const resolveAgentIdFromBody = async (agentCustomId) => {
// Handle ephemeral agents - they don't need permission checks
if (agentCustomId === Constants.EPHEMERAL_AGENT_ID) {
return null; // No permission check needed for ephemeral agents
}
return await getAgent({ id: agentCustomId });
};
/**
* Middleware factory that creates middleware to check agent access permissions from request body.
* This middleware is specifically designed for chat routes where the agent_id comes from req.body
* instead of route parameters.
*
* @param {Object} options - Configuration options
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
* @returns {Function} Express middleware function
*
* @example
* // Basic usage for agent chat (requires VIEW permission)
* router.post('/chat',
* canAccessAgentFromBody({ requiredPermission: PermissionBits.VIEW }),
* buildEndpointOption,
* chatController
* );
*/
const canAccessAgentFromBody = (options) => {
const { requiredPermission } = options;
// Validate required options
if (!requiredPermission || typeof requiredPermission !== 'number') {
throw new Error('canAccessAgentFromBody: requiredPermission is required and must be a number');
}
return async (req, res, next) => {
try {
const { endpoint, agent_id } = req.body;
let agentId = agent_id;
if (!isAgentsEndpoint(endpoint)) {
agentId = Constants.EPHEMERAL_AGENT_ID;
}
if (!agentId) {
return res.status(400).json({
error: 'Bad Request',
message: 'agent_id is required in request body',
});
}
// Skip permission checks for ephemeral agents
if (agentId === Constants.EPHEMERAL_AGENT_ID) {
return next();
}
const agentAccessMiddleware = canAccessResource({
resourceType: 'agent',
requiredPermission,
resourceIdParam: 'agent_id', // This will be ignored since we use custom resolver
idResolver: () => resolveAgentIdFromBody(agentId),
});
const tempReq = {
...req,
params: {
...req.params,
agent_id: agentId,
},
};
return agentAccessMiddleware(tempReq, res, next);
} catch (error) {
logger.error('Failed to validate agent access permissions', error);
return res.status(500).json({
error: 'Internal Server Error',
message: 'Failed to validate agent access permissions',
});
}
};
};
module.exports = {
canAccessAgentFromBody,
};

View File

@@ -0,0 +1,58 @@
const { getAgent } = require('~/models/Agent');
const { canAccessResource } = require('./canAccessResource');
/**
* Agent ID resolver function
* Resolves custom agent ID (e.g., "agent_abc123") to MongoDB ObjectId
*
* @param {string} agentCustomId - Custom agent ID from route parameter
* @returns {Promise<Object|null>} Agent document with _id field, or null if not found
*/
const resolveAgentId = async (agentCustomId) => {
return await getAgent({ id: agentCustomId });
};
/**
* Agent-specific middleware factory that creates middleware to check agent access permissions.
* This middleware extends the generic canAccessResource to handle agent custom ID resolution.
*
* @param {Object} options - Configuration options
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
* @param {string} [options.resourceIdParam='id'] - The name of the route parameter containing the agent custom ID
* @returns {Function} Express middleware function
*
* @example
* // Basic usage for viewing agents
* router.get('/agents/:id',
* canAccessAgentResource({ requiredPermission: 1 }),
* getAgent
* );
*
* @example
* // Custom resource ID parameter and edit permission
* router.patch('/agents/:agent_id',
* canAccessAgentResource({
* requiredPermission: 2,
* resourceIdParam: 'agent_id'
* }),
* updateAgent
* );
*/
const canAccessAgentResource = (options) => {
const { requiredPermission, resourceIdParam = 'id' } = options;
if (!requiredPermission || typeof requiredPermission !== 'number') {
throw new Error('canAccessAgentResource: requiredPermission is required and must be a number');
}
return canAccessResource({
resourceType: 'agent',
requiredPermission,
resourceIdParam,
idResolver: resolveAgentId,
});
};
module.exports = {
canAccessAgentResource,
};

View File

@@ -0,0 +1,384 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { canAccessAgentResource } = require('./canAccessAgentResource');
const { User, Role, AclEntry } = require('~/db/models');
const { createAgent } = require('~/models/Agent');
describe('canAccessAgentResource middleware', () => {
let mongoServer;
let req, res, next;
let testUser;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await mongoose.connection.dropDatabase();
await Role.create({
name: 'test-role',
permissions: {
AGENTS: {
USE: true,
CREATE: true,
SHARED_GLOBAL: false,
},
},
});
// Create a test user
testUser = await User.create({
email: 'test@example.com',
name: 'Test User',
username: 'testuser',
role: 'test-role',
});
req = {
user: { id: testUser._id.toString(), role: 'test-role' },
params: {},
};
res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
};
next = jest.fn();
jest.clearAllMocks();
});
describe('middleware factory', () => {
test('should throw error if requiredPermission is not provided', () => {
expect(() => canAccessAgentResource({})).toThrow(
'canAccessAgentResource: requiredPermission is required and must be a number',
);
});
test('should throw error if requiredPermission is not a number', () => {
expect(() => canAccessAgentResource({ requiredPermission: '1' })).toThrow(
'canAccessAgentResource: requiredPermission is required and must be a number',
);
});
test('should create middleware with default resourceIdParam', () => {
const middleware = canAccessAgentResource({ requiredPermission: 1 });
expect(typeof middleware).toBe('function');
expect(middleware.length).toBe(3); // Express middleware signature
});
test('should create middleware with custom resourceIdParam', () => {
const middleware = canAccessAgentResource({
requiredPermission: 2,
resourceIdParam: 'agent_id',
});
expect(typeof middleware).toBe('function');
expect(middleware.length).toBe(3);
});
});
describe('permission checking with real agents', () => {
test('should allow access when user is the agent author', async () => {
// Create an agent owned by the test user
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Test Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
});
// Create ACL entry for the author (owner permissions)
await AclEntry.create({
principalType: 'user',
principalId: testUser._id,
principalModel: 'User',
resourceType: 'agent',
resourceId: agent._id,
permBits: 15, // All permissions (1+2+4+8)
grantedBy: testUser._id,
});
req.params.id = agent.id;
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
expect(res.status).not.toHaveBeenCalled();
});
test('should deny access when user is not the author and has no ACL entry', async () => {
// Create an agent owned by a different user
const otherUser = await User.create({
email: 'other@example.com',
name: 'Other User',
username: 'otheruser',
role: 'test-role',
});
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Other User Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
});
// Create ACL entry for the other user (owner)
await AclEntry.create({
principalType: 'user',
principalId: otherUser._id,
principalModel: 'User',
resourceType: 'agent',
resourceId: agent._id,
permBits: 15, // All permissions
grantedBy: otherUser._id,
});
req.params.id = agent.id;
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
await middleware(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
expect(res.json).toHaveBeenCalledWith({
error: 'Forbidden',
message: 'Insufficient permissions to access this agent',
});
});
test('should allow access when user has ACL entry with sufficient permissions', async () => {
// Create an agent owned by a different user
const otherUser = await User.create({
email: 'other2@example.com',
name: 'Other User 2',
username: 'otheruser2',
role: 'test-role',
});
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Shared Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
});
// Create ACL entry granting view permission to test user
await AclEntry.create({
principalType: 'user',
principalId: testUser._id,
principalModel: 'User',
resourceType: 'agent',
resourceId: agent._id,
permBits: 1, // VIEW permission
grantedBy: otherUser._id,
});
req.params.id = agent.id;
const middleware = canAccessAgentResource({ requiredPermission: 1 }); // VIEW permission
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
expect(res.status).not.toHaveBeenCalled();
});
test('should deny access when ACL permissions are insufficient', async () => {
// Create an agent owned by a different user
const otherUser = await User.create({
email: 'other3@example.com',
name: 'Other User 3',
username: 'otheruser3',
role: 'test-role',
});
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Limited Access Agent',
provider: 'openai',
model: 'gpt-4',
author: otherUser._id,
});
// Create ACL entry granting only view permission
await AclEntry.create({
principalType: 'user',
principalId: testUser._id,
principalModel: 'User',
resourceType: 'agent',
resourceId: agent._id,
permBits: 1, // VIEW permission only
grantedBy: otherUser._id,
});
req.params.id = agent.id;
const middleware = canAccessAgentResource({ requiredPermission: 2 }); // EDIT permission required
await middleware(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
expect(res.json).toHaveBeenCalledWith({
error: 'Forbidden',
message: 'Insufficient permissions to access this agent',
});
});
test('should handle non-existent agent', async () => {
req.params.id = 'agent_nonexistent';
const middleware = canAccessAgentResource({ requiredPermission: 1 });
await middleware(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(404);
expect(res.json).toHaveBeenCalledWith({
error: 'Not Found',
message: 'agent not found',
});
});
test('should use custom resourceIdParam', async () => {
const agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Custom Param Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
});
// Create ACL entry for the author
await AclEntry.create({
principalType: 'user',
principalId: testUser._id,
principalModel: 'User',
resourceType: 'agent',
resourceId: agent._id,
permBits: 15, // All permissions
grantedBy: testUser._id,
});
req.params.agent_id = agent.id; // Using custom param name
const middleware = canAccessAgentResource({
requiredPermission: 1,
resourceIdParam: 'agent_id',
});
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
expect(res.status).not.toHaveBeenCalled();
});
});
describe('permission levels', () => {
let agent;
beforeEach(async () => {
agent = await createAgent({
id: `agent_${Date.now()}`,
name: 'Permission Test Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
});
// Create ACL entry with all permissions for the owner
await AclEntry.create({
principalType: 'user',
principalId: testUser._id,
principalModel: 'User',
resourceType: 'agent',
resourceId: agent._id,
permBits: 15, // All permissions (1+2+4+8)
grantedBy: testUser._id,
});
req.params.id = agent.id;
});
test('should support view permission (1)', async () => {
const middleware = canAccessAgentResource({ requiredPermission: 1 });
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should support edit permission (2)', async () => {
const middleware = canAccessAgentResource({ requiredPermission: 2 });
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should support delete permission (4)', async () => {
const middleware = canAccessAgentResource({ requiredPermission: 4 });
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should support share permission (8)', async () => {
const middleware = canAccessAgentResource({ requiredPermission: 8 });
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should support combined permissions', async () => {
const viewAndEdit = 1 | 2; // 3
const middleware = canAccessAgentResource({ requiredPermission: viewAndEdit });
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
});
});
describe('integration with agent operations', () => {
test('should work with agent CRUD operations', async () => {
const agentId = `agent_${Date.now()}`;
// Create agent
const agent = await createAgent({
id: agentId,
name: 'Integration Test Agent',
provider: 'openai',
model: 'gpt-4',
author: testUser._id,
description: 'Testing integration',
});
// Create ACL entry for the author
await AclEntry.create({
principalType: 'user',
principalId: testUser._id,
principalModel: 'User',
resourceType: 'agent',
resourceId: agent._id,
permBits: 15, // All permissions
grantedBy: testUser._id,
});
req.params.id = agentId;
// Test view access
const viewMiddleware = canAccessAgentResource({ requiredPermission: 1 });
await viewMiddleware(req, res, next);
expect(next).toHaveBeenCalled();
jest.clearAllMocks();
// Update the agent
const { updateAgent } = require('~/models/Agent');
await updateAgent({ id: agentId }, { description: 'Updated description' });
// Test edit access
const editMiddleware = canAccessAgentResource({ requiredPermission: 2 });
await editMiddleware(req, res, next);
expect(next).toHaveBeenCalled();
});
});
});

View File

@@ -0,0 +1,157 @@
const { logger } = require('@librechat/data-schemas');
const { SystemRoles } = require('librechat-data-provider');
const { checkPermission } = require('~/server/services/PermissionService');
/**
* Generic base middleware factory that creates middleware to check resource access permissions.
* This middleware expects MongoDB ObjectIds as resource identifiers for ACL permission checks.
*
* @param {Object} options - Configuration options
* @param {string} options.resourceType - The type of resource (e.g., 'agent', 'file', 'project')
* @param {number} options.requiredPermission - The permission bit required (1=view, 2=edit, 4=delete, 8=share)
* @param {string} [options.resourceIdParam='resourceId'] - The name of the route parameter containing the resource ID
* @param {Function} [options.idResolver] - Optional function to resolve custom IDs to ObjectIds
* @returns {Function} Express middleware function
*
* @example
* // Direct usage with ObjectId (for resources that use MongoDB ObjectId in routes)
* router.get('/prompts/:promptId',
* canAccessResource({ resourceType: 'prompt', requiredPermission: 1 }),
* getPrompt
* );
*
* @example
* // Usage with custom ID resolver (for resources that use custom string IDs)
* router.get('/agents/:id',
* canAccessResource({
* resourceType: 'agent',
* requiredPermission: 1,
* resourceIdParam: 'id',
* idResolver: (customId) => resolveAgentId(customId)
* }),
* getAgent
* );
*/
const canAccessResource = (options) => {
const {
resourceType,
requiredPermission,
resourceIdParam = 'resourceId',
idResolver = null,
} = options;
if (!resourceType || typeof resourceType !== 'string') {
throw new Error('canAccessResource: resourceType is required and must be a string');
}
if (!requiredPermission || typeof requiredPermission !== 'number') {
throw new Error('canAccessResource: requiredPermission is required and must be a number');
}
return async (req, res, next) => {
try {
// Extract resource ID from route parameters
const rawResourceId = req.params[resourceIdParam];
if (!rawResourceId) {
logger.warn(`[canAccessResource] Missing ${resourceIdParam} in route parameters`);
return res.status(400).json({
error: 'Bad Request',
message: `${resourceIdParam} is required`,
});
}
// Check if user is authenticated
if (!req.user || !req.user.id) {
logger.warn(
`[canAccessResource] Unauthenticated request for ${resourceType} ${rawResourceId}`,
);
return res.status(401).json({
error: 'Unauthorized',
message: 'Authentication required',
});
}
// if system admin let through
if (req.user.role === SystemRoles.ADMIN) {
return next();
}
const userId = req.user.id;
let resourceId = rawResourceId;
let resourceInfo = null;
// Resolve custom ID to ObjectId if resolver is provided
if (idResolver) {
logger.debug(
`[canAccessResource] Resolving ${resourceType} custom ID ${rawResourceId} to ObjectId`,
);
const resolutionResult = await idResolver(rawResourceId);
if (!resolutionResult) {
logger.warn(`[canAccessResource] ${resourceType} not found: ${rawResourceId}`);
return res.status(404).json({
error: 'Not Found',
message: `${resourceType} not found`,
});
}
// Handle different resolver return formats
if (typeof resolutionResult === 'string' || resolutionResult._id) {
resourceId = resolutionResult._id || resolutionResult;
resourceInfo = typeof resolutionResult === 'object' ? resolutionResult : null;
} else {
resourceId = resolutionResult;
}
logger.debug(
`[canAccessResource] Resolved ${resourceType} ${rawResourceId} to ObjectId ${resourceId}`,
);
}
// Check permissions using PermissionService with ObjectId
const hasPermission = await checkPermission({
userId,
resourceType,
resourceId,
requiredPermission,
});
if (hasPermission) {
logger.debug(
`[canAccessResource] User ${userId} has permission ${requiredPermission} on ${resourceType} ${rawResourceId} (${resourceId})`,
);
req.resourceAccess = {
resourceType,
resourceId, // MongoDB ObjectId for ACL operations
customResourceId: rawResourceId, // Original ID from route params
permission: requiredPermission,
userId,
...(resourceInfo && { resourceInfo }),
};
return next();
}
logger.warn(
`[canAccessResource] User ${userId} denied access to ${resourceType} ${rawResourceId} ` +
`(required permission: ${requiredPermission})`,
);
return res.status(403).json({
error: 'Forbidden',
message: `Insufficient permissions to access this ${resourceType}`,
});
} catch (error) {
logger.error(`[canAccessResource] Error checking access for ${resourceType}:`, error);
return res.status(500).json({
error: 'Internal Server Error',
message: 'Failed to check resource access permissions',
});
}
};
};
module.exports = {
canAccessResource,
};

View File

@@ -0,0 +1,9 @@
const { canAccessResource } = require('./canAccessResource');
const { canAccessAgentResource } = require('./canAccessAgentResource');
const { canAccessAgentFromBody } = require('./canAccessAgentFromBody');
module.exports = {
canAccessResource,
canAccessAgentResource,
canAccessAgentFromBody,
};

View File

@@ -7,7 +7,6 @@ const {
} = require('librechat-data-provider');
const azureAssistants = require('~/server/services/Endpoints/azureAssistants');
const assistants = require('~/server/services/Endpoints/assistants');
const gptPlugins = require('~/server/services/Endpoints/gptPlugins');
const { processFiles } = require('~/server/services/Files/process');
const anthropic = require('~/server/services/Endpoints/anthropic');
const bedrock = require('~/server/services/Endpoints/bedrock');
@@ -25,7 +24,6 @@ const buildFunction = {
[EModelEndpoint.bedrock]: bedrock.buildOptions,
[EModelEndpoint.azureOpenAI]: openAI.buildOptions,
[EModelEndpoint.anthropic]: anthropic.buildOptions,
[EModelEndpoint.gptPlugins]: gptPlugins.buildOptions,
[EModelEndpoint.assistants]: assistants.buildOptions,
[EModelEndpoint.azureAssistants]: azureAssistants.buildOptions,
};
@@ -60,15 +58,6 @@ async function buildEndpointOption(req, res, next) {
return handleError(res, { text: 'Model spec mismatch' });
}
if (
currentModelSpec.preset.endpoint !== EModelEndpoint.gptPlugins &&
currentModelSpec.preset.tools
) {
return handleError(res, {
text: `Only the "${EModelEndpoint.gptPlugins}" endpoint can have tools defined in the preset`,
});
}
try {
currentModelSpec.preset.spec = spec;
if (currentModelSpec.iconURL != null && currentModelSpec.iconURL !== '') {

View File

@@ -1,6 +1,7 @@
const crypto = require('crypto');
const { sendEvent } = require('@librechat/api');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { sendMessage, sendError } = require('~/server/utils');
const { sendError } = require('~/server/middleware/error');
const { saveMessage } = require('~/models');
/**
@@ -36,7 +37,7 @@ const denyRequest = async (req, res, errorMessage) => {
isCreatedByUser: true,
text,
};
sendMessage(res, { message: userMessage, created: true });
sendEvent(res, { message: userMessage, created: true });
const shouldSaveMessage = _convoId && parentMessageId && parentMessageId !== Constants.NO_PARENT;

View File

@@ -1,31 +1,9 @@
const crypto = require('crypto');
const { logger } = require('@librechat/data-schemas');
const { parseConvo } = require('librechat-data-provider');
const { sendEvent, handleError } = require('@librechat/api');
const { saveMessage, getMessages } = require('~/models/Message');
const { getConvo } = require('~/models/Conversation');
const { logger } = require('~/config');
/**
* Sends error data in Server Sent Events format and ends the response.
* @param {object} res - The server response.
* @param {string} message - The error message.
*/
const handleError = (res, message) => {
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
res.end();
};
/**
* Sends message data in Server Sent Events format.
* @param {Express.Response} res - - The server response.
* @param {string | Object} message - The message to be sent.
* @param {'message' | 'error' | 'cancel'} event - [Optional] The type of event. Default is 'message'.
*/
const sendMessage = (res, message, event = 'message') => {
if (typeof message === 'string' && message.length === 0) {
return;
}
res.write(`event: ${event}\ndata: ${JSON.stringify(message)}\n\n`);
};
/**
* Processes an error with provided options, saves the error message and sends a corresponding SSE response
@@ -91,7 +69,7 @@ const sendError = async (req, res, options, callback) => {
convo = parseConvo(errorMessage);
}
return sendMessage(res, {
return sendEvent(res, {
final: true,
requestMessage: query?.[0] ? query[0] : requestMessage,
responseMessage: errorMessage,
@@ -120,12 +98,10 @@ const sendResponse = (req, res, data, errorMessage) => {
if (errorMessage) {
return sendError(req, res, { ...data, text: errorMessage });
}
return sendMessage(res, data);
return sendEvent(res, data);
};
module.exports = {
sendResponse,
handleError,
sendMessage,
sendError,
sendResponse,
};

View File

@@ -8,6 +8,7 @@ const concurrentLimiter = require('./concurrentLimiter');
const validateEndpoint = require('./validateEndpoint');
const requireLocalAuth = require('./requireLocalAuth');
const canDeleteAccount = require('./canDeleteAccount');
const accessResources = require('./accessResources');
const setBalanceConfig = require('./setBalanceConfig');
const requireLdapAuth = require('./requireLdapAuth');
const abortMiddleware = require('./abortMiddleware');
@@ -29,6 +30,7 @@ module.exports = {
...validate,
...limiters,
...roles,
...accessResources,
noIndex,
checkBan,
uaParser,

View File

@@ -1,78 +0,0 @@
const { getRoleByName } = require('~/models/Role');
const { logger } = require('~/config');
/**
* Core function to check if a user has one or more required permissions
*
* @param {object} user - The user object
* @param {PermissionTypes} permissionType - The type of permission to check
* @param {Permissions[]} permissions - The list of specific permissions to check
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of properties to check
* @param {object} [checkObject] - The object to check properties against
* @returns {Promise<boolean>} Whether the user has the required permissions
*/
const checkAccess = async (user, permissionType, permissions, bodyProps = {}, checkObject = {}) => {
if (!user) {
return false;
}
const role = await getRoleByName(user.role);
if (role && role.permissions && role.permissions[permissionType]) {
const hasAnyPermission = permissions.some((permission) => {
if (role.permissions[permissionType][permission]) {
return true;
}
if (bodyProps[permission] && checkObject) {
return bodyProps[permission].some((prop) =>
Object.prototype.hasOwnProperty.call(checkObject, prop),
);
}
return false;
});
return hasAnyPermission;
}
return false;
};
/**
* Middleware to check if a user has one or more required permissions, optionally based on `req.body` properties.
*
* @param {PermissionTypes} permissionType - The type of permission to check.
* @param {Permissions[]} permissions - The list of specific permissions to check.
* @param {Record<Permissions, string[]>} [bodyProps] - An optional object where keys are permissions and values are arrays of `req.body` properties to check.
* @returns {(req: ServerRequest, res: ServerResponse, next: NextFunction) => Promise<void>} Express middleware function.
*/
const generateCheckAccess = (permissionType, permissions, bodyProps = {}) => {
return async (req, res, next) => {
try {
const hasAccess = await checkAccess(
req.user,
permissionType,
permissions,
bodyProps,
req.body,
);
if (hasAccess) {
return next();
}
logger.warn(
`[${permissionType}] Forbidden: Insufficient permissions for User ${req.user.id}: ${permissions.join(', ')}`,
);
return res.status(403).json({ message: 'Forbidden: Insufficient permissions' });
} catch (error) {
logger.error(error);
return res.status(500).json({ message: `Server error: ${error.message}` });
}
};
};
module.exports = {
checkAccess,
generateCheckAccess,
};

View File

@@ -0,0 +1,357 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { checkAccess, generateCheckAccess } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { getRoleByName } = require('~/models/Role');
const { Role } = require('~/db/models');
// Mock the logger from @librechat/data-schemas
jest.mock('@librechat/data-schemas', () => ({
...jest.requireActual('@librechat/data-schemas'),
logger: {
warn: jest.fn(),
error: jest.fn(),
info: jest.fn(),
debug: jest.fn(),
},
}));
// Mock the cache to use a simple in-memory implementation
const mockCache = new Map();
jest.mock('~/cache/getLogStores', () => {
return jest.fn(() => ({
get: jest.fn(async (key) => mockCache.get(key)),
set: jest.fn(async (key, value) => mockCache.set(key, value)),
clear: jest.fn(async () => mockCache.clear()),
}));
});
describe('Access Middleware', () => {
let mongoServer;
let req, res, next;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
beforeEach(async () => {
await mongoose.connection.dropDatabase();
mockCache.clear(); // Clear the cache between tests
// Create test roles
await Role.create({
name: 'user',
permissions: {
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
[PermissionTypes.PROMPTS]: {
[Permissions.SHARED_GLOBAL]: false,
[Permissions.USE]: true,
[Permissions.CREATE]: true,
},
[PermissionTypes.MEMORIES]: {
[Permissions.USE]: true,
[Permissions.CREATE]: true,
[Permissions.UPDATE]: true,
[Permissions.READ]: true,
[Permissions.OPT_OUT]: true,
},
[PermissionTypes.AGENTS]: {
[Permissions.USE]: true,
[Permissions.CREATE]: false,
[Permissions.SHARED_GLOBAL]: false,
},
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: true },
[PermissionTypes.WEB_SEARCH]: { [Permissions.USE]: true },
},
});
await Role.create({
name: 'admin',
permissions: {
[PermissionTypes.BOOKMARKS]: { [Permissions.USE]: true },
[PermissionTypes.PROMPTS]: {
[Permissions.SHARED_GLOBAL]: true,
[Permissions.USE]: true,
[Permissions.CREATE]: true,
},
[PermissionTypes.MEMORIES]: {
[Permissions.USE]: true,
[Permissions.CREATE]: true,
[Permissions.UPDATE]: true,
[Permissions.READ]: true,
[Permissions.OPT_OUT]: true,
},
[PermissionTypes.AGENTS]: {
[Permissions.USE]: true,
[Permissions.CREATE]: true,
[Permissions.SHARED_GLOBAL]: true,
},
[PermissionTypes.MULTI_CONVO]: { [Permissions.USE]: true },
[PermissionTypes.TEMPORARY_CHAT]: { [Permissions.USE]: true },
[PermissionTypes.RUN_CODE]: { [Permissions.USE]: true },
[PermissionTypes.WEB_SEARCH]: { [Permissions.USE]: true },
},
});
// Create limited role with no AGENTS permissions
await Role.create({
name: 'limited',
permissions: {
// Explicitly set AGENTS permissions to false
[PermissionTypes.AGENTS]: {
[Permissions.USE]: false,
[Permissions.CREATE]: false,
[Permissions.SHARED_GLOBAL]: false,
},
// Has permissions for other types
[PermissionTypes.PROMPTS]: {
[Permissions.USE]: true,
},
},
});
req = {
user: { id: 'user123', role: 'user' },
body: {},
originalUrl: '/test',
};
res = {
status: jest.fn().mockReturnThis(),
json: jest.fn(),
};
next = jest.fn();
jest.clearAllMocks();
});
describe('checkAccess', () => {
test('should return false if user is not provided', async () => {
const result = await checkAccess({
user: null,
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
expect(result).toBe(false);
});
test('should return true if user has required permission', async () => {
const result = await checkAccess({
req: {},
user: { id: 'user123', role: 'user' },
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
expect(result).toBe(true);
});
test('should return false if user lacks required permission', async () => {
const result = await checkAccess({
req: {},
user: { id: 'user123', role: 'user' },
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.CREATE],
getRoleByName,
});
expect(result).toBe(false);
});
test('should return true if user has any of multiple permissions', async () => {
const result = await checkAccess({
req: {},
user: { id: 'user123', role: 'user' },
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.CREATE, Permissions.USE],
getRoleByName,
});
expect(result).toBe(true);
});
test('should check body properties when permission is not directly granted', async () => {
const req = { body: { id: 'agent123' } };
const result = await checkAccess({
req,
user: { id: 'user123', role: 'user' },
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.UPDATE],
bodyProps: {
[Permissions.UPDATE]: ['id'],
},
checkObject: req.body,
getRoleByName,
});
expect(result).toBe(true);
});
test('should return false if role is not found', async () => {
const result = await checkAccess({
req: {},
user: { id: 'user123', role: 'nonexistent' },
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
expect(result).toBe(false);
});
test('should return false if role has no permissions for the requested type', async () => {
const result = await checkAccess({
req: {},
user: { id: 'user123', role: 'limited' },
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
expect(result).toBe(false);
});
test('should handle admin role with all permissions', async () => {
const createResult = await checkAccess({
req: {},
user: { id: 'admin123', role: 'admin' },
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.CREATE],
getRoleByName,
});
expect(createResult).toBe(true);
const shareResult = await checkAccess({
req: {},
user: { id: 'admin123', role: 'admin' },
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.SHARED_GLOBAL],
getRoleByName,
});
expect(shareResult).toBe(true);
});
});
describe('generateCheckAccess', () => {
test('should call next() when user has required permission', async () => {
const middleware = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
expect(res.status).not.toHaveBeenCalled();
});
test('should return 403 when user lacks permission', async () => {
const middleware = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.CREATE],
getRoleByName,
});
await middleware(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
expect(res.json).toHaveBeenCalledWith({ message: 'Forbidden: Insufficient permissions' });
});
test('should check body properties when configured', async () => {
req.body = { agentId: 'agent123', description: 'test' };
const bodyProps = {
[Permissions.CREATE]: ['agentId'],
};
const middleware = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.CREATE],
bodyProps,
getRoleByName,
});
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
expect(res.status).not.toHaveBeenCalled();
});
test('should handle database errors gracefully', async () => {
// Mock getRoleByName to throw an error
const mockGetRoleByName = jest
.fn()
.mockRejectedValue(new Error('Database connection failed'));
const middleware = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName: mockGetRoleByName,
});
await middleware(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(500);
expect(res.json).toHaveBeenCalledWith({
message: expect.stringContaining('Server error:'),
});
});
test('should work with multiple permission types', async () => {
req.user.role = 'admin';
const middleware = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE, Permissions.SHARED_GLOBAL],
getRoleByName,
});
await middleware(req, res, next);
expect(next).toHaveBeenCalled();
});
test('should handle missing user gracefully', async () => {
req.user = null;
const middleware = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
await middleware(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
expect(res.json).toHaveBeenCalledWith({ message: 'Forbidden: Insufficient permissions' });
});
test('should handle role with no AGENTS permissions', async () => {
await Role.create({
name: 'noaccess',
permissions: {
// Explicitly set AGENTS with all permissions false
[PermissionTypes.AGENTS]: {
[Permissions.USE]: false,
[Permissions.CREATE]: false,
[Permissions.SHARED_GLOBAL]: false,
},
},
});
req.user.role = 'noaccess';
const middleware = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
await middleware(req, res, next);
expect(next).not.toHaveBeenCalled();
expect(res.status).toHaveBeenCalledWith(403);
expect(res.json).toHaveBeenCalledWith({ message: 'Forbidden: Insufficient permissions' });
});
});
});

View File

@@ -1,8 +1,5 @@
const checkAdmin = require('./admin');
const { checkAccess, generateCheckAccess } = require('./access');
module.exports = {
checkAdmin,
checkAccess,
generateCheckAccess,
};

View File

@@ -0,0 +1,62 @@
const express = require('express');
const { PermissionBits } = require('@librechat/data-schemas');
const {
getUserEffectivePermissions,
updateResourcePermissions,
getResourcePermissions,
getResourceRoles,
searchPrincipals,
} = require('~/server/controllers/PermissionsController');
const { requireJwtAuth, checkBan, uaParser, canAccessResource } = require('~/server/middleware');
const router = express.Router();
// Apply common middleware
router.use(requireJwtAuth);
router.use(checkBan);
router.use(uaParser);
/**
* Generic routes for resource permissions
* Pattern: /api/permissions/{resourceType}/{resourceId}
*/
/**
* GET /api/permissions/search-principals
* Search for users and groups to grant permissions
*/
router.get('/search-principals', searchPrincipals);
/**
* GET /api/permissions/{resourceType}/roles
* Get available roles for a resource type
*/
router.get('/:resourceType/roles', getResourceRoles);
/**
* GET /api/permissions/{resourceType}/{resourceId}
* Get all permissions for a specific resource
*/
router.get('/:resourceType/:resourceId', getResourcePermissions);
/**
* PUT /api/permissions/{resourceType}/{resourceId}
* Bulk update permissions for a specific resource
*/
router.put(
'/:resourceType/:resourceId',
canAccessResource({
resourceType: 'agent',
requiredPermission: PermissionBits.SHARE,
resourceIdParam: 'resourceId',
}),
updateResourcePermissions,
);
/**
* GET /api/permissions/{resourceType}/{resourceId}/effective
* Get user's effective permissions for a specific resource
*/
router.get('/:resourceType/:resourceId/effective', getUserEffectivePermissions);
module.exports = router;

View File

@@ -1,19 +1,27 @@
const express = require('express');
const { nanoid } = require('nanoid');
const { actionDelimiter, SystemRoles, removeNullishValues } = require('librechat-data-provider');
const { generateCheckAccess } = require('@librechat/api');
const { logger, PermissionBits } = require('@librechat/data-schemas');
const {
Permissions,
PermissionTypes,
actionDelimiter,
removeNullishValues,
} = require('librechat-data-provider');
const { encryptMetadata, domainParser } = require('~/server/services/ActionService');
const { updateAction, getActions, deleteAction } = require('~/models/Action');
const { isActionDomainAllowed } = require('~/server/services/domains');
const { canAccessAgentResource } = require('~/server/middleware');
const { getAgent, updateAgent } = require('~/models/Agent');
const { logger } = require('~/config');
const { getRoleByName } = require('~/models/Role');
const router = express.Router();
// If the user has ADMIN role
// then action edition is possible even if not owner of the assistant
const isAdmin = (req) => {
return req.user.role === SystemRoles.ADMIN;
};
const checkAgentCreate = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
/**
* Retrieves all user's actions
@@ -23,9 +31,8 @@ const isAdmin = (req) => {
*/
router.get('/', async (req, res) => {
try {
const admin = isAdmin(req);
// If admin, get all actions, otherwise only user's actions
const searchParams = admin ? {} : { user: req.user.id };
// Get all actions for the user (admin permissions handled by middleware if needed)
const searchParams = { user: req.user.id };
res.json(await getActions(searchParams));
} catch (error) {
res.status(500).json({ error: error.message });
@@ -41,106 +48,111 @@ router.get('/', async (req, res) => {
* @param {ActionMetadata} req.body.metadata - Metadata for the action.
* @returns {Object} 200 - success response - application/json
*/
router.post('/:agent_id', async (req, res) => {
try {
const { agent_id } = req.params;
router.post(
'/:agent_id',
canAccessAgentResource({
requiredPermission: PermissionBits.EDIT,
resourceIdParam: 'agent_id',
}),
checkAgentCreate,
async (req, res) => {
try {
const { agent_id } = req.params;
/** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */
const { functions, action_id: _action_id, metadata: _metadata } = req.body;
if (!functions.length) {
return res.status(400).json({ message: 'No functions provided' });
}
let metadata = await encryptMetadata(removeNullishValues(_metadata, true));
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
if (!isDomainAllowed) {
return res.status(400).json({ message: 'Domain not allowed' });
}
let { domain } = metadata;
domain = await domainParser(domain, true);
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });
}
const action_id = _action_id ?? nanoid();
const initialPromises = [];
const admin = isAdmin(req);
// If admin, can edit any agent, otherwise only user's agents
const agentQuery = admin ? { id: agent_id } : { id: agent_id, author: req.user.id };
// TODO: share agents
initialPromises.push(getAgent(agentQuery));
if (_action_id) {
initialPromises.push(getActions({ action_id }, true));
}
/** @type {[Agent, [Action|undefined]]} */
const [agent, actions_result] = await Promise.all(initialPromises);
if (!agent) {
return res.status(404).json({ message: 'Agent not found for adding action' });
}
if (actions_result && actions_result.length) {
const action = actions_result[0];
metadata = { ...action.metadata, ...metadata };
}
const { actions: _actions = [], author: agent_author } = agent ?? {};
const actions = [];
for (const action of _actions) {
const [_action_domain, current_action_id] = action.split(actionDelimiter);
if (current_action_id === action_id) {
continue;
/** @type {{ functions: FunctionTool[], action_id: string, metadata: ActionMetadata }} */
const { functions, action_id: _action_id, metadata: _metadata } = req.body;
if (!functions.length) {
return res.status(400).json({ message: 'No functions provided' });
}
actions.push(action);
}
actions.push(`${domain}${actionDelimiter}${action_id}`);
/** @type {string[]}} */
const { tools: _tools = [] } = agent;
const tools = _tools
.filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id))))
.concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`));
// Force version update since actions are changing
const updatedAgent = await updateAgent(
agentQuery,
{ tools, actions },
{
updatingUserId: req.user.id,
forceVersion: true,
},
);
// Only update user field for new actions
const actionUpdateData = { metadata, agent_id };
if (!actions_result || !actions_result.length) {
// For new actions, use the agent owner's user ID
actionUpdateData.user = agent_author || req.user.id;
}
/** @type {[Action]} */
const updatedAction = await updateAction({ action_id }, actionUpdateData);
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
for (let field of sensitiveFields) {
if (updatedAction.metadata[field]) {
delete updatedAction.metadata[field];
let metadata = await encryptMetadata(removeNullishValues(_metadata, true));
const isDomainAllowed = await isActionDomainAllowed(metadata.domain);
if (!isDomainAllowed) {
return res.status(400).json({ message: 'Domain not allowed' });
}
}
res.json([updatedAgent, updatedAction]);
} catch (error) {
const message = 'Trouble updating the Agent Action';
logger.error(message, error);
res.status(500).json({ message });
}
});
let { domain } = metadata;
domain = await domainParser(domain, true);
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });
}
const action_id = _action_id ?? nanoid();
const initialPromises = [];
// Permissions already validated by middleware - load agent directly
initialPromises.push(getAgent({ id: agent_id }));
if (_action_id) {
initialPromises.push(getActions({ action_id }, true));
}
/** @type {[Agent, [Action|undefined]]} */
const [agent, actions_result] = await Promise.all(initialPromises);
if (!agent) {
return res.status(404).json({ message: 'Agent not found for adding action' });
}
if (actions_result && actions_result.length) {
const action = actions_result[0];
metadata = { ...action.metadata, ...metadata };
}
const { actions: _actions = [], author: agent_author } = agent ?? {};
const actions = [];
for (const action of _actions) {
const [_action_domain, current_action_id] = action.split(actionDelimiter);
if (current_action_id === action_id) {
continue;
}
actions.push(action);
}
actions.push(`${domain}${actionDelimiter}${action_id}`);
/** @type {string[]}} */
const { tools: _tools = [] } = agent;
const tools = _tools
.filter((tool) => !(tool && (tool.includes(domain) || tool.includes(action_id))))
.concat(functions.map((tool) => `${tool.function.name}${actionDelimiter}${domain}`));
// Force version update since actions are changing
const updatedAgent = await updateAgent(
{ id: agent_id },
{ tools, actions },
{
updatingUserId: req.user.id,
forceVersion: true,
},
);
// Only update user field for new actions
const actionUpdateData = { metadata, agent_id };
if (!actions_result || !actions_result.length) {
// For new actions, use the agent owner's user ID
actionUpdateData.user = agent_author || req.user.id;
}
/** @type {[Action]} */
const updatedAction = await updateAction({ action_id }, actionUpdateData);
const sensitiveFields = ['api_key', 'oauth_client_id', 'oauth_client_secret'];
for (let field of sensitiveFields) {
if (updatedAction.metadata[field]) {
delete updatedAction.metadata[field];
}
}
res.json([updatedAgent, updatedAction]);
} catch (error) {
const message = 'Trouble updating the Agent Action';
logger.error(message, error);
res.status(500).json({ message });
}
},
);
/**
* Deletes an action for a specific agent.
@@ -149,52 +161,56 @@ router.post('/:agent_id', async (req, res) => {
* @param {string} req.params.action_id - The ID of the action to delete.
* @returns {Object} 200 - success response - application/json
*/
router.delete('/:agent_id/:action_id', async (req, res) => {
try {
const { agent_id, action_id } = req.params;
const admin = isAdmin(req);
router.delete(
'/:agent_id/:action_id',
canAccessAgentResource({
requiredPermission: PermissionBits.EDIT,
resourceIdParam: 'agent_id',
}),
checkAgentCreate,
async (req, res) => {
try {
const { agent_id, action_id } = req.params;
// If admin, can delete any agent, otherwise only user's agents
const agentQuery = admin ? { id: agent_id } : { id: agent_id, author: req.user.id };
const agent = await getAgent(agentQuery);
if (!agent) {
return res.status(404).json({ message: 'Agent not found for deleting action' });
}
const { tools = [], actions = [] } = agent;
let domain = '';
const updatedActions = actions.filter((action) => {
if (action.includes(action_id)) {
[domain] = action.split(actionDelimiter);
return false;
// Permissions already validated by middleware - load agent directly
const agent = await getAgent({ id: agent_id });
if (!agent) {
return res.status(404).json({ message: 'Agent not found for deleting action' });
}
return true;
});
domain = await domainParser(domain, true);
const { tools = [], actions = [] } = agent;
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });
let domain = '';
const updatedActions = actions.filter((action) => {
if (action.includes(action_id)) {
[domain] = action.split(actionDelimiter);
return false;
}
return true;
});
domain = await domainParser(domain, true);
if (!domain) {
return res.status(400).json({ message: 'No domain provided' });
}
const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain)));
// Force version update since actions are being removed
await updateAgent(
{ id: agent_id },
{ tools: updatedTools, actions: updatedActions },
{ updatingUserId: req.user.id, forceVersion: true },
);
await deleteAction({ action_id });
res.status(200).json({ message: 'Action deleted successfully' });
} catch (error) {
const message = 'Trouble deleting the Agent Action';
logger.error(message, error);
res.status(500).json({ message });
}
const updatedTools = tools.filter((tool) => !(tool && tool.includes(domain)));
// Force version update since actions are being removed
await updateAgent(
agentQuery,
{ tools: updatedTools, actions: updatedActions },
{ updatingUserId: req.user.id, forceVersion: true },
);
// If admin, can delete any action, otherwise only user's actions
const actionQuery = admin ? { action_id } : { action_id, user: req.user.id };
await deleteAction(actionQuery);
res.status(200).json({ message: 'Action deleted successfully' });
} catch (error) {
const message = 'Trouble deleting the Agent Action';
logger.error(message, error);
res.status(500).json({ message });
}
});
},
);
module.exports = router;

View File

@@ -1,24 +1,36 @@
const express = require('express');
const { PermissionBits } = require('@librechat/data-schemas');
const { generateCheckAccess, skipAgentCheck } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const {
setHeaders,
moderateText,
// validateModel,
generateCheckAccess,
validateConvoAccess,
buildEndpointOption,
canAccessAgentFromBody,
} = require('~/server/middleware');
const { initializeClient } = require('~/server/services/Endpoints/agents');
const AgentController = require('~/server/controllers/agents/request');
const addTitle = require('~/server/services/Endpoints/agents/title');
const { getRoleByName } = require('~/models/Role');
const router = express.Router();
router.use(moderateText);
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
const checkAgentAccess = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
skipCheck: skipAgentCheck,
getRoleByName,
});
const checkAgentResourceAccess = canAccessAgentFromBody({
requiredPermission: PermissionBits.VIEW,
});
router.use(checkAgentAccess);
router.use(checkAgentResourceAccess);
router.use(validateConvoAccess);
router.use(buildEndpointOption);
router.use(setHeaders);

View File

@@ -37,4 +37,6 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
chatRouter.use('/', chat);
router.use('/chat', chatRouter);
// Add marketplace routes
module.exports = router;

View File

@@ -1,29 +1,37 @@
const express = require('express');
const { generateCheckAccess } = require('@librechat/api');
const { PermissionBits } = require('@librechat/data-schemas');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const { requireJwtAuth, canAccessAgentResource } = require('~/server/middleware');
const v1 = require('~/server/controllers/agents/v1');
const { getRoleByName } = require('~/models/Role');
const actions = require('./actions');
const tools = require('./tools');
const router = express.Router();
const avatar = express.Router();
const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]);
const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [
Permissions.USE,
Permissions.CREATE,
]);
const checkAgentAccess = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE],
getRoleByName,
});
const checkAgentCreate = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
const checkGlobalAgentShare = generateCheckAccess(
PermissionTypes.AGENTS,
[Permissions.USE, Permissions.CREATE],
{
const checkGlobalAgentShare = generateCheckAccess({
permissionType: PermissionTypes.AGENTS,
permissions: [Permissions.USE, Permissions.CREATE],
bodyProps: {
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
},
);
getRoleByName,
});
router.use(requireJwtAuth);
router.use(checkAgentAccess);
/**
* Agent actions route.
@@ -37,6 +45,11 @@ router.use('/actions', actions);
*/
router.use('/tools', tools);
/**
* Get all agent categories with counts
* @route GET /agents/marketplace/categories
*/
router.get('/categories', v1.getAgentCategories);
/**
* Creates an agent.
* @route POST /agents
@@ -46,13 +59,38 @@ router.use('/tools', tools);
router.post('/', checkAgentCreate, v1.createAgent);
/**
* Retrieves an agent.
* Retrieves basic agent information (VIEW permission required).
* Returns safe, non-sensitive agent data for viewing purposes.
* @route GET /agents/:id
* @param {string} req.params.id - Agent identifier.
* @returns {Agent} 200 - Success response - application/json
* @returns {Agent} 200 - Basic agent info - application/json
*/
router.get('/:id', checkAgentAccess, v1.getAgent);
router.get(
'/:id',
checkAgentAccess,
canAccessAgentResource({
requiredPermission: PermissionBits.VIEW,
resourceIdParam: 'id',
}),
v1.getAgent,
);
/**
* Retrieves full agent details including sensitive configuration (EDIT permission required).
* Returns complete agent data for editing/configuration purposes.
* @route GET /agents/:id/expanded
* @param {string} req.params.id - Agent identifier.
* @returns {Agent} 200 - Full agent details - application/json
*/
router.get(
'/:id/expanded',
checkAgentAccess,
canAccessAgentResource({
requiredPermission: PermissionBits.EDIT,
resourceIdParam: 'id',
}),
(req, res) => v1.getAgent(req, res, true), // Expanded version
);
/**
* Updates an agent.
* @route PATCH /agents/:id
@@ -60,7 +98,15 @@ router.get('/:id', checkAgentAccess, v1.getAgent);
* @param {AgentUpdateParams} req.body - The agent update parameters.
* @returns {Agent} 200 - Success response - application/json
*/
router.patch('/:id', checkGlobalAgentShare, v1.updateAgent);
router.patch(
'/:id',
checkGlobalAgentShare,
canAccessAgentResource({
requiredPermission: PermissionBits.EDIT,
resourceIdParam: 'id',
}),
v1.updateAgent,
);
/**
* Duplicates an agent.
@@ -68,7 +114,15 @@ router.patch('/:id', checkGlobalAgentShare, v1.updateAgent);
* @param {string} req.params.id - Agent identifier.
* @returns {Agent} 201 - Success response - application/json
*/
router.post('/:id/duplicate', checkAgentCreate, v1.duplicateAgent);
router.post(
'/:id/duplicate',
checkAgentCreate,
canAccessAgentResource({
requiredPermission: PermissionBits.VIEW,
resourceIdParam: 'id',
}),
v1.duplicateAgent,
);
/**
* Deletes an agent.
@@ -76,7 +130,15 @@ router.post('/:id/duplicate', checkAgentCreate, v1.duplicateAgent);
* @param {string} req.params.id - Agent identifier.
* @returns {Agent} 200 - success response - application/json
*/
router.delete('/:id', checkAgentCreate, v1.deleteAgent);
router.delete(
'/:id',
checkAgentCreate,
canAccessAgentResource({
requiredPermission: PermissionBits.DELETE,
resourceIdParam: 'id',
}),
v1.deleteAgent,
);
/**
* Reverts an agent to a previous version.
@@ -103,6 +165,14 @@ router.get('/', checkAgentAccess, v1.getListAgents);
* @param {string} [req.body.metadata] - Optional metadata for the agent's avatar.
* @returns {Object} 200 - success response - application/json
*/
avatar.post('/:agent_id/avatar/', checkAgentAccess, v1.uploadAgentAvatar);
avatar.post(
'/:agent_id/avatar/',
checkAgentAccess,
canAccessAgentResource({
requiredPermission: PermissionBits.EDIT,
resourceIdParam: 'agent_id',
}),
v1.uploadAgentAvatar,
);
module.exports = { v1: router, avatar };

View File

@@ -1,207 +0,0 @@
const express = require('express');
const { getResponseSender } = require('librechat-data-provider');
const {
setHeaders,
moderateText,
validateModel,
handleAbortError,
validateEndpoint,
buildEndpointOption,
createAbortController,
} = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, updateMessage } = require('~/models');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
const router = express.Router();
router.use(moderateText);
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res) => {
let {
text,
generation,
endpointOption,
conversationId,
responseMessageId,
isContinued = false,
parentMessageId = null,
overrideParentMessageId = null,
} = req.body;
logger.debug('[/edit/gptPlugins]', {
text,
generation,
isContinued,
conversationId,
...endpointOption,
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
model: endpointOption.modelOptions.model,
});
const userMessageId = parentMessageId;
const user = req.user.id;
const plugin = {
loading: true,
inputs: [],
latest: null,
outputs: null,
};
const getReqData = (data = {}) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
promptTokens = data[key];
}
}
};
const {
onProgress: progressCallback,
sendIntermediateMessage,
getPartialText,
} = createOnProgress({
generation,
onProgress: () => {
if (plugin.loading === true) {
plugin.loading = false;
}
},
});
const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
plugin.loading = false;
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onChainEnd' },
);
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('CHAIN END', plugin.outputs);
};
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
plugin: { ...plugin, loading: false },
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });
const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start && !client.skipSaveUserMessage) {
saveMessage(
req,
{ ...userMessage, user },
{ context: 'api/server/routes/ask/gptPlugins.js - onAgentAction' },
);
}
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('PLUGIN ACTION', formattedAction);
};
let response = await client.sendMessage(text, {
user,
generation,
isContinued,
isEdited: true,
conversationId,
parentMessageId,
responseMessageId,
overrideParentMessageId,
getReqData,
onAgentAction,
onChainEnd,
onStart,
...endpointOption,
progressCallback,
progressOptions: {
res,
plugin,
// parentMessageId: overrideParentMessageId || userMessageId,
},
abortController,
});
if (overrideParentMessageId) {
response.parentMessageId = overrideParentMessageId;
}
logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response);
const { conversation = {} } = await response.databasePromise;
delete response.databasePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: conversation.title,
final: true,
conversation,
requestMessage: userMessage,
responseMessage: response,
});
res.end();
response.plugin = { ...plugin, loading: false };
await updateMessage(
req,
{ ...response, user },
{ context: 'api/server/routes/edit/gptPlugins.js' },
);
} catch (error) {
const partialText = getPartialText();
handleAbortError(res, req, error, {
partialText,
conversationId,
sender,
messageId: responseMessageId,
parentMessageId: userMessageId ?? parentMessageId,
});
}
},
);
module.exports = router;

View File

@@ -3,7 +3,6 @@ const openAI = require('./openAI');
const custom = require('./custom');
const google = require('./google');
const anthropic = require('./anthropic');
const gptPlugins = require('./gptPlugins');
const { isEnabled } = require('~/server/utils');
const { EModelEndpoint } = require('librechat-data-provider');
const {
@@ -39,7 +38,6 @@ if (isEnabled(LIMIT_MESSAGE_USER)) {
router.use(validateConvoAccess);
router.use([`/${EModelEndpoint.azureOpenAI}`, `/${EModelEndpoint.openAI}`], openAI);
router.use(`/${EModelEndpoint.gptPlugins}`, gptPlugins);
router.use(`/${EModelEndpoint.anthropic}`, anthropic);
router.use(`/${EModelEndpoint.google}`, google);
router.use(`/${EModelEndpoint.custom}`, custom);

View File

@@ -283,7 +283,10 @@ router.post('/', async (req, res) => {
message += ': ' + error.message;
}
if (error.message?.includes('Invalid file format')) {
if (
error.message?.includes('Invalid file format') ||
error.message?.includes('No OCR result')
) {
message = error.message;
}

View File

@@ -1,3 +1,4 @@
const accessPermissions = require('./accessPermissions');
const assistants = require('./assistants');
const categories = require('./categories');
const tokenizer = require('./tokenizer');
@@ -28,6 +29,7 @@ const user = require('./user');
const mcp = require('./mcp');
module.exports = {
mcp,
edit,
auth,
keys,
@@ -55,5 +57,5 @@ module.exports = {
assistants,
categories,
staticRoute,
mcp,
accessPermissions,
};

View File

@@ -1,37 +1,43 @@
const express = require('express');
const { Tokenizer } = require('@librechat/api');
const { Tokenizer, generateCheckAccess } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const {
getAllUserMemories,
toggleUserMemories,
createMemory,
setMemory,
deleteMemory,
setMemory,
} = require('~/models');
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const { requireJwtAuth } = require('~/server/middleware');
const { getRoleByName } = require('~/models/Role');
const router = express.Router();
const checkMemoryRead = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.READ,
]);
const checkMemoryCreate = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.CREATE,
]);
const checkMemoryUpdate = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.UPDATE,
]);
const checkMemoryDelete = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.UPDATE,
]);
const checkMemoryOptOut = generateCheckAccess(PermissionTypes.MEMORIES, [
Permissions.USE,
Permissions.OPT_OUT,
]);
const checkMemoryRead = generateCheckAccess({
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE, Permissions.READ],
getRoleByName,
});
const checkMemoryCreate = generateCheckAccess({
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
const checkMemoryUpdate = generateCheckAccess({
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE, Permissions.UPDATE],
getRoleByName,
});
const checkMemoryDelete = generateCheckAccess({
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE, Permissions.UPDATE],
getRoleByName,
});
const checkMemoryOptOut = generateCheckAccess({
permissionType: PermissionTypes.MEMORIES,
permissions: [Permissions.USE, Permissions.OPT_OUT],
getRoleByName,
});
router.use(requireJwtAuth);

View File

@@ -9,6 +9,7 @@ const {
setBalanceConfig,
checkDomainAllowed,
} = require('~/server/middleware');
const { syncUserEntraGroupMemberships } = require('~/server/services/PermissionService');
const { setAuthTokens, setOpenIDAuthTokens } = require('~/server/services/AuthService');
const { isEnabled } = require('~/server/utils');
const { logger } = require('~/config');
@@ -35,6 +36,7 @@ const oauthHandler = async (req, res) => {
req.user.provider == 'openid' &&
isEnabled(process.env.OPENID_REUSE_TOKENS) === true
) {
await syncUserEntraGroupMemberships(req.user, req.user.tokenset.access_token);
setOpenIDAuthTokens(req.user.tokenset, res);
} else {
await setAuthTokens(req.user._id, res);

View File

@@ -1,5 +1,7 @@
const express = require('express');
const { PermissionTypes, Permissions, SystemRoles } = require('librechat-data-provider');
const { logger } = require('@librechat/data-schemas');
const { generateCheckAccess } = require('@librechat/api');
const { Permissions, SystemRoles, PermissionTypes } = require('librechat-data-provider');
const {
getPrompt,
getPrompts,
@@ -14,24 +16,30 @@ const {
// updatePromptLabels,
makePromptProduction,
} = require('~/models/Prompt');
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const { logger } = require('~/config');
const { requireJwtAuth } = require('~/server/middleware');
const { getRoleByName } = require('~/models/Role');
const router = express.Router();
const checkPromptAccess = generateCheckAccess(PermissionTypes.PROMPTS, [Permissions.USE]);
const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [
Permissions.USE,
Permissions.CREATE,
]);
const checkPromptAccess = generateCheckAccess({
permissionType: PermissionTypes.PROMPTS,
permissions: [Permissions.USE],
getRoleByName,
});
const checkPromptCreate = generateCheckAccess({
permissionType: PermissionTypes.PROMPTS,
permissions: [Permissions.USE, Permissions.CREATE],
getRoleByName,
});
const checkGlobalPromptShare = generateCheckAccess(
PermissionTypes.PROMPTS,
[Permissions.USE, Permissions.CREATE],
{
const checkGlobalPromptShare = generateCheckAccess({
permissionType: PermissionTypes.PROMPTS,
permissions: [Permissions.USE, Permissions.CREATE],
bodyProps: {
[Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'],
},
);
getRoleByName,
});
router.use(requireJwtAuth);
router.use(checkPromptAccess);

View File

@@ -1,18 +1,24 @@
const express = require('express');
const { logger } = require('@librechat/data-schemas');
const { generateCheckAccess } = require('@librechat/api');
const { PermissionTypes, Permissions } = require('librechat-data-provider');
const {
getConversationTags,
updateTagsForConversation,
updateConversationTag,
createConversationTag,
deleteConversationTag,
updateTagsForConversation,
getConversationTags,
} = require('~/models/ConversationTag');
const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware');
const { logger } = require('~/config');
const { requireJwtAuth } = require('~/server/middleware');
const { getRoleByName } = require('~/models/Role');
const router = express.Router();
const checkBookmarkAccess = generateCheckAccess(PermissionTypes.BOOKMARKS, [Permissions.USE]);
const checkBookmarkAccess = generateCheckAccess({
permissionType: PermissionTypes.BOOKMARKS,
permissions: [Permissions.USE],
getRoleByName,
});
router.use(requireJwtAuth);
router.use(checkBookmarkAccess);

View File

@@ -1,5 +1,7 @@
jest.mock('~/models', () => ({
initializeRoles: jest.fn(),
seedDefaultRoles: jest.fn(),
ensureDefaultCategories: jest.fn(),
}));
jest.mock('~/models/Role', () => ({
updateAccessPermissions: jest.fn(),

View File

@@ -17,6 +17,7 @@ const {
const { azureAssistantsDefaults, assistantsConfigSetup } = require('./start/assistants');
const { initializeAzureBlobService } = require('./Files/Azure/initialize');
const { initializeFirebase } = require('./Files/Firebase/initialize');
const { seedDefaultRoles, initializeRoles, ensureDefaultCategories } = require('~/models');
const loadCustomConfig = require('./Config/loadCustomConfig');
const handleRateLimits = require('./Config/handleRateLimits');
const { loadDefaultInterface } = require('./start/interface');
@@ -26,7 +27,6 @@ const { processModelSpecs } = require('./start/modelSpecs');
const { initializeS3 } = require('./Files/S3/initialize');
const { loadAndFormatTools } = require('./ToolService');
const { isEnabled } = require('~/server/utils');
const { initializeRoles } = require('~/models');
const { setCachedTools } = require('./Config');
const paths = require('~/config/paths');
@@ -37,6 +37,8 @@ const paths = require('~/config/paths');
*/
const AppService = async (app) => {
await initializeRoles();
await seedDefaultRoles();
await ensureDefaultCategories();
/** @type {TCustomConfig} */
const config = (await loadCustomConfig()) ?? {};
const configDefaults = getConfigDefaults();

View File

@@ -28,6 +28,8 @@ jest.mock('./Files/Firebase/initialize', () => ({
}));
jest.mock('~/models', () => ({
initializeRoles: jest.fn(),
seedDefaultRoles: jest.fn(),
ensureDefaultCategories: jest.fn(),
}));
jest.mock('~/models/Role', () => ({
updateAccessPermissions: jest.fn(),

View File

@@ -1,4 +1,7 @@
const { klona } = require('klona');
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
StepTypes,
RunStatus,
@@ -11,11 +14,10 @@ const {
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processRequiredActions } = require('~/server/services/ToolService');
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
const { RunManager, waitForRun } = require('~/server/services/Runs');
const { processMessages } = require('~/server/services/Threads');
const { createOnProgress } = require('~/server/utils');
const { TextStream } = require('~/app/clients');
const { logger } = require('~/config');
/**
* Sorts, processes, and flattens messages to a single string.
@@ -64,7 +66,7 @@ async function createOnTextProgress({
};
logger.debug('Content data:', contentData);
sendMessage(openai.res, contentData);
sendEvent(openai.res, contentData);
};
}

View File

@@ -1,4 +1,5 @@
const bcrypt = require('bcryptjs');
const jwt = require('jsonwebtoken');
const { webcrypto } = require('node:crypto');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
@@ -499,6 +500,18 @@ const resendVerificationEmail = async (req) => {
};
}
};
/**
* Generate a short-lived JWT token
* @param {String} userId - The ID of the user
* @param {String} [expireIn='5m'] - The expiration time for the token (default is 5 minutes)
* @returns {String} - The generated JWT token
*/
const generateShortLivedToken = (userId, expireIn = '5m') => {
return jwt.sign({ id: userId }, process.env.JWT_SECRET, {
expiresIn: expireIn,
algorithm: 'HS256',
});
};
module.exports = {
logoutUser,
@@ -506,7 +519,8 @@ module.exports = {
registerUser,
setAuthTokens,
resetPassword,
setOpenIDAuthTokens,
requestPasswordReset,
resendVerificationEmail,
setOpenIDAuthTokens,
generateShortLivedToken,
};

View File

@@ -1,5 +1,6 @@
const { isUserProvided } = require('@librechat/api');
const { EModelEndpoint } = require('librechat-data-provider');
const { isUserProvided, generateConfig } = require('~/server/utils');
const { generateConfig } = require('~/server/utils/handleText');
const {
OPENAI_API_KEY: openAIApiKey,

View File

@@ -40,6 +40,7 @@ async function getBalanceConfig() {
/**
*
* @param {string | EModelEndpoint} endpoint
* @returns {Promise<TEndpoint | undefined>}
*/
const getCustomEndpointConfig = async (endpoint) => {
const customConfig = await getCustomConfig();

View File

@@ -1,18 +1,18 @@
const path = require('path');
const {
CacheKeys,
configSchema,
EImageOutputType,
validateSettingDefinitions,
agentParamSettings,
paramSettings,
} = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const loadYaml = require('~/utils/loadYaml');
const { logger } = require('~/config');
const axios = require('axios');
const yaml = require('js-yaml');
const keyBy = require('lodash/keyBy');
const { loadYaml } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
CacheKeys,
configSchema,
paramSettings,
EImageOutputType,
agentParamSettings,
validateSettingDefinitions,
} = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const projectRoot = path.resolve(__dirname, '..', '..', '..', '..');
const defaultConfigPath = path.resolve(projectRoot, 'librechat.yaml');

View File

@@ -1,6 +1,9 @@
jest.mock('axios');
jest.mock('~/cache/getLogStores');
jest.mock('~/utils/loadYaml');
jest.mock('@librechat/api', () => ({
...jest.requireActual('@librechat/api'),
loadYaml: jest.fn(),
}));
jest.mock('librechat-data-provider', () => {
const actual = jest.requireActual('librechat-data-provider');
return {
@@ -30,11 +33,22 @@ jest.mock('librechat-data-provider', () => {
};
});
jest.mock('@librechat/data-schemas', () => {
return {
logger: {
info: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
error: jest.fn(),
},
};
});
const axios = require('axios');
const { loadYaml } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const loadCustomConfig = require('./loadCustomConfig');
const getLogStores = require('~/cache/getLogStores');
const loadYaml = require('~/utils/loadYaml');
const { logger } = require('~/config');
describe('loadCustomConfig', () => {
const mockSet = jest.fn();

View File

@@ -11,30 +11,13 @@ const {
replaceSpecialVars,
providerEndpointMap,
} = require('librechat-data-provider');
const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize');
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const initCustom = require('~/server/services/Endpoints/custom/initialize');
const initGoogle = require('~/server/services/Endpoints/google/initialize');
const { getProviderConfig } = require('~/server/services/Endpoints');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { processFiles } = require('~/server/services/Files/process');
const { getFiles, getToolFilesByIds } = require('~/models/File');
const { getConvoFiles } = require('~/models/Conversation');
const { getModelMaxTokens } = require('~/utils');
const providerConfigMap = {
[Providers.XAI]: initCustom,
[Providers.OLLAMA]: initCustom,
[Providers.DEEPSEEK]: initCustom,
[Providers.OPENROUTER]: initCustom,
[EModelEndpoint.openAI]: initOpenAI,
[EModelEndpoint.google]: initGoogle,
[EModelEndpoint.azureOpenAI]: initOpenAI,
[EModelEndpoint.anthropic]: initAnthropic,
[EModelEndpoint.bedrock]: getBedrockOptions,
};
/**
* @param {object} params
* @param {ServerRequest} params.req
@@ -114,17 +97,9 @@ const initializeAgent = async ({
})) ?? {};
agent.endpoint = provider;
let getOptions = providerConfigMap[provider];
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
agent.provider = provider.toLowerCase();
getOptions = providerConfigMap[agent.provider];
} else if (!getOptions) {
const customEndpointConfig = await getCustomEndpointConfig(provider);
if (!customEndpointConfig) {
throw new Error(`Provider ${provider} not supported`);
}
getOptions = initCustom;
agent.provider = Providers.OPENAI;
const { getOptions, overrideProvider } = await getProviderConfig(provider);
if (overrideProvider) {
agent.provider = overrideProvider;
}
const _endpointOption =

View File

@@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => {
let timeoutId;
try {
const timeoutPromise = new Promise((_, reject) => {
timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 25000);
timeoutId = setTimeout(() => reject(new Error('Title generation timeout')), 45000);
}).catch((error) => {
logger.error('Title error:', error);
});

View File

@@ -41,7 +41,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
{
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
proxy: PROXY ?? null,
modelOptions: endpointOption.model_parameters,
modelOptions: endpointOption?.model_parameters ?? {},
},
clientOptions,
);

View File

@@ -75,6 +75,7 @@ function getLLMConfig(apiKey, options = {}) {
if (options.reverseProxyUrl) {
requestOptions.clientOptions.baseURL = options.reverseProxyUrl;
requestOptions.anthropicApiUrl = options.reverseProxyUrl;
}
return {

View File

@@ -1,11 +1,45 @@
const { anthropicSettings } = require('librechat-data-provider');
const { anthropicSettings, removeNullishValues } = require('librechat-data-provider');
const { getLLMConfig } = require('~/server/services/Endpoints/anthropic/llm');
const { checkPromptCacheSupport, getClaudeHeaders, configureReasoning } = require('./helpers');
jest.mock('https-proxy-agent', () => ({
HttpsProxyAgent: jest.fn().mockImplementation((proxy) => ({ proxy })),
}));
jest.mock('./helpers', () => ({
checkPromptCacheSupport: jest.fn(),
getClaudeHeaders: jest.fn(),
configureReasoning: jest.fn((requestOptions) => requestOptions),
}));
jest.mock('librechat-data-provider', () => ({
anthropicSettings: {
model: { default: 'claude-3-opus-20240229' },
maxOutputTokens: { default: 4096, reset: jest.fn(() => 4096) },
thinking: { default: false },
promptCache: { default: false },
thinkingBudget: { default: null },
},
removeNullishValues: jest.fn((obj) => {
const result = {};
for (const key in obj) {
if (obj[key] !== null && obj[key] !== undefined) {
result[key] = obj[key];
}
}
return result;
}),
}));
describe('getLLMConfig', () => {
beforeEach(() => {
jest.clearAllMocks();
checkPromptCacheSupport.mockReturnValue(false);
getClaudeHeaders.mockReturnValue(undefined);
configureReasoning.mockImplementation((requestOptions) => requestOptions);
anthropicSettings.maxOutputTokens.reset.mockReturnValue(4096);
});
it('should create a basic configuration with default values', () => {
const result = getLLMConfig('test-api-key', { modelOptions: {} });
@@ -36,6 +70,7 @@ describe('getLLMConfig', () => {
});
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'http://reverse-proxy');
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'http://reverse-proxy');
});
it('should include topK and topP for non-Claude-3.7 models', () => {
@@ -65,6 +100,11 @@ describe('getLLMConfig', () => {
});
it('should NOT include topK and topP for Claude-3-7 models (hyphen notation)', () => {
configureReasoning.mockImplementation((requestOptions) => {
requestOptions.thinking = { type: 'enabled' };
return requestOptions;
});
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
@@ -78,6 +118,11 @@ describe('getLLMConfig', () => {
});
it('should NOT include topK and topP for Claude-3.7 models (decimal notation)', () => {
configureReasoning.mockImplementation((requestOptions) => {
requestOptions.thinking = { type: 'enabled' };
return requestOptions;
});
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3.7-sonnet',
@@ -154,4 +199,160 @@ describe('getLLMConfig', () => {
expect(result3.llmConfig).toHaveProperty('topK', 10);
expect(result3.llmConfig).toHaveProperty('topP', 0.9);
});
describe('Edge cases', () => {
it('should handle missing apiKey', () => {
const result = getLLMConfig(undefined, { modelOptions: {} });
expect(result.llmConfig).not.toHaveProperty('apiKey');
});
it('should handle empty modelOptions', () => {
expect(() => {
getLLMConfig('test-api-key', {});
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
});
it('should handle no options parameter', () => {
expect(() => {
getLLMConfig('test-api-key');
}).toThrow("Cannot read properties of undefined (reading 'thinking')");
});
it('should handle temperature, stop sequences, and stream settings', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {
temperature: 0.7,
stop: ['\n\n', 'END'],
stream: false,
},
});
expect(result.llmConfig).toHaveProperty('temperature', 0.7);
expect(result.llmConfig).toHaveProperty('stopSequences', ['\n\n', 'END']);
expect(result.llmConfig).toHaveProperty('stream', false);
});
it('should handle maxOutputTokens when explicitly set to falsy value', () => {
anthropicSettings.maxOutputTokens.reset.mockReturnValue(8192);
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-opus',
maxOutputTokens: null,
},
});
expect(anthropicSettings.maxOutputTokens.reset).toHaveBeenCalledWith('claude-3-opus');
expect(result.llmConfig).toHaveProperty('maxTokens', 8192);
});
it('should handle both proxy and reverseProxyUrl', () => {
const result = getLLMConfig('test-api-key', {
modelOptions: {},
proxy: 'http://proxy:8080',
reverseProxyUrl: 'https://reverse-proxy.com',
});
expect(result.llmConfig.clientOptions).toHaveProperty('fetchOptions');
expect(result.llmConfig.clientOptions.fetchOptions).toHaveProperty('dispatcher');
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher).toBeDefined();
expect(result.llmConfig.clientOptions.fetchOptions.dispatcher.constructor.name).toBe(
'ProxyAgent',
);
expect(result.llmConfig.clientOptions).toHaveProperty('baseURL', 'https://reverse-proxy.com');
expect(result.llmConfig).toHaveProperty('anthropicApiUrl', 'https://reverse-proxy.com');
});
it('should handle prompt cache with supported model', () => {
checkPromptCacheSupport.mockReturnValue(true);
getClaudeHeaders.mockReturnValue({ 'anthropic-beta': 'prompt-caching-2024-07-31' });
const result = getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-5-sonnet',
promptCache: true,
},
});
expect(checkPromptCacheSupport).toHaveBeenCalledWith('claude-3-5-sonnet');
expect(getClaudeHeaders).toHaveBeenCalledWith('claude-3-5-sonnet', true);
expect(result.llmConfig.clientOptions.defaultHeaders).toEqual({
'anthropic-beta': 'prompt-caching-2024-07-31',
});
});
it('should handle thinking and thinkingBudget options', () => {
configureReasoning.mockImplementation((requestOptions, systemOptions) => {
if (systemOptions.thinking) {
requestOptions.thinking = { type: 'enabled' };
}
if (systemOptions.thinkingBudget) {
requestOptions.thinking = {
...requestOptions.thinking,
budget_tokens: systemOptions.thinkingBudget,
};
}
return requestOptions;
});
getLLMConfig('test-api-key', {
modelOptions: {
model: 'claude-3-7-sonnet',
thinking: true,
thinkingBudget: 5000,
},
});
expect(configureReasoning).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({
thinking: true,
promptCache: false,
thinkingBudget: 5000,
}),
);
});
it('should remove system options from modelOptions', () => {
const modelOptions = {
model: 'claude-3-opus',
thinking: true,
promptCache: true,
thinkingBudget: 1000,
temperature: 0.5,
};
getLLMConfig('test-api-key', { modelOptions });
expect(modelOptions).not.toHaveProperty('thinking');
expect(modelOptions).not.toHaveProperty('promptCache');
expect(modelOptions).not.toHaveProperty('thinkingBudget');
expect(modelOptions).toHaveProperty('temperature', 0.5);
});
it('should handle all nullish values removal', () => {
removeNullishValues.mockImplementation((obj) => {
const cleaned = {};
Object.entries(obj).forEach(([key, value]) => {
if (value !== null && value !== undefined) {
cleaned[key] = value;
}
});
return cleaned;
});
const result = getLLMConfig('test-api-key', {
modelOptions: {
temperature: null,
topP: undefined,
topK: 0,
stop: [],
},
});
expect(result.llmConfig).not.toHaveProperty('temperature');
expect(result.llmConfig).not.toHaveProperty('topP');
expect(result.llmConfig).toHaveProperty('topK', 0);
expect(result.llmConfig).toHaveProperty('stopSequences', []);
});
});
});

View File

@@ -1,12 +1,7 @@
const OpenAI = require('openai');
const { HttpsProxyAgent } = require('https-proxy-agent');
const { constructAzureURL, isUserProvided } = require('@librechat/api');
const {
ErrorTypes,
EModelEndpoint,
resolveHeaders,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { constructAzureURL, isUserProvided, resolveHeaders } = require('@librechat/api');
const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider');
const {
getUserKeyValues,
getUserKeyExpiry,
@@ -114,11 +109,14 @@ const initializeClient = async ({ req, res, version, endpointOption, initAppClie
apiKey = azureOptions.azureOpenAIApiKey;
opts.defaultQuery = { 'api-version': azureOptions.azureOpenAIApiVersion };
opts.defaultHeaders = resolveHeaders({
...headers,
'api-key': apiKey,
'OpenAI-Beta': `assistants=${version}`,
});
opts.defaultHeaders = resolveHeaders(
{
...headers,
'api-key': apiKey,
'OpenAI-Beta': `assistants=${version}`,
},
req.user,
);
opts.model = azureOptions.azureOpenAIApiDeploymentName;
if (initAppClient) {

View File

@@ -64,7 +64,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
/** @type {BedrockClientOptions} */
const requestOptions = {
model: overrideModel ?? endpointOption.model,
model: overrideModel ?? endpointOption?.model,
region: BEDROCK_AWS_DEFAULT_REGION,
};
@@ -76,7 +76,7 @@ const getOptions = async ({ req, overrideModel, endpointOption }) => {
const llmConfig = bedrockOutputParser(
bedrockInputParser.parse(
removeNullishValues(Object.assign(requestOptions, endpointOption.model_parameters)),
removeNullishValues(Object.assign(requestOptions, endpointOption?.model_parameters ?? {})),
),
);

View File

@@ -6,7 +6,7 @@ const {
extractEnvVariable,
} = require('librechat-data-provider');
const { Providers } = require('@librechat/agents');
const { getOpenAIConfig, createHandleLLMNewToken } = require('@librechat/api');
const { getOpenAIConfig, createHandleLLMNewToken, resolveHeaders } = require('@librechat/api');
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const { fetchModels } = require('~/server/services/ModelService');
@@ -28,12 +28,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
const CUSTOM_API_KEY = extractEnvVariable(endpointConfig.apiKey);
const CUSTOM_BASE_URL = extractEnvVariable(endpointConfig.baseURL);
let resolvedHeaders = {};
if (endpointConfig.headers && typeof endpointConfig.headers === 'object') {
Object.keys(endpointConfig.headers).forEach((key) => {
resolvedHeaders[key] = extractEnvVariable(endpointConfig.headers[key]);
});
}
let resolvedHeaders = resolveHeaders(endpointConfig.headers, req.user);
if (CUSTOM_API_KEY.match(envVarRegex)) {
throw new Error(`Missing API Key for ${endpoint}.`);
@@ -134,7 +129,7 @@ const initializeClient = async ({ req, res, endpointOption, optionsOnly, overrid
};
if (optionsOnly) {
const modelOptions = endpointOption.model_parameters;
const modelOptions = endpointOption?.model_parameters ?? {};
if (endpoint !== Providers.OLLAMA) {
clientOptions = Object.assign(
{

View File

@@ -0,0 +1,93 @@
const initializeClient = require('./initialize');
jest.mock('@librechat/api', () => ({
resolveHeaders: jest.fn(),
getOpenAIConfig: jest.fn(),
createHandleLLMNewToken: jest.fn(),
}));
jest.mock('librechat-data-provider', () => ({
CacheKeys: { TOKEN_CONFIG: 'token_config' },
ErrorTypes: { NO_USER_KEY: 'NO_USER_KEY', NO_BASE_URL: 'NO_BASE_URL' },
envVarRegex: /\$\{([^}]+)\}/,
FetchTokenConfig: {},
extractEnvVariable: jest.fn((value) => value),
}));
jest.mock('@librechat/agents', () => ({
Providers: { OLLAMA: 'ollama' },
}));
jest.mock('~/server/services/UserService', () => ({
getUserKeyValues: jest.fn(),
checkUserKeyExpiry: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({
getCustomEndpointConfig: jest.fn().mockResolvedValue({
apiKey: 'test-key',
baseURL: 'https://test.com',
headers: { 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
models: { default: ['test-model'] },
}),
}));
jest.mock('~/server/services/ModelService', () => ({
fetchModels: jest.fn(),
}));
jest.mock('~/app/clients/OpenAIClient', () => {
return jest.fn().mockImplementation(() => ({
options: {},
}));
});
jest.mock('~/server/utils', () => ({
isUserProvided: jest.fn().mockReturnValue(false),
}));
jest.mock('~/cache/getLogStores', () =>
jest.fn().mockReturnValue({
get: jest.fn(),
}),
);
describe('custom/initializeClient', () => {
const mockRequest = {
body: { endpoint: 'test-endpoint' },
user: { id: 'user-123', email: 'test@example.com' },
app: { locals: {} },
};
const mockResponse = {};
beforeEach(() => {
jest.clearAllMocks();
});
it('calls resolveHeaders with headers and user', async () => {
const { resolveHeaders } = require('@librechat/api');
await initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true });
expect(resolveHeaders).toHaveBeenCalledWith(
{ 'x-user': '{{LIBRECHAT_USER_ID}}', 'x-email': '{{LIBRECHAT_USER_EMAIL}}' },
{ id: 'user-123', email: 'test@example.com' },
);
});
it('throws if endpoint config is missing', async () => {
const { getCustomEndpointConfig } = require('~/server/services/Config');
getCustomEndpointConfig.mockResolvedValueOnce(null);
await expect(
initializeClient({ req: mockRequest, res: mockResponse, optionsOnly: true }),
).rejects.toThrow('Config not found for the test-endpoint custom endpoint.');
});
it('throws if user is missing', async () => {
await expect(
initializeClient({
req: { ...mockRequest, user: undefined },
res: mockResponse,
optionsOnly: true,
}),
).rejects.toThrow("Cannot read properties of undefined (reading 'id')");
});
});

View File

@@ -1,7 +1,6 @@
const { getGoogleConfig, isEnabled } = require('@librechat/api');
const { EModelEndpoint, AuthKeys } = require('librechat-data-provider');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');
const { getLLMConfig } = require('~/server/services/Endpoints/google/llm');
const { isEnabled } = require('~/server/utils');
const { GoogleClient } = require('~/app');
const initializeClient = async ({ req, res, endpointOption, overrideModel, optionsOnly }) => {
@@ -18,7 +17,7 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
let serviceKey = {};
try {
serviceKey = require('~/data/auth.json');
} catch (e) {
} catch (_e) {
// Do nothing
}
@@ -58,14 +57,14 @@ const initializeClient = async ({ req, res, endpointOption, overrideModel, optio
if (optionsOnly) {
clientOptions = Object.assign(
{
modelOptions: endpointOption.model_parameters,
modelOptions: endpointOption?.model_parameters ?? {},
},
clientOptions,
);
if (overrideModel) {
clientOptions.modelOptions.model = overrideModel;
}
return getLLMConfig(credentials, clientOptions);
return getGoogleConfig(credentials, clientOptions);
}
const client = new GoogleClient(credentials, clientOptions);

View File

@@ -1,41 +0,0 @@
const { removeNullishValues } = require('librechat-data-provider');
const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts');
const buildOptions = (endpoint, parsedBody) => {
const {
modelLabel,
chatGptLabel,
promptPrefix,
agentOptions,
tools = [],
iconURL,
greeting,
spec,
maxContextTokens,
artifacts,
...modelOptions
} = parsedBody;
const endpointOption = removeNullishValues({
endpoint,
tools: tools
.map((tool) => tool?.pluginKey ?? tool)
.filter((toolName) => typeof toolName === 'string'),
modelLabel,
chatGptLabel,
promptPrefix,
agentOptions,
iconURL,
greeting,
spec,
maxContextTokens,
modelOptions,
});
if (typeof artifacts === 'string') {
endpointOption.artifactsPrompt = generateArtifactsPrompt({ endpoint, artifacts });
}
return endpointOption;
};
module.exports = buildOptions;

View File

@@ -1,7 +0,0 @@
const buildOptions = require('./build');
const initializeClient = require('./initialize');
module.exports = {
buildOptions,
initializeClient,
};

View File

@@ -1,134 +0,0 @@
const {
EModelEndpoint,
resolveHeaders,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { isEnabled, isUserProvided, getAzureCredentials } = require('@librechat/api');
const { getUserKeyValues, checkUserKeyExpiry } = require('~/server/services/UserService');
const { PluginsClient } = require('~/app');
const initializeClient = async ({ req, res, endpointOption }) => {
const {
PROXY,
OPENAI_API_KEY,
AZURE_API_KEY,
PLUGINS_USE_AZURE,
OPENAI_REVERSE_PROXY,
AZURE_OPENAI_BASEURL,
OPENAI_SUMMARIZE,
DEBUG_PLUGINS,
} = process.env;
const { key: expiresAt, model: modelName } = req.body;
const contextStrategy = isEnabled(OPENAI_SUMMARIZE) ? 'summarize' : null;
let useAzure = isEnabled(PLUGINS_USE_AZURE);
let endpoint = useAzure ? EModelEndpoint.azureOpenAI : EModelEndpoint.openAI;
/** @type {false | TAzureConfig} */
const azureConfig = req.app.locals[EModelEndpoint.azureOpenAI];
useAzure = useAzure || azureConfig?.plugins;
if (useAzure && endpoint !== EModelEndpoint.azureOpenAI) {
endpoint = EModelEndpoint.azureOpenAI;
}
const credentials = {
[EModelEndpoint.openAI]: OPENAI_API_KEY,
[EModelEndpoint.azureOpenAI]: AZURE_API_KEY,
};
const baseURLOptions = {
[EModelEndpoint.openAI]: OPENAI_REVERSE_PROXY,
[EModelEndpoint.azureOpenAI]: AZURE_OPENAI_BASEURL,
};
const userProvidesKey = isUserProvided(credentials[endpoint]);
const userProvidesURL = isUserProvided(baseURLOptions[endpoint]);
let userValues = null;
if (expiresAt && (userProvidesKey || userProvidesURL)) {
checkUserKeyExpiry(expiresAt, endpoint);
userValues = await getUserKeyValues({ userId: req.user.id, name: endpoint });
}
let apiKey = userProvidesKey ? userValues?.apiKey : credentials[endpoint];
let baseURL = userProvidesURL ? userValues?.baseURL : baseURLOptions[endpoint];
const clientOptions = {
contextStrategy,
debug: isEnabled(DEBUG_PLUGINS),
reverseProxyUrl: baseURL ? baseURL : null,
proxy: PROXY ?? null,
req,
res,
...endpointOption,
};
if (useAzure && azureConfig) {
const { modelGroupMap, groupMap } = azureConfig;
const {
azureOptions,
baseURL,
headers = {},
serverless,
} = mapModelToAzureConfig({
modelName,
modelGroupMap,
groupMap,
});
clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) });
clientOptions.titleConvo = azureConfig.titleConvo;
clientOptions.titleModel = azureConfig.titleModel;
clientOptions.titleMethod = azureConfig.titleMethod ?? 'completion';
const azureRate = modelName.includes('gpt-4') ? 30 : 17;
clientOptions.streamRate = azureConfig.streamRate ?? azureRate;
const groupName = modelGroupMap[modelName].group;
clientOptions.addParams = azureConfig.groupMap[groupName].addParams;
clientOptions.dropParams = azureConfig.groupMap[groupName].dropParams;
clientOptions.forcePrompt = azureConfig.groupMap[groupName].forcePrompt;
apiKey = azureOptions.azureOpenAIApiKey;
clientOptions.azure = !serverless && azureOptions;
if (serverless === true) {
clientOptions.defaultQuery = azureOptions.azureOpenAIApiVersion
? { 'api-version': azureOptions.azureOpenAIApiVersion }
: undefined;
clientOptions.headers['api-key'] = apiKey;
}
} else if (useAzure || (apiKey && apiKey.includes('{"azure') && !clientOptions.azure)) {
clientOptions.azure = userProvidesKey ? JSON.parse(userValues.apiKey) : getAzureCredentials();
apiKey = clientOptions.azure.azureOpenAIApiKey;
}
/** @type {undefined | TBaseEndpoint} */
const pluginsConfig = req.app.locals[EModelEndpoint.gptPlugins];
if (!useAzure && pluginsConfig) {
clientOptions.streamRate = pluginsConfig.streamRate;
}
/** @type {undefined | TBaseEndpoint} */
const allConfig = req.app.locals.all;
if (allConfig) {
clientOptions.streamRate = allConfig.streamRate;
}
if (!apiKey) {
throw new Error(`${endpoint} API key not provided. Please provide it again.`);
}
const client = new PluginsClient(apiKey, clientOptions);
return {
client,
azure: clientOptions.azure,
openAIApiKey: apiKey,
};
};
module.exports = initializeClient;

View File

@@ -1,410 +0,0 @@
// gptPlugins/initializeClient.spec.js
jest.mock('~/cache/getLogStores');
const { EModelEndpoint, ErrorTypes, validateAzureGroups } = require('librechat-data-provider');
const { getUserKey, getUserKeyValues } = require('~/server/services/UserService');
const initializeClient = require('./initialize');
const { PluginsClient } = require('~/app');
// Mock getUserKey since it's the only function we want to mock
jest.mock('~/server/services/UserService', () => ({
getUserKey: jest.fn(),
getUserKeyValues: jest.fn(),
checkUserKeyExpiry: jest.requireActual('~/server/services/UserService').checkUserKeyExpiry,
}));
describe('gptPlugins/initializeClient', () => {
// Set up environment variables
const originalEnvironment = process.env;
const app = {
locals: {},
};
const validAzureConfigs = [
{
group: 'librechat-westus',
apiKey: 'WESTUS_API_KEY',
instanceName: 'librechat-westus',
version: '2023-12-01-preview',
models: {
'gpt-4-vision-preview': {
deploymentName: 'gpt-4-vision-preview',
version: '2024-02-15-preview',
},
'gpt-3.5-turbo': {
deploymentName: 'gpt-35-turbo',
},
'gpt-3.5-turbo-1106': {
deploymentName: 'gpt-35-turbo-1106',
},
'gpt-4': {
deploymentName: 'gpt-4',
},
'gpt-4-1106-preview': {
deploymentName: 'gpt-4-1106-preview',
},
},
},
{
group: 'librechat-eastus',
apiKey: 'EASTUS_API_KEY',
instanceName: 'librechat-eastus',
deploymentName: 'gpt-4-turbo',
version: '2024-02-15-preview',
models: {
'gpt-4-turbo': true,
},
baseURL: 'https://eastus.example.com',
additionalHeaders: {
'x-api-key': 'x-api-key-value',
},
},
{
group: 'mistral-inference',
apiKey: 'AZURE_MISTRAL_API_KEY',
baseURL:
'https://Mistral-large-vnpet-serverless.region.inference.ai.azure.com/v1/chat/completions',
serverless: true,
models: {
'mistral-large': true,
},
},
{
group: 'llama-70b-chat',
apiKey: 'AZURE_LLAMA2_70B_API_KEY',
baseURL:
'https://Llama-2-70b-chat-qmvyb-serverless.region.inference.ai.azure.com/v1/chat/completions',
serverless: true,
models: {
'llama-70b-chat': true,
},
},
];
const { modelNames, modelGroupMap, groupMap } = validateAzureGroups(validAzureConfigs);
beforeEach(() => {
jest.resetModules(); // Clears the cache
process.env = { ...originalEnvironment }; // Make a copy
});
afterAll(() => {
process.env = originalEnvironment; // Restore original env vars
});
test('should initialize PluginsClient with OpenAI API key and default options', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.PLUGINS_USE_AZURE = 'false';
process.env.DEBUG_PLUGINS = 'false';
process.env.OPENAI_SUMMARIZE = 'false';
const req = {
body: { key: null },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client, openAIApiKey } = await initializeClient({ req, res, endpointOption });
expect(openAIApiKey).toBe('test-openai-api-key');
expect(client).toBeInstanceOf(PluginsClient);
});
test('should initialize PluginsClient with Azure credentials when PLUGINS_USE_AZURE is true', async () => {
process.env.AZURE_API_KEY = 'test-azure-api-key';
(process.env.AZURE_OPENAI_API_INSTANCE_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_DEPLOYMENT_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_VERSION = 'some-value'),
(process.env.AZURE_OPENAI_API_COMPLETIONS_DEPLOYMENT_NAME = 'some-value'),
(process.env.AZURE_OPENAI_API_EMBEDDINGS_DEPLOYMENT_NAME = 'some-value'),
(process.env.PLUGINS_USE_AZURE = 'true');
process.env.DEBUG_PLUGINS = 'false';
process.env.OPENAI_SUMMARIZE = 'false';
const req = {
body: { key: null },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'test-model' } };
const { client, azure } = await initializeClient({ req, res, endpointOption });
expect(azure.azureOpenAIApiKey).toBe('test-azure-api-key');
expect(client).toBeInstanceOf(PluginsClient);
});
test('should use the debug option when DEBUG_PLUGINS is enabled', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.DEBUG_PLUGINS = 'true';
const req = {
body: { key: null },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client } = await initializeClient({ req, res, endpointOption });
expect(client.options.debug).toBe(true);
});
test('should set contextStrategy to summarize when OPENAI_SUMMARIZE is enabled', async () => {
process.env.OPENAI_API_KEY = 'test-openai-api-key';
process.env.OPENAI_SUMMARIZE = 'true';
const req = {
body: { key: null },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client } = await initializeClient({ req, res, endpointOption });
expect(client.options.contextStrategy).toBe('summarize');
});
// ... additional tests for reverseProxyUrl, proxy, user-provided keys, etc.
test('should throw an error if no API keys are provided in the environment', async () => {
// Clear the environment variables for API keys
delete process.env.OPENAI_API_KEY;
delete process.env.AZURE_API_KEY;
const req = {
body: { key: null },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
`${EModelEndpoint.openAI} API key not provided.`,
);
});
// Additional tests for gptPlugins/initializeClient.spec.js
// ... (previous test setup code)
test('should handle user-provided OpenAI keys and check expiry', async () => {
process.env.OPENAI_API_KEY = 'user_provided';
process.env.PLUGINS_USE_AZURE = 'false';
const futureDate = new Date(Date.now() + 10000).toISOString();
const req = {
body: { key: futureDate },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
getUserKeyValues.mockResolvedValue({ apiKey: 'test-user-provided-openai-api-key' });
const { openAIApiKey } = await initializeClient({ req, res, endpointOption });
expect(openAIApiKey).toBe('test-user-provided-openai-api-key');
});
test('should handle user-provided Azure keys and check expiry', async () => {
process.env.AZURE_API_KEY = 'user_provided';
process.env.PLUGINS_USE_AZURE = 'true';
const futureDate = new Date(Date.now() + 10000).toISOString();
const req = {
body: { key: futureDate },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'test-model' } };
getUserKeyValues.mockResolvedValue({
apiKey: JSON.stringify({
azureOpenAIApiKey: 'test-user-provided-azure-api-key',
azureOpenAIApiDeploymentName: 'test-deployment',
}),
});
const { azure } = await initializeClient({ req, res, endpointOption });
expect(azure.azureOpenAIApiKey).toBe('test-user-provided-azure-api-key');
});
test('should throw an error if the user-provided key has expired', async () => {
process.env.OPENAI_API_KEY = 'user_provided';
process.env.PLUGINS_USE_AZURE = 'FALSE';
const expiresAt = new Date(Date.now() - 10000).toISOString(); // Expired
const req = {
body: { key: expiresAt },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
/expired_user_key/,
);
});
test('should throw an error if the user-provided Azure key is invalid JSON', async () => {
process.env.AZURE_API_KEY = 'user_provided';
process.env.PLUGINS_USE_AZURE = 'true';
const req = {
body: { key: new Date(Date.now() + 10000).toISOString() },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
// Simulate an invalid JSON string returned from getUserKey
getUserKey.mockResolvedValue('invalid-json');
getUserKeyValues.mockImplementation(() => {
let userValues = getUserKey();
try {
userValues = JSON.parse(userValues);
} catch (e) {
throw new Error(
JSON.stringify({
type: ErrorTypes.INVALID_USER_KEY,
}),
);
}
return userValues;
});
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
/invalid_user_key/,
);
});
test('should correctly handle the presence of a reverse proxy', async () => {
process.env.OPENAI_REVERSE_PROXY = 'http://reverse.proxy';
process.env.PROXY = 'http://proxy';
process.env.OPENAI_API_KEY = 'test-openai-api-key';
const req = {
body: { key: null },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = { modelOptions: { model: 'default-model' } };
const { client } = await initializeClient({ req, res, endpointOption });
expect(client.options.reverseProxyUrl).toBe('http://reverse.proxy');
expect(client.options.proxy).toBe('http://proxy');
});
test('should throw an error when user-provided values are not valid JSON', async () => {
process.env.OPENAI_API_KEY = 'user_provided';
const req = {
body: { key: new Date(Date.now() + 10000).toISOString(), endpoint: 'openAI' },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = {};
// Mock getUserKey to return a non-JSON string
getUserKey.mockResolvedValue('not-a-json');
getUserKeyValues.mockImplementation(() => {
let userValues = getUserKey();
try {
userValues = JSON.parse(userValues);
} catch (e) {
throw new Error(
JSON.stringify({
type: ErrorTypes.INVALID_USER_KEY,
}),
);
}
return userValues;
});
await expect(initializeClient({ req, res, endpointOption })).rejects.toThrow(
/invalid_user_key/,
);
});
test('should initialize client correctly for Azure OpenAI with valid configuration', async () => {
const req = {
body: {
key: null,
endpoint: EModelEndpoint.gptPlugins,
model: modelNames[0],
},
user: { id: '123' },
app: {
locals: {
[EModelEndpoint.azureOpenAI]: {
plugins: true,
modelNames,
modelGroupMap,
groupMap,
},
},
},
};
const res = {};
const endpointOption = {};
const client = await initializeClient({ req, res, endpointOption });
expect(client.client.options.azure).toBeDefined();
});
test('should initialize client with default options when certain env vars are not set', async () => {
delete process.env.OPENAI_SUMMARIZE;
process.env.OPENAI_API_KEY = 'some-api-key';
const req = {
body: { key: null, endpoint: EModelEndpoint.gptPlugins },
user: { id: '123' },
app,
};
const res = {};
const endpointOption = {};
const client = await initializeClient({ req, res, endpointOption });
expect(client.client.options.contextStrategy).toBe(null);
});
test('should correctly use user-provided apiKey and baseURL when provided', async () => {
process.env.OPENAI_API_KEY = 'user_provided';
process.env.OPENAI_REVERSE_PROXY = 'user_provided';
const req = {
body: {
key: new Date(Date.now() + 10000).toISOString(),
endpoint: 'openAI',
},
user: {
id: '123',
},
app,
};
const res = {};
const endpointOption = {};
getUserKeyValues.mockResolvedValue({
apiKey: 'test',
baseURL: 'https://user-provided-url.com',
});
const result = await initializeClient({ req, res, endpointOption });
expect(result.openAIApiKey).toBe('test');
expect(result.client.options.reverseProxyUrl).toBe('https://user-provided-url.com');
});
});

View File

@@ -0,0 +1,58 @@
const { Providers } = require('@librechat/agents');
const { EModelEndpoint } = require('librechat-data-provider');
const initAnthropic = require('~/server/services/Endpoints/anthropic/initialize');
const getBedrockOptions = require('~/server/services/Endpoints/bedrock/options');
const initOpenAI = require('~/server/services/Endpoints/openAI/initialize');
const initCustom = require('~/server/services/Endpoints/custom/initialize');
const initGoogle = require('~/server/services/Endpoints/google/initialize');
const { getCustomEndpointConfig } = require('~/server/services/Config');
const providerConfigMap = {
[Providers.XAI]: initCustom,
[Providers.OLLAMA]: initCustom,
[Providers.DEEPSEEK]: initCustom,
[Providers.OPENROUTER]: initCustom,
[EModelEndpoint.openAI]: initOpenAI,
[EModelEndpoint.google]: initGoogle,
[EModelEndpoint.azureOpenAI]: initOpenAI,
[EModelEndpoint.anthropic]: initAnthropic,
[EModelEndpoint.bedrock]: getBedrockOptions,
};
/**
* Get the provider configuration and override endpoint based on the provider string
* @param {string} provider - The provider string
* @returns {Promise<{
* getOptions: Function,
* overrideProvider?: string,
* customEndpointConfig?: TEndpoint
* }>}
*/
async function getProviderConfig(provider) {
let getOptions = providerConfigMap[provider];
let overrideProvider;
/** @type {TEndpoint | undefined} */
let customEndpointConfig;
if (!getOptions && providerConfigMap[provider.toLowerCase()] != null) {
overrideProvider = provider.toLowerCase();
getOptions = providerConfigMap[overrideProvider];
} else if (!getOptions) {
customEndpointConfig = await getCustomEndpointConfig(provider);
if (!customEndpointConfig) {
throw new Error(`Provider ${provider} not supported`);
}
getOptions = initCustom;
overrideProvider = Providers.OPENAI;
}
return {
getOptions,
overrideProvider,
customEndpointConfig,
};
}
module.exports = {
getProviderConfig,
};

View File

@@ -1,11 +1,7 @@
const {
ErrorTypes,
EModelEndpoint,
resolveHeaders,
mapModelToAzureConfig,
} = require('librechat-data-provider');
const { ErrorTypes, EModelEndpoint, mapModelToAzureConfig } = require('librechat-data-provider');
const {
isEnabled,
resolveHeaders,
isUserProvided,
getOpenAIConfig,
getAzureCredentials,
@@ -84,7 +80,10 @@ const initializeClient = async ({
});
clientOptions.reverseProxyUrl = baseURL ?? clientOptions.reverseProxyUrl;
clientOptions.headers = resolveHeaders({ ...headers, ...(clientOptions.headers ?? {}) });
clientOptions.headers = resolveHeaders(
{ ...headers, ...(clientOptions.headers ?? {}) },
req.user,
);
clientOptions.titleConvo = azureConfig.titleConvo;
clientOptions.titleModel = azureConfig.titleModel;
@@ -139,7 +138,7 @@ const initializeClient = async ({
}
if (optionsOnly) {
const modelOptions = endpointOption.model_parameters;
const modelOptions = endpointOption?.model_parameters ?? {};
modelOptions.model = modelName;
clientOptions = Object.assign({ modelOptions }, clientOptions);
clientOptions.modelOptions.user = req.user.id;

View File

@@ -1,10 +1,11 @@
const fs = require('fs');
const path = require('path');
const axios = require('axios');
const { logger } = require('@librechat/data-schemas');
const { EModelEndpoint } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
const { getBufferMetadata } = require('~/server/utils');
const paths = require('~/config/paths');
const { logger } = require('~/config');
/**
* Saves a file to a specified output path with a new filename.
@@ -206,7 +207,7 @@ const deleteLocalFile = async (req, file) => {
const cleanFilepath = file.filepath.split('?')[0];
if (file.embedded && process.env.RAG_API_URL) {
const jwtToken = req.headers.authorization.split(' ')[1];
const jwtToken = generateShortLivedToken(req.user.id);
axios.delete(`${process.env.RAG_API_URL}/documents`, {
headers: {
Authorization: `Bearer ${jwtToken}`,

View File

@@ -4,6 +4,7 @@ const FormData = require('form-data');
const { logAxiosError } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { FileSources } = require('librechat-data-provider');
const { generateShortLivedToken } = require('~/server/services/AuthService');
/**
* Deletes a file from the vector database. This function takes a file object, constructs the full path, and
@@ -23,7 +24,8 @@ const deleteVectors = async (req, file) => {
return;
}
try {
const jwtToken = req.headers.authorization.split(' ')[1];
const jwtToken = generateShortLivedToken(req.user.id);
return await axios.delete(`${process.env.RAG_API_URL}/documents`, {
headers: {
Authorization: `Bearer ${jwtToken}`,
@@ -70,7 +72,7 @@ async function uploadVectors({ req, file, file_id, entity_id }) {
}
try {
const jwtToken = req.headers.authorization.split(' ')[1];
const jwtToken = generateShortLivedToken(req.user.id);
const formData = new FormData();
formData.append('file_id', file_id);
formData.append('file', fs.createReadStream(file.path));

View File

@@ -55,7 +55,9 @@ const processFiles = async (files, fileIds) => {
}
if (!fileIds) {
return await Promise.all(promises);
const results = await Promise.all(promises);
// Filter out null results from failed updateFileUsage calls
return results.filter((result) => result != null);
}
for (let file_id of fileIds) {
@@ -67,7 +69,9 @@ const processFiles = async (files, fileIds) => {
}
// TODO: calculate token cost when image is first uploaded
return await Promise.all(promises);
const results = await Promise.all(promises);
// Filter out null results from failed updateFileUsage calls
return results.filter((result) => result != null);
};
/**

View File

@@ -0,0 +1,208 @@
// Mock the updateFileUsage function before importing the actual processFiles
jest.mock('~/models/File', () => ({
updateFileUsage: jest.fn(),
}));
// Mock winston and logger configuration to avoid dependency issues
jest.mock('~/config', () => ({
logger: {
info: jest.fn(),
warn: jest.fn(),
debug: jest.fn(),
error: jest.fn(),
},
}));
// Mock all other dependencies that might cause issues
jest.mock('librechat-data-provider', () => ({
isUUID: { parse: jest.fn() },
megabyte: 1024 * 1024,
FileContext: { message_attachment: 'message_attachment' },
FileSources: { local: 'local' },
EModelEndpoint: { assistants: 'assistants' },
EToolResources: { file_search: 'file_search' },
mergeFileConfig: jest.fn(),
removeNullishValues: jest.fn((obj) => obj),
isAssistantsEndpoint: jest.fn(),
}));
jest.mock('~/server/services/Files/images', () => ({
convertImage: jest.fn(),
resizeAndConvert: jest.fn(),
resizeImageBuffer: jest.fn(),
}));
jest.mock('~/server/controllers/assistants/v2', () => ({
addResourceFileId: jest.fn(),
deleteResourceFileId: jest.fn(),
}));
jest.mock('~/models/Agent', () => ({
addAgentResourceFile: jest.fn(),
removeAgentResourceFiles: jest.fn(),
}));
jest.mock('~/server/controllers/assistants/helpers', () => ({
getOpenAIClient: jest.fn(),
}));
jest.mock('~/server/services/Tools/credentials', () => ({
loadAuthValues: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({
checkCapability: jest.fn(),
}));
jest.mock('~/server/utils/queue', () => ({
LB_QueueAsyncCall: jest.fn(),
}));
jest.mock('./strategies', () => ({
getStrategyFunctions: jest.fn(),
}));
jest.mock('~/server/utils', () => ({
determineFileType: jest.fn(),
}));
// Import the actual processFiles function after all mocks are set up
const { processFiles } = require('./process');
const { updateFileUsage } = require('~/models/File');
describe('processFiles', () => {
beforeEach(() => {
jest.clearAllMocks();
});
describe('null filtering functionality', () => {
it('should filter out null results from updateFileUsage when files do not exist', async () => {
const mockFiles = [
{ file_id: 'existing-file-1' },
{ file_id: 'non-existent-file' },
{ file_id: 'existing-file-2' },
];
// Mock updateFileUsage to return null for non-existent files
updateFileUsage.mockImplementation(({ file_id }) => {
if (file_id === 'non-existent-file') {
return Promise.resolve(null); // Simulate file not found in the database
}
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles);
expect(updateFileUsage).toHaveBeenCalledTimes(3);
expect(result).toEqual([
{ file_id: 'existing-file-1', usage: 1 },
{ file_id: 'existing-file-2', usage: 1 },
]);
// Critical test - ensure no null values in result
expect(result).not.toContain(null);
expect(result).not.toContain(undefined);
expect(result.length).toBe(2); // Only valid files should be returned
});
it('should return empty array when all updateFileUsage calls return null', async () => {
const mockFiles = [{ file_id: 'non-existent-1' }, { file_id: 'non-existent-2' }];
// All updateFileUsage calls return null
updateFileUsage.mockResolvedValue(null);
const result = await processFiles(mockFiles);
expect(updateFileUsage).toHaveBeenCalledTimes(2);
expect(result).toEqual([]);
expect(result).not.toContain(null);
expect(result.length).toBe(0);
});
it('should work correctly when all files exist', async () => {
const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }];
updateFileUsage.mockImplementation(({ file_id }) => {
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles);
expect(result).toEqual([
{ file_id: 'file-1', usage: 1 },
{ file_id: 'file-2', usage: 1 },
]);
expect(result).not.toContain(null);
expect(result.length).toBe(2);
});
it('should handle fileIds parameter and filter nulls correctly', async () => {
const mockFiles = [{ file_id: 'file-1' }];
const mockFileIds = ['file-2', 'non-existent-file'];
updateFileUsage.mockImplementation(({ file_id }) => {
if (file_id === 'non-existent-file') {
return Promise.resolve(null);
}
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles, mockFileIds);
expect(result).toEqual([
{ file_id: 'file-1', usage: 1 },
{ file_id: 'file-2', usage: 1 },
]);
expect(result).not.toContain(null);
expect(result).not.toContain(undefined);
expect(result.length).toBe(2);
});
it('should handle duplicate file_ids correctly', async () => {
const mockFiles = [
{ file_id: 'duplicate-file' },
{ file_id: 'duplicate-file' }, // Duplicate should be ignored
{ file_id: 'unique-file' },
];
updateFileUsage.mockImplementation(({ file_id }) => {
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles);
// Should only call updateFileUsage twice (duplicate ignored)
expect(updateFileUsage).toHaveBeenCalledTimes(2);
expect(result).toEqual([
{ file_id: 'duplicate-file', usage: 1 },
{ file_id: 'unique-file', usage: 1 },
]);
expect(result.length).toBe(2);
});
});
describe('edge cases', () => {
it('should handle empty files array', async () => {
const result = await processFiles([]);
expect(result).toEqual([]);
expect(updateFileUsage).not.toHaveBeenCalled();
});
it('should handle mixed null and undefined returns from updateFileUsage', async () => {
const mockFiles = [{ file_id: 'file-1' }, { file_id: 'file-2' }, { file_id: 'file-3' }];
updateFileUsage.mockImplementation(({ file_id }) => {
if (file_id === 'file-1') return Promise.resolve(null);
if (file_id === 'file-2') return Promise.resolve(undefined);
return Promise.resolve({ file_id, usage: 1 });
});
const result = await processFiles(mockFiles);
expect(result).toEqual([{ file_id: 'file-3', usage: 1 }]);
expect(result).not.toContain(null);
expect(result).not.toContain(undefined);
expect(result.length).toBe(1);
});
});
});

View File

@@ -0,0 +1,525 @@
const client = require('openid-client');
const { isEnabled } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const { CacheKeys } = require('librechat-data-provider');
const { Client } = require('@microsoft/microsoft-graph-client');
const { getOpenIdConfig } = require('~/strategies/openidStrategy');
const getLogStores = require('~/cache/getLogStores');
/**
* @import { TPrincipalSearchResult, TGraphPerson, TGraphUser, TGraphGroup, TGraphPeopleResponse, TGraphUsersResponse, TGraphGroupsResponse } from 'librechat-data-provider'
*/
/**
* Checks if Entra ID principal search feature is enabled based on environment variables and user authentication
* @param {Object} user - User object from request
* @param {string} user.provider - Authentication provider
* @param {string} user.openidId - OpenID subject identifier
* @returns {boolean} True if Entra ID principal search is enabled and user is authenticated via OpenID
*/
const entraIdPrincipalFeatureEnabled = (user) => {
return (
isEnabled(process.env.USE_ENTRA_ID_FOR_PEOPLE_SEARCH) &&
isEnabled(process.env.OPENID_REUSE_TOKENS) &&
user?.provider === 'openid' &&
user?.openidId
);
};
/**
* Creates a Microsoft Graph client with on-behalf-of token exchange
* @param {string} accessToken - OpenID Connect access token from user
* @param {string} sub - Subject identifier from token claims
* @returns {Promise<Client>} Authenticated Graph API client
*/
const createGraphClient = async (accessToken, sub) => {
try {
// Reason: Use existing OpenID configuration and token exchange pattern from openidStrategy.js
const openidConfig = getOpenIdConfig();
const exchangedToken = await exchangeTokenForGraphAccess(openidConfig, accessToken, sub);
const graphClient = Client.init({
authProvider: (done) => {
done(null, exchangedToken);
},
});
return graphClient;
} catch (error) {
logger.error('[createGraphClient] Error creating Graph client:', error);
throw error;
}
};
/**
* Exchange OpenID token for Graph API access using on-behalf-of flow
* Similar to exchangeAccessTokenIfNeeded in openidStrategy.js but for Graph scopes
* @param {Configuration} config - OpenID configuration
* @param {string} accessToken - Original access token
* @param {string} sub - Subject identifier
* @returns {Promise<string>} Graph API access token
*/
const exchangeTokenForGraphAccess = async (config, accessToken, sub) => {
try {
const tokensCache = getLogStores(CacheKeys.OPENID_EXCHANGED_TOKENS);
const cacheKey = `${sub}:graph`;
const cachedToken = await tokensCache.get(cacheKey);
if (cachedToken) {
return cachedToken.access_token;
}
const graphScopes = process.env.OPENID_GRAPH_SCOPES || 'User.Read,People.Read,Group.Read.All';
const scopeString = graphScopes
.split(',')
.map((scope) => `https://graph.microsoft.com/${scope}`)
.join(' ');
const grantResponse = await client.genericGrantRequest(
config,
'urn:ietf:params:oauth:grant-type:jwt-bearer',
{
scope: scopeString,
assertion: accessToken,
requested_token_use: 'on_behalf_of',
},
);
await tokensCache.set(
cacheKey,
{
access_token: grantResponse.access_token,
},
grantResponse.expires_in * 1000,
);
return grantResponse.access_token;
} catch (error) {
logger.error('[exchangeTokenForGraphAccess] Token exchange failed:', error);
throw error;
}
};
/**
* Search for principals (people and groups) using Microsoft Graph API
* Uses searchContacts first, then searchUsers and searchGroups to fill remaining slots
* @param {string} accessToken - OpenID Connect access token
* @param {string} sub - Subject identifier
* @param {string} query - Search query string
* @param {string} type - Type filter ('users', 'groups', or 'all')
* @param {number} limit - Maximum number of results
* @returns {Promise<TPrincipalSearchResult[]>} Array of principal search results
*/
const searchEntraIdPrincipals = async (accessToken, sub, query, type = 'all', limit = 10) => {
try {
if (!query || query.trim().length < 2) {
return [];
}
const graphClient = await createGraphClient(accessToken, sub);
let allResults = [];
if (type === 'users' || type === 'all') {
const contactResults = await searchContacts(graphClient, query, limit);
allResults.push(...contactResults);
}
if (allResults.length >= limit) {
return allResults.slice(0, limit);
}
if (type === 'users') {
const userResults = await searchUsers(graphClient, query, limit);
allResults.push(...userResults);
} else if (type === 'groups') {
const groupResults = await searchGroups(graphClient, query, limit);
allResults.push(...groupResults);
} else if (type === 'all') {
const [userResults, groupResults] = await Promise.all([
searchUsers(graphClient, query, limit),
searchGroups(graphClient, query, limit),
]);
allResults.push(...userResults, ...groupResults);
}
const seenIds = new Set();
const uniqueResults = allResults.filter((result) => {
if (seenIds.has(result.idOnTheSource)) {
return false;
}
seenIds.add(result.idOnTheSource);
return true;
});
return uniqueResults.slice(0, limit);
} catch (error) {
logger.error('[searchEntraIdPrincipals] Error searching principals:', error);
return [];
}
};
/**
* Get current user's Entra ID group memberships from Microsoft Graph
* Uses /me/memberOf endpoint to get groups the user is a member of
* @param {string} accessToken - OpenID Connect access token
* @param {string} sub - Subject identifier
* @returns {Promise<Array<string>>} Array of group ID strings (GUIDs)
*/
const getUserEntraGroups = async (accessToken, sub) => {
try {
const graphClient = await createGraphClient(accessToken, sub);
const groupsResponse = await graphClient.api('/me/memberOf').select('id').get();
return (groupsResponse.value || []).map((group) => group.id);
} catch (error) {
logger.error('[getUserEntraGroups] Error fetching user groups:', error);
return [];
}
};
/**
* Get current user's owned Entra ID groups from Microsoft Graph
* Uses /me/ownedObjects/microsoft.graph.group endpoint to get groups the user owns
* @param {string} accessToken - OpenID Connect access token
* @param {string} sub - Subject identifier
* @returns {Promise<Array<string>>} Array of group ID strings (GUIDs)
*/
const getUserOwnedEntraGroups = async (accessToken, sub) => {
try {
const graphClient = await createGraphClient(accessToken, sub);
const groupsResponse = await graphClient
.api('/me/ownedObjects/microsoft.graph.group')
.select('id')
.get();
return (groupsResponse.value || []).map((group) => group.id);
} catch (error) {
logger.error('[getUserOwnedEntraGroups] Error fetching user owned groups:', error);
return [];
}
};
/**
* Get group members from Microsoft Graph API
* Recursively fetches all members using pagination (@odata.nextLink)
* @param {string} accessToken - OpenID Connect access token
* @param {string} sub - Subject identifier
* @param {string} groupId - Entra ID group object ID
* @returns {Promise<Array>} Array of member IDs (idOnTheSource values)
*/
const getGroupMembers = async (accessToken, sub, groupId) => {
try {
const graphClient = await createGraphClient(accessToken, sub);
const allMembers = [];
let nextLink = `/groups/${groupId}/members`;
while (nextLink) {
const membersResponse = await graphClient.api(nextLink).select('id').top(999).get();
const members = membersResponse.value || [];
allMembers.push(...members.map((member) => member.id));
nextLink = membersResponse['@odata.nextLink']
? membersResponse['@odata.nextLink'].split('/v1.0')[1]
: null;
}
return allMembers;
} catch (error) {
logger.error('[getGroupMembers] Error fetching group members:', error);
return [];
}
};
/**
* Get group owners from Microsoft Graph API
* Recursively fetches all owners using pagination (@odata.nextLink)
* @param {string} accessToken - OpenID Connect access token
* @param {string} sub - Subject identifier
* @param {string} groupId - Entra ID group object ID
* @returns {Promise<Array>} Array of owner IDs (idOnTheSource values)
*/
const getGroupOwners = async (accessToken, sub, groupId) => {
try {
const graphClient = await createGraphClient(accessToken, sub);
const allOwners = [];
let nextLink = `/groups/${groupId}/owners`;
while (nextLink) {
const ownersResponse = await graphClient.api(nextLink).select('id').top(999).get();
const owners = ownersResponse.value || [];
allOwners.push(...owners.map((member) => member.id));
nextLink = ownersResponse['@odata.nextLink']
? ownersResponse['@odata.nextLink'].split('/v1.0')[1]
: null;
}
return allOwners;
} catch (error) {
logger.error('[getGroupOwners] Error fetching group owners:', error);
return [];
}
};
/**
* Search for contacts (users only) using Microsoft Graph /me/people endpoint
* Returns mapped TPrincipalSearchResult objects for users only
* @param {Client} graphClient - Authenticated Microsoft Graph client
* @param {string} query - Search query string
* @param {number} limit - Maximum number of results (default: 10)
* @returns {Promise<TPrincipalSearchResult[]>} Array of mapped user contact results
*/
const searchContacts = async (graphClient, query, limit = 10) => {
try {
if (!query || query.trim().length < 2) {
return [];
}
if (
process.env.OPENID_GRAPH_SCOPES &&
!process.env.OPENID_GRAPH_SCOPES.toLowerCase().includes('people.read')
) {
logger.warn('[searchContacts] People.Read scope is not enabled, skipping contact search');
return [];
}
// Reason: Search only for OrganizationUser (person) type, not groups
const filter = "personType/subclass eq 'OrganizationUser'";
let apiCall = graphClient
.api('/me/people')
.search(`"${query}"`)
.select(
'id,displayName,givenName,surname,userPrincipalName,jobTitle,department,companyName,scoredEmailAddresses,personType,phones',
)
.header('ConsistencyLevel', 'eventual')
.filter(filter)
.top(limit);
const contactsResponse = await apiCall.get();
return (contactsResponse.value || []).map(mapContactToTPrincipalSearchResult);
} catch (error) {
logger.error('[searchContacts] Error searching contacts:', error);
return [];
}
};
/**
* Search for users using Microsoft Graph /users endpoint
* Returns mapped TPrincipalSearchResult objects
* @param {Client} graphClient - Authenticated Microsoft Graph client
* @param {string} query - Search query string
* @param {number} limit - Maximum number of results (default: 10)
* @returns {Promise<TPrincipalSearchResult[]>} Array of mapped user results
*/
const searchUsers = async (graphClient, query, limit = 10) => {
try {
if (!query || query.trim().length < 2) {
return [];
}
// Reason: Search users by display name, email, and user principal name
const usersResponse = await graphClient
.api('/users')
.search(
`"displayName:${query}" OR "userPrincipalName:${query}" OR "mail:${query}" OR "givenName:${query}" OR "surname:${query}"`,
)
.select(
'id,displayName,givenName,surname,userPrincipalName,jobTitle,department,companyName,mail,phones',
)
.header('ConsistencyLevel', 'eventual')
.top(limit)
.get();
return (usersResponse.value || []).map(mapUserToTPrincipalSearchResult);
} catch (error) {
logger.error('[searchUsers] Error searching users:', error);
return [];
}
};
/**
* Search for groups using Microsoft Graph /groups endpoint
* Returns mapped TPrincipalSearchResult objects, includes all group types
* @param {Client} graphClient - Authenticated Microsoft Graph client
* @param {string} query - Search query string
* @param {number} limit - Maximum number of results (default: 10)
* @returns {Promise<TPrincipalSearchResult[]>} Array of mapped group results
*/
const searchGroups = async (graphClient, query, limit = 10) => {
try {
if (!query || query.trim().length < 2) {
return [];
}
// Reason: Search all groups by display name and email without filtering group types
const groupsResponse = await graphClient
.api('/groups')
.search(`"displayName:${query}" OR "mail:${query}" OR "mailNickname:${query}"`)
.select('id,displayName,mail,mailNickname,description,groupTypes,resourceProvisioningOptions')
.header('ConsistencyLevel', 'eventual')
.top(limit)
.get();
return (groupsResponse.value || []).map(mapGroupToTPrincipalSearchResult);
} catch (error) {
logger.error('[searchGroups] Error searching groups:', error);
return [];
}
};
/**
* Test Graph API connectivity and permissions
* @param {string} accessToken - OpenID Connect access token
* @param {string} sub - Subject identifier
* @returns {Promise<Object>} Test results with available permissions
*/
const testGraphApiAccess = async (accessToken, sub) => {
try {
const graphClient = await createGraphClient(accessToken, sub);
const results = {
userAccess: false,
peopleAccess: false,
groupsAccess: false,
usersEndpointAccess: false,
groupsEndpointAccess: false,
errors: [],
};
// Test User.Read permission
try {
await graphClient.api('/me').select('id,displayName').get();
results.userAccess = true;
} catch (error) {
results.errors.push(`User.Read: ${error.message}`);
}
// Test People.Read permission with OrganizationUser filter
try {
await graphClient
.api('/me/people')
.filter("personType/subclass eq 'OrganizationUser'")
.top(1)
.get();
results.peopleAccess = true;
} catch (error) {
results.errors.push(`People.Read (OrganizationUser): ${error.message}`);
}
// Test People.Read permission with UnifiedGroup filter
try {
await graphClient
.api('/me/people')
.filter("personType/subclass eq 'UnifiedGroup'")
.top(1)
.get();
results.groupsAccess = true;
} catch (error) {
results.errors.push(`People.Read (UnifiedGroup): ${error.message}`);
}
// Test /users endpoint access (requires User.Read.All or similar)
try {
await graphClient
.api('/users')
.search('"displayName:test"')
.select('id,displayName,userPrincipalName')
.top(1)
.get();
results.usersEndpointAccess = true;
} catch (error) {
results.errors.push(`Users endpoint: ${error.message}`);
}
// Test /groups endpoint access (requires Group.Read.All or similar)
try {
await graphClient
.api('/groups')
.search('"displayName:test"')
.select('id,displayName,mail')
.top(1)
.get();
results.groupsEndpointAccess = true;
} catch (error) {
results.errors.push(`Groups endpoint: ${error.message}`);
}
return results;
} catch (error) {
logger.error('[testGraphApiAccess] Error testing Graph API access:', error);
return {
userAccess: false,
peopleAccess: false,
groupsAccess: false,
usersEndpointAccess: false,
groupsEndpointAccess: false,
errors: [error.message],
};
}
};
/**
* Map Graph API user object to TPrincipalSearchResult format
* @param {TGraphUser} user - Raw user object from Graph API
* @returns {TPrincipalSearchResult} Mapped user result
*/
const mapUserToTPrincipalSearchResult = (user) => {
return {
id: null,
type: 'user',
name: user.displayName,
email: user.mail || user.userPrincipalName,
username: user.userPrincipalName,
source: 'entra',
idOnTheSource: user.id,
};
};
/**
* Map Graph API group object to TPrincipalSearchResult format
* @param {TGraphGroup} group - Raw group object from Graph API
* @returns {TPrincipalSearchResult} Mapped group result
*/
const mapGroupToTPrincipalSearchResult = (group) => {
return {
id: null,
type: 'group',
name: group.displayName,
email: group.mail || group.userPrincipalName,
description: group.description,
source: 'entra',
idOnTheSource: group.id,
};
};
/**
* Map Graph API /me/people contact object to TPrincipalSearchResult format
* Handles both user and group contacts from the people endpoint
* @param {TGraphPerson} contact - Raw contact object from Graph API /me/people
* @returns {TPrincipalSearchResult} Mapped contact result
*/
const mapContactToTPrincipalSearchResult = (contact) => {
const isGroup = contact.personType?.class === 'Group';
const primaryEmail = contact.scoredEmailAddresses?.[0]?.address;
return {
id: null,
type: isGroup ? 'group' : 'user',
name: contact.displayName,
email: primaryEmail,
username: !isGroup ? contact.userPrincipalName : undefined,
source: 'entra',
idOnTheSource: contact.id,
};
};
module.exports = {
getGroupMembers,
getGroupOwners,
createGraphClient,
getUserEntraGroups,
getUserOwnedEntraGroups,
testGraphApiAccess,
searchEntraIdPrincipals,
exchangeTokenForGraphAccess,
entraIdPrincipalFeatureEnabled,
};

View File

@@ -0,0 +1,720 @@
jest.mock('@microsoft/microsoft-graph-client');
jest.mock('~/strategies/openidStrategy');
jest.mock('~/cache/getLogStores');
jest.mock('@librechat/data-schemas', () => ({
...jest.requireActual('@librechat/data-schemas'),
logger: {
error: jest.fn(),
debug: jest.fn(),
},
}));
jest.mock('~/config', () => ({
logger: {
error: jest.fn(),
debug: jest.fn(),
},
createAxiosInstance: jest.fn(() => ({
create: jest.fn(),
defaults: {},
})),
}));
jest.mock('~/utils', () => ({
logAxiosError: jest.fn(),
}));
jest.mock('~/server/services/Config', () => ({}));
jest.mock('~/server/services/Files/strategies', () => ({
getStrategyFunctions: jest.fn(),
}));
const mongoose = require('mongoose');
const client = require('openid-client');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { Client } = require('@microsoft/microsoft-graph-client');
const { getOpenIdConfig } = require('~/strategies/openidStrategy');
const getLogStores = require('~/cache/getLogStores');
const GraphApiService = require('./GraphApiService');
describe('GraphApiService', () => {
let mongoServer;
let mockGraphClient;
let mockTokensCache;
let mockOpenIdConfig;
beforeAll(async () => {
mongoServer = await MongoMemoryServer.create();
const mongoUri = mongoServer.getUri();
await mongoose.connect(mongoUri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongoServer.stop();
});
afterEach(() => {
// Clean up environment variables
delete process.env.OPENID_GRAPH_SCOPES;
});
beforeEach(async () => {
jest.clearAllMocks();
await mongoose.connection.dropDatabase();
// Set up environment variable for People.Read scope
process.env.OPENID_GRAPH_SCOPES = 'User.Read,People.Read,Group.Read.All';
// Mock Graph client
mockGraphClient = {
api: jest.fn().mockReturnThis(),
search: jest.fn().mockReturnThis(),
filter: jest.fn().mockReturnThis(),
select: jest.fn().mockReturnThis(),
header: jest.fn().mockReturnThis(),
top: jest.fn().mockReturnThis(),
get: jest.fn(),
};
Client.init.mockReturnValue(mockGraphClient);
// Mock tokens cache
mockTokensCache = {
get: jest.fn(),
set: jest.fn(),
};
getLogStores.mockReturnValue(mockTokensCache);
// Mock OpenID config
mockOpenIdConfig = {
client_id: 'test-client-id',
issuer: 'https://test-issuer.com',
};
getOpenIdConfig.mockReturnValue(mockOpenIdConfig);
// Mock openid-client (using the existing jest mock configuration)
if (client.genericGrantRequest) {
client.genericGrantRequest.mockResolvedValue({
access_token: 'mocked-graph-token',
expires_in: 3600,
});
}
});
describe('Dependency Contract Tests', () => {
it('should fail if getOpenIdConfig interface changes', () => {
// Reason: Ensure getOpenIdConfig returns expected structure
const config = getOpenIdConfig();
expect(config).toBeDefined();
expect(typeof config).toBe('object');
// Add specific property checks that GraphApiService depends on
expect(config).toHaveProperty('client_id');
expect(config).toHaveProperty('issuer');
// Ensure the function is callable
expect(typeof getOpenIdConfig).toBe('function');
});
it('should fail if openid-client.genericGrantRequest interface changes', () => {
// Reason: Ensure client.genericGrantRequest maintains expected signature
if (client.genericGrantRequest) {
expect(typeof client.genericGrantRequest).toBe('function');
// Test that it accepts the expected parameters
const mockCall = client.genericGrantRequest(
mockOpenIdConfig,
'urn:ietf:params:oauth:grant-type:jwt-bearer',
{
scope: 'test-scope',
assertion: 'test-token',
requested_token_use: 'on_behalf_of',
},
);
expect(mockCall).toBeDefined();
}
});
it('should fail if Microsoft Graph Client interface changes', () => {
// Reason: Ensure Graph Client maintains expected fluent API
expect(typeof Client.init).toBe('function');
const client = Client.init({ authProvider: jest.fn() });
expect(client).toHaveProperty('api');
expect(typeof client.api).toBe('function');
});
});
describe('createGraphClient', () => {
it('should create graph client with exchanged token', async () => {
const accessToken = 'test-access-token';
const sub = 'test-user-id';
const result = await GraphApiService.createGraphClient(accessToken, sub);
expect(getOpenIdConfig).toHaveBeenCalled();
expect(Client.init).toHaveBeenCalledWith({
authProvider: expect.any(Function),
});
expect(result).toBe(mockGraphClient);
});
it('should handle token exchange errors gracefully', async () => {
if (client.genericGrantRequest) {
client.genericGrantRequest.mockRejectedValue(new Error('Token exchange failed'));
}
await expect(GraphApiService.createGraphClient('invalid-token', 'test-user')).rejects.toThrow(
'Token exchange failed',
);
});
});
describe('exchangeTokenForGraphAccess', () => {
it('should return cached token if available', async () => {
const cachedToken = { access_token: 'cached-token' };
mockTokensCache.get.mockResolvedValue(cachedToken);
const result = await GraphApiService.exchangeTokenForGraphAccess(
mockOpenIdConfig,
'test-token',
'test-user',
);
expect(result).toBe('cached-token');
expect(mockTokensCache.get).toHaveBeenCalledWith('test-user:graph');
if (client.genericGrantRequest) {
expect(client.genericGrantRequest).not.toHaveBeenCalled();
}
});
it('should exchange token and cache result', async () => {
mockTokensCache.get.mockResolvedValue(null);
const result = await GraphApiService.exchangeTokenForGraphAccess(
mockOpenIdConfig,
'test-token',
'test-user',
);
if (client.genericGrantRequest) {
expect(client.genericGrantRequest).toHaveBeenCalledWith(
mockOpenIdConfig,
'urn:ietf:params:oauth:grant-type:jwt-bearer',
{
scope:
'https://graph.microsoft.com/User.Read https://graph.microsoft.com/People.Read https://graph.microsoft.com/Group.Read.All',
assertion: 'test-token',
requested_token_use: 'on_behalf_of',
},
);
}
expect(mockTokensCache.set).toHaveBeenCalledWith(
'test-user:graph',
{ access_token: 'mocked-graph-token' },
3600000,
);
expect(result).toBe('mocked-graph-token');
});
it('should use custom scopes from environment', async () => {
const originalEnv = process.env.OPENID_GRAPH_SCOPES;
process.env.OPENID_GRAPH_SCOPES = 'Custom.Read,Custom.Write';
mockTokensCache.get.mockResolvedValue(null);
await GraphApiService.exchangeTokenForGraphAccess(
mockOpenIdConfig,
'test-token',
'test-user',
);
if (client.genericGrantRequest) {
expect(client.genericGrantRequest).toHaveBeenCalledWith(
mockOpenIdConfig,
'urn:ietf:params:oauth:grant-type:jwt-bearer',
{
scope:
'https://graph.microsoft.com/Custom.Read https://graph.microsoft.com/Custom.Write',
assertion: 'test-token',
requested_token_use: 'on_behalf_of',
},
);
}
process.env.OPENID_GRAPH_SCOPES = originalEnv;
});
});
describe('searchEntraIdPrincipals', () => {
// Mock data used by multiple tests
const mockContactsResponse = {
value: [
{
id: 'contact-user-1',
displayName: 'John Doe',
userPrincipalName: 'john@company.com',
mail: 'john@company.com',
personType: { class: 'Person', subclass: 'OrganizationUser' },
scoredEmailAddresses: [{ address: 'john@company.com', relevanceScore: 0.9 }],
},
{
id: 'contact-group-1',
displayName: 'Marketing Team',
mail: 'marketing@company.com',
personType: { class: 'Group', subclass: 'UnifiedGroup' },
scoredEmailAddresses: [{ address: 'marketing@company.com', relevanceScore: 0.8 }],
},
],
};
const mockUsersResponse = {
value: [
{
id: 'dir-user-1',
displayName: 'Jane Smith',
userPrincipalName: 'jane@company.com',
mail: 'jane@company.com',
},
],
};
const mockGroupsResponse = {
value: [
{
id: 'dir-group-1',
displayName: 'Development Team',
mail: 'dev@company.com',
},
],
};
beforeEach(() => {
// Reset mock call history for each test
jest.clearAllMocks();
// Re-apply the Client.init mock after clearAllMocks
Client.init.mockReturnValue(mockGraphClient);
// Re-apply openid-client mock
if (client.genericGrantRequest) {
client.genericGrantRequest.mockResolvedValue({
access_token: 'mocked-graph-token',
expires_in: 3600,
});
}
// Re-apply cache mock
mockTokensCache.get.mockResolvedValue(null); // Force token exchange
mockTokensCache.set.mockResolvedValue();
getLogStores.mockReturnValue(mockTokensCache);
getOpenIdConfig.mockReturnValue(mockOpenIdConfig);
});
it('should return empty results for short queries', async () => {
const result = await GraphApiService.searchEntraIdPrincipals('token', 'user', 'a', 'all', 10);
expect(result).toEqual([]);
expect(mockGraphClient.api).not.toHaveBeenCalled();
});
it('should search contacts first and additional users for users type', async () => {
// Mock responses for this specific test
const contactsFilteredResponse = {
value: [
{
id: 'contact-user-1',
displayName: 'John Doe',
userPrincipalName: 'john@company.com',
mail: 'john@company.com',
personType: { class: 'Person', subclass: 'OrganizationUser' },
scoredEmailAddresses: [{ address: 'john@company.com', relevanceScore: 0.9 }],
},
],
};
mockGraphClient.get
.mockResolvedValueOnce(contactsFilteredResponse) // contacts call
.mockResolvedValueOnce(mockUsersResponse); // users call
const result = await GraphApiService.searchEntraIdPrincipals(
'token',
'user',
'john',
'users',
10,
);
// Should call contacts first with user filter
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/people');
expect(mockGraphClient.search).toHaveBeenCalledWith('"john"');
expect(mockGraphClient.filter).toHaveBeenCalledWith(
"personType/subclass eq 'OrganizationUser'",
);
// Should call users endpoint for additional results
expect(mockGraphClient.api).toHaveBeenCalledWith('/users');
expect(mockGraphClient.search).toHaveBeenCalledWith(
'"displayName:john" OR "userPrincipalName:john" OR "mail:john" OR "givenName:john" OR "surname:john"',
);
// Should return TPrincipalSearchResult array
expect(Array.isArray(result)).toBe(true);
expect(result).toHaveLength(2); // 1 from contacts + 1 from users
expect(result[0]).toMatchObject({
id: null,
type: 'user',
name: 'John Doe',
email: 'john@company.com',
source: 'entra',
idOnTheSource: 'contact-user-1',
});
});
it('should search groups endpoint only for groups type', async () => {
// Mock responses for this specific test - only groups endpoint called
mockGraphClient.get.mockResolvedValueOnce(mockGroupsResponse); // only groups call
const result = await GraphApiService.searchEntraIdPrincipals(
'token',
'user',
'team',
'groups',
10,
);
// Should NOT call contacts for groups type
expect(mockGraphClient.api).not.toHaveBeenCalledWith('/me/people');
// Should call groups endpoint only
expect(mockGraphClient.api).toHaveBeenCalledWith('/groups');
expect(mockGraphClient.search).toHaveBeenCalledWith(
'"displayName:team" OR "mail:team" OR "mailNickname:team"',
);
expect(Array.isArray(result)).toBe(true);
expect(result).toHaveLength(1); // 1 from groups only
});
it('should search all endpoints for all type', async () => {
// Mock responses for this specific test
mockGraphClient.get
.mockResolvedValueOnce(mockContactsResponse) // contacts call (both user and group)
.mockResolvedValueOnce(mockUsersResponse) // users call
.mockResolvedValueOnce(mockGroupsResponse); // groups call
const result = await GraphApiService.searchEntraIdPrincipals(
'token',
'user',
'test',
'all',
10,
);
// Should call contacts first with user filter
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/people');
expect(mockGraphClient.search).toHaveBeenCalledWith('"test"');
expect(mockGraphClient.filter).toHaveBeenCalledWith(
"personType/subclass eq 'OrganizationUser'",
);
// Should call both users and groups endpoints
expect(mockGraphClient.api).toHaveBeenCalledWith('/users');
expect(mockGraphClient.api).toHaveBeenCalledWith('/groups');
expect(Array.isArray(result)).toBe(true);
expect(result).toHaveLength(4); // 2 from contacts + 1 from users + 1 from groups
});
it('should early exit if contacts reach limit', async () => {
// Mock contacts to return exactly the limit
const limitedContactsResponse = {
value: Array(10).fill({
id: 'contact-1',
displayName: 'Contact User',
mail: 'contact@company.com',
personType: { class: 'Person', subclass: 'OrganizationUser' },
}),
};
mockGraphClient.get.mockResolvedValueOnce(limitedContactsResponse);
const result = await GraphApiService.searchEntraIdPrincipals(
'token',
'user',
'test',
'all',
10,
);
// Should call contacts first
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/people');
expect(mockGraphClient.search).toHaveBeenCalledWith('"test"');
// Should not call users endpoint since limit was reached
expect(mockGraphClient.api).not.toHaveBeenCalledWith('/users');
expect(result).toHaveLength(10);
});
it('should deduplicate results based on idOnTheSource', async () => {
// Mock responses with duplicate IDs
const duplicateContactsResponse = {
value: [
{
id: 'duplicate-id',
displayName: 'John Doe',
mail: 'john@company.com',
personType: { class: 'Person', subclass: 'OrganizationUser' },
},
],
};
const duplicateUsersResponse = {
value: [
{
id: 'duplicate-id', // Same ID as contact
displayName: 'John Doe',
mail: 'john@company.com',
},
],
};
mockGraphClient.get
.mockResolvedValueOnce(duplicateContactsResponse)
.mockResolvedValueOnce(duplicateUsersResponse);
const result = await GraphApiService.searchEntraIdPrincipals(
'token',
'user',
'john',
'users',
10,
);
// Should only return one result despite duplicate IDs
expect(result).toHaveLength(1);
expect(result[0].idOnTheSource).toBe('duplicate-id');
});
it('should handle Graph API errors gracefully', async () => {
mockGraphClient.get.mockRejectedValue(new Error('Graph API error'));
const result = await GraphApiService.searchEntraIdPrincipals(
'token',
'user',
'test',
'all',
10,
);
expect(result).toEqual([]);
});
});
describe('getUserEntraGroups', () => {
it('should fetch user groups from memberOf endpoint', async () => {
const mockGroupsResponse = {
value: [
{
id: 'group-1',
},
{
id: 'group-2',
},
],
};
mockGraphClient.get.mockResolvedValue(mockGroupsResponse);
const result = await GraphApiService.getUserEntraGroups('token', 'user');
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/memberOf');
expect(mockGraphClient.select).toHaveBeenCalledWith('id');
expect(result).toHaveLength(2);
expect(result).toEqual(['group-1', 'group-2']);
});
it('should return empty array on error', async () => {
mockGraphClient.get.mockRejectedValue(new Error('API error'));
const result = await GraphApiService.getUserEntraGroups('token', 'user');
expect(result).toEqual([]);
});
it('should handle empty response', async () => {
const mockGroupsResponse = {
value: [],
};
mockGraphClient.get.mockResolvedValue(mockGroupsResponse);
const result = await GraphApiService.getUserEntraGroups('token', 'user');
expect(result).toEqual([]);
});
it('should handle missing value property', async () => {
mockGraphClient.get.mockResolvedValue({});
const result = await GraphApiService.getUserEntraGroups('token', 'user');
expect(result).toEqual([]);
});
});
describe('testGraphApiAccess', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should test all permissions and return success results', async () => {
// Mock successful responses for all tests
mockGraphClient.get
.mockResolvedValueOnce({ id: 'user-123', displayName: 'Test User' }) // /me test
.mockResolvedValueOnce({ value: [] }) // people OrganizationUser test
.mockResolvedValueOnce({ value: [] }) // people UnifiedGroup test
.mockResolvedValueOnce({ value: [] }) // /users endpoint test
.mockResolvedValueOnce({ value: [] }); // /groups endpoint test
const result = await GraphApiService.testGraphApiAccess('token', 'user');
expect(result).toEqual({
userAccess: true,
peopleAccess: true,
groupsAccess: true,
usersEndpointAccess: true,
groupsEndpointAccess: true,
errors: [],
});
// Verify all endpoints were tested
expect(mockGraphClient.api).toHaveBeenCalledWith('/me');
expect(mockGraphClient.api).toHaveBeenCalledWith('/me/people');
expect(mockGraphClient.api).toHaveBeenCalledWith('/users');
expect(mockGraphClient.api).toHaveBeenCalledWith('/groups');
expect(mockGraphClient.filter).toHaveBeenCalledWith(
"personType/subclass eq 'OrganizationUser'",
);
expect(mockGraphClient.filter).toHaveBeenCalledWith("personType/subclass eq 'UnifiedGroup'");
expect(mockGraphClient.search).toHaveBeenCalledWith('"displayName:test"');
});
it('should handle partial failures and record errors', async () => {
// Mock mixed success/failure responses
mockGraphClient.get
.mockResolvedValueOnce({ id: 'user-123', displayName: 'Test User' }) // /me success
.mockRejectedValueOnce(new Error('People access denied')) // people OrganizationUser fail
.mockResolvedValueOnce({ value: [] }) // people UnifiedGroup success
.mockRejectedValueOnce(new Error('Users endpoint access denied')) // /users fail
.mockResolvedValueOnce({ value: [] }); // /groups success
const result = await GraphApiService.testGraphApiAccess('token', 'user');
expect(result).toEqual({
userAccess: true,
peopleAccess: false,
groupsAccess: true,
usersEndpointAccess: false,
groupsEndpointAccess: true,
errors: [
'People.Read (OrganizationUser): People access denied',
'Users endpoint: Users endpoint access denied',
],
});
});
it('should handle complete Graph client creation failure', async () => {
// Mock token exchange failure to test error handling
if (client.genericGrantRequest) {
client.genericGrantRequest.mockRejectedValue(new Error('Token exchange failed'));
}
const result = await GraphApiService.testGraphApiAccess('invalid-token', 'user');
expect(result).toEqual({
userAccess: false,
peopleAccess: false,
groupsAccess: false,
usersEndpointAccess: false,
groupsEndpointAccess: false,
errors: ['Token exchange failed'],
});
});
it('should record all permission errors', async () => {
// Mock all requests to fail
mockGraphClient.get
.mockRejectedValueOnce(new Error('User.Read denied'))
.mockRejectedValueOnce(new Error('People.Read OrganizationUser denied'))
.mockRejectedValueOnce(new Error('People.Read UnifiedGroup denied'))
.mockRejectedValueOnce(new Error('Users directory access denied'))
.mockRejectedValueOnce(new Error('Groups directory access denied'));
const result = await GraphApiService.testGraphApiAccess('token', 'user');
expect(result).toEqual({
userAccess: false,
peopleAccess: false,
groupsAccess: false,
usersEndpointAccess: false,
groupsEndpointAccess: false,
errors: [
'User.Read: User.Read denied',
'People.Read (OrganizationUser): People.Read OrganizationUser denied',
'People.Read (UnifiedGroup): People.Read UnifiedGroup denied',
'Users endpoint: Users directory access denied',
'Groups endpoint: Groups directory access denied',
],
});
});
it('should test new endpoints with correct search patterns', async () => {
// Mock successful responses for endpoint testing
mockGraphClient.get
.mockResolvedValueOnce({ id: 'user-123', displayName: 'Test User' }) // /me
.mockResolvedValueOnce({ value: [] }) // people OrganizationUser
.mockResolvedValueOnce({ value: [] }) // people UnifiedGroup
.mockResolvedValueOnce({ value: [] }) // /users
.mockResolvedValueOnce({ value: [] }); // /groups
await GraphApiService.testGraphApiAccess('token', 'user');
// Verify /users endpoint test
expect(mockGraphClient.api).toHaveBeenCalledWith('/users');
expect(mockGraphClient.search).toHaveBeenCalledWith('"displayName:test"');
expect(mockGraphClient.select).toHaveBeenCalledWith('id,displayName,userPrincipalName');
// Verify /groups endpoint test
expect(mockGraphClient.api).toHaveBeenCalledWith('/groups');
expect(mockGraphClient.select).toHaveBeenCalledWith('id,displayName,mail');
});
it('should handle endpoint-specific permission failures', async () => {
// Mock specific endpoint failures
mockGraphClient.get
.mockResolvedValueOnce({ id: 'user-123', displayName: 'Test User' }) // /me success
.mockResolvedValueOnce({ value: [] }) // people OrganizationUser success
.mockResolvedValueOnce({ value: [] }) // people UnifiedGroup success
.mockRejectedValueOnce(new Error('Insufficient privileges')) // /users fail (User.Read.All needed)
.mockRejectedValueOnce(new Error('Access denied to groups')); // /groups fail (Group.Read.All needed)
const result = await GraphApiService.testGraphApiAccess('token', 'user');
expect(result).toEqual({
userAccess: true,
peopleAccess: true,
groupsAccess: true,
usersEndpointAccess: false,
groupsEndpointAccess: false,
errors: [
'Users endpoint: Insufficient privileges',
'Groups endpoint: Access denied to groups',
],
});
});
});
});

View File

@@ -0,0 +1,721 @@
const mongoose = require('mongoose');
const { getTransactionSupport, logger } = require('@librechat/data-schemas');
const { isEnabled } = require('~/server/utils');
const {
entraIdPrincipalFeatureEnabled,
getUserEntraGroups,
getUserOwnedEntraGroups,
getGroupMembers,
getGroupOwners,
} = require('~/server/services/GraphApiService');
const {
findGroupByExternalId,
findRoleByIdentifier,
getUserPrincipals,
createGroup,
createUser,
updateUser,
findUser,
grantPermission: grantPermissionACL,
findAccessibleResources: findAccessibleResourcesACL,
hasPermission,
getEffectivePermissions: getEffectivePermissionsACL,
findEntriesByPrincipalsAndResource,
} = require('~/models');
const { AclEntry, AccessRole, Group } = require('~/db/models');
/** @type {boolean|null} */
let transactionSupportCache = null;
/**
* @import { TPrincipal } from 'librechat-data-provider'
*/
/**
* Grant a permission to a principal for a resource using a role
* @param {Object} params - Parameters for granting role-based permission
* @param {string} params.principalType - 'user', 'group', or 'public'
* @param {string|mongoose.Types.ObjectId|null} params.principalId - The ID of the principal (null for 'public')
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
* @param {string} params.accessRoleId - The ID of the role (e.g., 'agent_viewer', 'agent_editor')
* @param {string|mongoose.Types.ObjectId} params.grantedBy - User ID granting the permission
* @param {mongoose.ClientSession} [params.session] - Optional MongoDB session for transactions
* @returns {Promise<Object>} The created or updated ACL entry
*/
const grantPermission = async ({
principalType,
principalId,
resourceType,
resourceId,
accessRoleId,
grantedBy,
session,
}) => {
try {
if (!['user', 'group', 'public'].includes(principalType)) {
throw new Error(`Invalid principal type: ${principalType}`);
}
if (principalType !== 'public' && !principalId) {
throw new Error('Principal ID is required for user and group principals');
}
if (principalId && !mongoose.Types.ObjectId.isValid(principalId)) {
throw new Error(`Invalid principal ID: ${principalId}`);
}
if (!resourceId || !mongoose.Types.ObjectId.isValid(resourceId)) {
throw new Error(`Invalid resource ID: ${resourceId}`);
}
// Get the role to determine permission bits
const role = await findRoleByIdentifier(accessRoleId);
if (!role) {
throw new Error(`Role ${accessRoleId} not found`);
}
// Ensure the role is for the correct resource type
if (role.resourceType !== resourceType) {
throw new Error(
`Role ${accessRoleId} is for ${role.resourceType} resources, not ${resourceType}`,
);
}
return await grantPermissionACL(
principalType,
principalId,
resourceType,
resourceId,
role.permBits,
grantedBy,
session,
role._id,
);
} catch (error) {
logger.error(`[PermissionService.grantPermission] Error: ${error.message}`);
throw error;
}
};
/**
* Check if a user has specific permission bits on a resource
* @param {Object} params - Parameters for checking permissions
* @param {string|mongoose.Types.ObjectId} params.userId - The ID of the user
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
* @param {number} params.requiredPermissions - The permission bits required (e.g., 1 for VIEW, 3 for VIEW+EDIT)
* @returns {Promise<boolean>} Whether the user has the required permission bits
*/
const checkPermission = async ({ userId, resourceType, resourceId, requiredPermission }) => {
try {
if (typeof requiredPermission !== 'number' || requiredPermission < 1) {
throw new Error('requiredPermission must be a positive number');
}
// Get all principals for the user (user + groups + public)
const principals = await getUserPrincipals(userId);
if (principals.length === 0) {
return false;
}
return await hasPermission(principals, resourceType, resourceId, requiredPermission);
} catch (error) {
logger.error(`[PermissionService.checkPermission] Error: ${error.message}`);
// Re-throw validation errors
if (error.message.includes('requiredPermission must be')) {
throw error;
}
return false;
}
};
/**
* Get effective permission bitmask for a user on a resource
* @param {Object} params - Parameters for getting effective permissions
* @param {string|mongoose.Types.ObjectId} params.userId - The ID of the user
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
* @returns {Promise<number>} Effective permission bitmask
*/
const getEffectivePermissions = async ({ userId, resourceType, resourceId }) => {
try {
// Get all principals for the user (user + groups + public)
const principals = await getUserPrincipals(userId);
if (principals.length === 0) {
return 0;
}
return await getEffectivePermissionsACL(principals, resourceType, resourceId);
} catch (error) {
logger.error(`[PermissionService.getEffectivePermissions] Error: ${error.message}`);
return 0;
}
};
/**
* Find all resources of a specific type that a user has access to with specific permission bits
* @param {Object} params - Parameters for finding accessible resources
* @param {string|mongoose.Types.ObjectId} params.userId - The ID of the user
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
* @param {number} params.requiredPermissions - The minimum permission bits required (e.g., 1 for VIEW, 3 for VIEW+EDIT)
* @returns {Promise<Array>} Array of resource IDs
*/
const findAccessibleResources = async ({ userId, resourceType, requiredPermissions }) => {
try {
if (typeof requiredPermissions !== 'number' || requiredPermissions < 1) {
throw new Error('requiredPermissions must be a positive number');
}
// Get all principals for the user (user + groups + public)
const principalsList = await getUserPrincipals(userId);
if (principalsList.length === 0) {
return [];
}
return await findAccessibleResourcesACL(principalsList, resourceType, requiredPermissions);
} catch (error) {
logger.error(`[PermissionService.findAccessibleResources] Error: ${error.message}`);
// Re-throw validation errors
if (error.message.includes('requiredPermissions must be')) {
throw error;
}
return [];
}
};
/**
* Find all publicly accessible resources of a specific type
* @param {Object} params - Parameters for finding publicly accessible resources
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
* @param {number} params.requiredPermissions - The minimum permission bits required (e.g., 1 for VIEW, 3 for VIEW+EDIT)
* @returns {Promise<Array>} Array of resource IDs
*/
const findPubliclyAccessibleResources = async ({ resourceType, requiredPermissions }) => {
try {
if (typeof requiredPermissions !== 'number' || requiredPermissions < 1) {
throw new Error('requiredPermissions must be a positive number');
}
// Find all public ACL entries where the public principal has at least the required permission bits
const entries = await AclEntry.find({
principalType: 'public',
resourceType,
permBits: { $bitsAllSet: requiredPermissions },
}).distinct('resourceId');
return entries;
} catch (error) {
logger.error(`[PermissionService.findPubliclyAccessibleResources] Error: ${error.message}`);
// Re-throw validation errors
if (error.message.includes('requiredPermissions must be')) {
throw error;
}
return [];
}
};
/**
* Get available roles for a resource type
* @param {Object} params - Parameters for getting available roles
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
* @returns {Promise<Array>} Array of role definitions
*/
const getAvailableRoles = async ({ resourceType }) => {
try {
return await AccessRole.find({ resourceType }).lean();
} catch (error) {
logger.error(`[PermissionService.getAvailableRoles] Error: ${error.message}`);
return [];
}
};
/**
* Ensures a principal exists in the database based on TPrincipal data
* Creates user if it doesn't exist locally (for Entra ID users)
* @param {Object} principal - TPrincipal object from frontend
* @param {string} principal.type - 'user', 'group', or 'public'
* @param {string} [principal.id] - Local database ID (null for Entra ID principals not yet synced)
* @param {string} principal.name - Display name
* @param {string} [principal.email] - Email address
* @param {string} [principal.source] - 'local' or 'entra'
* @param {string} [principal.idOnTheSource] - Entra ID object ID for external principals
* @returns {Promise<string|null>} Returns the principalId for database operations, null for public
*/
const ensurePrincipalExists = async function (principal) {
if (principal.type === 'public') {
return null;
}
if (principal.id) {
return principal.id;
}
if (principal.type === 'user' && principal.source === 'entra') {
if (!principal.email || !principal.idOnTheSource) {
throw new Error('Entra ID user principals must have email and idOnTheSource');
}
let existingUser = await findUser({ idOnTheSource: principal.idOnTheSource });
if (!existingUser) {
existingUser = await findUser({ email: principal.email.toLowerCase() });
}
if (existingUser) {
if (!existingUser.idOnTheSource && principal.idOnTheSource) {
await updateUser(existingUser._id, {
idOnTheSource: principal.idOnTheSource,
provider: 'openid',
});
}
return existingUser._id.toString();
}
const userData = {
name: principal.name,
email: principal.email.toLowerCase(),
emailVerified: false,
provider: 'openid',
idOnTheSource: principal.idOnTheSource,
};
const userId = await createUser(userData, true, false);
return userId.toString();
}
if (principal.type === 'group') {
throw new Error('Group principals should be handled by group-specific methods');
}
throw new Error(`Unsupported principal type: ${principal.type}`);
};
/**
* Ensures a group principal exists in the database based on TPrincipal data
* Creates group if it doesn't exist locally (for Entra ID groups)
* For Entra ID groups, always synchronizes member IDs when authentication context is provided
* @param {Object} principal - TPrincipal object from frontend
* @param {string} principal.type - Must be 'group'
* @param {string} [principal.id] - Local database ID (null for Entra ID principals not yet synced)
* @param {string} principal.name - Display name
* @param {string} [principal.email] - Email address
* @param {string} [principal.description] - Group description
* @param {string} [principal.source] - 'local' or 'entra'
* @param {string} [principal.idOnTheSource] - Entra ID object ID for external principals
* @param {Object} [authContext] - Optional authentication context for fetching member data
* @param {string} [authContext.accessToken] - Access token for Graph API calls
* @param {string} [authContext.sub] - Subject identifier
* @returns {Promise<string>} Returns the groupId for database operations
*/
const ensureGroupPrincipalExists = async function (principal, authContext = null) {
if (principal.type !== 'group') {
throw new Error(`Invalid principal type: ${principal.type}. Expected 'group'`);
}
if (principal.source === 'entra') {
if (!principal.name || !principal.idOnTheSource) {
throw new Error('Entra ID group principals must have name and idOnTheSource');
}
let memberIds = [];
if (authContext && authContext.accessToken && authContext.sub) {
try {
memberIds = await getGroupMembers(
authContext.accessToken,
authContext.sub,
principal.idOnTheSource,
);
// Include group owners as members if feature is enabled
if (isEnabled(process.env.ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS)) {
const ownerIds = await getGroupOwners(
authContext.accessToken,
authContext.sub,
principal.idOnTheSource,
);
if (ownerIds && ownerIds.length > 0) {
memberIds.push(...ownerIds);
// Remove duplicates
memberIds = [...new Set(memberIds)];
}
}
} catch (error) {
logger.error('Failed to fetch group members from Graph API:', error);
}
}
let existingGroup = await findGroupByExternalId(principal.idOnTheSource, 'entra');
if (!existingGroup && principal.email) {
existingGroup = await Group.findOne({ email: principal.email.toLowerCase() }).lean();
}
if (existingGroup) {
const updateData = {};
let needsUpdate = false;
if (!existingGroup.idOnTheSource && principal.idOnTheSource) {
updateData.idOnTheSource = principal.idOnTheSource;
updateData.source = 'entra';
needsUpdate = true;
}
if (principal.description && existingGroup.description !== principal.description) {
updateData.description = principal.description;
needsUpdate = true;
}
if (principal.email && existingGroup.email !== principal.email.toLowerCase()) {
updateData.email = principal.email.toLowerCase();
needsUpdate = true;
}
if (authContext && authContext.accessToken && authContext.sub) {
updateData.memberIds = memberIds;
needsUpdate = true;
}
if (needsUpdate) {
await Group.findByIdAndUpdate(existingGroup._id, { $set: updateData }, { new: true });
}
return existingGroup._id.toString();
}
const groupData = {
name: principal.name,
source: 'entra',
idOnTheSource: principal.idOnTheSource,
memberIds: memberIds, // Store idOnTheSource values of group members (empty if no auth context)
};
if (principal.email) {
groupData.email = principal.email.toLowerCase();
}
if (principal.description) {
groupData.description = principal.description;
}
const newGroup = await createGroup(groupData);
return newGroup._id.toString();
}
if (principal.id && authContext == null) {
return principal.id;
}
throw new Error(`Unsupported group principal source: ${principal.source}`);
};
/**
* Synchronize user's Entra ID group memberships on sign-in
* Gets user's group IDs from GraphAPI and updates memberships only for existing groups in database
* Optionally includes groups the user owns if ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS is enabled
* @param {Object} user - User object with authentication context
* @param {string} user.openidId - User's OpenID subject identifier
* @param {string} user.idOnTheSource - User's Entra ID (oid from token claims)
* @param {string} user.provider - Authentication provider ('openid')
* @param {string} accessToken - Access token for Graph API calls
* @param {mongoose.ClientSession} [session] - Optional MongoDB session for transactions
* @returns {Promise<void>}
*/
const syncUserEntraGroupMemberships = async (user, accessToken, session = null) => {
try {
if (!entraIdPrincipalFeatureEnabled(user) || !accessToken || !user.idOnTheSource) {
return;
}
const memberGroupIds = await getUserEntraGroups(accessToken, user.openidId);
let allGroupIds = [...(memberGroupIds || [])];
// Include owned groups if feature is enabled
if (isEnabled(process.env.ENTRA_ID_INCLUDE_OWNERS_AS_MEMBERS)) {
const ownedGroupIds = await getUserOwnedEntraGroups(accessToken, user.openidId);
if (ownedGroupIds && ownedGroupIds.length > 0) {
allGroupIds.push(...ownedGroupIds);
// Remove duplicates
allGroupIds = [...new Set(allGroupIds)];
}
}
if (!allGroupIds || allGroupIds.length === 0) {
return;
}
const sessionOptions = session ? { session } : {};
await Group.updateMany(
{
idOnTheSource: { $in: allGroupIds },
source: 'entra',
memberIds: { $ne: user.idOnTheSource },
},
{ $addToSet: { memberIds: user.idOnTheSource } },
sessionOptions,
);
await Group.updateMany(
{
source: 'entra',
memberIds: user.idOnTheSource,
idOnTheSource: { $nin: allGroupIds },
},
{ $pull: { memberIds: user.idOnTheSource } },
sessionOptions,
);
} catch (error) {
logger.error(`[PermissionService.syncUserEntraGroupMemberships] Error syncing groups:`, error);
}
};
/**
* Check if public has a specific permission on a resource
* @param {Object} params - Parameters for checking public permission
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
* @param {number} params.requiredPermissions - The permission bits required (e.g., 1 for VIEW, 3 for VIEW+EDIT)
* @returns {Promise<boolean>} Whether public has the required permission bits
*/
const hasPublicPermission = async ({ resourceType, resourceId, requiredPermissions }) => {
try {
if (typeof requiredPermissions !== 'number' || requiredPermissions < 1) {
throw new Error('requiredPermissions must be a positive number');
}
// Use public principal to check permissions
const publicPrincipal = [{ principalType: 'public' }];
const entries = await findEntriesByPrincipalsAndResource(
publicPrincipal,
resourceType,
resourceId,
);
// Check if any entry has the required permission bits
return entries.some((entry) => (entry.permBits & requiredPermissions) === requiredPermissions);
} catch (error) {
logger.error(`[PermissionService.hasPublicPermission] Error: ${error.message}`);
// Re-throw validation errors
if (error.message.includes('requiredPermissions must be')) {
throw error;
}
return false;
}
};
/**
* Bulk update permissions for a resource (grant, update, revoke)
* Efficiently handles multiple permission changes in a single transaction
*
* @param {Object} params - Parameters for bulk permission update
* @param {string} params.resourceType - Type of resource (e.g., 'agent')
* @param {string|mongoose.Types.ObjectId} params.resourceId - The ID of the resource
* @param {Array<TPrincipal>} params.updatedPrincipals - Array of principals to grant/update permissions for
* @param {Array<TPrincipal>} params.revokedPrincipals - Array of principals to revoke permissions from
* @param {string|mongoose.Types.ObjectId} params.grantedBy - User ID making the changes
* @param {mongoose.ClientSession} [params.session] - Optional MongoDB session for transactions
* @returns {Promise<Object>} Results object with granted, updated, revoked arrays and error details
*/
const bulkUpdateResourcePermissions = async ({
resourceType,
resourceId,
updatedPrincipals = [],
revokedPrincipals = [],
grantedBy,
session,
}) => {
const supportsTransactions = await getTransactionSupport(mongoose, transactionSupportCache);
transactionSupportCache = supportsTransactions;
let localSession = session;
let shouldEndSession = false;
try {
if (!Array.isArray(updatedPrincipals)) {
throw new Error('updatedPrincipals must be an array');
}
if (!Array.isArray(revokedPrincipals)) {
throw new Error('revokedPrincipals must be an array');
}
if (!resourceId || !mongoose.Types.ObjectId.isValid(resourceId)) {
throw new Error(`Invalid resource ID: ${resourceId}`);
}
if (!localSession && supportsTransactions) {
localSession = await mongoose.startSession();
localSession.startTransaction();
shouldEndSession = true;
}
const sessionOptions = localSession ? { session: localSession } : {};
const roles = await AccessRole.find({ resourceType }).lean();
const rolesMap = new Map();
roles.forEach((role) => {
rolesMap.set(role.accessRoleId, role);
});
const results = {
granted: [],
updated: [],
revoked: [],
errors: [],
};
const bulkWrites = [];
for (const principal of updatedPrincipals) {
try {
if (!principal.accessRoleId) {
results.errors.push({
principal,
error: 'accessRoleId is required for updated principals',
});
continue;
}
const role = rolesMap.get(principal.accessRoleId);
if (!role) {
results.errors.push({
principal,
error: `Role ${principal.accessRoleId} not found`,
});
continue;
}
const query = {
principalType: principal.type,
resourceType,
resourceId,
};
if (principal.type !== 'public') {
query.principalId = principal.id;
}
const update = {
$set: {
permBits: role.permBits,
roleId: role._id,
grantedBy,
grantedAt: new Date(),
},
$setOnInsert: {
principalType: principal.type,
resourceType,
resourceId,
...(principal.type !== 'public' && {
principalId: principal.id,
principalModel: principal.type === 'user' ? 'User' : 'Group',
}),
},
};
bulkWrites.push({
updateOne: {
filter: query,
update: update,
upsert: true,
},
});
results.granted.push({
type: principal.type,
id: principal.id,
name: principal.name,
email: principal.email,
source: principal.source,
avatar: principal.avatar,
description: principal.description,
idOnTheSource: principal.idOnTheSource,
accessRoleId: principal.accessRoleId,
memberCount: principal.memberCount,
memberIds: principal.memberIds,
});
} catch (error) {
results.errors.push({
principal,
error: error.message,
});
}
}
if (bulkWrites.length > 0) {
await AclEntry.bulkWrite(bulkWrites, sessionOptions);
}
const deleteQueries = [];
for (const principal of revokedPrincipals) {
try {
const query = {
principalType: principal.type,
resourceType,
resourceId,
};
if (principal.type !== 'public') {
query.principalId = principal.id;
}
deleteQueries.push(query);
results.revoked.push({
type: principal.type,
id: principal.id,
name: principal.name,
email: principal.email,
source: principal.source,
avatar: principal.avatar,
description: principal.description,
idOnTheSource: principal.idOnTheSource,
memberCount: principal.memberCount,
});
} catch (error) {
results.errors.push({
principal,
error: error.message,
});
}
}
if (deleteQueries.length > 0) {
await AclEntry.deleteMany(
{
$or: deleteQueries,
},
sessionOptions,
);
}
if (shouldEndSession && supportsTransactions) {
await localSession.commitTransaction();
}
return results;
} catch (error) {
if (shouldEndSession && supportsTransactions) {
await localSession.abortTransaction();
}
logger.error(`[PermissionService.bulkUpdateResourcePermissions] Error: ${error.message}`);
throw error;
} finally {
if (shouldEndSession && localSession) {
localSession.endSession();
}
}
};
module.exports = {
grantPermission,
checkPermission,
getEffectivePermissions,
findAccessibleResources,
findPubliclyAccessibleResources,
hasPublicPermission,
getAvailableRoles,
bulkUpdateResourcePermissions,
ensurePrincipalExists,
ensureGroupPrincipalExists,
syncUserEntraGroupMemberships,
};

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,6 @@
const { sleep } = require('@librechat/agents');
const { sendEvent } = require('@librechat/api');
const { logger } = require('@librechat/data-schemas');
const {
Constants,
StepTypes,
@@ -8,9 +11,8 @@ const {
} = require('librechat-data-provider');
const { retrieveAndProcessFile } = require('~/server/services/Files/process');
const { processRequiredActions } = require('~/server/services/ToolService');
const { createOnProgress, sendMessage, sleep } = require('~/server/utils');
const { processMessages } = require('~/server/services/Threads');
const { logger } = require('~/config');
const { createOnProgress } = require('~/server/utils');
/**
* Implements the StreamRunManager functionality for managing the streaming
@@ -126,7 +128,7 @@ class StreamRunManager {
conversationId: this.finalMessage.conversationId,
};
sendMessage(this.res, contentData);
sendEvent(this.res, contentData);
}
/* <------------------ Misc. Helpers ------------------> */
@@ -302,7 +304,7 @@ class StreamRunManager {
for (const d of delta[key]) {
if (typeof d === 'object' && !Object.prototype.hasOwnProperty.call(d, 'index')) {
logger.warn('Expected an object with an \'index\' for array updates but got:', d);
logger.warn("Expected an object with an 'index' for array updates but got:", d);
continue;
}

View File

@@ -1,9 +1,9 @@
const { logger } = require('@librechat/data-schemas');
const { CacheKeys, processMCPEnv } = require('librechat-data-provider');
const { CacheKeys } = require('librechat-data-provider');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
const { getMCPManager, getFlowStateManager } = require('~/config');
const { getCachedTools, setCachedTools } = require('./Config');
const { getLogStores } = require('~/cache');
const { findToken, updateToken, createToken, deleteTokens } = require('~/models');
/**
* Initialize MCP servers
@@ -30,7 +30,6 @@ async function initializeMCP(app) {
createToken,
deleteTokens,
},
processMCPEnv,
});
delete app.locals.mcpConfig;

View File

@@ -41,6 +41,7 @@ async function loadDefaultInterface(config, configDefaults, roleName = SystemRol
sidePanel: interfaceConfig?.sidePanel ?? defaults.sidePanel,
privacyPolicy: interfaceConfig?.privacyPolicy ?? defaults.privacyPolicy,
termsOfService: interfaceConfig?.termsOfService ?? defaults.termsOfService,
mcpServers: interfaceConfig?.mcpServers ?? defaults.mcpServers,
bookmarks: interfaceConfig?.bookmarks ?? defaults.bookmarks,
memories: shouldDisableMemories ? false : (interfaceConfig?.memories ?? defaults.memories),
prompts: interfaceConfig?.prompts ?? defaults.prompts,

View File

@@ -7,9 +7,9 @@ const {
defaultAssistantsVersion,
defaultAgentCapabilities,
} = require('librechat-data-provider');
const { sendEvent } = require('@librechat/api');
const { Providers } = require('@librechat/agents');
const partialRight = require('lodash/partialRight');
const { sendMessage } = require('./streamResponse');
/** Helper function to escape special characters in regex
* @param {string} string - The string to escape.
@@ -37,7 +37,7 @@ const createOnProgress = (
basePayload.text = basePayload.text + chunk;
const payload = Object.assign({}, basePayload, rest);
sendMessage(res, payload);
sendEvent(res, payload);
if (_onProgress) {
_onProgress(payload);
}
@@ -50,7 +50,7 @@ const createOnProgress = (
const sendIntermediateMessage = (res, payload, extraTokens = '') => {
basePayload.text = basePayload.text + extraTokens;
const message = Object.assign({}, basePayload, payload);
sendMessage(res, message);
sendEvent(res, message);
if (i === 0) {
basePayload.initial = false;
}

View File

@@ -1,11 +1,9 @@
const streamResponse = require('./streamResponse');
const removePorts = require('./removePorts');
const countTokens = require('./countTokens');
const handleText = require('./handleText');
const sendEmail = require('./sendEmail');
const queue = require('./queue');
const files = require('./files');
const math = require('./math');
/**
* Check if email configuration is set
@@ -28,7 +26,6 @@ function checkEmailConfig() {
}
module.exports = {
...streamResponse,
checkEmailConfig,
...handleText,
countTokens,
@@ -36,5 +33,4 @@ module.exports = {
sendEmail,
...files,
...queue,
math,
};

Some files were not shown because too many files have changed in this diff Show More