From 434289fe929f3e7412980b74baa32d365776255f Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Tue, 1 Jul 2025 15:43:10 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=80=20feat:=20Save=20&=20Submit=20Mess?= =?UTF-8?q?age=20Content=20Parts=20(#8171)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🐛 fix: Enhance provider validation and error handling in getProviderConfig function * WIP: edit text part * refactor: Allow updating of both TEXT and THINK content types in message updates * WIP: first pass, save & submit * chore: remove legacy generation user message field * feat: merge edited content * fix: update placeholder and description for bedrock setting * fix: remove unsupported warning message for AI resubmission --- api/app/clients/BaseClient.js | 85 +++++++++++++++++-- api/server/controllers/agents/request.js | 11 ++- api/server/routes/messages.js | 7 +- api/server/services/Endpoints/index.js | 17 ++++ client/src/common/types.ts | 5 ++ .../Chat/Messages/Content/ContentParts.tsx | 13 ++- .../Messages/Content/Parts/EditTextPart.tsx | 39 ++++----- client/src/hooks/Chat/useChatFunctions.ts | 68 +++++++-------- client/src/hooks/SSE/useStepHandler.ts | 59 +++++++++++-- client/src/locales/en/translation.json | 3 +- packages/data-provider/src/createPayload.ts | 2 + .../data-provider/src/parameterSettings.ts | 4 +- packages/data-provider/src/schemas.ts | 1 + packages/data-provider/src/types.ts | 10 +++ 14 files changed, 240 insertions(+), 84 deletions(-) diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index c8f4228f1..0598f0da2 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -13,7 +13,6 @@ const { const { getMessages, saveMessage, updateMessage, saveConvo, getConvo } = require('~/models'); const { checkBalance } = require('~/models/balanceMethods'); const { truncateToolCallOutputs } = require('./prompts'); -const { addSpaceIfNeeded } = require('~/server/utils'); const { getFiles } = require('~/models/File'); const TextStream = require('./TextStream'); const { logger } = require('~/config'); @@ -572,7 +571,7 @@ class BaseClient { }); } - const { generation = '' } = opts; + const { editedContent } = opts; // It's not necessary to push to currentMessages // depending on subclass implementation of handling messages @@ -587,11 +586,21 @@ class BaseClient { isCreatedByUser: false, model: this.modelOptions?.model ?? this.model, sender: this.sender, - text: generation, }; this.currentMessages.push(userMessage, latestMessage); - } else { - latestMessage.text = generation; + } else if (editedContent != null) { + // Handle editedContent for content parts + if (editedContent && latestMessage.content && Array.isArray(latestMessage.content)) { + const { index, text, type } = editedContent; + if (index >= 0 && index < latestMessage.content.length) { + const contentPart = latestMessage.content[index]; + if (type === ContentTypes.THINK && contentPart.type === ContentTypes.THINK) { + contentPart[ContentTypes.THINK] = text; + } else if (type === ContentTypes.TEXT && contentPart.type === ContentTypes.TEXT) { + contentPart[ContentTypes.TEXT] = text; + } + } + } } this.continued = true; } else { @@ -672,16 +681,32 @@ class BaseClient { }; if (typeof completion === 'string') { - responseMessage.text = addSpaceIfNeeded(generation) + completion; + responseMessage.text = completion; } else if ( Array.isArray(completion) && (this.clientName === EModelEndpoint.agents || isParamEndpoint(this.options.endpoint, this.options.endpointType)) ) { responseMessage.text = ''; - responseMessage.content = completion; + + if (!opts.editedContent || this.currentMessages.length === 0) { + responseMessage.content = completion; + } else { + const latestMessage = this.currentMessages[this.currentMessages.length - 1]; + if (!latestMessage?.content) { + responseMessage.content = completion; + } else { + const existingContent = [...latestMessage.content]; + const { type: editedType } = opts.editedContent; + responseMessage.content = this.mergeEditedContent( + existingContent, + completion, + editedType, + ); + } + } } else if (Array.isArray(completion)) { - responseMessage.text = addSpaceIfNeeded(generation) + completion.join(''); + responseMessage.text = completion.join(''); } if ( @@ -1095,6 +1120,50 @@ class BaseClient { return numTokens; } + /** + * Merges completion content with existing content when editing TEXT or THINK types + * @param {Array} existingContent - The existing content array + * @param {Array} newCompletion - The new completion content + * @param {string} editedType - The type of content being edited + * @returns {Array} The merged content array + */ + mergeEditedContent(existingContent, newCompletion, editedType) { + if (!newCompletion.length) { + return existingContent.concat(newCompletion); + } + + if (editedType !== ContentTypes.TEXT && editedType !== ContentTypes.THINK) { + return existingContent.concat(newCompletion); + } + + const lastIndex = existingContent.length - 1; + const lastExisting = existingContent[lastIndex]; + const firstNew = newCompletion[0]; + + if (lastExisting?.type !== firstNew?.type || firstNew?.type !== editedType) { + return existingContent.concat(newCompletion); + } + + const mergedContent = [...existingContent]; + if (editedType === ContentTypes.TEXT) { + mergedContent[lastIndex] = { + ...mergedContent[lastIndex], + [ContentTypes.TEXT]: + (mergedContent[lastIndex][ContentTypes.TEXT] || '') + (firstNew[ContentTypes.TEXT] || ''), + }; + } else { + mergedContent[lastIndex] = { + ...mergedContent[lastIndex], + [ContentTypes.THINK]: + (mergedContent[lastIndex][ContentTypes.THINK] || '') + + (firstNew[ContentTypes.THINK] || ''), + }; + } + + // Add remaining completion items + return mergedContent.concat(newCompletion.slice(1)); + } + async sendPayload(payload, opts = {}) { if (opts && typeof opts === 'object') { this.setOptions(opts); diff --git a/api/server/controllers/agents/request.js b/api/server/controllers/agents/request.js index 5d55991e1..2c8e424b5 100644 --- a/api/server/controllers/agents/request.js +++ b/api/server/controllers/agents/request.js @@ -14,8 +14,11 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { text, endpointOption, conversationId, + isContinued = false, + editedContent = null, parentMessageId = null, overrideParentMessageId = null, + responseMessageId: editedResponseMessageId = null, } = req.body; let sender; @@ -67,7 +70,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { handler(); } } catch (e) { - // Ignore cleanup errors + logger.error('[AgentController] Error in cleanup handler', e); } } } @@ -155,7 +158,7 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { try { res.removeListener('close', closeHandler); } catch (e) { - // Ignore + logger.error('[AgentController] Error removing close listener', e); } }); @@ -163,10 +166,14 @@ const AgentController = async (req, res, next, initializeClient, addTitle) => { user: userId, onStart, getReqData, + isContinued, + editedContent, conversationId, parentMessageId, abortController, overrideParentMessageId, + isEdited: !!editedContent, + responseMessageId: editedResponseMessageId, progressOptions: { res, }, diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 356dd2509..0a277a1bd 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -235,12 +235,13 @@ router.put('/:conversationId/:messageId', validateMessageReq, async (req, res) = return res.status(400).json({ error: 'Content part not found' }); } - if (updatedContent[index].type !== ContentTypes.TEXT) { + const currentPartType = updatedContent[index].type; + if (currentPartType !== ContentTypes.TEXT && currentPartType !== ContentTypes.THINK) { return res.status(400).json({ error: 'Cannot update non-text content' }); } - const oldText = updatedContent[index].text; - updatedContent[index] = { type: ContentTypes.TEXT, text }; + const oldText = updatedContent[index][currentPartType]; + updatedContent[index] = { type: currentPartType, [currentPartType]: text }; let tokenCount = message.tokenCount; if (tokenCount !== undefined) { diff --git a/api/server/services/Endpoints/index.js b/api/server/services/Endpoints/index.js index b6e398366..817178941 100644 --- a/api/server/services/Endpoints/index.js +++ b/api/server/services/Endpoints/index.js @@ -7,6 +7,16 @@ const initCustom = require('~/server/services/Endpoints/custom/initialize'); const initGoogle = require('~/server/services/Endpoints/google/initialize'); const { getCustomEndpointConfig } = require('~/server/services/Config'); +/** Check if the provider is a known custom provider + * @param {string | undefined} [provider] - The provider string + * @returns {boolean} - True if the provider is a known custom provider, false otherwise + */ +function isKnownCustomProvider(provider) { + return [Providers.XAI, Providers.OLLAMA, Providers.DEEPSEEK, Providers.OPENROUTER].includes( + provider || '', + ); +} + const providerConfigMap = { [Providers.XAI]: initCustom, [Providers.OLLAMA]: initCustom, @@ -46,6 +56,13 @@ async function getProviderConfig(provider) { overrideProvider = Providers.OPENAI; } + if (isKnownCustomProvider(overrideProvider)) { + customEndpointConfig = await getCustomEndpointConfig(provider); + if (!customEndpointConfig) { + throw new Error(`Provider ${provider} not supported`); + } + } + return { getOptions, overrideProvider, diff --git a/client/src/common/types.ts b/client/src/common/types.ts index c7f2d6788..9349b7695 100644 --- a/client/src/common/types.ts +++ b/client/src/common/types.ts @@ -336,6 +336,11 @@ export type TAskProps = { export type TOptions = { editedMessageId?: string | null; editedText?: string | null; + editedContent?: { + index: number; + text: string; + type: 'text' | 'think'; + }; isRegenerate?: boolean; isContinued?: boolean; isEdited?: boolean; diff --git a/client/src/components/Chat/Messages/Content/ContentParts.tsx b/client/src/components/Chat/Messages/Content/ContentParts.tsx index 0a1b4616a..49f6be255 100644 --- a/client/src/components/Chat/Messages/Content/ContentParts.tsx +++ b/client/src/components/Chat/Messages/Content/ContentParts.tsx @@ -81,14 +81,23 @@ const ContentParts = memo( return ( <> {content.map((part, idx) => { - if (part?.type !== ContentTypes.TEXT || typeof part.text !== 'string') { + if (!part) { + return null; + } + const isTextPart = + part?.type === ContentTypes.TEXT || + typeof (part as unknown as Agents.MessageContentText)?.text !== 'string'; + const isThinkPart = + part?.type === ContentTypes.THINK || + typeof (part as unknown as Agents.ReasoningDeltaUpdate)?.think !== 'string'; + if (!isTextPart && !isThinkPart) { return null; } return ( & { +}: Omit & { index: number; messageId: string; + part: Agents.MessageContentText | Agents.ReasoningDeltaUpdate; }) => { const localize = useLocalize(); const { addedIndex } = useAddedChatContext(); - const { getMessages, setMessages, conversation } = useChatContext(); + const { ask, getMessages, setMessages, conversation } = useChatContext(); const [latestMultiMessage, setLatestMultiMessage] = useRecoilState( store.latestMessageFamily(addedIndex), ); @@ -34,15 +36,16 @@ const EditTextPart = ({ [getMessages, messageId], ); + const chatDirection = useRecoilValue(store.chatDirection); + const textAreaRef = useRef(null); const updateMessageContentMutation = useUpdateMessageContentMutation(conversationId ?? ''); - const chatDirection = useRecoilValue(store.chatDirection).toLowerCase(); - const isRTL = chatDirection === 'rtl'; + const isRTL = chatDirection?.toLowerCase() === 'rtl'; const { register, handleSubmit, setValue } = useForm({ defaultValues: { - text: text ?? '', + text: (ContentTypes.THINK in part ? part.think : part.text) || '', }, }); @@ -55,15 +58,7 @@ const EditTextPart = ({ } }, []); - /* - const resubmitMessage = () => { - showToast({ - status: 'warning', - message: localize('com_warning_resubmit_unsupported'), - }); - - // const resubmitMessage = (data: { text: string }) => { - // Not supported by AWS Bedrock + const resubmitMessage = (data: { text: string }) => { const messages = getMessages(); const parentMessage = messages?.find((msg) => msg.messageId === message?.parentMessageId); @@ -73,17 +68,19 @@ const EditTextPart = ({ ask( { ...parentMessage }, { - editedText: data.text, + editedContent: { + index, + text: data.text, + type: part.type, + }, editedMessageId: messageId, isRegenerate: true, isEdited: true, }, ); - setSiblingIdx((siblingIdx ?? 0) - 1); enterEdit(true); }; - */ const updateMessage = (data: { text: string }) => { const messages = getMessages(); @@ -167,13 +164,13 @@ const EditTextPart = ({ />
- {/* */} +