From 431fc6284f24beb6f504e4240c3406a9503358d5 Mon Sep 17 00:00:00 2001 From: Danny Avila <110412045+danny-avila@users.noreply.github.com> Date: Sat, 30 Dec 2023 14:34:32 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F=20fix:=20Minor=20Fixes=20?= =?UTF-8?q?in=20`Message`,=20`Ask/EditController`,=20`OpenAIClient`,=20and?= =?UTF-8?q?=20`countTokens`=20(#1463)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(Message): avoid overwriting unprovided properties * fix(OpenAIClient): return intermediateReply on user abort * fix(AskController): do not send/save final message if abort was triggered * fix(countTokens): avoid fetching remote registry and exclusively use cl100k_base or p50k_base weights for token counting * refactor(Message/messageSchema): rely on messageSchema for default values when saving messages * fix(EditController): do not send/save final message if abort was triggered * fix(config/helpers): fix module resolution error --- api/app/clients/OpenAIClient.js | 2 +- api/models/Message.js | 14 +++++++------- api/models/schema/messageSchema.js | 1 + api/server/controllers/AskController.js | 21 ++++++++++++--------- api/server/controllers/EditController.js | 20 +++++++++++--------- api/server/utils/countTokens.js | 7 +++---- config/helpers.js | 3 ++- 7 files changed, 37 insertions(+), 31 deletions(-) diff --git a/api/app/clients/OpenAIClient.js b/api/app/clients/OpenAIClient.js index ce39311f3..f0dbc366b 100644 --- a/api/app/clients/OpenAIClient.js +++ b/api/app/clients/OpenAIClient.js @@ -847,7 +847,7 @@ ${convo} err?.message?.includes('abort') || (err instanceof OpenAI.APIError && err?.message?.includes('abort')) ) { - return ''; + return intermediateReply; } if ( err?.message?.includes( diff --git a/api/models/Message.js b/api/models/Message.js index 7cb9bdc37..7accf9285 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -15,16 +15,16 @@ module.exports = { parentMessageId, sender, text, - isCreatedByUser = false, + isCreatedByUser, error, unfinished, files, - isEdited = false, - finish_reason = null, - tokenCount = null, - plugin = null, - plugins = null, - model = null, + isEdited, + finish_reason, + tokenCount, + plugin, + plugins, + model, }) { try { const validConvoId = idSchema.safeParse(conversationId); diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index bcc28cd23..33d799544 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -21,6 +21,7 @@ const messageSchema = mongoose.Schema( }, model: { type: String, + default: null, }, conversationSignature: { type: String, diff --git a/api/server/controllers/AskController.js b/api/server/controllers/AskController.js index ffaa10938..78933feeb 100644 --- a/api/server/controllers/AskController.js +++ b/api/server/controllers/AskController.js @@ -118,16 +118,19 @@ const AskController = async (req, res, next, initializeClient, addTitle) => { delete userMessage.image_urls; } - sendMessage(res, { - title: await getConvoTitle(user, conversationId), - final: true, - conversation: await getConvo(user, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + if (!abortController.signal.aborted) { + sendMessage(res, { + title: await getConvoTitle(user, conversationId), + final: true, + conversation: await getConvo(user, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + + await saveMessage({ ...response, user }); + } - await saveMessage({ ...response, user }); await saveMessage(userMessage); if (addTitle && parentMessageId === '00000000-0000-0000-0000-000000000000' && newConvo) { diff --git a/api/server/controllers/EditController.js b/api/server/controllers/EditController.js index ecc146126..72ee58026 100644 --- a/api/server/controllers/EditController.js +++ b/api/server/controllers/EditController.js @@ -112,16 +112,18 @@ const EditController = async (req, res, next, initializeClient) => { response = { ...response, ...metadata }; } - await saveMessage({ ...response, user }); + if (!abortController.signal.aborted) { + sendMessage(res, { + title: await getConvoTitle(user, conversationId), + final: true, + conversation: await getConvo(user, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); - sendMessage(res, { - title: await getConvoTitle(user, conversationId), - final: true, - conversation: await getConvo(user, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + await saveMessage({ ...response, user }); + } } catch (error) { const partialText = getPartialText(); handleAbortError(res, req, error, { diff --git a/api/server/utils/countTokens.js b/api/server/utils/countTokens.js index 9c8c98e76..34c070aa8 100644 --- a/api/server/utils/countTokens.js +++ b/api/server/utils/countTokens.js @@ -1,13 +1,12 @@ -const { load } = require('tiktoken/load'); const { Tiktoken } = require('tiktoken/lite'); -const registry = require('tiktoken/registry.json'); -const models = require('tiktoken/model_to_encoding.json'); +const p50k_base = require('tiktoken/encoders/p50k_base.json'); +const cl100k_base = require('tiktoken/encoders/cl100k_base.json'); const logger = require('~/config/winston'); const countTokens = async (text = '', modelName = 'gpt-3.5-turbo') => { let encoder = null; try { - const model = await load(registry[models[modelName]]); + const model = modelName.includes('text-davinci-003') ? p50k_base : cl100k_base; encoder = new Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str); const tokens = encoder.encode(text); encoder.free(); diff --git a/config/helpers.js b/config/helpers.js index a86d562eb..2b634612d 100644 --- a/config/helpers.js +++ b/config/helpers.js @@ -6,7 +6,8 @@ const fs = require('fs'); const path = require('path'); const readline = require('readline'); const { execSync } = require('child_process'); -const { connectDb } = require('@librechat/backend/lib/db'); +require('module-alias')({ base: path.resolve(__dirname, '..', 'api') }); +const connectDb = require('~/lib/db/connectDb'); const askQuestion = (query) => { const rl = readline.createInterface({