add support for TongyiQwen tts (#2311)
Browse files### What problem does this PR solve?
add support for TongyiQwen tts
#1853
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
---------
Co-authored-by: Zhedong Cen <[email protected]>
- conf/llm_factories.json +12 -6
- rag/llm/__init__.py +2 -1
- rag/llm/tts_model.py +59 -1
conf/llm_factories.json
CHANGED
@@ -104,18 +104,24 @@
|
|
104 |
"max_tokens": 2048,
|
105 |
"model_type": "embedding"
|
106 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
{
|
108 |
"llm_name": "text-embedding-v3",
|
109 |
"tags": "TEXT EMBEDDING,8K",
|
110 |
"max_tokens": 8192,
|
111 |
"model_type": "embedding"
|
112 |
},
|
113 |
-
{
|
114 |
-
"llm_name": "paraformer-realtime-8k-v1",
|
115 |
-
"tags": "SPEECH2TEXT",
|
116 |
-
"max_tokens": 26214400,
|
117 |
-
"model_type": "speech2text"
|
118 |
-
},
|
119 |
{
|
120 |
"llm_name": "qwen-vl-max",
|
121 |
"tags": "LLM,CHAT,IMAGE2TEXT",
|
|
|
104 |
"max_tokens": 2048,
|
105 |
"model_type": "embedding"
|
106 |
},
|
107 |
+
{
|
108 |
+
"llm_name": "sambert-zhide-v1",
|
109 |
+
"tags": "TTS",
|
110 |
+
"max_tokens": 2048,
|
111 |
+
"model_type": "tts"
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"llm_name": "sambert-zhiru-v1",
|
115 |
+
"tags": "TTS",
|
116 |
+
"max_tokens": 2048,
|
117 |
+
"model_type": "tts"
|
118 |
+
},
|
119 |
{
|
120 |
"llm_name": "text-embedding-v3",
|
121 |
"tags": "TEXT EMBEDDING,8K",
|
122 |
"max_tokens": 8192,
|
123 |
"model_type": "embedding"
|
124 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
{
|
126 |
"llm_name": "qwen-vl-max",
|
127 |
"tags": "LLM,CHAT,IMAGE2TEXT",
|
rag/llm/__init__.py
CHANGED
@@ -137,5 +137,6 @@ Seq2txtModel = {
|
|
137 |
}
|
138 |
|
139 |
TTSModel = {
|
140 |
-
"Fish Audio": FishAudioTTS
|
|
|
141 |
}
|
|
|
137 |
}
|
138 |
|
139 |
TTSModel = {
|
140 |
+
"Fish Audio": FishAudioTTS,
|
141 |
+
"Tongyi-Qianwen": QwenTTS
|
142 |
}
|
rag/llm/tts_model.py
CHANGED
@@ -22,7 +22,7 @@ from pydantic import BaseModel, conint
|
|
22 |
from rag.utils import num_tokens_from_string
|
23 |
import json
|
24 |
import re
|
25 |
-
|
26 |
class ServeReferenceAudio(BaseModel):
|
27 |
audio: bytes
|
28 |
text: str
|
@@ -96,3 +96,61 @@ class FishAudioTTS(Base):
|
|
96 |
|
97 |
except httpx.HTTPStatusError as e:
|
98 |
raise RuntimeError(f"**ERROR**: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
from rag.utils import num_tokens_from_string
|
23 |
import json
|
24 |
import re
|
25 |
+
import time
|
26 |
class ServeReferenceAudio(BaseModel):
|
27 |
audio: bytes
|
28 |
text: str
|
|
|
96 |
|
97 |
except httpx.HTTPStatusError as e:
|
98 |
raise RuntimeError(f"**ERROR**: {e}")
|
99 |
+
|
100 |
+
|
101 |
+
class QwenTTS(Base):
|
102 |
+
def __init__(self, key, model_name, base_url=""):
|
103 |
+
import dashscope
|
104 |
+
|
105 |
+
self.model_name = model_name
|
106 |
+
dashscope.api_key = key
|
107 |
+
|
108 |
+
def tts(self, text):
|
109 |
+
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
110 |
+
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
|
111 |
+
from collections import deque
|
112 |
+
|
113 |
+
class Callback(ResultCallback):
|
114 |
+
def __init__(self) -> None:
|
115 |
+
self.dque = deque()
|
116 |
+
|
117 |
+
def _run(self):
|
118 |
+
while True:
|
119 |
+
if not self.dque:
|
120 |
+
time.sleep(0)
|
121 |
+
continue
|
122 |
+
val = self.dque.popleft()
|
123 |
+
if val:
|
124 |
+
yield val
|
125 |
+
else:
|
126 |
+
break
|
127 |
+
|
128 |
+
def on_open(self):
|
129 |
+
pass
|
130 |
+
|
131 |
+
def on_complete(self):
|
132 |
+
self.dque.append(None)
|
133 |
+
|
134 |
+
def on_error(self, response: SpeechSynthesisResponse):
|
135 |
+
raise RuntimeError(str(response))
|
136 |
+
|
137 |
+
def on_close(self):
|
138 |
+
pass
|
139 |
+
|
140 |
+
def on_event(self, result: SpeechSynthesisResult):
|
141 |
+
if result.get_audio_frame() is not None:
|
142 |
+
self.dque.append(result.get_audio_frame())
|
143 |
+
|
144 |
+
text = self.normalize_text(text)
|
145 |
+
callback = Callback()
|
146 |
+
SpeechSynthesizer.call(model=self.model_name,
|
147 |
+
text=text,
|
148 |
+
callback=callback,
|
149 |
+
format="mp3")
|
150 |
+
try:
|
151 |
+
for data in callback._run():
|
152 |
+
yield data
|
153 |
+
yield num_tokens_from_string(text)
|
154 |
+
|
155 |
+
except Exception as e:
|
156 |
+
raise RuntimeError(f"**ERROR**: {e}")
|