diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 6a306726b..f49b431d0 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -26,6 +26,10 @@ class BaseClient { throw new Error("Method 'getCompletion' must be implemented."); } + sendCompletion() { + throw new Error("Method 'sendCompletion' must be implemented."); + } + getSaveOptions() { throw new Error('Subclasses must implement getSaveOptions'); } diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index e396581c7..8f0c415bb 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -18,10 +18,26 @@ class GoogleClient extends BaseClient { this.setOptions(options); } + /* Google/PaLM2 specific methods */ constructUrl() { return `https://us-central1-aiplatform.googleapis.com/v1/projects/${this.project_id}/locations/us-central1/publishers/google/models/${this.modelOptions.model}:predict`; } + async getClient() { + const scopes = ['https://www.googleapis.com/auth/cloud-platform']; + const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes); + + jwtClient.authorize((err) => { + if (err) { + console.log(err); + throw err; + } + }); + + return jwtClient; + } + + /* Required Client methods */ setOptions(options) { if (this.options && !this.options.replaceOptions) { // nested options aren't spread properly, so we need to do this manually @@ -124,25 +140,19 @@ class GoogleClient extends BaseClient { return this; } - async getClient() { - const scopes = ['https://www.googleapis.com/auth/cloud-platform']; - const jwtClient = new google.auth.JWT(this.client_email, null, this.private_key, scopes); - - jwtClient.authorize((err) => { - if (err) { - console.log(err); - throw err; - } - }); - - return jwtClient; + getMessageMapMethod() { + return ((message) => ({ + author: message?.author ?? (message.isCreatedByUser ? this.userLabel : this.modelLabel), + content: message?.content ?? message.text + })).bind(this); } - buildMessages(input, { messages = [] }) { + buildMessages(messages = []) { + const formattedMessages = messages.map(this.getMessageMapMethod()); let payload = { instances: [ { - messages: [...messages, { author: this.userLabel, content: input }] + messages: formattedMessages, } ], parameters: this.options.modelOptions @@ -156,23 +166,24 @@ class GoogleClient extends BaseClient { payload.instances[0].examples = this.options.examples; } + /* TO-DO: text model needs more context since it can't process an array of messages */ if (this.isTextModel) { payload.instances = [ { - prompt: input + prompt: messages[messages.length -1].content } ]; } if (this.options.debug) { - console.debug('buildMessages'); + console.debug('GoogleClient buildMessages'); console.dir(payload, { depth: null }); } - return payload; + return { prompt: payload }; } - async getCompletion(input, messages = [], abortController = null) { + async getCompletion(payload, abortController = null) { if (!abortController) { abortController = new AbortController(); } @@ -198,19 +209,11 @@ class GoogleClient extends BaseClient { } const client = await this.getClient(); - const payload = this.buildMessages(input, { messages }); const res = await client.request({ url, method: 'POST', data: payload }); console.dir(res.data, { depth: null }); return res.data; } - getMessageMapMethod() { - return ((message) => ({ - author: message.isCreatedByUser ? this.userLabel : this.modelLabel, - content: message?.content ?? message.text - })).bind(this); - } - getSaveOptions() { return { ...this.modelOptions @@ -218,24 +221,15 @@ class GoogleClient extends BaseClient { } getBuildMessagesOptions() { - console.log('GoogleClient doesn\'t use getBuildMessagesOptions'); + // console.log('GoogleClient doesn\'t use getBuildMessagesOptions'); } - async sendMessage(message, opts = {}) { - console.log('GoogleClient: sendMessage', message, opts); - const { - user, - conversationId, - responseMessageId, - saveOptions, - userMessage, - } = await this.handleStartMethods(message, opts); - - await this.saveMessageToDatabase(userMessage, saveOptions, user); + async sendCompletion(payload, opts = {}) { + console.log('GoogleClient: sendcompletion', payload, opts); let reply = ''; let blocked = false; try { - const result = await this.getCompletion(message, this.currentMessages, opts.abortController); + const result = await this.getCompletion(payload, opts.abortController); blocked = result?.predictions?.[0]?.safetyAttributes?.blocked; reply = result?.predictions?.[0]?.candidates?.[0]?.content || @@ -254,29 +248,14 @@ class GoogleClient extends BaseClient { console.error(err); } - if (this.options.debug) { - console.debug('options'); - console.debug(this.options); - } - if (!blocked) { await this.generateTextStream(reply, opts.onProgress, { delay: 0.5 }); } - const responseMessage = { - messageId: responseMessageId, - conversationId, - parentMessageId: userMessage.messageId, - sender: this.sender, - text: reply, - error: blocked, - isCreatedByUser: false - }; - - await this.saveMessageToDatabase(responseMessage, saveOptions, user); - return responseMessage; + return reply.trim(); } + /* TO-DO: Handle tokens with Google tokenization NOTE: these are required */ static getTokenizer(encoding, isModelName = false, extendSpecialTokens = {}) { if (tokenizersCache[encoding]) { return tokenizersCache[encoding];