diff --git a/.env.example b/.env.example index 103728fe0..5f2d15742 100644 --- a/.env.example +++ b/.env.example @@ -111,6 +111,23 @@ ANTHROPIC_API_KEY=user_provided BINGAI_TOKEN=user_provided # BINGAI_HOST=https://cn.bing.com +#=================# +# AWS Bedrock # +#=================# + +# BEDROCK_AWS_DEFAULT_REGION=us-east-1 # A default region must be provided +# BEDROCK_AWS_ACCESS_KEY_ID=someAccessKey +# BEDROCK_AWS_SECRET_ACCESS_KEY=someSecretAccessKey + +# Note: This example list is not meant to be exhaustive. If omitted, all known, supported model IDs will be included for you. +# BEDROCK_AWS_MODELS=anthropic.claude-3-5-sonnet-20240620-v1:0,meta.llama3-1-8b-instruct-v1:0 + +# See all Bedrock model IDs here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns + +# Notes on specific models: +# 'ai21.j2-mid-v1', # Not supported, as it doesn't support streaming +# 'ai21.j2-ultra-v1', # Not supported, as it doesn't support conversation history + #============# # Google # #============# diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 764038806..972716ee0 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -2,6 +2,7 @@ const crypto = require('crypto'); const fetch = require('node-fetch'); const { supportsBalanceCheck, + isAgentsEndpoint, ErrorTypes, Constants, CacheKeys, @@ -66,6 +67,17 @@ class BaseClient { throw new Error('Subclasses attempted to call summarizeMessages without implementing it'); } + /** + * @returns {string} + */ + getResponseModel() { + if (isAgentsEndpoint(this.options.endpoint) && this.options.agent && this.options.agent.id) { + return this.options.agent.id; + } + + return this.modelOptions.model; + } + /** * Abstract method to get the token count for a message. Subclasses must implement this method. * @param {TMessage} responseMessage @@ -217,6 +229,7 @@ class BaseClient { userMessage, conversationId, responseMessageId, + sender: this.sender, }); } @@ -557,7 +570,7 @@ class BaseClient { parentMessageId: userMessage.messageId, isCreatedByUser: false, isEdited, - model: this.modelOptions.model, + model: this.getResponseModel(), sender: this.sender, promptTokens, iconURL: this.options.iconURL, diff --git a/api/app/clients/prompts/formatMessages.js b/api/app/clients/prompts/formatMessages.js index 87d5ba7a1..a5ed05c64 100644 --- a/api/app/clients/prompts/formatMessages.js +++ b/api/app/clients/prompts/formatMessages.js @@ -170,7 +170,15 @@ const formatAgentMessages = (payload) => { } // Note: `tool_calls` list is defined when constructed by `AIMessage` class, and outputs should be excluded from it - const { output, ...tool_call } = part.tool_call; + const { output, args: _args, ...tool_call } = part.tool_call; + // TODO: investigate; args as dictionary may need to be provider-or-tool-specific + let args = _args; + try { + args = JSON.parse(args); + } catch (e) { + // failed to parse, leave as is + } + tool_call.args = args; lastAIMessage.tool_calls.push(tool_call); // Add the corresponding ToolMessage diff --git a/api/app/clients/specs/BaseClient.test.js b/api/app/clients/specs/BaseClient.test.js index e3cc1515c..0fdc6ce16 100644 --- a/api/app/clients/specs/BaseClient.test.js +++ b/api/app/clients/specs/BaseClient.test.js @@ -565,11 +565,13 @@ describe('BaseClient', () => { const getReqData = jest.fn(); const opts = { getReqData }; const response = await TestClient.sendMessage('Hello, world!', opts); - expect(getReqData).toHaveBeenCalledWith({ - userMessage: expect.objectContaining({ text: 'Hello, world!' }), - conversationId: response.conversationId, - responseMessageId: response.messageId, - }); + expect(getReqData).toHaveBeenCalledWith( + expect.objectContaining({ + userMessage: expect.objectContaining({ text: 'Hello, world!' }), + conversationId: response.conversationId, + responseMessageId: response.messageId, + }), + ); }); test('onStart is called with the correct arguments', async () => { diff --git a/api/models/Action.js b/api/models/Action.js index 7971f3e61..299b3bf20 100644 --- a/api/models/Action.js +++ b/api/models/Action.js @@ -5,17 +5,16 @@ const Action = mongoose.model('action', actionSchema); /** * Update an action with new data without overwriting existing properties, - * or create a new action if it doesn't exist, within a transaction session if provided. + * or create a new action if it doesn't exist. * * @param {Object} searchParams - The search parameters to find the action to update. * @param {string} searchParams.action_id - The ID of the action to update. * @param {string} searchParams.user - The user ID of the action's author. * @param {Object} updateData - An object containing the properties to update. - * @param {mongoose.ClientSession} [session] - The transaction session to use. * @returns {Promise} The updated or newly created action document as a plain object. */ -const updateAction = async (searchParams, updateData, session = null) => { - const options = { new: true, upsert: true, session }; +const updateAction = async (searchParams, updateData) => { + const options = { new: true, upsert: true }; return await Action.findOneAndUpdate(searchParams, updateData, options).lean(); }; @@ -49,31 +48,27 @@ const getActions = async (searchParams, includeSensitive = false) => { }; /** - * Deletes an action by params, within a transaction session if provided. + * Deletes an action by params. * * @param {Object} searchParams - The search parameters to find the action to delete. * @param {string} searchParams.action_id - The ID of the action to delete. * @param {string} searchParams.user - The user ID of the action's author. - * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} A promise that resolves to the deleted action document as a plain object, or null if no document was found. */ -const deleteAction = async (searchParams, session = null) => { - const options = session ? { session } : {}; - return await Action.findOneAndDelete(searchParams, options).lean(); +const deleteAction = async (searchParams) => { + return await Action.findOneAndDelete(searchParams).lean(); }; /** - * Deletes actions by params, within a transaction session if provided. + * Deletes actions by params. * * @param {Object} searchParams - The search parameters to find the actions to delete. * @param {string} searchParams.action_id - The ID of the action(s) to delete. * @param {string} searchParams.user - The user ID of the action's author. - * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} A promise that resolves to the number of deleted action documents. */ -const deleteActions = async (searchParams, session = null) => { - const options = session ? { session } : {}; - const result = await Action.deleteMany(searchParams, options); +const deleteActions = async (searchParams) => { + const result = await Action.deleteMany(searchParams); return result.deletedCount; }; diff --git a/api/models/Agent.js b/api/models/Agent.js index 1ee783b10..2112a4499 100644 --- a/api/models/Agent.js +++ b/api/models/Agent.js @@ -1,4 +1,11 @@ const mongoose = require('mongoose'); +const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; +const { + getProjectByName, + addAgentIdsToProject, + removeAgentIdsFromProject, + removeAgentFromAllProjects, +} = require('./Project'); const agentSchema = require('./schema/agent'); const Agent = mongoose.model('agent', agentSchema); @@ -24,18 +31,17 @@ const createAgent = async (agentData) => { const getAgent = async (searchParameter) => await Agent.findOne(searchParameter).lean(); /** - * Update an agent with new data without overwriting existing properties, - * or create a new agent if it doesn't exist, within a transaction session if provided. + * Update an agent with new data without overwriting existing + * properties, or create a new agent if it doesn't exist. * * @param {Object} searchParameter - The search parameters to find the agent to update. * @param {string} searchParameter.id - The ID of the agent to update. - * @param {string} searchParameter.author - The user ID of the agent's author. + * @param {string} [searchParameter.author] - The user ID of the agent's author. * @param {Object} updateData - An object containing the properties to update. - * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} The updated or newly created agent document as a plain object. */ -const updateAgent = async (searchParameter, updateData, session = null) => { - const options = { new: true, upsert: true, session }; +const updateAgent = async (searchParameter, updateData) => { + const options = { new: true, upsert: true }; return await Agent.findOneAndUpdate(searchParameter, updateData, options).lean(); }; @@ -44,11 +50,15 @@ const updateAgent = async (searchParameter, updateData, session = null) => { * * @param {Object} searchParameter - The search parameters to find the agent to delete. * @param {string} searchParameter.id - The ID of the agent to delete. - * @param {string} searchParameter.author - The user ID of the agent's author. + * @param {string} [searchParameter.author] - The user ID of the agent's author. * @returns {Promise} Resolves when the agent has been successfully deleted. */ const deleteAgent = async (searchParameter) => { - return await Agent.findOneAndDelete(searchParameter); + const agent = await Agent.findOneAndDelete(searchParameter); + if (agent) { + await removeAgentFromAllProjects(agent.id); + } + return agent; }; /** @@ -58,11 +68,24 @@ const deleteAgent = async (searchParameter) => { * @returns {Promise} A promise that resolves to an object containing the agents data and pagination info. */ const getListAgents = async (searchParameter) => { - const agents = await Agent.find(searchParameter, { + const { author, ...otherParams } = searchParameter; + + let query = Object.assign({ author }, otherParams); + + const globalProject = await getProjectByName(GLOBAL_PROJECT_NAME, ['agentIds']); + if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) { + const globalQuery = { id: { $in: globalProject.agentIds }, ...otherParams }; + delete globalQuery.author; + query = { $or: [globalQuery, query] }; + } + + const agents = await Agent.find(query, { id: 1, name: 1, avatar: 1, + projectIds: 1, }).lean(); + const hasMore = agents.length > 0; const firstId = agents.length > 0 ? agents[0].id : null; const lastId = agents.length > 0 ? agents[agents.length - 1].id : null; @@ -75,10 +98,45 @@ const getListAgents = async (searchParameter) => { }; }; +/** + * Updates the projects associated with an agent, adding and removing project IDs as specified. + * This function also updates the corresponding projects to include or exclude the agent ID. + * + * @param {string} agentId - The ID of the agent to update. + * @param {string[]} [projectIds] - Array of project IDs to add to the agent. + * @param {string[]} [removeProjectIds] - Array of project IDs to remove from the agent. + * @returns {Promise} The updated agent document. + * @throws {Error} If there's an error updating the agent or projects. + */ +const updateAgentProjects = async (agentId, projectIds, removeProjectIds) => { + const updateOps = {}; + + if (removeProjectIds && removeProjectIds.length > 0) { + for (const projectId of removeProjectIds) { + await removeAgentIdsFromProject(projectId, [agentId]); + } + updateOps.$pull = { projectIds: { $in: removeProjectIds } }; + } + + if (projectIds && projectIds.length > 0) { + for (const projectId of projectIds) { + await addAgentIdsToProject(projectId, [agentId]); + } + updateOps.$addToSet = { projectIds: { $each: projectIds } }; + } + + if (Object.keys(updateOps).length === 0) { + return await getAgent({ id: agentId }); + } + + return await updateAgent({ id: agentId }, updateOps); +}; + module.exports = { createAgent, getAgent, updateAgent, deleteAgent, getListAgents, + updateAgentProjects, }; diff --git a/api/models/Assistant.js b/api/models/Assistant.js index 2c98287a8..d0e73ad4e 100644 --- a/api/models/Assistant.js +++ b/api/models/Assistant.js @@ -5,17 +5,16 @@ const Assistant = mongoose.model('assistant', assistantSchema); /** * Update an assistant with new data without overwriting existing properties, - * or create a new assistant if it doesn't exist, within a transaction session if provided. + * or create a new assistant if it doesn't exist. * * @param {Object} searchParams - The search parameters to find the assistant to update. * @param {string} searchParams.assistant_id - The ID of the assistant to update. * @param {string} searchParams.user - The user ID of the assistant's author. * @param {Object} updateData - An object containing the properties to update. - * @param {mongoose.ClientSession} [session] - The transaction session to use (optional). * @returns {Promise} The updated or newly created assistant document as a plain object. */ -const updateAssistantDoc = async (searchParams, updateData, session = null) => { - const options = { new: true, upsert: true, session }; +const updateAssistantDoc = async (searchParams, updateData) => { + const options = { new: true, upsert: true }; return await Assistant.findOneAndUpdate(searchParams, updateData, options).lean(); }; diff --git a/api/models/Project.js b/api/models/Project.js index e982e34b5..17ef3093a 100644 --- a/api/models/Project.js +++ b/api/models/Project.js @@ -1,4 +1,5 @@ const { model } = require('mongoose'); +const { GLOBAL_PROJECT_NAME } = require('librechat-data-provider').Constants; const projectSchema = require('~/models/schema/projectSchema'); const Project = model('Project', projectSchema); @@ -33,7 +34,7 @@ const getProjectByName = async function (projectName, fieldsToSelect = null) { const update = { $setOnInsert: { name: projectName } }; const options = { new: true, - upsert: projectName === 'instance', + upsert: projectName === GLOBAL_PROJECT_NAME, lean: true, select: fieldsToSelect, }; @@ -81,10 +82,55 @@ const removeGroupFromAllProjects = async (promptGroupId) => { await Project.updateMany({}, { $pull: { promptGroupIds: promptGroupId } }); }; +/** + * Add an array of agent IDs to a project's agentIds array, ensuring uniqueness. + * + * @param {string} projectId - The ID of the project to update. + * @param {string[]} agentIds - The array of agent IDs to add to the project. + * @returns {Promise} The updated project document. + */ +const addAgentIdsToProject = async function (projectId, agentIds) { + return await Project.findByIdAndUpdate( + projectId, + { $addToSet: { agentIds: { $each: agentIds } } }, + { new: true }, + ); +}; + +/** + * Remove an array of agent IDs from a project's agentIds array. + * + * @param {string} projectId - The ID of the project to update. + * @param {string[]} agentIds - The array of agent IDs to remove from the project. + * @returns {Promise} The updated project document. + */ +const removeAgentIdsFromProject = async function (projectId, agentIds) { + return await Project.findByIdAndUpdate( + projectId, + { $pull: { agentIds: { $in: agentIds } } }, + { new: true }, + ); +}; + +/** + * Remove an agent ID from all projects. + * + * @param {string} agentId - The ID of the agent to remove from projects. + * @returns {Promise} + */ +const removeAgentFromAllProjects = async (agentId) => { + await Project.updateMany({}, { $pull: { agentIds: agentId } }); +}; + module.exports = { getProjectById, getProjectByName, + /* prompts */ addGroupIdsToProject, removeGroupIdsFromProject, removeGroupFromAllProjects, + /* agents */ + addAgentIdsToProject, + removeAgentIdsFromProject, + removeAgentFromAllProjects, }; diff --git a/api/models/Prompt.js b/api/models/Prompt.js index 56dcd7857..548589b4d 100644 --- a/api/models/Prompt.js +++ b/api/models/Prompt.js @@ -1,5 +1,5 @@ const { ObjectId } = require('mongodb'); -const { SystemRoles, SystemCategories } = require('librechat-data-provider'); +const { SystemRoles, SystemCategories, Constants } = require('librechat-data-provider'); const { getProjectByName, addGroupIdsToProject, @@ -123,7 +123,7 @@ const getAllPromptGroups = async (req, filter) => { let combinedQuery = query; if (searchShared) { - const project = await getProjectByName('instance', 'promptGroupIds'); + const project = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, 'promptGroupIds'); if (project && project.promptGroupIds.length > 0) { const projectQuery = { _id: { $in: project.promptGroupIds }, ...query }; delete projectQuery.author; @@ -177,7 +177,7 @@ const getPromptGroups = async (req, filter) => { if (searchShared) { // const projects = req.user.projects || []; // TODO: handle multiple projects - const project = await getProjectByName('instance', 'promptGroupIds'); + const project = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, 'promptGroupIds'); if (project && project.promptGroupIds.length > 0) { const projectQuery = { _id: { $in: project.promptGroupIds }, ...query }; delete projectQuery.author; diff --git a/api/models/Role.js b/api/models/Role.js index d21efee3b..62ec8dfe2 100644 --- a/api/models/Role.js +++ b/api/models/Role.js @@ -4,6 +4,7 @@ const { roleDefaults, PermissionTypes, removeNullishValues, + agentPermissionsSchema, promptPermissionsSchema, bookmarkPermissionsSchema, } = require('librechat-data-provider'); @@ -71,6 +72,7 @@ const updateRoleByName = async function (roleName, updates) { }; const permissionSchemas = { + [PermissionTypes.AGENTS]: agentPermissionsSchema, [PermissionTypes.PROMPTS]: promptPermissionsSchema, [PermissionTypes.BOOKMARKS]: bookmarkPermissionsSchema, }; @@ -130,6 +132,7 @@ async function updateAccessPermissions(roleName, permissionsUpdate) { /** * Initialize default roles in the system. * Creates the default roles (ADMIN, USER) if they don't exist in the database. + * Updates existing roles with new permission types if they're missing. * * @returns {Promise} */ @@ -137,14 +140,27 @@ const initializeRoles = async function () { const defaultRoles = [SystemRoles.ADMIN, SystemRoles.USER]; for (const roleName of defaultRoles) { - let role = await Role.findOne({ name: roleName }).select('name').lean(); + let role = await Role.findOne({ name: roleName }); + if (!role) { + // Create new role if it doesn't exist role = new Role(roleDefaults[roleName]); - await role.save(); + } else { + // Add missing permission types + let isUpdated = false; + for (const permType of Object.values(PermissionTypes)) { + if (!role[permType]) { + role[permType] = roleDefaults[roleName][permType]; + isUpdated = true; + } + } + if (isUpdated) { + await role.save(); + } } + await role.save(); } }; - module.exports = { getRoleByName, initializeRoles, diff --git a/api/models/Role.spec.js b/api/models/Role.spec.js index c183b9d1c..753df77e6 100644 --- a/api/models/Role.spec.js +++ b/api/models/Role.spec.js @@ -1,9 +1,14 @@ const mongoose = require('mongoose'); const { MongoMemoryServer } = require('mongodb-memory-server'); -const { SystemRoles, PermissionTypes } = require('librechat-data-provider'); -const Role = require('~/models/schema/roleSchema'); -const { updateAccessPermissions } = require('~/models/Role'); +const { + SystemRoles, + PermissionTypes, + roleDefaults, + Permissions, +} = require('librechat-data-provider'); +const { updateAccessPermissions, initializeRoles } = require('~/models/Role'); const getLogStores = require('~/cache/getLogStores'); +const Role = require('~/models/schema/roleSchema'); // Mock the cache jest.mock('~/cache/getLogStores', () => { @@ -195,3 +200,117 @@ describe('updateAccessPermissions', () => { }); }); }); + +describe('initializeRoles', () => { + beforeEach(async () => { + await Role.deleteMany({}); + }); + + it('should create default roles if they do not exist', async () => { + await initializeRoles(); + + const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean(); + const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + + expect(adminRole).toBeTruthy(); + expect(userRole).toBeTruthy(); + + // Check if all permission types exist + Object.values(PermissionTypes).forEach((permType) => { + expect(adminRole[permType]).toBeDefined(); + expect(userRole[permType]).toBeDefined(); + }); + + // Check if permissions match defaults (example for ADMIN role) + expect(adminRole[PermissionTypes.PROMPTS].SHARED_GLOBAL).toBe(true); + expect(adminRole[PermissionTypes.BOOKMARKS].USE).toBe(true); + expect(adminRole[PermissionTypes.AGENTS].CREATE).toBe(true); + }); + + it('should not modify existing permissions for existing roles', async () => { + const customUserRole = { + name: SystemRoles.USER, + [PermissionTypes.PROMPTS]: { + [Permissions.USE]: false, + [Permissions.CREATE]: true, + [Permissions.SHARED_GLOBAL]: true, + }, + [PermissionTypes.BOOKMARKS]: { + [Permissions.USE]: false, + }, + }; + + await new Role(customUserRole).save(); + + await initializeRoles(); + + const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + + expect(userRole[PermissionTypes.PROMPTS]).toEqual(customUserRole[PermissionTypes.PROMPTS]); + expect(userRole[PermissionTypes.BOOKMARKS]).toEqual(customUserRole[PermissionTypes.BOOKMARKS]); + expect(userRole[PermissionTypes.AGENTS]).toBeDefined(); + }); + + it('should add new permission types to existing roles', async () => { + const partialUserRole = { + name: SystemRoles.USER, + [PermissionTypes.PROMPTS]: roleDefaults[SystemRoles.USER][PermissionTypes.PROMPTS], + [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.USER][PermissionTypes.BOOKMARKS], + }; + + await new Role(partialUserRole).save(); + + await initializeRoles(); + + const userRole = await Role.findOne({ name: SystemRoles.USER }).lean(); + + expect(userRole[PermissionTypes.AGENTS]).toBeDefined(); + expect(userRole[PermissionTypes.AGENTS].CREATE).toBeDefined(); + expect(userRole[PermissionTypes.AGENTS].USE).toBeDefined(); + expect(userRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined(); + }); + + it('should handle multiple runs without duplicating or modifying data', async () => { + await initializeRoles(); + await initializeRoles(); + + const adminRoles = await Role.find({ name: SystemRoles.ADMIN }); + const userRoles = await Role.find({ name: SystemRoles.USER }); + + expect(adminRoles).toHaveLength(1); + expect(userRoles).toHaveLength(1); + + const adminRole = adminRoles[0].toObject(); + const userRole = userRoles[0].toObject(); + + // Check if all permission types exist + Object.values(PermissionTypes).forEach((permType) => { + expect(adminRole[permType]).toBeDefined(); + expect(userRole[permType]).toBeDefined(); + }); + }); + + it('should update roles with missing permission types from roleDefaults', async () => { + const partialAdminRole = { + name: SystemRoles.ADMIN, + [PermissionTypes.PROMPTS]: { + [Permissions.USE]: false, + [Permissions.CREATE]: false, + [Permissions.SHARED_GLOBAL]: false, + }, + [PermissionTypes.BOOKMARKS]: roleDefaults[SystemRoles.ADMIN][PermissionTypes.BOOKMARKS], + }; + + await new Role(partialAdminRole).save(); + + await initializeRoles(); + + const adminRole = await Role.findOne({ name: SystemRoles.ADMIN }).lean(); + + expect(adminRole[PermissionTypes.PROMPTS]).toEqual(partialAdminRole[PermissionTypes.PROMPTS]); + expect(adminRole[PermissionTypes.AGENTS]).toBeDefined(); + expect(adminRole[PermissionTypes.AGENTS].CREATE).toBeDefined(); + expect(adminRole[PermissionTypes.AGENTS].USE).toBeDefined(); + expect(adminRole[PermissionTypes.AGENTS].SHARED_GLOBAL).toBeDefined(); + }); +}); diff --git a/api/models/schema/agent.js b/api/models/schema/agent.js index 97f052791..819398ee7 100644 --- a/api/models/schema/agent.js +++ b/api/models/schema/agent.js @@ -57,6 +57,11 @@ const agentSchema = mongoose.Schema( ref: 'User', required: true, }, + projectIds: { + type: [mongoose.Schema.Types.ObjectId], + ref: 'Project', + index: true, + }, }, { timestamps: true, diff --git a/api/models/schema/defaults.js b/api/models/schema/defaults.js index 4a99a6837..6dced3af8 100644 --- a/api/models/schema/defaults.js +++ b/api/models/schema/defaults.js @@ -13,6 +13,11 @@ const conversationPreset = { type: String, required: false, }, + // for bedrock only + region: { + type: String, + required: false, + }, // for azureOpenAI, openAI only chatGptLabel: { type: String, @@ -78,6 +83,9 @@ const conversationPreset = { promptCache: { type: Boolean, }, + system: { + type: String, + }, // files resendFiles: { type: Boolean, diff --git a/api/models/schema/projectSchema.js b/api/models/schema/projectSchema.js index 0e27c6a8f..dfa68a06c 100644 --- a/api/models/schema/projectSchema.js +++ b/api/models/schema/projectSchema.js @@ -21,6 +21,11 @@ const projectSchema = new Schema( ref: 'PromptGroup', default: [], }, + agentIds: { + type: [String], + ref: 'Agent', + default: [], + }, }, { timestamps: true, diff --git a/api/models/schema/roleSchema.js b/api/models/schema/roleSchema.js index ebd1d0bc4..b0cbeb8c2 100644 --- a/api/models/schema/roleSchema.js +++ b/api/models/schema/roleSchema.js @@ -28,6 +28,20 @@ const roleSchema = new mongoose.Schema({ default: true, }, }, + [PermissionTypes.AGENTS]: { + [Permissions.SHARED_GLOBAL]: { + type: Boolean, + default: false, + }, + [Permissions.USE]: { + type: Boolean, + default: true, + }, + [Permissions.CREATE]: { + type: Boolean, + default: true, + }, + }, }); const Role = mongoose.model('Role', roleSchema); diff --git a/api/models/tx.js b/api/models/tx.js index 1b515cca2..2de31f1b3 100644 --- a/api/models/tx.js +++ b/api/models/tx.js @@ -3,38 +3,28 @@ const defaultRate = 6; /** AWS Bedrock pricing */ const bedrockValues = { - 'anthropic.claude-3-haiku-20240307-v1:0': { prompt: 0.25, completion: 1.25 }, - 'anthropic.claude-3-sonnet-20240229-v1:0': { prompt: 3.0, completion: 15.0 }, - 'anthropic.claude-3-opus-20240229-v1:0': { prompt: 15.0, completion: 75.0 }, - 'anthropic.claude-3-5-sonnet-20240620-v1:0': { prompt: 3.0, completion: 15.0 }, - 'anthropic.claude-v2:1': { prompt: 8.0, completion: 24.0 }, - 'anthropic.claude-instant-v1': { prompt: 0.8, completion: 2.4 }, - 'meta.llama2-13b-chat-v1': { prompt: 0.75, completion: 1.0 }, - 'meta.llama2-70b-chat-v1': { prompt: 1.95, completion: 2.56 }, - 'meta.llama3-8b-instruct-v1:0': { prompt: 0.3, completion: 0.6 }, - 'meta.llama3-70b-instruct-v1:0': { prompt: 2.65, completion: 3.5 }, - 'meta.llama3-1-8b-instruct-v1:0': { prompt: 0.3, completion: 0.6 }, - 'meta.llama3-1-70b-instruct-v1:0': { prompt: 2.65, completion: 3.5 }, - 'meta.llama3-1-405b-instruct-v1:0': { prompt: 5.32, completion: 16.0 }, - 'mistral.mistral-7b-instruct-v0:2': { prompt: 0.15, completion: 0.2 }, - 'mistral.mistral-small-2402-v1:0': { prompt: 0.15, completion: 0.2 }, - 'mistral.mixtral-8x7b-instruct-v0:1': { prompt: 0.45, completion: 0.7 }, - 'mistral.mistral-large-2402-v1:0': { prompt: 4.0, completion: 12.0 }, - 'mistral.mistral-large-2407-v1:0': { prompt: 3.0, completion: 9.0 }, - 'cohere.command-text-v14': { prompt: 1.5, completion: 2.0 }, - 'cohere.command-light-text-v14': { prompt: 0.3, completion: 0.6 }, - 'cohere.command-r-v1:0': { prompt: 0.5, completion: 1.5 }, - 'cohere.command-r-plus-v1:0': { prompt: 3.0, completion: 15.0 }, + 'llama2-13b': { prompt: 0.75, completion: 1.0 }, + 'llama2-70b': { prompt: 1.95, completion: 2.56 }, + 'llama3-8b': { prompt: 0.3, completion: 0.6 }, + 'llama3-70b': { prompt: 2.65, completion: 3.5 }, + 'llama3-1-8b': { prompt: 0.3, completion: 0.6 }, + 'llama3-1-70b': { prompt: 2.65, completion: 3.5 }, + 'llama3-1-405b': { prompt: 5.32, completion: 16.0 }, + 'mistral-7b': { prompt: 0.15, completion: 0.2 }, + 'mistral-small': { prompt: 0.15, completion: 0.2 }, + 'mixtral-8x7b': { prompt: 0.45, completion: 0.7 }, + 'mistral-large-2402': { prompt: 4.0, completion: 12.0 }, + 'mistral-large-2407': { prompt: 3.0, completion: 9.0 }, + 'command-text': { prompt: 1.5, completion: 2.0 }, + 'command-light': { prompt: 0.3, completion: 0.6 }, 'ai21.j2-mid-v1': { prompt: 12.5, completion: 12.5 }, 'ai21.j2-ultra-v1': { prompt: 18.8, completion: 18.8 }, + 'ai21.jamba-instruct-v1:0': { prompt: 0.5, completion: 0.7 }, 'amazon.titan-text-lite-v1': { prompt: 0.15, completion: 0.2 }, 'amazon.titan-text-express-v1': { prompt: 0.2, completion: 0.6 }, + 'amazon.titan-text-premier-v1:0': { prompt: 0.5, completion: 1.5 }, }; -for (const [key, value] of Object.entries(bedrockValues)) { - bedrockValues[`bedrock/${key}`] = value; -} - /** * Mapping of model token sizes to their respective multipliers for prompt and completion. * The rates are 1 USD per 1M tokens. @@ -59,6 +49,7 @@ const tokenValues = Object.assign( 'claude-3-haiku': { prompt: 0.25, completion: 1.25 }, 'claude-2.1': { prompt: 8, completion: 24 }, 'claude-2': { prompt: 8, completion: 24 }, + 'claude-instant': { prompt: 0.8, completion: 2.4 }, 'claude-': { prompt: 0.8, completion: 2.4 }, 'command-r-plus': { prompt: 3, completion: 15 }, 'command-r': { prompt: 0.5, completion: 1.5 }, diff --git a/api/models/tx.spec.js b/api/models/tx.spec.js index f0a118c01..c8a8b335e 100644 --- a/api/models/tx.spec.js +++ b/api/models/tx.spec.js @@ -1,3 +1,4 @@ +const { EModelEndpoint } = require('librechat-data-provider'); const { defaultRate, tokenValues, @@ -224,34 +225,18 @@ describe('AWS Bedrock Model Tests', () => { it('should return the correct prompt multipliers for all models', () => { const results = awsModels.map((model) => { - const multiplier = getMultiplier({ valueKey: model, tokenType: 'prompt' }); - return multiplier === tokenValues[model].prompt; + const valueKey = getValueKey(model, EModelEndpoint.bedrock); + const multiplier = getMultiplier({ valueKey, tokenType: 'prompt' }); + return tokenValues[valueKey].prompt && multiplier === tokenValues[valueKey].prompt; }); expect(results.every(Boolean)).toBe(true); }); it('should return the correct completion multipliers for all models', () => { const results = awsModels.map((model) => { - const multiplier = getMultiplier({ valueKey: model, tokenType: 'completion' }); - return multiplier === tokenValues[model].completion; - }); - expect(results.every(Boolean)).toBe(true); - }); - - it('should return the correct prompt multipliers for all models with Bedrock prefix', () => { - const results = awsModels.map((model) => { - const modelName = `bedrock/${model}`; - const multiplier = getMultiplier({ valueKey: modelName, tokenType: 'prompt' }); - return multiplier === tokenValues[model].prompt; - }); - expect(results.every(Boolean)).toBe(true); - }); - - it('should return the correct completion multipliers for all models with Bedrock prefix', () => { - const results = awsModels.map((model) => { - const modelName = `bedrock/${model}`; - const multiplier = getMultiplier({ valueKey: modelName, tokenType: 'completion' }); - return multiplier === tokenValues[model].completion; + const valueKey = getValueKey(model, EModelEndpoint.bedrock); + const multiplier = getMultiplier({ valueKey, tokenType: 'completion' }); + return tokenValues[valueKey].completion && multiplier === tokenValues[valueKey].completion; }); expect(results.every(Boolean)).toBe(true); }); diff --git a/api/package.json b/api/package.json index 43d8609a8..75df20b2d 100644 --- a/api/package.json +++ b/api/package.json @@ -43,7 +43,7 @@ "@langchain/core": "^0.2.18", "@langchain/google-genai": "^0.0.11", "@langchain/google-vertexai": "^0.0.17", - "@librechat/agents": "^1.4.1", + "@librechat/agents": "^1.5.2", "axios": "^1.3.4", "bcryptjs": "^2.4.3", "cheerio": "^1.0.0-rc.12", diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index ce6e0fb17..dd5c8d657 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -123,11 +123,6 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { }; let response = await client.sendMessage(text, messageOptions); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - response.endpoint = endpointOption.endpoint; const { conversation = {} } = await client.responsePromise; diff --git a/api/server/controllers/EndpointController.js b/api/server/controllers/EndpointController.js index d80ea6b14..1e716870c 100644 --- a/api/server/controllers/EndpointController.js +++ b/api/server/controllers/EndpointController.js @@ -44,6 +44,14 @@ async function endpointController(req, res) { }; } + if (mergedConfig[EModelEndpoint.bedrock] && req.app.locals?.[EModelEndpoint.bedrock]) { + const { availableRegions } = req.app.locals[EModelEndpoint.bedrock]; + mergedConfig[EModelEndpoint.bedrock] = { + ...mergedConfig[EModelEndpoint.bedrock], + availableRegions, + }; + } + const endpointsConfig = orderEndpointsConfig(mergedConfig); await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig); diff --git a/api/server/controllers/agents/callbacks.js b/api/server/controllers/agents/callbacks.js index 9649f56a5..f6c1972b4 100644 --- a/api/server/controllers/agents/callbacks.js +++ b/api/server/controllers/agents/callbacks.js @@ -1,7 +1,10 @@ const { GraphEvents, ToolEndHandler, ChatModelStreamHandler } = require('@librechat/agents'); +/** @typedef {import('@librechat/agents').Graph} Graph */ /** @typedef {import('@librechat/agents').EventHandler} EventHandler */ +/** @typedef {import('@librechat/agents').ModelEndData} ModelEndData */ /** @typedef {import('@librechat/agents').ChatModelStreamHandler} ChatModelStreamHandler */ +/** @typedef {import('@librechat/agents').ContentAggregatorResult['aggregateContent']} ContentAggregator */ /** @typedef {import('@librechat/agents').GraphEvents} GraphEvents */ /** @@ -18,18 +21,55 @@ const sendEvent = (res, event) => { res.write(`event: message\ndata: ${JSON.stringify(event)}\n\n`); }; +class ModelEndHandler { + /** + * @param {Array} collectedUsage + */ + constructor(collectedUsage) { + if (!Array.isArray(collectedUsage)) { + throw new Error('collectedUsage must be an array'); + } + this.collectedUsage = collectedUsage; + } + + /** + * @param {string} event + * @param {ModelEndData | undefined} data + * @param {Record | undefined} metadata + * @param {Graph} graph + * @returns + */ + handle(event, data, metadata, graph) { + if (!graph || !metadata) { + console.warn(`Graph or metadata not found in ${event} event`); + return; + } + + const usage = data?.output?.usage_metadata; + + if (usage) { + this.collectedUsage.push(usage); + } + } +} + /** * Get default handlers for stream events. - * @param {{ res?: ServerResponse }} options - The options object. + * @param {Object} options - The options object. + * @param {ServerResponse} options.res - The options object. + * @param {ContentAggregator} options.aggregateContent - The options object. + * @param {Array} options.collectedUsage - The list of collected usage metadata. * @returns {Record} The default handlers. * @throws {Error} If the request is not found. */ -function getDefaultHandlers({ res }) { - if (!res) { - throw new Error('Request not found'); +function getDefaultHandlers({ res, aggregateContent, collectedUsage }) { + if (!res || !aggregateContent) { + throw new Error( + `[getDefaultHandlers] Missing required options: res: ${!res}, aggregateContent: ${!aggregateContent}`, + ); } const handlers = { - // [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(), + [GraphEvents.CHAT_MODEL_END]: new ModelEndHandler(collectedUsage), [GraphEvents.TOOL_END]: new ToolEndHandler(), [GraphEvents.CHAT_MODEL_STREAM]: new ChatModelStreamHandler(), [GraphEvents.ON_RUN_STEP]: { @@ -40,6 +80,7 @@ function getDefaultHandlers({ res }) { */ handle: (event, data) => { sendEvent(res, { event, data }); + aggregateContent({ event, data }); }, }, [GraphEvents.ON_RUN_STEP_DELTA]: { @@ -50,6 +91,7 @@ function getDefaultHandlers({ res }) { */ handle: (event, data) => { sendEvent(res, { event, data }); + aggregateContent({ event, data }); }, }, [GraphEvents.ON_RUN_STEP_COMPLETED]: { @@ -60,6 +102,7 @@ function getDefaultHandlers({ res }) { */ handle: (event, data) => { sendEvent(res, { event, data }); + aggregateContent({ event, data }); }, }, [GraphEvents.ON_MESSAGE_DELTA]: { @@ -70,6 +113,7 @@ function getDefaultHandlers({ res }) { */ handle: (event, data) => { sendEvent(res, { event, data }); + aggregateContent({ event, data }); }, }, }; diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index 82e6a6f48..137068ddd 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -7,9 +7,11 @@ // validateVisionModel, // mapModelToAzureConfig, // } = require('librechat-data-provider'); -const { Callback } = require('@librechat/agents'); +const { Callback, createMetadataAggregator } = require('@librechat/agents'); const { + Constants, EModelEndpoint, + bedrockOutputParser, providerEndpointMap, removeNullishValues, } = require('librechat-data-provider'); @@ -23,15 +25,27 @@ const { formatAgentMessages, createContextHandlers, } = require('~/app/clients/prompts'); +const { encodeAndFormat } = require('~/server/services/Files/images/encode'); const Tokenizer = require('~/server/services/Tokenizer'); +const { spendTokens } = require('~/models/spendTokens'); const BaseClient = require('~/app/clients/BaseClient'); // const { sleep } = require('~/server/utils'); const { createRun } = require('./run'); const { logger } = require('~/config'); +/** @typedef {import('@librechat/agents').MessageContentComplex} MessageContentComplex */ + +// const providerSchemas = { +// [EModelEndpoint.bedrock]: true, +// }; + +const providerParsers = { + [EModelEndpoint.bedrock]: bedrockOutputParser, +}; + class AgentClient extends BaseClient { constructor(options = {}) { - super(options); + super(null, options); /** @type {'discard' | 'summarize'} */ this.contextStrategy = 'discard'; @@ -39,11 +53,31 @@ class AgentClient extends BaseClient { /** @deprecated @type {true} - Is a Chat Completion Request */ this.isChatCompletion = true; - const { maxContextTokens, modelOptions = {}, ...clientOptions } = options; + /** @type {AgentRun} */ + this.run; + + const { + maxContextTokens, + modelOptions = {}, + contentParts, + collectedUsage, + ...clientOptions + } = options; this.modelOptions = modelOptions; this.maxContextTokens = maxContextTokens; - this.options = Object.assign({ endpoint: EModelEndpoint.agents }, clientOptions); + /** @type {MessageContentComplex[]} */ + this.contentParts = contentParts; + /** @type {Array} */ + this.collectedUsage = collectedUsage; + this.options = Object.assign({ endpoint: options.endpoint }, clientOptions); + } + + /** + * Returns the aggregated content parts for the current run. + * @returns {MessageContentComplex[]} */ + getContentParts() { + return this.contentParts; } setOptions(options) { @@ -112,9 +146,27 @@ class AgentClient extends BaseClient { } getSaveOptions() { + const parseOptions = providerParsers[this.options.endpoint]; + let runOptions = + this.options.endpoint === EModelEndpoint.agents + ? { + model: undefined, + // TODO: + // would need to be override settings; otherwise, model needs to be undefined + // model: this.override.model, + // instructions: this.override.instructions, + // additional_instructions: this.override.additional_instructions, + } + : {}; + + if (parseOptions) { + runOptions = parseOptions(this.modelOptions); + } + return removeNullishValues( Object.assign( { + endpoint: this.options.endpoint, agent_id: this.options.agent.id, modelLabel: this.options.modelLabel, maxContextTokens: this.options.maxContextTokens, @@ -122,15 +174,8 @@ class AgentClient extends BaseClient { imageDetail: this.options.imageDetail, spec: this.options.spec, }, - this.modelOptions, - { - model: undefined, - // TODO: - // would need to be override settings; otherwise, model needs to be undefined - // model: this.override.model, - // instructions: this.override.instructions, - // additional_instructions: this.override.additional_instructions, - }, + // TODO: PARSE OPTIONS BY PROVIDER, MAY CONTAIN SENSITIVE DATA + runOptions, ), ); } @@ -142,6 +187,16 @@ class AgentClient extends BaseClient { }; } + async addImageURLs(message, attachments) { + const { files, image_urls } = await encodeAndFormat( + this.options.req, + attachments, + this.options.agent.provider, + ); + message.image_urls = image_urls.length ? image_urls : undefined; + return files; + } + async buildMessages( messages, parentMessageId, @@ -270,25 +325,34 @@ class AgentClient extends BaseClient { /** @type {sendCompletion} */ async sendCompletion(payload, opts = {}) { this.modelOptions.user = this.user; - return await this.chatCompletion({ + await this.chatCompletion({ payload, onProgress: opts.onProgress, abortController: opts.abortController, }); + return this.contentParts; } - // async recordTokenUsage({ promptTokens, completionTokens, context = 'message' }) { - // await spendTokens( - // { - // context, - // model: this.modelOptions.model, - // conversationId: this.conversationId, - // user: this.user ?? this.options.req.user?.id, - // endpointTokenConfig: this.options.endpointTokenConfig, - // }, - // { promptTokens, completionTokens }, - // ); - // } + /** + * @param {Object} params + * @param {string} [params.model] + * @param {string} [params.context='message'] + * @param {UsageMetadata[]} [params.collectedUsage=this.collectedUsage] + */ + async recordCollectedUsage({ model, context = 'message', collectedUsage = this.collectedUsage }) { + for (const usage of collectedUsage) { + await spendTokens( + { + context, + model: model ?? this.modelOptions.model, + conversationId: this.conversationId, + user: this.user ?? this.options.req.user?.id, + endpointTokenConfig: this.options.endpointTokenConfig, + }, + { promptTokens: usage.input_tokens, completionTokens: usage.output_tokens }, + ); + } + } async chatCompletion({ payload, abortController = null }) { try { @@ -398,9 +462,8 @@ class AgentClient extends BaseClient { // }); // } - // const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; - const run = await createRun({ + req: this.options.req, agent: this.options.agent, tools: this.options.tools, toolMap: this.options.toolMap, @@ -415,6 +478,7 @@ class AgentClient extends BaseClient { thread_id: this.conversationId, }, run_id: this.responseMessageId, + signal: abortController.signal, streamMode: 'values', version: 'v2', }; @@ -423,8 +487,10 @@ class AgentClient extends BaseClient { throw new Error('Failed to create run'); } + this.run = run; + const messages = formatAgentMessages(payload); - const runMessages = await run.processStream({ messages }, config, { + await run.processStream({ messages }, config, { [Callback.TOOL_ERROR]: (graph, error, toolId) => { logger.error( '[api/server/controllers/agents/client.js #chatCompletion] Tool Error', @@ -433,14 +499,94 @@ class AgentClient extends BaseClient { ); }, }); - // console.dir(runMessages, { depth: null }); - return runMessages; + this.recordCollectedUsage({ context: 'message' }).catch((err) => { + logger.error( + '[api/server/controllers/agents/client.js #chatCompletion] Error recording collected usage', + err, + ); + }); } catch (err) { - logger.error( - '[api/server/controllers/agents/client.js #chatCompletion] Unhandled error type', + if (!abortController.signal.aborted) { + logger.error( + '[api/server/controllers/agents/client.js #sendCompletion] Unhandled error type', + err, + ); + throw err; + } + + logger.warn( + '[api/server/controllers/agents/client.js #sendCompletion] Operation aborted', err, ); - throw err; + } + } + + /** + * + * @param {Object} params + * @param {string} params.text + * @param {string} params.conversationId + */ + async titleConvo({ text }) { + if (!this.run) { + throw new Error('Run not initialized'); + } + const { handleLLMEnd, collected: collectedMetadata } = createMetadataAggregator(); + const clientOptions = {}; + const providerConfig = this.options.req.app.locals[this.options.agent.provider]; + if ( + providerConfig && + providerConfig.titleModel && + providerConfig.titleModel !== Constants.CURRENT_MODEL + ) { + clientOptions.model = providerConfig.titleModel; + } + try { + const titleResult = await this.run.generateTitle({ + inputText: text, + contentParts: this.contentParts, + clientOptions, + chainOptions: { + callbacks: [ + { + handleLLMEnd, + }, + ], + }, + }); + + const collectedUsage = collectedMetadata.map((item) => { + let input_tokens, output_tokens; + + if (item.usage) { + input_tokens = item.usage.input_tokens || item.usage.inputTokens; + output_tokens = item.usage.output_tokens || item.usage.outputTokens; + } else if (item.tokenUsage) { + input_tokens = item.tokenUsage.promptTokens; + output_tokens = item.tokenUsage.completionTokens; + } + + return { + input_tokens: input_tokens, + output_tokens: output_tokens, + }; + }); + + this.recordCollectedUsage({ + model: clientOptions.model, + context: 'title', + collectedUsage, + }).catch((err) => { + logger.error( + '[api/server/controllers/agents/client.js #titleConvo] Error recording collected usage', + err, + ); + }); + + return titleResult.title; + } catch (err) { + logger.error('[api/server/controllers/agents/client.js #titleConvo] Error', err); + return; } } diff --git a/api/server/controllers/agents/demo.js b/api/server/controllers/agents/demo.js deleted file mode 100644 index c90745ba8..000000000 --- a/api/server/controllers/agents/demo.js +++ /dev/null @@ -1,44 +0,0 @@ -// Import the necessary modules -const path = require('path'); -const base = path.resolve(__dirname, '..', '..', '..', '..', 'api'); -console.log(base); -//api/server/controllers/agents/demo.js -require('module-alias')({ base }); -const connectDb = require('~/lib/db/connectDb'); -const AgentClient = require('./client'); - -// Define the user and message options -const user = 'user123'; -const parentMessageId = 'pmid123'; -const conversationId = 'cid456'; -const maxContextTokens = 200000; -const req = { - user: { id: user }, -}; -const progressOptions = { - res: {}, -}; - -// Define the message options -const messageOptions = { - user, - parentMessageId, - conversationId, - progressOptions, -}; - -async function main() { - await connectDb(); - const client = new AgentClient({ req, maxContextTokens }); - - const text = 'Hello, this is a test message.'; - - try { - let response = await client.sendMessage(text, messageOptions); - console.log('Response:', response); - } catch (error) { - console.error('Error sending message:', error); - } -} - -main(); diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index 648020597..2006d4e6e 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -1,4 +1,4 @@ -const { Constants, getResponseSender } = require('librechat-data-provider'); +const { Constants } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage } = require('~/server/utils'); const { saveMessage } = require('~/models'); @@ -9,22 +9,17 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { text, endpointOption, conversationId, - modelDisplayLabel, parentMessageId = null, overrideParentMessageId = null, } = req.body; + let sender; let userMessage; - let userMessagePromise; let promptTokens; let userMessageId; let responseMessageId; + let userMessagePromise; - const sender = getResponseSender({ - ...endpointOption, - model: endpointOption.modelOptions.model, - modelDisplayLabel, - }); const newConvo = !conversationId; const user = req.user.id; @@ -39,6 +34,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { responseMessageId = data[key]; } else if (key === 'promptTokens') { promptTokens = data[key]; + } else if (key === 'sender') { + sender = data[key]; } else if (!conversationId && key === 'conversationId') { conversationId = data[key]; } @@ -46,6 +43,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { }; try { + /** @type {{ client: TAgentClient }} */ const { client } = await initializeClient({ req, res, endpointOption }); const getAbortData = () => ({ @@ -54,8 +52,8 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { promptTokens, conversationId, userMessagePromise, - // text: getPartialText(), messageId: responseMessageId, + content: client.getContentParts(), parentMessageId: overrideParentMessageId ?? userMessageId, }); @@ -90,11 +88,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { }; let response = await client.sendMessage(text, messageOptions); - - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - response.endpoint = endpointOption.endpoint; const { conversation = {} } = await client.responsePromise; @@ -103,7 +96,6 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { if (client.options.attachments) { userMessage.files = client.options.attachments; - conversation.model = endpointOption.modelOptions.model; delete userMessage.image_urls; } diff --git a/api/server/controllers/agents/run.js b/api/server/controllers/agents/run.js index d30d43bd9..5aeefa122 100644 --- a/api/server/controllers/agents/run.js +++ b/api/server/controllers/agents/run.js @@ -1,4 +1,4 @@ -const { Run } = require('@librechat/agents'); +const { Run, Providers } = require('@librechat/agents'); const { providerEndpointMap } = require('librechat-data-provider'); /** @@ -14,11 +14,12 @@ const { providerEndpointMap } = require('librechat-data-provider'); * Creates a new Run instance with custom handlers and configuration. * * @param {Object} options - The options for creating the Run instance. + * @param {ServerRequest} [options.req] - The server request. + * @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated. * @param {Agent} options.agent - The agent for this run. * @param {StructuredTool[] | undefined} [options.tools] - The tools to use in the run. * @param {Record | undefined} [options.toolMap] - The tool map for the run. * @param {Record | undefined} [options.customHandlers] - Custom event handlers. - * @param {string | undefined} [options.runId] - Optional run ID; otherwise, a new run ID will be generated. * @param {ClientOptions} [options.modelOptions] - Optional model to use; if not provided, it will use the default from modelMap. * @param {boolean} [options.streaming=true] - Whether to use streaming. * @param {boolean} [options.streamUsage=true] - Whether to stream usage information. @@ -43,15 +44,22 @@ async function createRun({ modelOptions, ); + const graphConfig = { + runId, + llmConfig, + tools, + toolMap, + instructions: agent.instructions, + additional_instructions: agent.additional_instructions, + }; + + // TEMPORARY FOR TESTING + if (agent.provider === Providers.ANTHROPIC) { + graphConfig.streamBuffer = 2000; + } + return Run.create({ - graphConfig: { - runId, - llmConfig, - tools, - toolMap, - instructions: agent.instructions, - additional_instructions: agent.additional_instructions, - }, + graphConfig, customHandlers, }); } diff --git a/api/server/controllers/agents/v1.js b/api/server/controllers/agents/v1.js index 2a9911c54..65e37f261 100644 --- a/api/server/controllers/agents/v1.js +++ b/api/server/controllers/agents/v1.js @@ -1,5 +1,5 @@ const { nanoid } = require('nanoid'); -const { FileContext } = require('librechat-data-provider'); +const { FileContext, Constants } = require('librechat-data-provider'); const { getAgent, createAgent, @@ -9,6 +9,8 @@ const { } = require('~/models/Agent'); const { getStrategyFunctions } = require('~/server/services/Files/strategies'); const { uploadImageBuffer } = require('~/server/services/Files/process'); +const { getProjectByName } = require('~/models/Project'); +const { updateAgentProjects } = require('~/models/Agent'); const { deleteFileByFilter } = require('~/models/File'); const { logger } = require('~/config'); @@ -53,16 +55,35 @@ const createAgentHandler = async (req, res) => { * @param {object} req - Express Request * @param {object} req.params - Request params * @param {string} req.params.id - Agent identifier. - * @returns {Agent} 200 - success response - application/json + * @param {object} req.user - Authenticated user information + * @param {string} req.user.id - User ID + * @returns {Promise} 200 - success response - application/json * @returns {Error} 404 - Agent not found */ const getAgentHandler = async (req, res) => { try { const id = req.params.id; - const agent = await getAgent({ id }); + const author = req.user.id; + + let query = { id, author }; + + const globalProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, ['agentIds']); + if (globalProject && (globalProject.agentIds?.length ?? 0) > 0) { + query = { + $or: [{ id, $in: globalProject.agentIds }, query], + }; + } + + const agent = await getAgent(query); + if (!agent) { return res.status(404).json({ error: 'Agent not found' }); } + + if (agent.author !== author) { + delete agent.author; + } + return res.status(200).json(agent); } catch (error) { logger.error('[/Agents/:id] Error retrieving agent', error); @@ -82,7 +103,17 @@ const getAgentHandler = async (req, res) => { const updateAgentHandler = async (req, res) => { try { const id = req.params.id; - const updatedAgent = await updateAgent({ id, author: req.user.id }, req.body); + const { projectIds, removeProjectIds, ...updateData } = req.body; + + let updatedAgent; + if (Object.keys(updateData).length > 0) { + updatedAgent = await updateAgent({ id, author: req.user.id }, updateData); + } + + if (projectIds || removeProjectIds) { + updatedAgent = await updateAgentProjects(id, projectIds, removeProjectIds); + } + return res.json(updatedAgent); } catch (error) { logger.error('[/Agents/:id] Error updating Agent', error); @@ -119,13 +150,13 @@ const deleteAgentHandler = async (req, res) => { * @param {object} req - Express Request * @param {object} req.query - Request query * @param {string} [req.query.user] - The user ID of the agent's author. - * @returns {AgentListResponse} 200 - success response - application/json + * @returns {Promise} 200 - success response - application/json */ const getListAgentsHandler = async (req, res) => { try { - const { user } = req.query; - const filter = user ? { author: user } : {}; - const data = await getListAgents(filter); + const data = await getListAgents({ + author: req.user.id, + }); return res.json(data); } catch (error) { logger.error('[/Agents] Error listing Agents', error); diff --git a/api/server/index.js b/api/server/index.js index 3fa577830..47ce354f2 100644 --- a/api/server/index.js +++ b/api/server/index.js @@ -106,6 +106,7 @@ const startServer = async () => { app.use('/api/share', routes.share); app.use('/api/roles', routes.roles); app.use('/api/agents', routes.agents); + app.use('/api/bedrock', routes.bedrock); app.use('/api/tags', routes.tags); diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index e855c0cb6..7fb84a307 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -107,7 +107,7 @@ const createAbortController = (req, res, getAbortData, getReqData) => { finish_reason: 'incomplete', endpoint: endpointOption.endpoint, iconURL: endpointOption.iconURL, - model: endpointOption.modelOptions.model, + model: endpointOption.modelOptions?.model ?? endpointOption.model_parameters?.model, unfinished: false, error: false, isCreatedByUser: false, diff --git a/api/server/middleware/buildEndpointOption.js b/api/server/middleware/buildEndpointOption.js index 83e06d77c..2b4ba4017 100644 --- a/api/server/middleware/buildEndpointOption.js +++ b/api/server/middleware/buildEndpointOption.js @@ -5,6 +5,7 @@ const assistants = require('~/server/services/Endpoints/assistants'); const gptPlugins = require('~/server/services/Endpoints/gptPlugins'); const { processFiles } = require('~/server/services/Files/process'); const anthropic = require('~/server/services/Endpoints/anthropic'); +const bedrock = require('~/server/services/Endpoints/bedrock'); const openAI = require('~/server/services/Endpoints/openAI'); const agents = require('~/server/services/Endpoints/agents'); const custom = require('~/server/services/Endpoints/custom'); @@ -17,6 +18,7 @@ const buildFunction = { [EModelEndpoint.google]: google.buildOptions, [EModelEndpoint.custom]: custom.buildOptions, [EModelEndpoint.agents]: agents.buildOptions, + [EModelEndpoint.bedrock]: bedrock.buildOptions, [EModelEndpoint.azureOpenAI]: openAI.buildOptions, [EModelEndpoint.anthropic]: anthropic.buildOptions, [EModelEndpoint.gptPlugins]: gptPlugins.buildOptions, diff --git a/api/server/routes/agents/actions.js b/api/server/routes/agents/actions.js index e79f749fc..dde3293b4 100644 --- a/api/server/routes/agents/actions.js +++ b/api/server/routes/agents/actions.js @@ -41,7 +41,7 @@ router.post('/:agent_id', async (req, res) => { return res.status(400).json({ message: 'No functions provided' }); } - let metadata = encryptMetadata(_metadata); + let metadata = await encryptMetadata(_metadata); let { domain } = metadata; domain = await domainParser(req, domain, true); diff --git a/api/server/routes/agents/v1.js b/api/server/routes/agents/v1.js index 1001873fe..d3a3005bd 100644 --- a/api/server/routes/agents/v1.js +++ b/api/server/routes/agents/v1.js @@ -1,11 +1,30 @@ const multer = require('multer'); const express = require('express'); +const { PermissionTypes, Permissions } = require('librechat-data-provider'); +const { requireJwtAuth, generateCheckAccess } = require('~/server/middleware'); const v1 = require('~/server/controllers/agents/v1'); const actions = require('./actions'); const upload = multer(); const router = express.Router(); +const checkAgentAccess = generateCheckAccess(PermissionTypes.AGENTS, [Permissions.USE]); +const checkAgentCreate = generateCheckAccess(PermissionTypes.AGENTS, [ + Permissions.USE, + Permissions.CREATE, +]); + +const checkGlobalAgentShare = generateCheckAccess( + PermissionTypes.AGENTS, + [Permissions.USE, Permissions.CREATE], + { + [Permissions.SHARED_GLOBAL]: ['projectIds', 'removeProjectIds'], + }, +); + +router.use(requireJwtAuth); +router.use(checkAgentAccess); + /** * Agent actions route. * @route GET|POST /agents/actions @@ -27,7 +46,7 @@ router.use('/tools', (req, res) => { * @param {AgentCreateParams} req.body - The agent creation parameters. * @returns {Agent} 201 - Success response - application/json */ -router.post('/', v1.createAgent); +router.post('/', checkAgentCreate, v1.createAgent); /** * Retrieves an agent. @@ -35,7 +54,7 @@ router.post('/', v1.createAgent); * @param {string} req.params.id - Agent identifier. * @returns {Agent} 200 - Success response - application/json */ -router.get('/:id', v1.getAgent); +router.get('/:id', checkAgentAccess, v1.getAgent); /** * Updates an agent. @@ -44,7 +63,7 @@ router.get('/:id', v1.getAgent); * @param {AgentUpdateParams} req.body - The agent update parameters. * @returns {Agent} 200 - Success response - application/json */ -router.patch('/:id', v1.updateAgent); +router.patch('/:id', checkGlobalAgentShare, v1.updateAgent); /** * Deletes an agent. @@ -52,7 +71,7 @@ router.patch('/:id', v1.updateAgent); * @param {string} req.params.id - Agent identifier. * @returns {Agent} 200 - success response - application/json */ -router.delete('/:id', v1.deleteAgent); +router.delete('/:id', checkAgentCreate, v1.deleteAgent); /** * Returns a list of agents. @@ -60,9 +79,7 @@ router.delete('/:id', v1.deleteAgent); * @param {AgentListParams} req.query - The agent list parameters for pagination and sorting. * @returns {AgentListResponse} 200 - success response - application/json */ -router.get('/', v1.getListAgents); - -// TODO: handle private agents +router.get('/', checkAgentAccess, v1.getListAgents); /** * Uploads and updates an avatar for a specific agent. @@ -72,6 +89,6 @@ router.get('/', v1.getListAgents); * @param {string} [req.body.metadata] - Optional metadata for the agent's avatar. * @returns {Object} 200 - success response - application/json */ -router.post('/avatar/:agent_id', upload.single('file'), v1.uploadAgentAvatar); +router.post('/avatar/:agent_id', checkAgentAccess, upload.single('file'), v1.uploadAgentAvatar); module.exports = router; diff --git a/api/server/routes/bedrock/chat.js b/api/server/routes/bedrock/chat.js new file mode 100644 index 000000000..605a01271 --- /dev/null +++ b/api/server/routes/bedrock/chat.js @@ -0,0 +1,36 @@ +const express = require('express'); + +const router = express.Router(); +const { + setHeaders, + handleAbort, + // validateModel, + // validateEndpoint, + buildEndpointOption, +} = require('~/server/middleware'); +const { initializeClient } = require('~/server/services/Endpoints/bedrock'); +const AgentController = require('~/server/controllers/agents/request'); +const addTitle = require('~/server/services/Endpoints/bedrock/title'); + +router.post('/abort', handleAbort()); + +/** + * @route POST / + * @desc Chat with an assistant + * @access Public + * @param {express.Request} req - The request object, containing the request data. + * @param {express.Response} res - The response object, used to send back a response. + * @returns {void} + */ +router.post( + '/', + // validateModel, + // validateEndpoint, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AgentController(req, res, next, initializeClient, addTitle); + }, +); + +module.exports = router; diff --git a/api/server/routes/bedrock/index.js b/api/server/routes/bedrock/index.js new file mode 100644 index 000000000..b1a9efec4 --- /dev/null +++ b/api/server/routes/bedrock/index.js @@ -0,0 +1,19 @@ +const express = require('express'); +const router = express.Router(); +const { + uaParser, + checkBan, + requireJwtAuth, + // concurrentLimiter, + // messageIpLimiter, + // messageUserLimiter, +} = require('~/server/middleware'); + +const chat = require('./chat'); + +router.use(requireJwtAuth); +router.use(checkBan); +router.use(uaParser); +router.use('/chat', chat); + +module.exports = router; diff --git a/api/server/routes/config.js b/api/server/routes/config.js index 3fc90c14b..f6669169a 100644 --- a/api/server/routes/config.js +++ b/api/server/routes/config.js @@ -1,5 +1,5 @@ const express = require('express'); -const { CacheKeys, defaultSocialLogins } = require('librechat-data-provider'); +const { CacheKeys, defaultSocialLogins, Constants } = require('librechat-data-provider'); const { getLdapConfig } = require('~/server/services/Config/ldap'); const { getProjectByName } = require('~/models/Project'); const { isEnabled } = require('~/server/utils'); @@ -32,7 +32,7 @@ router.get('/', async function (req, res) { return today.getMonth() === 1 && today.getDate() === 11; }; - const instanceProject = await getProjectByName('instance', '_id'); + const instanceProject = await getProjectByName(Constants.GLOBAL_PROJECT_NAME, '_id'); const ldap = getLdapConfig(); diff --git a/api/server/routes/index.js b/api/server/routes/index.js index 90ba5c73a..3790aacd2 100644 --- a/api/server/routes/index.js +++ b/api/server/routes/index.js @@ -8,6 +8,7 @@ const presets = require('./presets'); const prompts = require('./prompts'); const balance = require('./balance'); const plugins = require('./plugins'); +const bedrock = require('./bedrock'); const search = require('./search'); const models = require('./models'); const convos = require('./convos'); @@ -36,6 +37,7 @@ module.exports = { files, share, agents, + bedrock, convos, search, prompts, diff --git a/api/server/routes/prompts.js b/api/server/routes/prompts.js index 5a6dcafcb..54128d3b3 100644 --- a/api/server/routes/prompts.js +++ b/api/server/routes/prompts.js @@ -24,6 +24,7 @@ const checkPromptCreate = generateCheckAccess(PermissionTypes.PROMPTS, [ Permissions.USE, Permissions.CREATE, ]); + const checkGlobalPromptShare = generateCheckAccess( PermissionTypes.PROMPTS, [Permissions.USE, Permissions.CREATE], diff --git a/api/server/services/ActionService.js b/api/server/services/ActionService.js index 04a9b9829..da69548b4 100644 --- a/api/server/services/ActionService.js +++ b/api/server/services/ActionService.js @@ -165,7 +165,7 @@ async function createActionTool({ action, requestBuilder, zodSchema, name, descr * Encrypts sensitive metadata values for an action. * * @param {ActionMetadata} metadata - The action metadata to encrypt. - * @returns {ActionMetadata} The updated action metadata with encrypted values. + * @returns {Promise} The updated action metadata with encrypted values. */ async function encryptMetadata(metadata) { const encryptedMetadata = { ...metadata }; diff --git a/api/server/services/AppService.js b/api/server/services/AppService.js index eae83bc6e..19a9fc91a 100644 --- a/api/server/services/AppService.js +++ b/api/server/services/AppService.js @@ -94,18 +94,19 @@ const AppService = async (app) => { ); } - if (endpoints?.[EModelEndpoint.openAI]) { - endpointLocals[EModelEndpoint.openAI] = endpoints[EModelEndpoint.openAI]; - } - if (endpoints?.[EModelEndpoint.google]) { - endpointLocals[EModelEndpoint.google] = endpoints[EModelEndpoint.google]; - } - if (endpoints?.[EModelEndpoint.anthropic]) { - endpointLocals[EModelEndpoint.anthropic] = endpoints[EModelEndpoint.anthropic]; - } - if (endpoints?.[EModelEndpoint.gptPlugins]) { - endpointLocals[EModelEndpoint.gptPlugins] = endpoints[EModelEndpoint.gptPlugins]; - } + const endpointKeys = [ + EModelEndpoint.openAI, + EModelEndpoint.google, + EModelEndpoint.bedrock, + EModelEndpoint.anthropic, + EModelEndpoint.gptPlugins, + ]; + + endpointKeys.forEach((key) => { + if (endpoints?.[key]) { + endpointLocals[key] = endpoints[key]; + } + }); app.locals = { ...defaultLocals, diff --git a/api/server/services/Config/EndpointService.js b/api/server/services/Config/EndpointService.js index b2f82f383..485c99f37 100644 --- a/api/server/services/Config/EndpointService.js +++ b/api/server/services/Config/EndpointService.js @@ -45,6 +45,7 @@ module.exports = { AZURE_ASSISTANTS_BASE_URL, EModelEndpoint.azureAssistants, ), + [EModelEndpoint.bedrock]: generateConfig(process.env.BEDROCK_AWS_SECRET_ACCESS_KEY), /* key will be part of separate config */ [EModelEndpoint.agents]: generateConfig(process.env.I_AM_A_TEAPOT), }, diff --git a/api/server/services/Config/loadDefaultEConfig.js b/api/server/services/Config/loadDefaultEConfig.js index df331d92f..c11ddbe9d 100644 --- a/api/server/services/Config/loadDefaultEConfig.js +++ b/api/server/services/Config/loadDefaultEConfig.js @@ -9,22 +9,13 @@ const { config } = require('./EndpointService'); */ async function loadDefaultEndpointsConfig(req) { const { google, gptPlugins } = await loadAsyncEndpoints(req); - const { - openAI, - agents, - assistants, - azureAssistants, - bingAI, - anthropic, - azureOpenAI, - chatGPTBrowser, - } = config; + const { assistants, azureAssistants, bingAI, azureOpenAI, chatGPTBrowser } = config; const enabledEndpoints = getEnabledEndpoints(); const endpointConfig = { - [EModelEndpoint.openAI]: openAI, - [EModelEndpoint.agents]: agents, + [EModelEndpoint.openAI]: config[EModelEndpoint.openAI], + [EModelEndpoint.agents]: config[EModelEndpoint.agents], [EModelEndpoint.assistants]: assistants, [EModelEndpoint.azureAssistants]: azureAssistants, [EModelEndpoint.azureOpenAI]: azureOpenAI, @@ -32,7 +23,8 @@ async function loadDefaultEndpointsConfig(req) { [EModelEndpoint.bingAI]: bingAI, [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, [EModelEndpoint.gptPlugins]: gptPlugins, - [EModelEndpoint.anthropic]: anthropic, + [EModelEndpoint.anthropic]: config[EModelEndpoint.anthropic], + [EModelEndpoint.bedrock]: config[EModelEndpoint.bedrock], }; const orderedAndFilteredEndpoints = enabledEndpoints.reduce((config, key, index) => { diff --git a/api/server/services/Config/loadDefaultModels.js b/api/server/services/Config/loadDefaultModels.js index e06b73c0c..464e84d44 100644 --- a/api/server/services/Config/loadDefaultModels.js +++ b/api/server/services/Config/loadDefaultModels.js @@ -3,6 +3,7 @@ const { useAzurePlugins } = require('~/server/services/Config/EndpointService'). const { getOpenAIModels, getGoogleModels, + getBedrockModels, getAnthropicModels, getChatGPTBrowserModels, } = require('~/server/services/ModelService'); @@ -38,6 +39,7 @@ async function loadDefaultModels(req) { [EModelEndpoint.chatGPTBrowser]: chatGPTBrowser, [EModelEndpoint.assistants]: assistants, [EModelEndpoint.azureAssistants]: azureAssistants, + [EModelEndpoint.bedrock]: getBedrockModels(), }; } diff --git a/api/server/services/Endpoints/agents/build.js b/api/server/services/Endpoints/agents/build.js index 256901057..d04dee9a0 100644 --- a/api/server/services/Endpoints/agents/build.js +++ b/api/server/services/Endpoints/agents/build.js @@ -2,7 +2,7 @@ const { getAgent } = require('~/models/Agent'); const { logger } = require('~/config'); const buildOptions = (req, endpoint, parsedBody) => { - const { agent_id, instructions, spec, ...rest } = parsedBody; + const { agent_id, instructions, spec, ...model_parameters } = parsedBody; const agentPromise = getAgent({ id: agent_id, @@ -19,9 +19,7 @@ const buildOptions = (req, endpoint, parsedBody) => { agent_id, instructions, spec, - modelOptions: { - ...rest, - }, + model_parameters, }; return endpointOption; diff --git a/api/server/services/Endpoints/agents/initialize.js b/api/server/services/Endpoints/agents/initialize.js index 8627775ce..a079e2145 100644 --- a/api/server/services/Endpoints/agents/initialize.js +++ b/api/server/services/Endpoints/agents/initialize.js @@ -11,7 +11,12 @@ const { z } = require('zod'); const { tool } = require('@langchain/core/tools'); -const { EModelEndpoint, providerEndpointMap } = require('librechat-data-provider'); +const { createContentAggregator } = require('@librechat/agents'); +const { + EModelEndpoint, + providerEndpointMap, + getResponseSender, +} = require('librechat-data-provider'); const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); // for testing purposes // const createTavilySearchTool = require('~/app/clients/tools/structured/TavilySearch'); @@ -53,7 +58,8 @@ const initializeClient = async ({ req, res, endpointOption }) => { } // TODO: use endpointOption to determine options/modelOptions - const eventHandlers = getDefaultHandlers({ res }); + const { contentParts, aggregateContent } = createContentAggregator(); + const eventHandlers = getDefaultHandlers({ res, aggregateContent }); // const tools = [createTavilySearchTool()]; // const tools = [_getWeather]; @@ -90,7 +96,7 @@ const initializeClient = async ({ req, res, endpointOption }) => { } // TODO: pass-in override settings that are specific to current run - endpointOption.modelOptions.model = agent.model; + endpointOption.model_parameters.model = agent.model; const options = await getOptions({ req, res, @@ -101,13 +107,21 @@ const initializeClient = async ({ req, res, endpointOption }) => { }); modelOptions = Object.assign(modelOptions, options.llmConfig); + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.model_parameters.model, + }); + const client = new AgentClient({ req, agent, tools, + sender, toolMap, + contentParts, modelOptions, eventHandlers, + endpoint: EModelEndpoint.agents, configOptions: options.configOptions, maxContextTokens: agent.max_context_tokens ?? diff --git a/api/server/services/Endpoints/anthropic/addTitle.js b/api/server/services/Endpoints/anthropic/addTitle.js index b69c04de6..5c477632d 100644 --- a/api/server/services/Endpoints/anthropic/addTitle.js +++ b/api/server/services/Endpoints/anthropic/addTitle.js @@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => { const title = await client.titleConvo({ text, - responseText: response?.text, + responseText: response?.text ?? '', conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Endpoints/bedrock/build.js b/api/server/services/Endpoints/bedrock/build.js new file mode 100644 index 000000000..d6fb0636a --- /dev/null +++ b/api/server/services/Endpoints/bedrock/build.js @@ -0,0 +1,44 @@ +const { removeNullishValues, bedrockInputParser } = require('librechat-data-provider'); +const generateArtifactsPrompt = require('~/app/clients/prompts/artifacts'); +const { logger } = require('~/config'); + +const buildOptions = (endpoint, parsedBody) => { + const { + modelLabel: name, + promptPrefix, + maxContextTokens, + resendFiles = true, + imageDetail, + iconURL, + greeting, + spec, + artifacts, + ...model_parameters + } = parsedBody; + let parsedParams = model_parameters; + try { + parsedParams = bedrockInputParser.parse(model_parameters); + } catch (error) { + logger.warn('Failed to parse bedrock input', error); + } + const endpointOption = removeNullishValues({ + endpoint, + name, + resendFiles, + imageDetail, + iconURL, + greeting, + spec, + promptPrefix, + maxContextTokens, + model_parameters: parsedParams, + }); + + if (typeof artifacts === 'string') { + endpointOption.artifactsPrompt = generateArtifactsPrompt({ endpoint, artifacts }); + } + + return endpointOption; +}; + +module.exports = { buildOptions }; diff --git a/api/server/services/Endpoints/bedrock/index.js b/api/server/services/Endpoints/bedrock/index.js new file mode 100644 index 000000000..8989f7df8 --- /dev/null +++ b/api/server/services/Endpoints/bedrock/index.js @@ -0,0 +1,7 @@ +const build = require('./build'); +const initialize = require('./initialize'); + +module.exports = { + ...build, + ...initialize, +}; diff --git a/api/server/services/Endpoints/bedrock/initialize.js b/api/server/services/Endpoints/bedrock/initialize.js new file mode 100644 index 000000000..db90d5fc8 --- /dev/null +++ b/api/server/services/Endpoints/bedrock/initialize.js @@ -0,0 +1,72 @@ +const { createContentAggregator } = require('@librechat/agents'); +const { + EModelEndpoint, + providerEndpointMap, + getResponseSender, +} = require('librechat-data-provider'); +const { getDefaultHandlers } = require('~/server/controllers/agents/callbacks'); +// const { loadAgentTools } = require('~/server/services/ToolService'); +const getOptions = require('~/server/services/Endpoints/bedrock/options'); +const AgentClient = require('~/server/controllers/agents/client'); +const { getModelMaxTokens } = require('~/utils'); + +const initializeClient = async ({ req, res, endpointOption }) => { + if (!endpointOption) { + throw new Error('Endpoint option not provided'); + } + + /** @type {Array} */ + const collectedUsage = []; + const { contentParts, aggregateContent } = createContentAggregator(); + const eventHandlers = getDefaultHandlers({ res, aggregateContent, collectedUsage }); + + // const tools = [createTavilySearchTool()]; + + /** @type {Agent} */ + const agent = { + id: EModelEndpoint.bedrock, + name: endpointOption.name, + instructions: endpointOption.promptPrefix, + provider: EModelEndpoint.bedrock, + model: endpointOption.model_parameters.model, + model_parameters: endpointOption.model_parameters, + }; + + let modelOptions = { model: agent.model }; + + // TODO: pass-in override settings that are specific to current run + const options = await getOptions({ + req, + res, + endpointOption, + }); + + modelOptions = Object.assign(modelOptions, options.llmConfig); + const maxContextTokens = + agent.max_context_tokens ?? + getModelMaxTokens(modelOptions.model, providerEndpointMap[agent.provider]); + + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.model_parameters.model, + }); + + const client = new AgentClient({ + req, + agent, + sender, + // tools, + // toolMap, + modelOptions, + contentParts, + eventHandlers, + collectedUsage, + maxContextTokens, + endpoint: EModelEndpoint.bedrock, + configOptions: options.configOptions, + attachments: endpointOption.attachments, + }); + return { client }; +}; + +module.exports = { initializeClient }; diff --git a/api/server/services/Endpoints/bedrock/options.js b/api/server/services/Endpoints/bedrock/options.js new file mode 100644 index 000000000..0839d033c --- /dev/null +++ b/api/server/services/Endpoints/bedrock/options.js @@ -0,0 +1,90 @@ +const { HttpsProxyAgent } = require('https-proxy-agent'); +const { + EModelEndpoint, + Constants, + AuthType, + removeNullishValues, +} = require('librechat-data-provider'); +const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService'); +const { sleep } = require('~/server/utils'); + +const getOptions = async ({ req, endpointOption }) => { + const { + BEDROCK_AWS_SECRET_ACCESS_KEY, + BEDROCK_AWS_ACCESS_KEY_ID, + BEDROCK_REVERSE_PROXY, + BEDROCK_AWS_DEFAULT_REGION, + PROXY, + } = process.env; + const expiresAt = req.body.key; + const isUserProvided = BEDROCK_AWS_SECRET_ACCESS_KEY === AuthType.USER_PROVIDED; + + const credentials = isUserProvided + ? await getUserKey({ userId: req.user.id, name: EModelEndpoint.bedrock }) + : { + accessKeyId: BEDROCK_AWS_ACCESS_KEY_ID, + secretAccessKey: BEDROCK_AWS_SECRET_ACCESS_KEY, + }; + + if (!credentials) { + throw new Error('Bedrock credentials not provided. Please provide them again.'); + } + + if (expiresAt && isUserProvided) { + checkUserKeyExpiry(expiresAt, EModelEndpoint.bedrock); + } + + /** @type {number} */ + let streamRate = Constants.DEFAULT_STREAM_RATE; + + /** @type {undefined | TBaseEndpoint} */ + const bedrockConfig = req.app.locals[EModelEndpoint.bedrock]; + + if (bedrockConfig && bedrockConfig.streamRate) { + streamRate = bedrockConfig.streamRate; + } + + /** @type {undefined | TBaseEndpoint} */ + const allConfig = req.app.locals.all; + if (allConfig && allConfig.streamRate) { + streamRate = allConfig.streamRate; + } + + /** @type {import('@librechat/agents').BedrockConverseClientOptions} */ + const requestOptions = Object.assign( + { + credentials, + model: endpointOption.model, + region: BEDROCK_AWS_DEFAULT_REGION, + streaming: true, + streamUsage: true, + callbacks: [ + { + handleLLMNewToken: async () => { + if (!streamRate) { + return; + } + await sleep(streamRate); + }, + }, + ], + }, + endpointOption.model_parameters, + ); + + const configOptions = {}; + if (PROXY) { + configOptions.httpAgent = new HttpsProxyAgent(PROXY); + } + + if (BEDROCK_REVERSE_PROXY) { + configOptions.endpointHost = BEDROCK_REVERSE_PROXY; + } + + return { + llmConfig: removeNullishValues(requestOptions), + configOptions, + }; +}; + +module.exports = getOptions; diff --git a/api/server/services/Endpoints/bedrock/title.js b/api/server/services/Endpoints/bedrock/title.js new file mode 100644 index 000000000..520b9f78c --- /dev/null +++ b/api/server/services/Endpoints/bedrock/title.js @@ -0,0 +1,40 @@ +const { CacheKeys } = require('librechat-data-provider'); +const getLogStores = require('~/cache/getLogStores'); +const { isEnabled } = require('~/server/utils'); +const { saveConvo } = require('~/models'); + +const addTitle = async (req, { text, response, client }) => { + const { TITLE_CONVO = true } = process.env ?? {}; + if (!isEnabled(TITLE_CONVO)) { + return; + } + + if (client.options.titleConvo === false) { + return; + } + + // If the request was aborted, don't generate the title. + if (client.abortController.signal.aborted) { + return; + } + + const titleCache = getLogStores(CacheKeys.GEN_TITLE); + const key = `${req.user.id}-${response.conversationId}`; + + const title = await client.titleConvo({ + text, + responseText: response?.text ?? '', + conversationId: response.conversationId, + }); + await titleCache.set(key, title, 120000); + await saveConvo( + req, + { + conversationId: response.conversationId, + title, + }, + { context: 'api/server/services/Endpoints/bedrock/title.js' }, + ); +}; + +module.exports = addTitle; diff --git a/api/server/services/Endpoints/google/addTitle.js b/api/server/services/Endpoints/google/addTitle.js index 14eafe841..f21d12321 100644 --- a/api/server/services/Endpoints/google/addTitle.js +++ b/api/server/services/Endpoints/google/addTitle.js @@ -49,7 +49,7 @@ const addTitle = async (req, { text, response, client }) => { const title = await titleClient.titleConvo({ text, - responseText: response?.text, + responseText: response?.text ?? '', conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Endpoints/openAI/addTitle.js b/api/server/services/Endpoints/openAI/addTitle.js index af886dd22..35291c5e3 100644 --- a/api/server/services/Endpoints/openAI/addTitle.js +++ b/api/server/services/Endpoints/openAI/addTitle.js @@ -23,7 +23,7 @@ const addTitle = async (req, { text, response, client }) => { const title = await client.titleConvo({ text, - responseText: response?.text, + responseText: response?.text ?? '', conversationId: response.conversationId, }); await titleCache.set(key, title, 120000); diff --git a/api/server/services/Files/images/encode.js b/api/server/services/Files/images/encode.js index 4edb0bd56..05c9fc1d3 100644 --- a/api/server/services/Files/images/encode.js +++ b/api/server/services/Files/images/encode.js @@ -23,7 +23,13 @@ async function fetchImageToBase64(url) { } } -const base64Only = new Set([EModelEndpoint.google, EModelEndpoint.anthropic, 'Ollama', 'ollama']); +const base64Only = new Set([ + EModelEndpoint.google, + EModelEndpoint.anthropic, + 'Ollama', + 'ollama', + EModelEndpoint.bedrock, +]); /** * Encodes and formats the given files. diff --git a/api/server/services/ModelService.js b/api/server/services/ModelService.js index b6ca6e4f4..7d2a3ae9e 100644 --- a/api/server/services/ModelService.js +++ b/api/server/services/ModelService.js @@ -5,6 +5,21 @@ const { extractBaseURL, inputSchema, processModelData, logAxiosError } = require const { OllamaClient } = require('~/app/clients/OllamaClient'); const getLogStores = require('~/cache/getLogStores'); +/** + * Splits a string by commas and trims each resulting value. + * @param {string} input - The input string to split. + * @returns {string[]} An array of trimmed values. + */ +const splitAndTrim = (input) => { + if (!input || typeof input !== 'string') { + return []; + } + return input + .split(',') + .map((item) => item.trim()) + .filter(Boolean); +}; + const { openAIApiKey, userProvidedOpenAI } = require('./Config/EndpointService').config; /** @@ -194,7 +209,7 @@ const getOpenAIModels = async (opts) => { } if (process.env[key]) { - models = String(process.env[key]).split(','); + models = splitAndTrim(process.env[key]); return models; } @@ -208,7 +223,7 @@ const getOpenAIModels = async (opts) => { const getChatGPTBrowserModels = () => { let models = ['text-davinci-002-render-sha', 'gpt-4']; if (process.env.CHATGPT_MODELS) { - models = String(process.env.CHATGPT_MODELS).split(','); + models = splitAndTrim(process.env.CHATGPT_MODELS); } return models; @@ -217,7 +232,7 @@ const getChatGPTBrowserModels = () => { const getAnthropicModels = () => { let models = defaultModels[EModelEndpoint.anthropic]; if (process.env.ANTHROPIC_MODELS) { - models = String(process.env.ANTHROPIC_MODELS).split(','); + models = splitAndTrim(process.env.ANTHROPIC_MODELS); } return models; @@ -226,7 +241,16 @@ const getAnthropicModels = () => { const getGoogleModels = () => { let models = defaultModels[EModelEndpoint.google]; if (process.env.GOOGLE_MODELS) { - models = String(process.env.GOOGLE_MODELS).split(','); + models = splitAndTrim(process.env.GOOGLE_MODELS); + } + + return models; +}; + +const getBedrockModels = () => { + let models = defaultModels[EModelEndpoint.bedrock]; + if (process.env.BEDROCK_AWS_MODELS) { + models = splitAndTrim(process.env.BEDROCK_AWS_MODELS); } return models; @@ -234,7 +258,9 @@ const getGoogleModels = () => { module.exports = { fetchModels, + splitAndTrim, getOpenAIModels, + getBedrockModels, getChatGPTBrowserModels, getAnthropicModels, getGoogleModels, diff --git a/api/server/services/ModelService.spec.js b/api/server/services/ModelService.spec.js index fc7c8b107..4e4647ee3 100644 --- a/api/server/services/ModelService.spec.js +++ b/api/server/services/ModelService.spec.js @@ -1,7 +1,16 @@ const axios = require('axios'); +const { EModelEndpoint, defaultModels } = require('librechat-data-provider'); const { logger } = require('~/config'); -const { fetchModels, getOpenAIModels } = require('./ModelService'); +const { + fetchModels, + splitAndTrim, + getOpenAIModels, + getGoogleModels, + getBedrockModels, + getAnthropicModels, +} = require('./ModelService'); + jest.mock('~/utils', () => { const originalUtils = jest.requireActual('~/utils'); return { @@ -329,3 +338,71 @@ describe('fetchModels with Ollama specific logic', () => { ); }); }); + +describe('splitAndTrim', () => { + it('should split a string by commas and trim each value', () => { + const input = ' model1, model2 , model3,model4 '; + const expected = ['model1', 'model2', 'model3', 'model4']; + expect(splitAndTrim(input)).toEqual(expected); + }); + + it('should return an empty array for empty input', () => { + expect(splitAndTrim('')).toEqual([]); + }); + + it('should return an empty array for null input', () => { + expect(splitAndTrim(null)).toEqual([]); + }); + + it('should return an empty array for undefined input', () => { + expect(splitAndTrim(undefined)).toEqual([]); + }); + + it('should filter out empty values after trimming', () => { + const input = 'model1,, ,model2,'; + const expected = ['model1', 'model2']; + expect(splitAndTrim(input)).toEqual(expected); + }); +}); + +describe('getAnthropicModels', () => { + it('returns default models when ANTHROPIC_MODELS is not set', () => { + delete process.env.ANTHROPIC_MODELS; + const models = getAnthropicModels(); + expect(models).toEqual(defaultModels[EModelEndpoint.anthropic]); + }); + + it('returns models from ANTHROPIC_MODELS when set', () => { + process.env.ANTHROPIC_MODELS = 'claude-1, claude-2 '; + const models = getAnthropicModels(); + expect(models).toEqual(['claude-1', 'claude-2']); + }); +}); + +describe('getGoogleModels', () => { + it('returns default models when GOOGLE_MODELS is not set', () => { + delete process.env.GOOGLE_MODELS; + const models = getGoogleModels(); + expect(models).toEqual(defaultModels[EModelEndpoint.google]); + }); + + it('returns models from GOOGLE_MODELS when set', () => { + process.env.GOOGLE_MODELS = 'gemini-pro, bard '; + const models = getGoogleModels(); + expect(models).toEqual(['gemini-pro', 'bard']); + }); +}); + +describe('getBedrockModels', () => { + it('returns default models when BEDROCK_AWS_MODELS is not set', () => { + delete process.env.BEDROCK_AWS_MODELS; + const models = getBedrockModels(); + expect(models).toEqual(defaultModels[EModelEndpoint.bedrock]); + }); + + it('returns models from BEDROCK_AWS_MODELS when set', () => { + process.env.BEDROCK_AWS_MODELS = 'anthropic.claude-v2, ai21.j2-ultra '; + const models = getBedrockModels(); + expect(models).toEqual(['anthropic.claude-v2', 'ai21.j2-ultra']); + }); +}); diff --git a/api/typedefs.js b/api/typedefs.js index 6591d192b..b2548acf1 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -20,12 +20,30 @@ * @memberof typedefs */ +/** + * @exports AgentRun + * @typedef {import('@librechat/agents').Run} AgentRun + * @memberof typedefs + */ + +/** + * @exports IState + * @typedef {import('@librechat/agents').IState} IState + * @memberof typedefs + */ + /** * @exports ClientCallbacks * @typedef {import('@librechat/agents').ClientCallbacks} ClientCallbacks * @memberof typedefs */ +/** + * @exports BedrockClientOptions + * @typedef {import('@librechat/agents').BedrockConverseClientOptions} BedrockClientOptions + * @memberof typedefs + */ + /** * @exports StreamEventData * @typedef {import('@librechat/agents').StreamEventData} StreamEventData @@ -38,6 +56,12 @@ * @memberof typedefs */ +/** + * @exports UsageMetadata + * @typedef {import('@langchain/core/messages').UsageMetadata} UsageMetadata + * @memberof typedefs + */ + /** * @exports Ollama * @typedef {import('ollama').Ollama} Ollama @@ -893,6 +917,12 @@ * @memberof typedefs */ +/** + * @exports TAgentClient + * @typedef {import('./server/controllers/agents/client')} TAgentClient + * @memberof typedefs + */ + /** * @exports ImportBatchBuilder * @typedef {import('./server/utils/import/importBatchBuilder.js').ImportBatchBuilder} ImportBatchBuilder diff --git a/api/utils/tokens.js b/api/utils/tokens.js index 83246c5b7..ec248fe92 100644 --- a/api/utils/tokens.js +++ b/api/utils/tokens.js @@ -7,13 +7,13 @@ const openAIModels = { 'gpt-4-32k': 32758, // -10 from max 'gpt-4-32k-0314': 32758, // -10 from max 'gpt-4-32k-0613': 32758, // -10 from max - 'gpt-4-1106': 127990, // -10 from max - 'gpt-4-0125': 127990, // -10 from max - 'gpt-4o': 127990, // -10 from max - 'gpt-4o-mini': 127990, // -10 from max - 'gpt-4o-2024-08-06': 127990, // -10 from max - 'gpt-4-turbo': 127990, // -10 from max - 'gpt-4-vision': 127990, // -10 from max + 'gpt-4-1106': 127500, // -500 from max + 'gpt-4-0125': 127500, // -500 from max + 'gpt-4o': 127500, // -500 from max + 'gpt-4o-mini': 127500, // -500 from max + 'gpt-4o-2024-08-06': 127500, // -500 from max + 'gpt-4-turbo': 127500, // -500 from max + 'gpt-4-vision': 127500, // -500 from max 'gpt-3.5-turbo': 16375, // -10 from max 'gpt-3.5-turbo-0613': 4092, // -5 from max 'gpt-3.5-turbo-0301': 4092, // -5 from max @@ -21,9 +21,15 @@ const openAIModels = { 'gpt-3.5-turbo-16k-0613': 16375, // -10 from max 'gpt-3.5-turbo-1106': 16375, // -10 from max 'gpt-3.5-turbo-0125': 16375, // -10 from max +}; + +const mistralModels = { 'mistral-': 31990, // -10 from max - llama3: 8187, // -5 from max - 'llama-3': 8187, // -5 from max + 'mistral-7b': 31990, // -10 from max + 'mistral-small': 31990, // -10 from max + 'mixtral-8x7b': 31990, // -10 from max + 'mistral-large-2402': 127500, + 'mistral-large-2407': 127500, }; const cohereModels = { @@ -54,6 +60,7 @@ const googleModels = { const anthropicModels = { 'claude-': 100000, + 'claude-instant': 100000, 'claude-2': 100000, 'claude-2.1': 200000, 'claude-3-haiku': 200000, @@ -63,7 +70,38 @@ const anthropicModels = { 'claude-3.5-sonnet': 200000, }; -const aggregateModels = { ...openAIModels, ...googleModels, ...anthropicModels, ...cohereModels }; +const metaModels = { + 'llama2-13b': 4000, + 'llama2-70b': 4000, + 'llama3-8b': 8000, + 'llama3-70b': 8000, + 'llama3-1-8b': 127500, + 'llama3-1-70b': 127500, + 'llama3-1-405b': 127500, +}; + +const ai21Models = { + 'ai21.j2-mid-v1': 8182, // -10 from max + 'ai21.j2-ultra-v1': 8182, // -10 from max + 'ai21.jamba-instruct-v1:0': 255500, // -500 from max +}; + +const amazonModels = { + 'amazon.titan-text-lite-v1': 4000, + 'amazon.titan-text-express-v1': 8000, + 'amazon.titan-text-premier-v1:0': 31500, // -500 from max +}; + +const bedrockModels = { + ...anthropicModels, + ...mistralModels, + ...cohereModels, + ...metaModels, + ...ai21Models, + ...amazonModels, +}; + +const aggregateModels = { ...openAIModels, ...googleModels, ...bedrockModels }; const maxTokensMap = { [EModelEndpoint.azureOpenAI]: openAIModels, @@ -72,6 +110,7 @@ const maxTokensMap = { [EModelEndpoint.custom]: aggregateModels, [EModelEndpoint.google]: googleModels, [EModelEndpoint.anthropic]: anthropicModels, + [EModelEndpoint.bedrock]: bedrockModels, }; /** diff --git a/api/utils/tokens.spec.js b/api/utils/tokens.spec.js index afcd4b217..e76e01a56 100644 --- a/api/utils/tokens.spec.js +++ b/api/utils/tokens.spec.js @@ -20,18 +20,6 @@ describe('getModelMaxTokens', () => { ); }); - test('should return correct tokens for LLama 3 models', () => { - expect(getModelMaxTokens('meta-llama/llama-3-8b')).toBe( - maxTokensMap[EModelEndpoint.openAI]['llama-3'], - ); - expect(getModelMaxTokens('meta-llama/llama-3-8b')).toBe( - maxTokensMap[EModelEndpoint.openAI]['llama3'], - ); - expect(getModelMaxTokens('llama-3-500b')).toBe(maxTokensMap[EModelEndpoint.openAI]['llama-3']); - expect(getModelMaxTokens('llama3-70b')).toBe(maxTokensMap[EModelEndpoint.openAI]['llama3']); - expect(getModelMaxTokens('llama3:latest')).toBe(maxTokensMap[EModelEndpoint.openAI]['llama3']); - }); - test('should return undefined for no match', () => { expect(getModelMaxTokens('unknown-model')).toBeUndefined(); }); diff --git a/client/src/common/agents-types.ts b/client/src/common/agents-types.ts index eaf64f4c6..07633a68d 100644 --- a/client/src/common/agents-types.ts +++ b/client/src/common/agents-types.ts @@ -1,8 +1,8 @@ import { Capabilities } from 'librechat-data-provider'; import type { Agent, AgentProvider, AgentModelParameters } from 'librechat-data-provider'; -import type { Option, ExtendedFile } from './types'; +import type { OptionWithIcon, ExtendedFile } from './types'; -export type TAgentOption = Option & +export type TAgentOption = OptionWithIcon & Agent & { files?: Array<[string, ExtendedFile]>; code_files?: Array<[string, ExtendedFile]>; @@ -23,5 +23,5 @@ export type AgentForm = { model: string | null; model_parameters: AgentModelParameters; tools?: string[]; - provider?: AgentProvider | Option; + provider?: AgentProvider | OptionWithIcon; } & AgentCapabilities; diff --git a/client/src/components/Chat/Input/HeaderOptions.tsx b/client/src/components/Chat/Input/HeaderOptions.tsx index ae28ae162..652e97128 100644 --- a/client/src/components/Chat/Input/HeaderOptions.tsx +++ b/client/src/components/Chat/Input/HeaderOptions.tsx @@ -2,7 +2,7 @@ import { useRecoilState } from 'recoil'; import { Settings2 } from 'lucide-react'; import { Root, Anchor } from '@radix-ui/react-popover'; import { useState, useEffect, useMemo } from 'react'; -import { tPresetUpdateSchema, EModelEndpoint } from 'librechat-data-provider'; +import { tPresetUpdateSchema, EModelEndpoint, paramEndpoints } from 'librechat-data-provider'; import type { TPreset, TInterfaceConfig } from 'librechat-data-provider'; import { EndpointSettings, SaveAsPresetDialog, AlternativeSettings } from '~/components/Endpoints'; import { ModelSelect } from '~/components/Input/ModelSelect'; @@ -12,7 +12,6 @@ import PopoverButtons from './PopoverButtons'; import { useSetIndexOptions } from '~/hooks'; import { useChatContext } from '~/Providers'; import { Button } from '~/components/ui'; -import { cn, cardStyle } from '~/utils/'; import store from '~/store'; export default function HeaderOptions({ @@ -29,10 +28,10 @@ export default function HeaderOptions({ useChatContext(); const { setOption } = useSetIndexOptions(); - const { endpoint, conversationId, jailbreak } = conversation ?? {}; + const { endpoint, conversationId, jailbreak = false } = conversation ?? {}; const altConditions: { [key: string]: boolean } = { - bingAI: !!(latestMessage && conversation?.jailbreak && endpoint === 'bingAI'), + bingAI: !!(latestMessage && jailbreak && endpoint === 'bingAI'), }; const altSettings: { [key: string]: () => void } = { @@ -74,7 +73,7 @@ export default function HeaderOptions({
- {interfaceConfig?.modelSelect && ( + {interfaceConfig?.modelSelect === true && ( )} - {!noSettings[endpoint] && interfaceConfig?.parameters && ( + {!noSettings[endpoint] && + interfaceConfig?.parameters === true && + !paramEndpoints.has(endpoint) && (
- {interfaceConfig?.parameters && ( + {interfaceConfig?.parameters === true && !paramEndpoints.has(endpoint) && ( } closePopover={() => setShowPopover(false)} > @@ -114,7 +115,7 @@ export default function HeaderOptions({
)} - {interfaceConfig?.presets && ( + {interfaceConfig?.presets === true && ( )} - {interfaceConfig?.parameters && ( + {interfaceConfig?.parameters === true && ( ; }; +const Bedrock = ({ className = '' }: IconMapProps) => { + return ; +}; + export const icons = { [EModelEndpoint.azureOpenAI]: AzureMinimalIcon, [EModelEndpoint.openAI]: GPTIcon, @@ -64,5 +69,6 @@ export const icons = { [EModelEndpoint.assistants]: AssistantAvatar, [EModelEndpoint.azureAssistants]: AssistantAvatar, [EModelEndpoint.agents]: AgentAvatar, + [EModelEndpoint.bedrock]: Bedrock, unknown: UnknownIcon, }; diff --git a/client/src/components/Chat/Messages/Content/Container.tsx b/client/src/components/Chat/Messages/Content/Container.tsx index cbd085e30..ecc40d6cd 100644 --- a/client/src/components/Chat/Messages/Content/Container.tsx +++ b/client/src/components/Chat/Messages/Content/Container.tsx @@ -1,12 +1,12 @@ import { TMessage } from 'librechat-data-provider'; import Files from './Files'; -const Container = ({ children, message }: { children: React.ReactNode; message: TMessage }) => ( +const Container = ({ children, message }: { children: React.ReactNode; message?: TMessage }) => (
- {message.isCreatedByUser && } + {message?.isCreatedByUser === true && } {children}
); diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index 3227bce07..d73377241 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -1,51 +1,34 @@ -import { Suspense } from 'react'; +import { memo } from 'react'; import type { TMessageContentParts } from 'librechat-data-provider'; -import { UnfinishedMessage } from './MessageContent'; -import { DelayedRender } from '~/components/ui'; import Part from './Part'; -const ContentParts = ({ - error, - unfinished, - isSubmitting, - isLast, - content, - ...props -}: // eslint-disable-next-line @typescript-eslint/no-explicit-any -any) => { - if (error) { - // return ; - } else { - const { message } = props; - const { messageId } = message; +type ContentPartsProps = { + content: Array; + messageId: string; + isCreatedByUser: boolean; + isLast: boolean; + isSubmitting: boolean; +}; +const ContentParts = memo( + ({ content, messageId, isCreatedByUser, isLast, isSubmitting }: ContentPartsProps) => { return ( <> {content - .filter((part: TMessageContentParts | undefined) => part) - .map((part: TMessageContentParts | undefined, idx: number) => { - const showCursor = idx === content.length - 1 && isLast; - return ( - - ); - })} - {/* Temporarily remove this */} - {/* {!isSubmitting && unfinished && ( - - - - - - )} */} + .filter((part) => part) + .map((part, idx) => ( + + ))} ); - } -}; + }, +); export default ContentParts; diff --git a/client/src/components/Chat/Messages/Content/Files.tsx b/client/src/components/Chat/Messages/Content/Files.tsx index beff81b58..09801d92c 100644 --- a/client/src/components/Chat/Messages/Content/Files.tsx +++ b/client/src/components/Chat/Messages/Content/Files.tsx @@ -3,7 +3,7 @@ import type { TFile, TMessage } from 'librechat-data-provider'; import FileContainer from '~/components/Chat/Input/Files/FileContainer'; import Image from './Image'; -const Files = ({ message }: { message: TMessage }) => { +const Files = ({ message }: { message?: TMessage }) => { const imageFiles = useMemo(() => { return message?.files?.filter((file) => file.type?.startsWith('image/')) || []; }, [message?.files]); @@ -20,7 +20,7 @@ const Files = ({ message }: { message: TMessage }) => { imageFiles.map((file) => ( ) => { +}: Pick & { + message?: TMessage; +}) => { const localize = useLocalize(); if (text === 'Error connecting to server, try refreshing the page.') { console.log('error message', message); diff --git a/client/src/components/Chat/Messages/Content/Part.tsx b/client/src/components/Chat/Messages/Content/Part.tsx index 7fcc9dbe4..ca0469cbf 100644 --- a/client/src/components/Chat/Messages/Content/Part.tsx +++ b/client/src/components/Chat/Messages/Content/Part.tsx @@ -4,81 +4,47 @@ import { imageGenTools, isImageVisionTool, } from 'librechat-data-provider'; -import { useMemo } from 'react'; -import type { TMessageContentParts, TMessage } from 'librechat-data-provider'; -import type { TDisplayProps } from '~/common'; +import { memo } from 'react'; +import type { TMessageContentParts } from 'librechat-data-provider'; import { ErrorMessage } from './MessageContent'; -import { useChatContext } from '~/Providers'; import RetrievalCall from './RetrievalCall'; import CodeAnalyze from './CodeAnalyze'; import Container from './Container'; import ToolCall from './ToolCall'; -import Markdown from './Markdown'; import ImageGen from './ImageGen'; -import { cn } from '~/utils'; +import Text from './Parts/Text'; import Image from './Image'; -// Display Message Component -const DisplayMessage = ({ text, isCreatedByUser = false, message, showCursor }: TDisplayProps) => { - const { isSubmitting, latestMessage } = useChatContext(); - const showCursorState = useMemo( - () => showCursor === true && isSubmitting, - [showCursor, isSubmitting], - ); - const isLatestMessage = useMemo( - () => message.messageId === latestMessage?.messageId, - [message.messageId, latestMessage?.messageId], - ); - - // Note: for testing purposes - // isSubmitting && isLatestMessage && logger.log('message_stream', { text, isCreatedByUser, isSubmitting, showCursorState }); - - return ( -
- {!isCreatedByUser ? ( - - ) : ( - <>{text} - )} -
- ); -}; - -export default function Part({ - part, - showCursor, - isSubmitting, - message, -}: { - part: TMessageContentParts | undefined; +type PartProps = { + part?: TMessageContentParts; isSubmitting: boolean; showCursor: boolean; - message: TMessage; -}) { + messageId: string; + isCreatedByUser: boolean; +}; + +const Part = memo(({ part, isSubmitting, showCursor, messageId, isCreatedByUser }: PartProps) => { if (!part) { return null; } if (part.type === ContentTypes.ERROR) { - return ; + return ; } else if (part.type === ContentTypes.TEXT) { const text = typeof part.text === 'string' ? part.text : part.text.value; + if (typeof text !== 'string') { return null; } + if (part.tool_call_ids != null && !text) { + return null; + } return ( - - + @@ -93,7 +59,7 @@ export default function Part({ if ('args' in toolCall && (!toolCall.type || toolCall.type === ToolCallTypes.TOOL_CALL)) { return ( - + @@ -174,4 +140,6 @@ export default function Part({ } return null; -} +}); + +export default Part; diff --git a/client/src/components/Chat/Messages/Content/Parts/Text.tsx b/client/src/components/Chat/Messages/Content/Parts/Text.tsx new file mode 100644 index 000000000..7d0b386c8 --- /dev/null +++ b/client/src/components/Chat/Messages/Content/Parts/Text.tsx @@ -0,0 +1,39 @@ +import { memo, useMemo } from 'react'; +import { useChatContext } from '~/Providers'; +import Markdown from '~/components/Chat/Messages/Content/Markdown'; +import { cn } from '~/utils'; + +type TextPartProps = { + text: string; + isCreatedByUser: boolean; + messageId: string; + showCursor: boolean; +}; + +const TextPart = memo(({ text, isCreatedByUser, messageId, showCursor }: TextPartProps) => { + const { isSubmitting, latestMessage } = useChatContext(); + const showCursorState = useMemo(() => showCursor && isSubmitting, [showCursor, isSubmitting]); + const isLatestMessage = useMemo( + () => messageId === latestMessage?.messageId, + [messageId, latestMessage?.messageId], + ); + + return ( +
+ {!isCreatedByUser ? ( + + ) : ( + <>{text} + )} +
+ ); +}); + +export default TextPart; diff --git a/client/src/components/Chat/Messages/MessageParts.tsx b/client/src/components/Chat/Messages/MessageParts.tsx index 37794c44f..52a2dbe2d 100644 --- a/client/src/components/Chat/Messages/MessageParts.tsx +++ b/client/src/components/Chat/Messages/MessageParts.tsx @@ -1,4 +1,5 @@ import { useRecoilValue } from 'recoil'; +import type { TMessageContentParts } from 'librechat-data-provider'; import type { TMessageProps } from '~/common'; import Icon from '~/components/Chat/Messages/MessageIcon'; import { useMessageHelpers, useLocalize } from '~/hooks'; @@ -17,7 +18,6 @@ export default function Message(props: TMessageProps) { props; const { - ask, edit, index, agent, @@ -33,7 +33,7 @@ export default function Message(props: TMessageProps) { regenerateMessage, } = useMessageHelpers(props); const fontSize = useRecoilValue(store.fontSize); - const { content, children, messageId = null, isCreatedByUser, error, unfinished } = message ?? {}; + const { children, messageId = null, isCreatedByUser } = message ?? {}; if (!message) { return null; @@ -82,24 +82,11 @@ export default function Message(props: TMessageProps) {
} + messageId={message.messageId} + isCreatedByUser={message.isCreatedByUser} isLast={isLast} - content={content ?? []} - message={message} - messageId={messageId} - enterEdit={enterEdit} - error={!!(error ?? false)} isSubmitting={isSubmitting} - unfinished={unfinished ?? false} - isCreatedByUser={isCreatedByUser ?? true} - siblingIdx={siblingIdx ?? 0} - setSiblingIdx={ - setSiblingIdx ?? - (() => { - return; - }) - } />
diff --git a/client/src/components/Chat/Messages/MultiMessage.tsx b/client/src/components/Chat/Messages/MultiMessage.tsx index 0e3204f39..eb2d2e925 100644 --- a/client/src/components/Chat/Messages/MultiMessage.tsx +++ b/client/src/components/Chat/Messages/MultiMessage.tsx @@ -1,10 +1,14 @@ import { useRecoilState } from 'recoil'; import { useEffect, useCallback } from 'react'; +import { isAssistantsEndpoint } from 'librechat-data-provider'; +import type { TMessage } from 'librechat-data-provider'; import type { TMessageProps } from '~/common'; // eslint-disable-next-line import/no-cycle -import Message from './Message'; +import MessageContent from '~/components/Messages/MessageContent'; // eslint-disable-next-line import/no-cycle import MessageParts from './MessageParts'; +// eslint-disable-next-line import/no-cycle +import Message from './Message'; import store from '~/store'; export default function MultiMessage({ @@ -30,22 +34,22 @@ export default function MultiMessage({ }, [messagesTree?.length]); useEffect(() => { - if (messagesTree?.length && siblingIdx >= messagesTree?.length) { + if (messagesTree?.length && siblingIdx >= messagesTree.length) { setSiblingIdx(0); } }, [siblingIdx, messagesTree?.length, setSiblingIdx]); - if (!(messagesTree && messagesTree?.length)) { + if (!(messagesTree && messagesTree.length)) { return null; } - const message = messagesTree[messagesTree.length - siblingIdx - 1]; + const message = messagesTree[messagesTree.length - siblingIdx - 1] as TMessage | undefined; if (!message) { return null; } - if (message.content) { + if (isAssistantsEndpoint(message.endpoint) && message.content) { return ( ); + } else if (message.content) { + return ( + + ); } return ( diff --git a/client/src/components/Endpoints/MessageEndpointIcon.tsx b/client/src/components/Endpoints/MessageEndpointIcon.tsx index fc084bbb5..ad204a7d4 100644 --- a/client/src/components/Endpoints/MessageEndpointIcon.tsx +++ b/client/src/components/Endpoints/MessageEndpointIcon.tsx @@ -1,4 +1,4 @@ -import { EModelEndpoint, isAssistantsEndpoint } from 'librechat-data-provider'; +import { EModelEndpoint, isAssistantsEndpoint, alternateName } from 'librechat-data-provider'; import UnknownIcon from '~/components/Chat/Menus/Endpoints/UnknownIcon'; import { BrainCircuit } from 'lucide-react'; import { @@ -7,6 +7,7 @@ import { PaLMIcon, CodeyIcon, GeminiIcon, + BedrockIcon, AssistantIcon, AnthropicIcon, AzureMinimalIcon, @@ -16,11 +17,31 @@ import { import { IconProps } from '~/common'; import { cn } from '~/utils'; +function getGoogleIcon(model: string | null | undefined, size: number) { + if (model?.toLowerCase().includes('code') === true) { + return ; + } else if (model?.toLowerCase().includes('gemini') === true) { + return ; + } else { + return ; + } +} + +function getGoogleModelName(model: string | null | undefined) { + if (model?.toLowerCase().includes('code') === true) { + return 'Codey'; + } else if (model?.toLowerCase().includes('gemini') === true) { + return 'Gemini'; + } else { + return 'PaLM2'; + } +} + const MessageEndpointIcon: React.FC = (props) => { const { error, button, - iconURL, + iconURL = '', endpoint, jailbreak, size = 30, @@ -30,7 +51,7 @@ const MessageEndpointIcon: React.FC = (props) => { } = props; const assistantsIcon = { - icon: props.iconURL ? ( + icon: iconURL ? (
= (props) => { > {assistantName} @@ -59,7 +80,7 @@ const MessageEndpointIcon: React.FC = (props) => { }; const agentsIcon = { - icon: props.iconURL ? ( + icon: iconURL ? (
= (props) => { > {agentName} @@ -104,42 +125,38 @@ const MessageEndpointIcon: React.FC = (props) => { }, [EModelEndpoint.gptPlugins]: { icon: , - bg: `rgba(69, 89, 164, ${button ? 0.75 : 1})`, + bg: `rgba(69, 89, 164, ${button === true ? 0.75 : 1})`, name: 'Plugins', }, [EModelEndpoint.google]: { - icon: model?.toLowerCase()?.includes('code') ? ( - - ) : model?.toLowerCase()?.includes('gemini') ? ( - - ) : ( - - ), - name: model?.toLowerCase()?.includes('code') - ? 'Codey' - : model?.toLowerCase()?.includes('gemini') - ? 'Gemini' - : 'PaLM2', + icon: getGoogleIcon(model, size), + name: getGoogleModelName(model), }, [EModelEndpoint.anthropic]: { icon: , bg: '#d09a74', name: 'Claude', }, + [EModelEndpoint.bedrock]: { + icon: , + bg: '#268672', + name: alternateName[EModelEndpoint.bedrock], + }, [EModelEndpoint.bingAI]: { - icon: jailbreak ? ( - Bing Icon - ) : ( - Sydney Icon - ), - name: jailbreak ? 'Sydney' : 'BingAI', + icon: + jailbreak === true ? ( + Bing Icon + ) : ( + Sydney Icon + ), + name: jailbreak === true ? 'Sydney' : 'BingAI', }, [EModelEndpoint.chatGPTBrowser]: { icon: , bg: typeof model === 'string' && model.toLowerCase().includes('gpt-4') ? '#AB68FF' - : `rgba(0, 163, 255, ${button ? 0.75 : 1})`, + : `rgba(0, 163, 255, ${button === true ? 0.75 : 1})`, name: 'ChatGPT', }, [EModelEndpoint.custom]: { @@ -152,7 +169,7 @@ const MessageEndpointIcon: React.FC = (props) => {
= (props) => { }} className={cn( 'relative flex h-9 w-9 items-center justify-center rounded-sm p-1 text-white', - props.className || '', + props.className ?? '', )} > {icon} - {error && ( + {error === true && ( ! diff --git a/client/src/components/Endpoints/MinimalIcon.tsx b/client/src/components/Endpoints/MinimalIcon.tsx index 80ab657a8..1f008e263 100644 --- a/client/src/components/Endpoints/MinimalIcon.tsx +++ b/client/src/components/Endpoints/MinimalIcon.tsx @@ -1,4 +1,4 @@ -import { EModelEndpoint } from 'librechat-data-provider'; +import { EModelEndpoint, alternateName } from 'librechat-data-provider'; import { BrainCircuit } from 'lucide-react'; import UnknownIcon from '~/components/Chat/Menus/Endpoints/UnknownIcon'; import { @@ -10,6 +10,7 @@ import { GoogleMinimalIcon, CustomMinimalIcon, AnthropicIcon, + BedrockIcon, Sparkles, } from '~/components/svg'; import { cn } from '~/utils'; @@ -27,17 +28,17 @@ const MinimalIcon: React.FC = (props) => { const endpointIcons = { [EModelEndpoint.azureOpenAI]: { icon: , - name: props.chatGptLabel || 'ChatGPT', + name: props.chatGptLabel ?? 'ChatGPT', }, [EModelEndpoint.openAI]: { icon: , - name: props.chatGptLabel || 'ChatGPT', + name: props.chatGptLabel ?? 'ChatGPT', }, [EModelEndpoint.gptPlugins]: { icon: , name: 'Plugins' }, - [EModelEndpoint.google]: { icon: , name: props.modelLabel || 'Google' }, + [EModelEndpoint.google]: { icon: , name: props.modelLabel ?? 'Google' }, [EModelEndpoint.anthropic]: { icon: , - name: props.modelLabel || 'Claude', + name: props.modelLabel ?? 'Claude', }, [EModelEndpoint.custom]: { icon: , @@ -47,7 +48,14 @@ const MinimalIcon: React.FC = (props) => { [EModelEndpoint.chatGPTBrowser]: { icon: , name: 'ChatGPT' }, [EModelEndpoint.assistants]: { icon: , name: 'Assistant' }, [EModelEndpoint.azureAssistants]: { icon: , name: 'Assistant' }, - [EModelEndpoint.agents]: { icon: , name: 'Agent' }, + [EModelEndpoint.agents]: { + icon: , + name: props.modelLabel ?? alternateName[EModelEndpoint.agents], + }, + [EModelEndpoint.bedrock]: { + icon: , + name: props.modelLabel ?? alternateName[EModelEndpoint.bedrock], + }, default: { icon: ( = (props) => { }} className={cn( 'relative flex items-center justify-center rounded-sm text-black dark:text-white', - props.className || '', + props.className ?? '', )} > {icon} - {error && ( + {error === true && ( ! diff --git a/client/src/components/Endpoints/SaveAsPresetDialog.tsx b/client/src/components/Endpoints/SaveAsPresetDialog.tsx index 46abdb980..f9a290968 100644 --- a/client/src/components/Endpoints/SaveAsPresetDialog.tsx +++ b/client/src/components/Endpoints/SaveAsPresetDialog.tsx @@ -2,14 +2,14 @@ import React, { useEffect, useState } from 'react'; import { useCreatePresetMutation } from 'librechat-data-provider/react-query'; import type { TEditPresetProps } from '~/common'; import { cn, removeFocusOutlines, cleanupPreset, defaultTextProps } from '~/utils/'; -import DialogTemplate from '~/components/ui/DialogTemplate'; -import { Dialog, Input, Label } from '~/components/ui/'; +import OGDialogTemplate from '~/components/ui/OGDialogTemplate'; +import { OGDialog, Input, Label } from '~/components/ui/'; import { NotificationSeverity } from '~/common'; import { useToastContext } from '~/Providers'; import { useLocalize } from '~/hooks'; const SaveAsPresetDialog = ({ open, onOpenChange, preset }: TEditPresetProps) => { - const [title, setTitle] = useState(preset.title || 'My Preset'); + const [title, setTitle] = useState(preset.title ?? 'My Preset'); const createPresetMutation = useCreatePresetMutation(); const { showToast } = useToastContext(); const localize = useLocalize(); @@ -22,15 +22,15 @@ const SaveAsPresetDialog = ({ open, onOpenChange, preset }: TEditPresetProps) => }, }); - const toastTitle = _preset.title - ? `\`${_preset.title}\`` - : localize('com_endpoint_preset_title'); + const toastTitle = + _preset.title ?? '' ? `\`${_preset.title}\`` : localize('com_endpoint_preset_title'); createPresetMutation.mutate(_preset, { onSuccess: () => { showToast({ message: `${toastTitle} ${localize('com_endpoint_preset_saved')}`, }); + onOpenChange(false); // Close the dialog on success }, onError: () => { showToast({ @@ -42,27 +42,38 @@ const SaveAsPresetDialog = ({ open, onOpenChange, preset }: TEditPresetProps) => }; useEffect(() => { - setTitle(preset.title || localize('com_endpoint_my_preset')); + setTitle(preset.title ?? localize('com_endpoint_my_preset')); // eslint-disable-next-line react-hooks/exhaustive-deps }, [open]); + // Handle Enter key press + const handleKeyDown = (event: React.KeyboardEvent) => { + if (event.key === 'Enter') { + event.preventDefault(); + submitPreset(); + } + }; + return ( - - +
-
+ ); }; diff --git a/client/src/components/Endpoints/Settings/Bedrock.tsx b/client/src/components/Endpoints/Settings/Bedrock.tsx new file mode 100644 index 000000000..1e88bf326 --- /dev/null +++ b/client/src/components/Endpoints/Settings/Bedrock.tsx @@ -0,0 +1,59 @@ +import { useMemo } from 'react'; +import { getSettingsKeys } from 'librechat-data-provider'; +import type { SettingDefinition } from 'librechat-data-provider'; +import type { TModelSelectProps } from '~/common'; +import { componentMapping } from '~/components/SidePanel/Parameters/components'; +import { presetSettings } from '~/components/SidePanel/Parameters/settings'; + +export default function BedrockSettings({ + conversation, + setOption, + models, + readonly, +}: TModelSelectProps) { + const parameters = useMemo(() => { + const [combinedKey, endpointKey] = getSettingsKeys( + conversation?.endpoint ?? '', + conversation?.model ?? '', + ); + return presetSettings[combinedKey] ?? presetSettings[endpointKey]; + }, [conversation]); + + if (!parameters) { + return null; + } + + const renderComponent = (setting: SettingDefinition) => { + const Component = componentMapping[setting.component]; + const { key, default: defaultValue, ...rest } = setting; + + const props = { + key, + settingKey: key, + defaultValue, + ...rest, + readonly, + setOption, + conversation, + }; + + if (key === 'model') { + return ; + } + + return ; + }; + + return ( +
+
+
+ {parameters.col1.map(renderComponent)} +
+
+ {parameters.col2.map(renderComponent)} +
+
+
+ ); +} diff --git a/client/src/components/Endpoints/Settings/OpenAI.tsx b/client/src/components/Endpoints/Settings/OpenAI.tsx index a95e7823c..506276358 100644 --- a/client/src/components/Endpoints/Settings/OpenAI.tsx +++ b/client/src/components/Endpoints/Settings/OpenAI.tsx @@ -18,8 +18,7 @@ import { HoverCardTrigger, } from '~/components/ui'; import { cn, defaultTextProps, optionText, removeFocusOutlines, removeFocusRings } from '~/utils'; -import OptionHoverAlt from '~/components/SidePanel/Parameters/OptionHover'; -import { DynamicTags } from '~/components/SidePanel/Parameters'; +import { OptionHoverAlt, DynamicTags } from '~/components/SidePanel/Parameters'; import { useLocalize, useDebouncedInput } from '~/hooks'; import OptionHover from './OptionHover'; import { ESide } from '~/common'; diff --git a/client/src/components/Endpoints/Settings/index.ts b/client/src/components/Endpoints/Settings/index.ts index 7d525d8a5..92af1a60b 100644 --- a/client/src/components/Endpoints/Settings/index.ts +++ b/client/src/components/Endpoints/Settings/index.ts @@ -1,5 +1,6 @@ export { default as Advanced } from './Advanced'; export { default as AssistantsSettings } from './Assistants'; +export { default as BedrockSettings } from './Bedrock'; export { default as OpenAISettings } from './OpenAI'; export { default as BingAISettings } from './BingAI'; export { default as GoogleSettings } from './Google'; diff --git a/client/src/components/Endpoints/Settings/settings.ts b/client/src/components/Endpoints/Settings/settings.ts index bcbaab8e8..d58917df8 100644 --- a/client/src/components/Endpoints/Settings/settings.ts +++ b/client/src/components/Endpoints/Settings/settings.ts @@ -4,6 +4,7 @@ import type { TModelSelectProps } from '~/common'; import { GoogleSettings, PluginSettings } from './MultiView'; import AssistantsSettings from './Assistants'; import AnthropicSettings from './Anthropic'; +import BedrockSettings from './Bedrock'; import BingAISettings from './BingAI'; import OpenAISettings from './OpenAI'; @@ -16,6 +17,7 @@ const settings: { [key: string]: FC } = { [EModelEndpoint.azureOpenAI]: OpenAISettings, [EModelEndpoint.bingAI]: BingAISettings, [EModelEndpoint.anthropic]: AnthropicSettings, + [EModelEndpoint.bedrock]: BedrockSettings, }; export const getSettings = () => { diff --git a/client/src/components/Input/ModelSelect/options.ts b/client/src/components/Input/ModelSelect/options.ts index 0159f0782..24c820485 100644 --- a/client/src/components/Input/ModelSelect/options.ts +++ b/client/src/components/Input/ModelSelect/options.ts @@ -12,6 +12,7 @@ import PluginsByIndex from './PluginsByIndex'; export const options: { [key: string]: FC } = { [EModelEndpoint.openAI]: OpenAI, [EModelEndpoint.custom]: OpenAI, + [EModelEndpoint.bedrock]: OpenAI, [EModelEndpoint.azureOpenAI]: OpenAI, [EModelEndpoint.bingAI]: BingAI, [EModelEndpoint.google]: Google, diff --git a/client/src/components/Messages/ContentRender.tsx b/client/src/components/Messages/ContentRender.tsx new file mode 100644 index 000000000..f341efbfd --- /dev/null +++ b/client/src/components/Messages/ContentRender.tsx @@ -0,0 +1,170 @@ +import { useRecoilValue } from 'recoil'; +import { useCallback, useMemo, memo } from 'react'; +import type { TMessage, TMessageContentParts } from 'librechat-data-provider'; +import type { TMessageProps } from '~/common'; +import ContentParts from '~/components/Chat/Messages/Content/ContentParts'; +import PlaceholderRow from '~/components/Chat/Messages/ui/PlaceholderRow'; +import SiblingSwitch from '~/components/Chat/Messages/SiblingSwitch'; +import HoverButtons from '~/components/Chat/Messages/HoverButtons'; +import Icon from '~/components/Chat/Messages/MessageIcon'; +import SubRow from '~/components/Chat/Messages/SubRow'; +import { useMessageActions } from '~/hooks'; +import { cn, logger } from '~/utils'; +import store from '~/store'; + +type ContentRenderProps = { + message?: TMessage; + isCard?: boolean; + isMultiMessage?: boolean; + isSubmittingFamily?: boolean; +} & Pick< + TMessageProps, + 'currentEditId' | 'setCurrentEditId' | 'siblingIdx' | 'setSiblingIdx' | 'siblingCount' +>; + +const ContentRender = memo( + ({ + isCard, + siblingIdx, + siblingCount, + message: msg, + setSiblingIdx, + currentEditId, + isMultiMessage, + setCurrentEditId, + isSubmittingFamily, + }: ContentRenderProps) => { + const { + // ask, + edit, + index, + agent, + assistant, + enterEdit, + conversation, + messageLabel, + isSubmitting, + latestMessage, + handleContinue, + copyToClipboard, + setLatestMessage, + regenerateMessage, + } = useMessageActions({ + message: msg, + currentEditId, + isMultiMessage, + setCurrentEditId, + }); + + const fontSize = useRecoilValue(store.fontSize); + const handleRegenerateMessage = useCallback(() => regenerateMessage(), [regenerateMessage]); + // const { isCreatedByUser, error, unfinished } = msg ?? {}; + const isLast = useMemo( + () => + !(msg?.children?.length ?? 0) && (msg?.depth === latestMessage?.depth || msg?.depth === -1), + [msg?.children, msg?.depth, latestMessage?.depth], + ); + + if (!msg) { + return null; + } + + const isLatestMessage = msg.messageId === latestMessage?.messageId; + const showCardRender = isLast && !(isSubmittingFamily === true) && isCard === true; + const isLatestCard = isCard === true && !(isSubmittingFamily === true) && isLatestMessage; + const clickHandler = + showCardRender && !isLatestMessage + ? () => { + logger.log(`Message Card click: Setting ${msg.messageId} as latest message`); + logger.dir(msg); + setLatestMessage(msg); + } + : undefined; + + return ( +
{ + if ((e.key === 'Enter' || e.key === ' ') && clickHandler) { + clickHandler(); + } + }} + role={showCardRender ? 'button' : undefined} + tabIndex={showCardRender ? 0 : undefined} + > + {isLatestCard === true && ( +
+ )} +
+
+
+
+ +
+
+
+
+
+

{messageLabel}

+
+
+ } + messageId={msg.messageId} + isCreatedByUser={msg.isCreatedByUser} + isLast={isLast} + isSubmitting={isSubmitting} + /> +
+
+ {!(msg.children?.length ?? 0) && (isSubmittingFamily === true || isSubmitting) ? ( + + ) : ( + + + + + )} +
+
+ ); + }, +); + +export default ContentRender; diff --git a/client/src/components/Messages/MessageContent.tsx b/client/src/components/Messages/MessageContent.tsx new file mode 100644 index 000000000..e472b1e8d --- /dev/null +++ b/client/src/components/Messages/MessageContent.tsx @@ -0,0 +1,82 @@ +import React from 'react'; +import { useMessageProcess } from '~/hooks'; +import type { TMessageProps } from '~/common'; +// eslint-disable-next-line import/no-cycle +import MultiMessage from '~/components/Chat/Messages/MultiMessage'; +import ContentRender from './ContentRender'; + +const MessageContainer = React.memo( + ({ + handleScroll, + children, + }: { + handleScroll: (event?: unknown) => void; + children: React.ReactNode; + }) => { + return ( +
+ {children} +
+ ); + }, +); + +export default function MessageContent(props: TMessageProps) { + const { + showSibling, + conversation, + handleScroll, + siblingMessage, + latestMultiMessage, + isSubmittingFamily, + } = useMessageProcess({ message: props.message }); + const { message, currentEditId, setCurrentEditId } = props; + + if (!message || typeof message !== 'object') { + return null; + } + + const { children, messageId = null } = message; + + return ( + <> + + {showSibling ? ( +
+
+ + +
+
+ ) : ( +
+ +
+ )} +
+ + + ); +} diff --git a/client/src/components/Prompts/Groups/ListCard.tsx b/client/src/components/Prompts/Groups/ListCard.tsx index 2730bdc27..c17faed8b 100644 --- a/client/src/components/Prompts/Groups/ListCard.tsx +++ b/client/src/components/Prompts/Groups/ListCard.tsx @@ -1,3 +1,4 @@ +import React from 'react'; import CategoryIcon from '~/components/Prompts/Groups/CategoryIcon'; export default function ListCard({ @@ -10,19 +11,33 @@ export default function ListCard({ category: string; name: string; snippet: string; - onClick?: React.MouseEventHandler; + onClick?: React.MouseEventHandler; children?: React.ReactNode; }) { + const handleKeyDown = (event: React.KeyboardEvent) => { + if (event.key === 'Enter' || event.key === ' ') { + event.preventDefault(); + onClick?.(event as unknown as React.MouseEvent); + } + }; + return ( - +
); } diff --git a/client/src/components/Prompts/SharePrompt.tsx b/client/src/components/Prompts/SharePrompt.tsx index 696b6a949..9a12553c4 100644 --- a/client/src/components/Prompts/SharePrompt.tsx +++ b/client/src/components/Prompts/SharePrompt.tsx @@ -31,7 +31,7 @@ const SharePrompt = ({ group, disabled }: { group?: TPromptGroup; disabled: bool const { data: startupConfig = {} as TStartupConfig, isFetching } = useGetStartupConfig(); const { instanceProjectId } = startupConfig; const groupIsGlobal = useMemo( - () => !!group?.projectIds?.includes(instanceProjectId), + () => !!(group?.projectIds ?? []).includes(instanceProjectId), [group, instanceProjectId], ); @@ -57,7 +57,8 @@ const SharePrompt = ({ group, disabled }: { group?: TPromptGroup; disabled: bool } const onSubmit = (data: FormValues) => { - if (!group._id || !instanceProjectId) { + const groupId = group._id ?? ''; + if (!groupId || !instanceProjectId) { return; } @@ -70,7 +71,7 @@ const SharePrompt = ({ group, disabled }: { group?: TPromptGroup; disabled: bool } updateGroup.mutate({ - id: group._id, + id: groupId, payload, }); }; @@ -87,24 +88,38 @@ const SharePrompt = ({ group, disabled }: { group?: TPromptGroup; disabled: bool - + {localize('com_ui_share_var', `"${group.name}"`)}
- +
+ + +
)} /> diff --git a/client/src/components/SidePanel/Agents/AgentConfig.tsx b/client/src/components/SidePanel/Agents/AgentConfig.tsx index e8395ba2b..a35aa8e2e 100644 --- a/client/src/components/SidePanel/Agents/AgentConfig.tsx +++ b/client/src/components/SidePanel/Agents/AgentConfig.tsx @@ -11,9 +11,10 @@ import Action from '~/components/SidePanel/Builder/Action'; import { useLocalize } from '~/hooks'; import { ToolSelectDialog } from '~/components/Tools'; import { useToastContext } from '~/Providers'; -import ContextButton from './ContextButton'; import { Spinner } from '~/components/svg'; +import DeleteButton from './DeleteButton'; import AgentAvatar from './AgentAvatar'; +import ShareAgent from './ShareAgent'; import AgentTool from './AgentTool'; import { Panel } from '~/common'; @@ -57,14 +58,14 @@ export default function AgentConfig({ () => agentsConfig?.capabilities?.includes(Capabilities.actions), [agentsConfig], ); - const retrievalEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(Capabilities.retrieval), - [agentsConfig], - ); - const codeEnabled = useMemo( - () => agentsConfig?.capabilities?.includes(Capabilities.code_interpreter), - [agentsConfig], - ); + // const retrievalEnabled = useMemo( + // () => agentsConfig?.capabilities?.includes(Capabilities.retrieval), + // [agentsConfig], + // ); + // const codeEnabled = useMemo( + // () => agentsConfig?.capabilities?.includes(Capabilities.code_interpreter), + // [agentsConfig], + // ); /* Mutations */ const update = useUpdateAgentMutation({ @@ -190,7 +191,7 @@ export default function AgentConfig({ name="id" control={control} render={({ field }) => ( -

+

{field.value}

)} @@ -221,12 +222,11 @@ export default function AgentConfig({ {/* Instructions */}
( <>