🌊 feat: enhance TTSService with Deepgram SDK integration and refactor voice validation

This commit is contained in:
Marco Beretta
2024-11-23 16:49:56 +01:00
parent 5eabd2493c
commit b7f4903acd
3 changed files with 123 additions and 33 deletions

View File

@@ -199,7 +199,7 @@ class STTService {
topics: sttSchema.intelligence?.topics,
};
[configOptions].forEach(this.removeUndefined);
this.removeUndefined(configOptions);
const { result, error } = await deepgram.listen.prerecorded.transcribeFile(
audioReadStream,
@@ -213,6 +213,7 @@ class STTService {
return result.results?.channels[0]?.alternatives[0]?.transcript || '';
}
// TODO: Implement a better way to determine if the SDK should be used
shouldUseSDK(provider, sttSchema) {
if (provider !== STTProviders.OPENAI && provider !== STTProviders.AZURE_OPENAI) {
return true;

View File

@@ -1,9 +1,11 @@
const axios = require('axios');
const { createClient } = require('@deepgram/sdk');
const { extractEnvVariable, TTSProviders } = require('librechat-data-provider');
const { getRandomVoiceId, createChunkProcessor, splitTextIntoChunks } = require('./streamAudio');
const { getCustomConfig } = require('~/server/services/Config');
const { genAzureEndpoint } = require('~/utils');
const { logger } = require('~/config');
const { Readable } = require('stream');
/**
* Service class for handling Text-to-Speech (TTS) operations.
@@ -16,13 +18,17 @@ class TTSService {
*/
constructor(customConfig) {
this.customConfig = customConfig;
this.providerStrategies = {
this.apiStrategies = {
[TTSProviders.OPENAI]: this.openAIProvider.bind(this),
[TTSProviders.AZURE_OPENAI]: this.azureOpenAIProvider.bind(this),
[TTSProviders.ELEVENLABS]: this.elevenLabsProvider.bind(this),
[TTSProviders.LOCALAI]: this.localAIProvider.bind(this),
[TTSProviders.ELEVENLABS]: this.elevenLabsProvider.bind(this),
};
this.sdkStrategies = {
[TTSProviders.DEEPGRAM]: this.deepgramSDKProvider.bind(this),
};
}
/**
@@ -110,19 +116,14 @@ class TTSService {
openAIProvider(ttsSchema, input, voice) {
const url = ttsSchema?.url || 'https://api.openai.com/v1/audio/speech';
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
if (ttsSchema?.voices && ttsSchema.voices.length > 0 && !ttsSchema.voices.includes(voice)) {
throw new Error(`Voice ${voice} is not available.`);
}
const data = {
input,
model: ttsSchema?.model,
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
voice: voice,
backend: ttsSchema?.backend,
};
@@ -148,19 +149,14 @@ class TTSService {
azureOpenAIApiDeploymentName: ttsSchema?.deploymentName,
})}/audio/speech?api-version=${ttsSchema?.apiVersion}`;
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
if (ttsSchema?.voices && ttsSchema.voices.length > 0 && !ttsSchema.voices.includes(voice)) {
throw new Error(`Voice ${voice} is not available.`);
}
const data = {
model: ttsSchema?.model,
input,
voice: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
voice: voice,
};
const headers = {
@@ -185,7 +181,7 @@ class TTSService {
ttsSchema?.url ||
`https://api.elevenlabs.io/v1/text-to-speech/${voice}${stream ? '/stream' : ''}`;
if (!ttsSchema?.voices.includes(voice) && !ttsSchema?.voices.includes('ALL')) {
if (!ttsSchema?.voices.includes(voice)) {
throw new Error(`Voice ${voice} is not available.`);
}
@@ -221,18 +217,13 @@ class TTSService {
localAIProvider(ttsSchema, input, voice) {
const url = ttsSchema?.url;
if (
ttsSchema?.voices &&
ttsSchema.voices.length > 0 &&
!ttsSchema.voices.includes(voice) &&
!ttsSchema.voices.includes('ALL')
) {
if (ttsSchema?.voices && ttsSchema.voices.length > 0 && !ttsSchema.voices.includes(voice)) {
throw new Error(`Voice ${voice} is not available.`);
}
const data = {
input,
model: ttsSchema?.voices && ttsSchema.voices.length > 0 ? voice : undefined,
model: voice,
backend: ttsSchema?.backend,
};
@@ -248,6 +239,89 @@ class TTSService {
return [url, data, headers];
}
/**
* Converts a ReadableStream to a Node.js stream (used in Deepgram SDK).
* @async
* @param {ReadableStream} readableStream - The ReadableStream to convert.
* @returns {Promise<Readable>} The Node.js stream.
* @throws {Error} If the conversion fails.
*/
async streamToNodeStream(readableStream) {
const reader = readableStream.getReader();
const nodeStream = new Readable({
async read() {
try {
const { value, done } = await reader.read();
if (done) {
this.push(null);
} else {
this.push(Buffer.from(value));
}
} catch (err) {
this.destroy(err);
}
},
});
return nodeStream;
}
/**
* Prepares the request for Deepgram SDK TTS provider.
* @async
* @param {Object} ttsSchema - The TTS schema for Deepgram SDK.
* @param {string} input - The input text.
* @param {string} voice - The selected voice.
* @returns {Promise<Object>} The response object.
* @throws {Error} If the selected voice is not available or the request fails.
*/
async deepgramSDKProvider(ttsSchema, input, voice) {
const apiKey = extractEnvVariable(ttsSchema.apiKey) || '';
const deepgram = createClient(apiKey);
if (ttsSchema?.voices && ttsSchema.voices.length > 0 && !ttsSchema.voices.includes(voice)) {
throw new Error(`Voice ${voice} is not available.`);
}
const modelParts = [ttsSchema.model, voice, ttsSchema.language].filter(Boolean);
const configOptions = {
model: modelParts.join('-'),
encoding: 'linear16',
container: 'wav',
bit_rate: ttsSchema.media_settings?.bit_rate,
sample_rate: ttsSchema.media_settings?.sample_rate,
};
this.removeUndefined(configOptions);
try {
const response = await deepgram.speak.request({ text: input }, configOptions);
const audioStream = await response.getStream();
const headers = await response.getHeaders();
// Convert ReadableStream to Node.js stream
const nodeStream = await this.streamToNodeStream(audioStream);
return {
data: nodeStream,
headers,
status: 200,
};
} catch (error) {
logger.error('Deepgram TTS request failed:', error);
throw error;
}
}
// TODO: Implement a better way to determine if the SDK should be used
shouldUseSDK(provider, sttSchema) {
if (provider == TTSProviders.DEEPGRAM) {
return true;
}
return typeof sttSchema.url === 'string' && sttSchema.url.trim().length > 0;
}
/**
* Sends a TTS request to the specified provider.
* @async
@@ -261,22 +335,34 @@ class TTSService {
* @throws {Error} If the provider is invalid or the request fails.
*/
async ttsRequest(provider, ttsSchema, { input, voice, stream = true }) {
const strategy = this.providerStrategies[provider];
const useSDK = this.shouldUseSDK(provider, ttsSchema);
const strategy = useSDK ? this.sdkStrategies[provider] : this.apiStrategies[provider];
if (!strategy) {
throw new Error('Invalid provider');
}
const [url, data, headers] = strategy.call(this, ttsSchema, input, voice, stream);
if (useSDK) {
const response = await strategy.call(this, ttsSchema, input, voice, stream);
[data, headers].forEach(this.removeUndefined.bind(this));
return {
data: response.data,
headers: response.headers,
status: response.status,
};
} else {
const [url, data, headers] = strategy.call(this, ttsSchema, input, voice, stream);
const options = { headers, responseType: stream ? 'stream' : 'arraybuffer' };
[data, headers].forEach(this.removeUndefined.bind(this));
try {
return await axios.post(url, data, options);
} catch (error) {
logger.error(`TTS request failed for provider ${provider}:`, error);
throw error;
const options = { headers, responseType: stream ? 'stream' : 'arraybuffer' };
try {
return await axios.post(url, data, options);
} catch (error) {
logger.error(`TTS request failed for provider ${provider}:`, error);
throw error;
}
}
}

View File

@@ -37,6 +37,9 @@ async function getVoices(req, res) {
case TTSProviders.LOCALAI:
voices = ttsSchema.localai?.voices;
break;
case TTSProviders.DEEPGRAM:
voices = ttsSchema.deepgram?.voices;
break;
default:
throw new Error('Invalid provider');
}