liuhua liuhua Kevin Hu commited on
Commit
32ef5e5
·
1 Parent(s): 45b5929

OpenAITTS (#2493)

Browse files

### What problem does this PR solve?

OpenAITTS

### Type of change


- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: liuhua <[email protected]>
Co-authored-by: Kevin Hu <[email protected]>

conf/llm_factories.json CHANGED
@@ -77,6 +77,12 @@
77
  "tags": "LLM,CHAT,IMAGE2TEXT",
78
  "max_tokens": 765,
79
  "model_type": "image2text"
 
 
 
 
 
 
80
  }
81
  ]
82
  },
 
77
  "tags": "LLM,CHAT,IMAGE2TEXT",
78
  "max_tokens": 765,
79
  "model_type": "image2text"
80
+ },
81
+ {
82
+ "llm_name": "tts-1",
83
+ "tags": "TTS",
84
+ "max_tokens": 2048,
85
+ "model_type": "tts"
86
  }
87
  ]
88
  },
rag/llm/__init__.py CHANGED
@@ -138,5 +138,6 @@ Seq2txtModel = {
138
 
139
  TTSModel = {
140
  "Fish Audio": FishAudioTTS,
141
- "Tongyi-Qianwen": QwenTTS
 
142
  }
 
138
 
139
  TTSModel = {
140
  "Fish Audio": FishAudioTTS,
141
+ "Tongyi-Qianwen": QwenTTS,
142
+ "OpenAI":OpenAITTS
143
  }
rag/llm/tts_model.py CHANGED
@@ -14,6 +14,7 @@
14
  # limitations under the License.
15
  #
16
 
 
17
  from typing import Annotated, Literal
18
  from abc import ABC
19
  import httpx
@@ -23,6 +24,8 @@ from rag.utils import num_tokens_from_string
23
  import json
24
  import re
25
  import time
 
 
26
  class ServeReferenceAudio(BaseModel):
27
  audio: bytes
28
  text: str
@@ -52,7 +55,7 @@ class Base(ABC):
52
 
53
  def tts(self, audio):
54
  pass
55
-
56
  def normalize_text(self, text):
57
  return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
58
 
@@ -78,13 +81,13 @@ class FishAudioTTS(Base):
78
  with httpx.Client() as client:
79
  try:
80
  with client.stream(
81
- method="POST",
82
- url=self.base_url,
83
- content=ormsgpack.packb(
84
- request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
85
- ),
86
- headers=self.headers,
87
- timeout=None,
88
  ) as response:
89
  if response.status_code == HTTPStatus.OK:
90
  for chunk in response.iter_bytes():
@@ -101,7 +104,7 @@ class FishAudioTTS(Base):
101
  class QwenTTS(Base):
102
  def __init__(self, key, model_name, base_url=""):
103
  import dashscope
104
-
105
  self.model_name = model_name
106
  dashscope.api_key = key
107
 
@@ -109,11 +112,11 @@ class QwenTTS(Base):
109
  from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
110
  from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
111
  from collections import deque
112
-
113
  class Callback(ResultCallback):
114
  def __init__(self) -> None:
115
  self.dque = deque()
116
-
117
  def _run(self):
118
  while True:
119
  if not self.dque:
@@ -144,13 +147,40 @@ class QwenTTS(Base):
144
  text = self.normalize_text(text)
145
  callback = Callback()
146
  SpeechSynthesizer.call(model=self.model_name,
147
- text=text,
148
- callback=callback,
149
- format="mp3")
150
  try:
151
  for data in callback._run():
152
  yield data
153
  yield num_tokens_from_string(text)
154
-
155
  except Exception as e:
156
- raise RuntimeError(f"**ERROR**: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # limitations under the License.
15
  #
16
 
17
+ import requests
18
  from typing import Annotated, Literal
19
  from abc import ABC
20
  import httpx
 
24
  import json
25
  import re
26
  import time
27
+
28
+
29
  class ServeReferenceAudio(BaseModel):
30
  audio: bytes
31
  text: str
 
55
 
56
  def tts(self, audio):
57
  pass
58
+
59
  def normalize_text(self, text):
60
  return re.sub(r'(\*\*|##\d+\$\$|#)', '', text)
61
 
 
81
  with httpx.Client() as client:
82
  try:
83
  with client.stream(
84
+ method="POST",
85
+ url=self.base_url,
86
+ content=ormsgpack.packb(
87
+ request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC
88
+ ),
89
+ headers=self.headers,
90
+ timeout=None,
91
  ) as response:
92
  if response.status_code == HTTPStatus.OK:
93
  for chunk in response.iter_bytes():
 
104
  class QwenTTS(Base):
105
  def __init__(self, key, model_name, base_url=""):
106
  import dashscope
107
+
108
  self.model_name = model_name
109
  dashscope.api_key = key
110
 
 
112
  from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
113
  from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
114
  from collections import deque
115
+
116
  class Callback(ResultCallback):
117
  def __init__(self) -> None:
118
  self.dque = deque()
119
+
120
  def _run(self):
121
  while True:
122
  if not self.dque:
 
147
  text = self.normalize_text(text)
148
  callback = Callback()
149
  SpeechSynthesizer.call(model=self.model_name,
150
+ text=text,
151
+ callback=callback,
152
+ format="mp3")
153
  try:
154
  for data in callback._run():
155
  yield data
156
  yield num_tokens_from_string(text)
157
+
158
  except Exception as e:
159
+ raise RuntimeError(f"**ERROR**: {e}")
160
+
161
+
162
+ class OpenAITTS(Base):
163
+ def __init__(self, key, model_name="tts-1", base_url="https://api.openai.com/v1"):
164
+ self.api_key = key
165
+ self.model_name = model_name
166
+ self.base_url = base_url
167
+ self.headers = {
168
+ "Authorization": f"Bearer {self.api_key}",
169
+ "Content-Type": "application/json"
170
+ }
171
+
172
+ def tts(self, text, voice="alloy"):
173
+ text = self.normalize_text(text)
174
+ payload = {
175
+ "model": self.model_name,
176
+ "voice": voice,
177
+ "input": text
178
+ }
179
+
180
+ response = requests.post(f"{self.base_url}/audio/speech", headers=self.headers, json=payload, stream=True)
181
+
182
+ if response.status_code != 200:
183
+ raise Exception(f"**Error**: {response.status_code}, {response.text}")
184
+ for chunk in response.iter_content(chunk_size=1024):
185
+ if chunk:
186
+ yield chunk