From a2fd975cd5dac6e59adc46a384ccd01b71fa6026 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 28 Jun 2024 08:44:47 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A4=20refactor:=20Optimize=20Request?= =?UTF-8?q?=20Lifecycle=20Speeds=20(#3222)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: optimize backend operations for client requests * fix: message styling * refactor: Improve handleKeyUp logic in StreamRunManager.js and handleText.js * refactor: Improve handleKeyUp logic in StreamRunManager.js and handleText.js * fix: clear new convo messages on clear all convos * fix: forgot to pass userId to getConvo * refactor: update getPartialText to send basePayload.text --- api/app/clients/BaseClient.js | 27 ++++++++------ api/app/clients/PluginsClient.js | 9 +++-- api/models/Conversation.js | 2 +- api/models/Message.js | 15 +++----- api/server/controllers/AskController.js | 8 +++-- api/server/controllers/EditController.js | 8 +++-- api/server/middleware/abortMiddleware.js | 19 +++++++--- api/server/middleware/validateMessageReq.js | 2 +- api/server/routes/ask/gptPlugins.js | 14 ++++++-- api/server/routes/edit/gptPlugins.js | 14 ++++++-- api/server/services/Runs/StreamRunManager.js | 2 +- api/server/utils/handleText.js | 36 ++++++++++--------- api/typedefs.js | 7 ++++ .../Chat/Messages/ui/MessageRender.tsx | 2 +- .../Chat/Messages/ui/PlaceholderRow.tsx | 5 +-- client/src/components/Nav/ClearConvos.tsx | 6 ++-- .../hooks/Conversations/useConversation.ts | 4 +++ 17 files changed, 115 insertions(+), 65 deletions(-) diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index d335272e9..c7b4f977c 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -1,7 +1,7 @@ const crypto = require('crypto'); const fetch = require('node-fetch'); const { supportsBalanceCheck, Constants } = require('librechat-data-provider'); -const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); +const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const checkBalance = require('~/models/checkBalance'); const { getFiles } = require('~/models/File'); @@ -23,6 +23,10 @@ class BaseClient { this.skipSaveConvo = false; /** @type {boolean} */ this.skipSaveUserMessage = false; + /** @type {ClientDatabaseSavePromise} */ + this.userMessagePromise; + /** @type {ClientDatabaseSavePromise} */ + this.responsePromise; } setOptions() { @@ -481,7 +485,12 @@ class BaseClient { } if (!isEdited && !this.skipSaveUserMessage) { - await this.saveMessageToDatabase(userMessage, saveOptions, user); + this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); + if (typeof opts?.getReqData === 'function') { + opts.getReqData({ + userMessagePromise: this.userMessagePromise, + }); + } } if ( @@ -530,15 +539,11 @@ class BaseClient { const completionTokens = this.getTokenCount(completion); await this.recordTokenUsage({ promptTokens, completionTokens }); } - await this.saveMessageToDatabase(responseMessage, saveOptions, user); + this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); delete responseMessage.tokenCount; return responseMessage; } - async getConversation(conversationId, user = null) { - return await getConvo(user, conversationId); - } - async loadHistory(conversationId, parentMessageId = null) { logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId }); @@ -593,7 +598,7 @@ class BaseClient { * @param {string | null} user */ async saveMessageToDatabase(message, endpointOptions, user = null) { - await saveMessage({ + const savedMessage = await saveMessage({ ...message, endpoint: this.options.endpoint, unfinished: false, @@ -601,14 +606,16 @@ class BaseClient { }); if (this.skipSaveConvo) { - return; + return { message: savedMessage }; } - await saveConvo(user, { + const conversation = await saveConvo(user, { conversationId: message.conversationId, endpoint: this.options.endpoint, endpointType: this.options.endpointType, ...endpointOptions, }); + + return { message: savedMessage, conversation }; } async updateMessageInDatabase(message) { diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index 123890dfb..86931c449 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -238,7 +238,7 @@ class PluginsClient extends OpenAIClient { await this.recordTokenUsage(responseMessage); } - await this.saveMessageToDatabase(responseMessage, saveOptions, user); + this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user); delete responseMessage.tokenCount; return { ...responseMessage, ...result }; } @@ -303,7 +303,12 @@ class PluginsClient extends OpenAIClient { } if (!this.skipSaveUserMessage) { - await this.saveMessageToDatabase(userMessage, saveOptions, user); + this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user); + if (typeof opts?.getReqData === 'function') { + opts.getReqData({ + userMessagePromise: this.userMessagePromise, + }); + } } if (isEnabled(process.env.CHECK_BALANCE)) { diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 1cd1b0aa9..c554c9da7 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -30,7 +30,7 @@ module.exports = { return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, { new: true, upsert: true, - }); + }).lean(); } catch (error) { logger.error('[saveConvo] Error saving conversation', error); return { message: 'Error saving conversation' }; diff --git a/api/models/Message.js b/api/models/Message.js index f86849fe9..c04bb3c7e 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -57,18 +57,11 @@ module.exports = { if (files) { update.files = files; } - // may also need to update the conversation here - await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true }); - return { - messageId, - conversationId, - parentMessageId, - sender, - text, - isCreatedByUser, - tokenCount, - }; + return await Message.findOneAndUpdate({ messageId }, update, { + upsert: true, + new: true, + }).lean(); } catch (err) { logger.error('Error saving message:', err); throw new Error('Failed to save message.'); diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index f6da23692..81b6f9396 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -2,7 +2,7 @@ const throttle = require('lodash/throttle'); const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage, createOnProgress } = require('~/server/utils'); -const { saveMessage, getConvo } = require('~/models'); +const { saveMessage } = require('~/models'); const { logger } = require('~/config'); const AskController = async (req, res, next, initializeClient, addTitle) => { @@ -18,6 +18,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { logger.debug('[AskController]', { text, conversationId, ...endpointOption }); let userMessage; + let userMessagePromise; let promptTokens; let userMessageId; let responseMessageId; @@ -34,6 +35,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { if (key === 'userMessage') { userMessage = data[key]; userMessageId = data[key].messageId; + } else if (key === 'userMessagePromise') { + userMessagePromise = data[key]; } else if (key === 'responseMessageId') { responseMessageId = data[key]; } else if (key === 'promptTokens') { @@ -74,6 +77,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { const getAbortData = () => ({ sender, conversationId, + userMessagePromise, messageId: responseMessageId, parentMessageId: overrideParentMessageId ?? userMessageId, text: getPartialText(), @@ -121,7 +125,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { response.endpoint = endpointOption.endpoint; - const conversation = await getConvo(user, conversationId); + const { conversation = {} } = await client.responsePromise; conversation.title = conversation && !conversation.title ? null : conversation?.title || 'New Chat'; diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index 5a2d71d1b..269a80ed0 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -2,7 +2,7 @@ const throttle = require('lodash/throttle'); const { getResponseSender, EModelEndpoint } = require('librechat-data-provider'); const { createAbortController, handleAbortError } = require('~/server/middleware'); const { sendMessage, createOnProgress } = require('~/server/utils'); -const { saveMessage, getConvo } = require('~/models'); +const { saveMessage } = require('~/models'); const { logger } = require('~/config'); const EditController = async (req, res, next, initializeClient) => { @@ -27,6 +27,7 @@ const EditController = async (req, res, next, initializeClient) => { }); let userMessage; + let userMessagePromise; let promptTokens; const sender = getResponseSender({ ...endpointOption, @@ -40,6 +41,8 @@ const EditController = async (req, res, next, initializeClient) => { for (let key in data) { if (key === 'userMessage') { userMessage = data[key]; + } else if (key === 'userMessagePromise') { + userMessagePromise = data[key]; } else if (key === 'responseMessageId') { responseMessageId = data[key]; } else if (key === 'promptTokens') { @@ -73,6 +76,7 @@ const EditController = async (req, res, next, initializeClient) => { const getAbortData = () => ({ conversationId, + userMessagePromise, messageId: responseMessageId, sender, parentMessageId: overrideParentMessageId ?? userMessageId, @@ -120,7 +124,7 @@ const EditController = async (req, res, next, initializeClient) => { }, }); - const conversation = await getConvo(user, conversationId); + const { conversation = {} } = await client.responsePromise; conversation.title = conversation && !conversation.title ? null : conversation?.title || 'New Chat'; diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index f0eabddd7..21f67d4db 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -1,9 +1,9 @@ const { isAssistantsEndpoint } = require('librechat-data-provider'); const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); -const { saveMessage, getConvo, getConvoTitle } = require('~/models'); const clearPendingReq = require('~/cache/clearPendingReq'); const abortControllers = require('./abortControllers'); +const { saveMessage, getConvo } = require('~/models'); const spendTokens = require('~/models/spendTokens'); const { abortRun } = require('./abortRun'); const { logger } = require('~/config'); @@ -90,7 +90,8 @@ const createAbortController = (req, res, getAbortData, getReqData) => { abortController.abortCompletion = async function () { abortController.abort(); - const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData(); + const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } = + getAbortData(); const completionTokens = await countTokens(responseData?.text ?? ''); const user = req.user.id; @@ -114,10 +115,20 @@ const createAbortController = (req, res, getAbortData, getReqData) => { saveMessage({ ...responseMessage, user }); + let conversation; + if (userMessagePromise) { + const resolved = await userMessagePromise; + conversation = resolved?.conversation; + } + + if (!conversation) { + conversation = await getConvo(req.user.id, conversationId); + } + return { - title: await getConvoTitle(user, conversationId), + title: conversation && !conversation.title ? null : conversation?.title || 'New Chat', final: true, - conversation: await getConvo(user, conversationId), + conversation, requestMessage: userMessage, responseMessage: responseMessage, }; diff --git a/api/server/middleware/validateMessageReq.js b/api/server/middleware/validateMessageReq.js index 7492c8fd4..430444a17 100644 --- a/api/server/middleware/validateMessageReq.js +++ b/api/server/middleware/validateMessageReq.js @@ -1,4 +1,4 @@ -const { getConvo } = require('../../models'); +const { getConvo } = require('~/models'); // Middleware to validate conversationId and user relationship const validateMessageReq = async (req, res, next) => { diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 66f15da0f..6b95ceaba 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -2,9 +2,9 @@ const express = require('express'); const throttle = require('lodash/throttle'); const { getResponseSender, Constants } = require('librechat-data-provider'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); -const { saveMessage, getConvoTitle, getConvo } = require('~/models'); const { sendMessage, createOnProgress } = require('~/server/utils'); const { addTitle } = require('~/server/services/Endpoints/openAI'); +const { saveMessage } = require('~/models'); const { handleAbort, createAbortController, @@ -41,6 +41,7 @@ router.post( logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); let userMessage; + let userMessagePromise; let promptTokens; let userMessageId; let responseMessageId; @@ -58,6 +59,8 @@ router.post( if (key === 'userMessage') { userMessage = data[key]; userMessageId = data[key].messageId; + } else if (key === 'userMessagePromise') { + userMessagePromise = data[key]; } else if (key === 'responseMessageId') { responseMessageId = data[key]; } else if (key === 'promptTokens') { @@ -151,6 +154,7 @@ router.post( const getAbortData = () => ({ sender, conversationId, + userMessagePromise, messageId: responseMessageId, parentMessageId: overrideParentMessageId ?? userMessageId, text: getPartialText(), @@ -207,10 +211,14 @@ router.post( response.plugins = plugins.map((p) => ({ ...p, loading: false })); await saveMessage({ ...response, user }); + const { conversation = {} } = await client.responsePromise; + conversation.title = + conversation && !conversation.title ? null : conversation?.title || 'New Chat'; + sendMessage(res, { - title: await getConvoTitle(user, conversationId), + title: conversation.title, final: true, - conversation: await getConvo(user, conversationId), + conversation, requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index 6fc2e4b1f..67c5aa2b6 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -13,7 +13,7 @@ const { } = require('~/server/middleware'); const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); -const { saveMessage, getConvoTitle, getConvo } = require('~/models'); +const { saveMessage } = require('~/models'); const { validateTools } = require('~/app'); const { logger } = require('~/config'); @@ -49,6 +49,7 @@ router.post( }); let userMessage; + let userMessagePromise; let promptTokens; const sender = getResponseSender({ ...endpointOption, @@ -68,6 +69,8 @@ router.post( for (let key in data) { if (key === 'userMessage') { userMessage = data[key]; + } else if (key === 'userMessagePromise') { + userMessagePromise = data[key]; } else if (key === 'responseMessageId') { responseMessageId = data[key]; } else if (key === 'promptTokens') { @@ -119,6 +122,7 @@ router.post( const getAbortData = () => ({ sender, conversationId, + userMessagePromise, messageId: responseMessageId, parentMessageId: overrideParentMessageId ?? userMessageId, text: getPartialText(), @@ -179,10 +183,14 @@ router.post( response.plugin = { ...plugin, loading: false }; await saveMessage({ ...response, user }); + const { conversation = {} } = await client.responsePromise; + conversation.title = + conversation && !conversation.title ? null : conversation?.title || 'New Chat'; + sendMessage(res, { - title: await getConvoTitle(user, conversationId), + title: conversation.title, final: true, - conversation: await getConvo(user, conversationId), + conversation, requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/services/Runs/StreamRunManager.js b/api/server/services/Runs/StreamRunManager.js index f19c73d73..01c97c0f7 100644 --- a/api/server/services/Runs/StreamRunManager.js +++ b/api/server/services/Runs/StreamRunManager.js @@ -427,7 +427,7 @@ class StreamRunManager { const toolCallDelta = toolCall[toolCall.type]; const progressCallback = this.progressCallbacks.get(stepKey); - await progressCallback(toolCallDelta); + progressCallback(toolCallDelta); } } diff --git a/api/server/utils/handleText.js b/api/server/utils/handleText.js index 70dc16b93..f07a30f0a 100644 --- a/api/server/utils/handleText.js +++ b/api/server/utils/handleText.js @@ -12,33 +12,35 @@ const citationRegex = /\[\^\d+?\^]/g; const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text); +const base = { message: true, initial: true }; const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { let i = 0; let tokens = addSpaceIfNeeded(generation); - const progressCallback = async (partial, { res, text, bing = false, ...rest }) => { + const basePayload = Object.assign({}, base, { text: tokens || '' }); + + const progressCallback = (partial, { res, text, ...rest }) => { let chunk = partial === text ? '' : partial; - tokens += chunk; - tokens = tokens.replaceAll('[DONE]', ''); + basePayload.text = basePayload.text + chunk; - if (bing) { - tokens = citeText(tokens, true); + const payload = Object.assign({}, basePayload, rest); + sendMessage(res, payload); + if (_onProgress) { + _onProgress(payload); + } + if (i === 0) { + basePayload.initial = false; } - - const payload = { text: tokens, message: true, initial: i === 0, ...rest }; - sendMessage(res, { ...payload, text: tokens }); - _onProgress && _onProgress(payload); i++; }; const sendIntermediateMessage = (res, payload, extraTokens = '') => { - tokens += extraTokens; - sendMessage(res, { - text: tokens?.length === 0 ? '' : tokens, - message: true, - initial: i === 0, - ...payload, - }); + basePayload.text = basePayload.text + extraTokens; + const message = Object.assign({}, basePayload, payload); + sendMessage(res, message); + if (i === 0) { + basePayload.initial = false; + } i++; }; @@ -47,7 +49,7 @@ const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { }; const getPartialText = () => { - return tokens; + return basePayload.text; }; return { onProgress, getPartialText, sendIntermediateMessage }; diff --git a/api/typedefs.js b/api/typedefs.js index cdb2c531f..ecf78c137 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -1442,3 +1442,10 @@ * @typedef {import('librechat-data-provider').TForkConvoRequest} TForkConvoRequest * @memberof typedefs */ + +/** Clients */ + +/** + * @typedef {Promise<{ message: TMessage, conversation: TConversation }> | undefined} ClientDatabaseSavePromise + * @memberof typedefs + */ diff --git a/client/src/components/Chat/Messages/ui/MessageRender.tsx b/client/src/components/Chat/Messages/ui/MessageRender.tsx index c78066c53..aa248a118 100644 --- a/client/src/components/Chat/Messages/ui/MessageRender.tsx +++ b/client/src/components/Chat/Messages/ui/MessageRender.tsx @@ -122,7 +122,7 @@ const MessageRender = React.memo( {!msg?.children?.length && (isSubmittingFamily || isSubmitting) ? ( - + ) : ( { +const PlaceholderRow = memo(({ isCard }: { isCard?: boolean }) => { if (!isCard) { return null; } - if (!isLast) { - return null; - } return
; }); diff --git a/client/src/components/Nav/ClearConvos.tsx b/client/src/components/Nav/ClearConvos.tsx index 438361aa5..310899f35 100644 --- a/client/src/components/Nav/ClearConvos.tsx +++ b/client/src/components/Nav/ClearConvos.tsx @@ -1,9 +1,9 @@ import { useState } from 'react'; -import { Dialog } from '~/components/ui/'; -import DialogTemplate from '~/components/ui/DialogTemplate'; -import { ClearChatsButton } from './SettingsTabs/'; import { useClearConversationsMutation } from 'librechat-data-provider/react-query'; import { useLocalize, useConversation, useConversations } from '~/hooks'; +import DialogTemplate from '~/components/ui/DialogTemplate'; +import { ClearChatsButton } from './SettingsTabs'; +import { Dialog } from '~/components/ui'; const ClearConvos = ({ open, onOpenChange }) => { const { newConversation } = useConversation(); diff --git a/client/src/hooks/Conversations/useConversation.ts b/client/src/hooks/Conversations/useConversation.ts index 7f58a136f..ada584e00 100644 --- a/client/src/hooks/Conversations/useConversation.ts +++ b/client/src/hooks/Conversations/useConversation.ts @@ -1,5 +1,7 @@ import { useCallback } from 'react'; import { useNavigate } from 'react-router-dom'; +import { QueryKeys } from 'librechat-data-provider'; +import { useQueryClient } from '@tanstack/react-query'; import { useSetRecoilState, useResetRecoilState, useRecoilCallback } from 'recoil'; import { useGetEndpointsQuery, useGetModelsQuery } from 'librechat-data-provider/react-query'; import type { @@ -15,6 +17,7 @@ import store from '~/store'; const useConversation = () => { const navigate = useNavigate(); + const queryClient = useQueryClient(); const setConversation = useSetRecoilState(store.conversation); const resetLatestMessage = useResetRecoilState(store.latestMessage); const setMessages = useSetRecoilState(store.messages); @@ -59,6 +62,7 @@ const useConversation = () => { resetLatestMessage(); if (conversation.conversationId === 'new' && !modelsData) { + queryClient.invalidateQueries([QueryKeys.messages, 'new']); navigate('/c/new'); } },