Commit
·
6a84ce9
1
Parent(s):
52f5c65
first draft
Browse files- app.py +36 -50
- requirements.txt +1 -0
app.py
CHANGED
@@ -9,8 +9,10 @@ import gradio as gr
|
|
9 |
import spaces
|
10 |
import torch
|
11 |
from gradio.utils import get_upload_folder
|
|
|
12 |
from transformers import AutoModelForImageTextToText, AutoProcessor
|
13 |
from transformers.generation.streamers import TextIteratorStreamer
|
|
|
14 |
|
15 |
model_id = "google/gemma-3n-E4B-it"
|
16 |
|
@@ -152,9 +154,8 @@ def process_history(history: list[dict]) -> list[dict]:
|
|
152 |
return messages
|
153 |
|
154 |
|
155 |
-
@spaces.GPU(duration=120)
|
156 |
@torch.inference_mode()
|
157 |
-
def
|
158 |
if not validate_media_constraints(message):
|
159 |
yield ""
|
160 |
return
|
@@ -199,54 +200,39 @@ def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
|
|
199 |
yield output
|
200 |
|
201 |
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
autofocus=True,
|
236 |
-
),
|
237 |
-
multimodal=True,
|
238 |
-
additional_inputs=[
|
239 |
-
gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
|
240 |
-
gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
|
241 |
-
],
|
242 |
-
stop_btn=False,
|
243 |
-
title="Gemma 3n E4B it",
|
244 |
-
examples=examples,
|
245 |
-
run_examples_on_click=False,
|
246 |
-
cache_examples=False,
|
247 |
-
css_paths="style.css",
|
248 |
-
delete_cache=(1800, 1800),
|
249 |
-
)
|
250 |
|
251 |
if __name__ == "__main__":
|
252 |
demo.launch()
|
|
|
9 |
import spaces
|
10 |
import torch
|
11 |
from gradio.utils import get_upload_folder
|
12 |
+
from gradio.processing_utils import save_audio_to_cache
|
13 |
from transformers import AutoModelForImageTextToText, AutoProcessor
|
14 |
from transformers.generation.streamers import TextIteratorStreamer
|
15 |
+
from fastrtc import ReplyOnPause, WebRTCData, WebRTC, AdditionalOutputs, get_hf_turn_credentials
|
16 |
|
17 |
model_id = "google/gemma-3n-E4B-it"
|
18 |
|
|
|
154 |
return messages
|
155 |
|
156 |
|
|
|
157 |
@torch.inference_mode()
|
158 |
+
def _generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
|
159 |
if not validate_media_constraints(message):
|
160 |
yield ""
|
161 |
return
|
|
|
200 |
yield output
|
201 |
|
202 |
|
203 |
+
@spaces.GPU(time_limit=120)
def generate(data: WebRTCData, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, image=None):
    """Stream an assistant reply for one WebRTC turn.

    Converts the incoming WebRTC payload (typed text plus the recorded
    audio clip) into the message dict expected by ``_generate``, then
    yields the chat history extended with the incrementally growing
    assistant message, wrapped in ``AdditionalOutputs`` so the UI can
    re-render the chatbot on every chunk.

    NOTE(review): ``image`` is wired as an input in the UI but never used
    here — presumably it should be forwarded in ``files``; confirm the
    intended behavior against ``_generate``.
    """
    # data.audio is (sample_rate, ndarray); persist it as an mp3 inside
    # Gradio's upload folder so downstream code can treat it as a file.
    sample_rate, waveform = data.audio
    audio_path = save_audio_to_cache(waveform, sample_rate, format="mp3", cache_dir=get_upload_folder())
    user_message = {"text": data.textbox, "files": [audio_path]}

    reply = {"role": "assistant", "content": ""}
    for chunk in _generate(user_message, history, system_prompt, max_new_tokens):
        reply["content"] += chunk
        # Emit the full transcript each time so partial output streams live.
        yield AdditionalOutputs(history + [reply])
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
# --- UI ------------------------------------------------------------------
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    # Audio-capable textbox: records speech client-side and streams it to
    # the server. TURN credentials are required for WebRTC to traverse NAT;
    # the server-side credentials are minted once with a 30-day TTL.
    webrtc = WebRTC(
        modality="audio",
        mode="send",
        variant="textbox",
        rtc_configuration=get_hf_turn_credentials,
        server_rtc_configuration=get_hf_turn_credentials(ttl=3_600 * 24 * 30),
    )
    with gr.Accordion(label="Additional Inputs"):
        # BUG FIX: the original line ended with a trailing comma, which made
        # `sp` a 1-tuple instead of a Textbox component and would break the
        # `inputs=[...]` wiring below.
        sp = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
        slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700)
        image = gr.Image()

    # BUG FIX: the handler was `response`, which is not defined in this file.
    # `generate` matches the wired inputs (data, history, system_prompt,
    # max_new_tokens, image) and yields AdditionalOutputs as a ReplyOnPause
    # handler is expected to — NOTE(review): confirm no `response` exists in
    # the unshown part of the file.
    webrtc.stream(
        ReplyOnPause(generate),  # type: ignore
        inputs=[webrtc, chatbot, sp, slider, image],
        outputs=[chatbot],
        concurrency_limit=100,
    )
    # AdditionalOutputs carries the updated history; replace the chatbot value.
    webrtc.on_additional_outputs(
        lambda old, new: new, inputs=[chatbot], outputs=[chatbot], concurrency_limit=100
    )

if __name__ == "__main__":
    demo.launch()
|
requirements.txt
CHANGED
@@ -311,3 +311,4 @@ uvicorn==0.34.3
|
|
311 |
# via gradio
|
312 |
websockets==15.0.1
|
313 |
# via gradio-client
|
|
|
|
311 |
# via gradio
|
312 |
websockets==15.0.1
|
313 |
# via gradio-client
|
314 |
+
fastrtc[vad]
|