diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index e90d9fefb..08c3b10a6 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -515,7 +515,7 @@ class BaseClient { } async saveMessageToDatabase(message, endpointOptions, user = null) { - await saveMessage({ ...message, unfinished: false, cancelled: false }); + await saveMessage({ ...message, user, unfinished: false, cancelled: false }); await saveConvo(user, { conversationId: message.conversationId, endpoint: this.options.endpoint, diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 9d56ee48c..c946a28af 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -108,6 +108,23 @@ module.exports = { return { message: 'Error getting conversation title' }; } }, + /** + * Asynchronously deletes conversations and associated messages for a given user and filter. + * + * @async + * @function + * @param {string|ObjectId} user - The user's ID. + * @param {Object} filter - Additional filter criteria for the conversations to be deleted. + * @returns {Promise<{ n: number, ok: number, deletedCount: number, messages: { n: number, ok: number, deletedCount: number } }>} + * An object containing the count of deleted conversations and associated messages. + * @throws {Error} Throws an error if there's an issue with the database operations. + * + * @example + * const user = 'someUserId'; + * const filter = { someField: 'someValue' }; + * const result = await deleteConvos(user, filter); + * console.log(result); // { n: 5, ok: 1, deletedCount: 5, messages: { n: 10, ok: 1, deletedCount: 10 } } + */ deleteConvos: async (user, filter) => { let toRemove = await Conversation.find({ ...filter, user }).select('conversationId'); const ids = toRemove.map((instance) => instance.conversationId); diff --git a/api/models/Message.js b/api/models/Message.js index adcdd9e56..a3380a8b0 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -7,6 +7,7 @@ module.exports = { Message, async saveMessage({ + user, messageId, newMessageId, conversationId, @@ -33,6 +34,7 @@ module.exports = { await Message.findOneAndUpdate( { messageId }, { + user, messageId: newMessageId || messageId, conversationId, parentMessageId, diff --git a/api/models/schema/convoSchema.js b/api/models/schema/convoSchema.js index e21ae0aa6..1ea928f25 100644 --- a/api/models/schema/convoSchema.js +++ b/api/models/schema/convoSchema.js @@ -17,6 +17,7 @@ const convoSchema = mongoose.Schema( }, user: { type: String, + index: true, default: null, }, messages: [{ type: mongoose.Schema.Types.ObjectId, ref: 'Message' }], diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index 7d4211dea..267fa26ab 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -14,6 +14,11 @@ const messageSchema = mongoose.Schema( required: true, meiliIndex: true, }, + user: { + type: String, + index: true, + default: null, + }, model: { type: String, }, diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 80cf26ba4..68ee9d15e 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -54,7 +54,7 @@ const createAbortController = (req, res, getAbortData) => { isCreatedByUser: false, }; - saveMessage(responseMessage); + saveMessage({ ...responseMessage, user: req.user.id }); return { title: await getConvoTitle(req.user.id, conversationId), @@ -80,6 +80,7 @@ const handleAbortError = async (res, req, error, data) => { parentMessageId, text: error.message, shouldSaveMessage: true, + user: req.user.id, }; const callback = async () => { if (abortControllers.has(conversationId)) { diff --git a/api/server/middleware/denyRequest.js b/api/server/middleware/denyRequest.js index 64ca86c63..1f44e2974 100644 --- a/api/server/middleware/denyRequest.js +++ b/api/server/middleware/denyRequest.js @@ -42,7 +42,7 @@ const denyRequest = async (req, res, errorMessage) => { _convoId && parentMessageId && parentMessageId !== '00000000-0000-0000-0000-000000000000'; if (shouldSaveMessage) { - await saveMessage(userMessage); + await saveMessage({ ...userMessage, user: req.user.id }); } return await sendError(res, { @@ -52,6 +52,7 @@ const denyRequest = async (req, res, errorMessage) => { parentMessageId: userMessage.messageId, text: responseText, shouldSaveMessage, + user: req.user.id, }); }; diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js index 673fd185d..e0fb9720e 100644 --- a/api/server/routes/ask/anthropic.js +++ b/api/server/routes/ask/anthropic.js @@ -30,6 +30,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, let responseMessageId; let lastSavedTimestamp = 0; let saveDelay = 100; + const user = req.user.id; const getIds = (data) => { userMessage = data.userMessage; @@ -55,6 +56,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, unfinished: true, cancelled: false, error: false, + user, }); } @@ -80,7 +82,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, let response = await client.sendMessage(text, { getIds, // debug: true, - user: req.user.id, + user, conversationId, parentMessageId, overrideParentMessageId, @@ -98,18 +100,18 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, response.parentMessageId = overrideParentMessageId; } - await saveConvo(req.user.id, { + await saveConvo(user, { ...endpointOption, ...endpointOption.modelOptions, conversationId, endpoint: 'anthropic', }); - await saveMessage(response); + await saveMessage({ ...response, user }); sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/routes/ask/askChatGPTBrowser.js b/api/server/routes/ask/askChatGPTBrowser.js index c3a9d56f1..04772a74a 100644 --- a/api/server/routes/ask/askChatGPTBrowser.js +++ b/api/server/routes/ask/askChatGPTBrowser.js @@ -48,7 +48,7 @@ router.post('/', setHeaders, async (req, res) => { }); if (!overrideParentMessageId) { - await saveMessage(userMessage); + await saveMessage({ ...userMessage, user: req.user.id }); await saveConvo(req.user.id, { ...userMessage, ...endpointOption, @@ -80,7 +80,7 @@ const ask = async ({ res, }) => { let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; - const userId = req.user.id; + const user = req.user.id; let responseMessageId = crypto.randomUUID(); let getPartialMessage = null; try { @@ -100,6 +100,7 @@ const ask = async ({ cancelled: false, error: false, isCreatedByUser: false, + user, }); } }, @@ -114,7 +115,7 @@ const ask = async ({ conversationId, ...endpointOption, abortController, - userId, + userId: user, onProgress: progressCallback.call(null, { res, text }), onEventMessage: (eventMessage) => { let data = null; @@ -157,7 +158,7 @@ const ask = async ({ isCreatedByUser: false, }; - await saveMessage(responseMessage); + await saveMessage({ ...responseMessage, user }); responseMessage.messageId = newResponseMessageId; // STEP2 update the conversation @@ -181,7 +182,7 @@ const ask = async ({ } } - await saveConvo(req.user.id, conversationUpdate); + await saveConvo(user, conversationUpdate); conversationId = newConversationId; // STEP3 update the user message @@ -192,6 +193,7 @@ const ask = async ({ if (!overrideParentMessageId) { await saveMessage({ ...userMessage, + user, messageId: userMessageId, newMessageId: newUserMassageId, }); @@ -199,9 +201,9 @@ const ask = async ({ userMessageId = newUserMassageId; sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: responseMessage, }); @@ -210,7 +212,7 @@ const ask = async ({ if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { // const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); const title = await response.details.title; - await saveConvo(req.user.id, { + await saveConvo(user, { conversationId: conversationId, title, }); @@ -227,7 +229,7 @@ const ask = async ({ isCreatedByUser: false, text: `${getPartialMessage() ?? ''}\n\nError message: "${error.message}"`, }; - await saveMessage(errorMessage); + await saveMessage({ ...errorMessage, user }); handleError(res, errorMessage); } }; diff --git a/api/server/routes/ask/bingAI.js b/api/server/routes/ask/bingAI.js index f3047c285..740d8cc04 100644 --- a/api/server/routes/ask/bingAI.js +++ b/api/server/routes/ask/bingAI.js @@ -67,7 +67,7 @@ router.post('/', setHeaders, async (req, res) => { }); if (!overrideParentMessageId) { - await saveMessage(userMessage); + await saveMessage({ ...userMessage, user: req.user.id }); await saveConvo(req.user.id, { ...userMessage, ...endpointOption, @@ -100,6 +100,7 @@ const ask = async ({ res, }) => { let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; + const user = req.user.id; let responseMessageId = crypto.randomUUID(); const model = endpointOption?.jailbreak ? 'Sydney' : 'BingAI'; @@ -125,6 +126,7 @@ const ask = async ({ cancelled: false, error: false, isCreatedByUser: false, + user, }); } }, @@ -132,14 +134,14 @@ const ask = async ({ const abortController = new AbortController(); let bingConversationId = null; if (!isNewConversation) { - const convo = await getConvo(req.user.id, conversationId); + const convo = await getConvo(user, conversationId); bingConversationId = convo.bingConversationId; } try { let response = await askBing({ text, - userId: req.user.id, + userId: user, parentMessageId: userParentMessageId, conversationId: bingConversationId ?? conversationId, ...endpointOption, @@ -194,7 +196,7 @@ const ask = async ({ isCreatedByUser: false, }; - await saveMessage(responseMessage); + await saveMessage({ ...responseMessage, user }); responseMessage.messageId = newResponseMessageId; let conversationUpdate = { @@ -213,13 +215,14 @@ const ask = async ({ conversationUpdate.invocationId = response.invocationId; } - await saveConvo(req.user.id, conversationUpdate); + await saveConvo(user, conversationUpdate); userMessage.messageId = newUserMessageId; // If response has parentMessageId, the fake userMessage.messageId should be updated to the real one. if (!overrideParentMessageId) { await saveMessage({ ...userMessage, + user, messageId: userMessageId, newMessageId: newUserMessageId, }); @@ -227,9 +230,9 @@ const ask = async ({ userMessageId = newUserMessageId; sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: responseMessage, }); @@ -241,7 +244,7 @@ const ask = async ({ response: responseMessage, }); - await saveConvo(req.user.id, { + await saveConvo(user, { conversationId: conversationId, title, }); @@ -263,12 +266,12 @@ const ask = async ({ isCreatedByUser: false, }; - saveMessage(responseMessage); + saveMessage({ ...responseMessage, user }); return { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: responseMessage, }; @@ -286,7 +289,7 @@ const ask = async ({ model, isCreatedByUser: false, }; - await saveMessage(errorMessage); + await saveMessage({ ...errorMessage, user }); handleError(res, errorMessage); } } diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js index 5742120b0..ed6859ee6 100644 --- a/api/server/routes/ask/google.js +++ b/api/server/routes/ask/google.js @@ -55,6 +55,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI let responseMessageId; let lastSavedTimestamp = 0; const { overrideParentMessageId = null } = req.body; + const user = req.user.id; try { const getIds = (data) => { @@ -82,6 +83,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI unfinished: true, cancelled: false, error: false, + user, }); } }, @@ -97,7 +99,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI endpointOption.key, 'Your GOOGLE_TOKEN has expired. Please provide your token again.', ); - key = await getUserKey({ userId: req.user.id, name: 'google' }); + key = await getUserKey({ userId: user, name: 'google' }); key = JSON.parse(key); delete endpointOption.key; console.log('Using service account key provided by User for PaLM models'); @@ -120,7 +122,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI let response = await client.sendMessage(text, { getIds, - user: req.user.id, + user, conversationId, parentMessageId, overrideParentMessageId, @@ -136,18 +138,18 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI response.parentMessageId = overrideParentMessageId; } - await saveConvo(req.user.id, { + await saveConvo(user, { ...endpointOption, ...endpointOption.modelOptions, conversationId, endpoint: 'google', }); - await saveMessage(response); + await saveMessage({ ...response, user }); sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: response, }); @@ -164,7 +166,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI error: true, text: error.message, }; - await saveMessage(errorMessage); + await saveMessage({ ...errorMessage, user }); handleError(res, errorMessage); } }; diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 330f9404d..a71c13352 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -76,6 +76,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, cancelled: false, error: false, plugins, + user, }); } @@ -129,7 +130,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, }; const onChainEnd = () => { - saveMessage(userMessage); + saveMessage({ ...userMessage, user }); sendIntermediateMessage(res, { plugins }); }; @@ -182,12 +183,12 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, console.log('CLIENT RESPONSE'); console.dir(response, { depth: null }); response.plugins = plugins.map((p) => ({ ...p, loading: false })); - await saveMessage(response); + await saveMessage({ ...response, user }); sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js index fb662809d..c822ddaf5 100644 --- a/api/server/routes/ask/openAI.js +++ b/api/server/routes/ask/openAI.js @@ -61,6 +61,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, unfinished: true, cancelled: false, error: false, + user, }); } @@ -113,12 +114,12 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, response.promptTokens, response.completionTokens, ); - await saveMessage(response); + await saveMessage({ ...response, user }); sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js index 5695d67cc..b69e589ec 100644 --- a/api/server/routes/edit/anthropic.js +++ b/api/server/routes/edit/anthropic.js @@ -33,6 +33,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, let lastSavedTimestamp = 0; let saveDelay = 100; const userMessageId = parentMessageId; + const user = req.user.id; const addMetadata = (data) => (metadata = data); const getIds = (data) => { @@ -56,6 +57,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, cancelled: false, isEdited: true, error: false, + user, }); } @@ -79,7 +81,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, const { client } = await initializeClient(req, endpointOption); let response = await client.sendMessage(text, { - user: req.user.id, + user, generation, isContinued, isEdited: true, @@ -107,11 +109,11 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, response.parentMessageId = overrideParentMessageId; } - await saveMessage(response); + await saveMessage({ ...response, user }); sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index b180c844f..5745d8e0b 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -75,6 +75,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, cancelled: false, isEdited: true, error: false, + user, }); } @@ -89,7 +90,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, plugin.inputs.push(formattedAction); plugin.latest = formattedAction.plugin; if (!start) { - saveMessage(userMessage); + saveMessage({ ...userMessage, user }); } sendIntermediateMessage(res, { plugin }); // console.log('PLUGIN ACTION', formattedAction); @@ -99,7 +100,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, let { intermediateSteps: steps } = data; plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; plugin.loading = false; - saveMessage(userMessage); + saveMessage({ ...userMessage, user }); sendIntermediateMessage(res, { plugin }); // console.log('CHAIN END', plugin.outputs); }; @@ -154,12 +155,12 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, console.log('CLIENT RESPONSE'); console.dir(response, { depth: null }); response.plugin = { ...plugin, loading: false }; - await saveMessage(response); + await saveMessage({ ...response, user }); sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js index 8af7ee206..f98c123ea 100644 --- a/api/server/routes/edit/openAI.js +++ b/api/server/routes/edit/openAI.js @@ -33,6 +33,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, let lastSavedTimestamp = 0; let saveDelay = 100; const userMessageId = parentMessageId; + const user = req.user.id; const addMetadata = (data) => (metadata = data); const getIds = (data) => { @@ -58,6 +59,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, cancelled: false, isEdited: true, error: false, + user, }); } @@ -82,7 +84,7 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, const { client } = await initializeClient(req, endpointOption); let response = await client.sendMessage(text, { - user: req.user.id, + user, generation, isContinued, isEdited: true, @@ -110,12 +112,12 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, response.promptTokens, response.completionTokens, ); - await saveMessage(response); + await saveMessage({ ...response, user }); sendMessage(res, { - title: await getConvoTitle(req.user.id, conversationId), + title: await getConvoTitle(user, conversationId), final: true, - conversation: await getConvo(req.user.id, conversationId), + conversation: await getConvo(user, conversationId), requestMessage: userMessage, responseMessage: response, }); diff --git a/api/server/routes/endpoints/schemas.js b/api/server/routes/endpoints/schemas.js index 7c948f295..99a603605 100644 --- a/api/server/routes/endpoints/schemas.js +++ b/api/server/routes/endpoints/schemas.js @@ -46,7 +46,7 @@ const tAgentOptionsSchema = z.object({ const tConversationSchema = z.object({ conversationId: z.string().nullable(), - title: z.string(), + title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'), user: z.string().optional(), endpoint: eModelEndpointSchema.nullable(), suggestions: z.array(z.string()).optional(), diff --git a/api/server/routes/messages.js b/api/server/routes/messages.js index 7dd72fad1..597b2480a 100644 --- a/api/server/routes/messages.js +++ b/api/server/routes/messages.js @@ -17,7 +17,7 @@ router.get('/:conversationId', requireJwtAuth, validateMessageReq, async (req, r // CREATE router.post('/:conversationId', requireJwtAuth, validateMessageReq, async (req, res) => { const message = req.body; - const savedMessage = await saveMessage(message); + const savedMessage = await saveMessage({ ...message, user: req.user.id }); await saveConvo(req.user.id, savedMessage); res.status(201).send(savedMessage); }); diff --git a/api/server/utils/streamResponse.js b/api/server/utils/streamResponse.js index 26cb0c238..133f18c43 100644 --- a/api/server/utils/streamResponse.js +++ b/api/server/utils/streamResponse.js @@ -32,7 +32,8 @@ const sendMessage = (res, message, event = 'message') => { * @param {function} callback - [Optional] The callback function to be executed. */ const sendError = async (res, options, callback) => { - const { sender, conversationId, messageId, parentMessageId, text, shouldSaveMessage } = options; + const { user, sender, conversationId, messageId, parentMessageId, text, shouldSaveMessage } = + options; const errorMessage = { sender, messageId: messageId ?? crypto.randomUUID(), @@ -50,7 +51,7 @@ const sendError = async (res, options, callback) => { } if (shouldSaveMessage) { - await saveMessage(errorMessage); + await saveMessage({ ...errorMessage, user }); } handleError(res, errorMessage); diff --git a/client/src/components/Input/EndpointMenu/EndpointMenu.jsx b/client/src/components/Input/EndpointMenu/EndpointMenu.jsx index b620be91e..c59d905b4 100644 --- a/client/src/components/Input/EndpointMenu/EndpointMenu.jsx +++ b/client/src/components/Input/EndpointMenu/EndpointMenu.jsx @@ -79,11 +79,12 @@ export default function NewConversationMenu() { lastModelUpdate.secondaryModel = conversation.agentOptions.model; } setLastModel(lastModelUpdate); - setLastConvo(conversation); } else if (endpoint === 'bingAI') { const { jailbreak, toneStyle } = conversation; setLastBingSettings({ ...lastBingSettings, jailbreak, toneStyle }); } + + setLastConvo(conversation); }, [conversation]); // set the current model diff --git a/e2e/playwright.config.local.ts b/e2e/playwright.config.local.ts index b6ceee5e1..09f499d3e 100644 --- a/e2e/playwright.config.local.ts +++ b/e2e/playwright.config.local.ts @@ -9,11 +9,13 @@ const config: PlaywrightTestConfig = { ...mainConfig, retries: 0, globalSetup: require.resolve('./setup/global-setup.local'), + globalTeardown: require.resolve('./setup/global-teardown.local'), webServer: { ...mainConfig.webServer, command: `node ${absolutePath}`, env: { ...process.env, + SEARCH: 'false', NODE_ENV: 'development', SESSION_EXPIRY: '60000', REFRESH_TOKEN_EXPIRY: '300000', diff --git a/e2e/playwright.config.ts b/e2e/playwright.config.ts index 7c6fe7107..ea45f1665 100644 --- a/e2e/playwright.config.ts +++ b/e2e/playwright.config.ts @@ -6,6 +6,7 @@ dotenv.config(); export default defineConfig({ globalSetup: require.resolve('./setup/global-setup'), + globalTeardown: require.resolve('./setup/global-teardown'), testDir: 'specs/', outputDir: 'specs/.test-results', /* Run tests in files in parallel. @@ -61,8 +62,10 @@ export default defineConfig({ reuseExistingServer: true, env: { ...process.env, + SEARCH: 'false', NODE_ENV: 'development', SESSION_EXPIRY: '60000', + ALLOW_REGISTRATION: 'true', REFRESH_TOKEN_EXPIRY: '300000', }, }, diff --git a/e2e/setup/authenticate.ts b/e2e/setup/authenticate.ts index d0bb2f2ea..9e91efd9f 100644 --- a/e2e/setup/authenticate.ts +++ b/e2e/setup/authenticate.ts @@ -2,10 +2,31 @@ import { Page, FullConfig, chromium } from '@playwright/test'; import dotenv from 'dotenv'; dotenv.config(); -type User = { username: string; password: string }; +type User = { email: string; name: string; password: string }; + +async function register(page: Page, user: User) { + await page.getByRole('link', { name: 'Sign up' }).click(); + await page.getByLabel('Full name').click(); + await page.getByLabel('Full name').fill('test'); + await page.getByText('Username (optional)').click(); + await page.getByLabel('Username (optional)').fill('test'); + await page.getByLabel('Email').click(); + await page.getByLabel('Email').fill(user.email); + await page.getByLabel('Email').press('Tab'); + await page.getByTestId('password').click(); + await page.getByTestId('password').fill(user.password); + await page.getByTestId('confirm_password').click(); + await page.getByTestId('confirm_password').fill(user.password); + await page.getByLabel('Submit registration').click(); +} + +async function logout(page: Page, user: User) { + await page.getByRole('button', { name: user.name }).click(); + await page.getByRole('button', { name: 'Log out' }).click(); +} async function login(page: Page, user: User) { - await page.locator('input[name="email"]').fill(user.username); + await page.locator('input[name="email"]').fill(user.email); await page.locator('input[name="password"]').fill(user.password); await page.locator('input[name="password"]').press('Enter'); } @@ -15,22 +36,36 @@ async function authenticate(config: FullConfig, user: User) { const { baseURL, storageState } = config.projects[0].use; console.log('🤖: using baseURL', baseURL); console.dir(user, { depth: null }); - const browser = await chromium.launch(); + const browser = await chromium.launch({ + // headless: false, + }); const page = await browser.newPage(); - console.log('🤖: 🗝 authenticating user:', user.username); + console.log('🤖: 🗝 authenticating user:', user.email); if (!baseURL) { throw new Error('🤖: baseURL is not defined'); } - await page.goto(baseURL, { timeout: 5000 }); - await login(page, user); - await page.waitForURL(`${baseURL}/chat/new`); - console.log('🤖: ✔️ user successfully authenticated'); + // Set localStorage before navigating to the page await page.context().addInitScript(() => { localStorage.setItem('navVisible', 'true'); }); console.log('🤖: ✔️ localStorage: set Nav as Visible', storageState); + + await page.goto(baseURL, { timeout: 5000 }); + await register(page, user); + await page.waitForURL(`${baseURL}/chat/new`); + console.log('🤖: ✔️ user successfully registered'); + + // Logout + await logout(page, user); + await page.waitForURL(`${baseURL}/login`); + console.log('🤖: ✔️ user successfully logged out'); + + await login(page, user); + await page.waitForURL(`${baseURL}/chat/new`); + console.log('🤖: ✔️ user successfully authenticated'); + await page.context().storageState({ path: storageState as string }); console.log('🤖: ✔️ authentication state successfully saved in', storageState); await browser.close(); diff --git a/e2e/setup/cleanupUser.ts b/e2e/setup/cleanupUser.ts new file mode 100644 index 000000000..e22316c3d --- /dev/null +++ b/e2e/setup/cleanupUser.ts @@ -0,0 +1,43 @@ +import connectDb from '@librechat/backend/lib/db/connectDb'; +import User from '@librechat/backend/models/User'; +import Session from '@librechat/backend/models/Session'; +import { deleteMessages } from '@librechat/backend/models/Message'; +import { deleteConvos } from '@librechat/backend/models/Conversation'; +type TUser = { email: string; password: string }; + +export default async function cleanupUser(user: TUser) { + const { email } = user; + try { + console.log('🤖: global teardown has been started'); + const db = await connectDb(); + console.log('🤖: ✅ Connected to Database'); + + const { _id } = await User.findOne({ email }).lean(); + console.log('🤖: ✅ Found user in Database'); + + // Delete all conversations & associated messages + const { deletedCount, messages } = await deleteConvos(_id, {}); + + if (messages.deletedCount > 0 || deletedCount > 0) { + console.log(`🤖: ✅ Deleted ${deletedCount} convos & ${messages.deletedCount} messages`); + } + + // Ensure all user messages are deleted + const { deletedCount: deletedMessages } = await deleteMessages({ user: _id }); + if (deletedMessages > 0) { + console.log(`🤖: ✅ Deleted ${deletedMessages} remaining message(s)`); + } + + await Session.deleteAllUserSessions(_id); + + await User.deleteMany({ email }); + + console.log('🤖: ✅ Deleted user from Database'); + + await db.connection.close(); + } catch (error) { + console.error('Error:', error); + } +} + +process.on('uncaughtException', (err) => console.error('Uncaught Exception:', err)); diff --git a/e2e/setup/global-setup.ts b/e2e/setup/global-setup.ts index 73ff9839a..25c60e11a 100644 --- a/e2e/setup/global-setup.ts +++ b/e2e/setup/global-setup.ts @@ -3,7 +3,8 @@ import authenticate from './authenticate'; async function globalSetup(config: FullConfig) { const user = { - username: String(process.env.E2E_USER_EMAIL), + name: 'test', + email: String(process.env.E2E_USER_EMAIL), password: String(process.env.E2E_USER_PASSWORD), }; diff --git a/e2e/setup/global-teardown.local.ts b/e2e/setup/global-teardown.local.ts new file mode 100644 index 000000000..cef902cfc --- /dev/null +++ b/e2e/setup/global-teardown.local.ts @@ -0,0 +1,12 @@ +import localUser from '../config.local'; +import cleanupUser from './cleanupUser'; + +async function globalTeardown() { + try { + await cleanupUser(localUser); + } catch (error) { + console.error('Error:', error); + } +} + +export default globalTeardown; diff --git a/e2e/setup/global-teardown.ts b/e2e/setup/global-teardown.ts new file mode 100644 index 000000000..c71e4d56a --- /dev/null +++ b/e2e/setup/global-teardown.ts @@ -0,0 +1,16 @@ +import cleanupUser from './cleanupUser'; + +async function globalTeardown() { + const user = { + email: String(process.env.E2E_USER_EMAIL), + password: String(process.env.E2E_USER_PASSWORD), + }; + + try { + await cleanupUser(user); + } catch (error) { + console.error('Error:', error); + } +} + +export default globalTeardown; diff --git a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index 27543c4db..0b00a32a7 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -73,7 +73,7 @@ export const tMessageSchema = z.object({ overrideParentMessageId: z.string().nullable().optional(), bg: z.string().nullable().optional(), model: z.string().nullable().optional(), - title: z.string().nullable().optional(), + title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'), sender: z.string(), text: z.string(), generation: z.string().nullable().optional(), @@ -103,7 +103,7 @@ export type TMessage = z.input & { export const tConversationSchema = z.object({ conversationId: z.string().nullable(), - title: z.string(), + title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'), user: z.string().optional(), endpoint: eModelEndpointSchema.nullable(), suggestions: z.array(z.string()).optional(),