freddyaboulton (HF Staff) committed
Commit 6a84ce9 · Parent: 52f5c65

first draft

Files changed (2):
  1. app.py +36 -50
  2. requirements.txt +1 -0
app.py CHANGED
@@ -9,8 +9,10 @@ import gradio as gr
 import spaces
 import torch
 from gradio.utils import get_upload_folder
+from gradio.processing_utils import save_audio_to_cache
 from transformers import AutoModelForImageTextToText, AutoProcessor
 from transformers.generation.streamers import TextIteratorStreamer
+from fastrtc import ReplyOnPause, WebRTCData, WebRTC, AdditionalOutputs, get_hf_turn_credentials
 
 model_id = "google/gemma-3n-E4B-it"
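
The two added imports carry the new audio path: save_audio_to_cache persists a raw microphone buffer into Gradio's upload cache so the existing file-based pipeline can treat it like an uploaded file, and the fastrtc names drive the WebRTC flow added below. A minimal sketch of the caching step, assuming the (sample_rate, samples) tuple layout that fastrtc uses for audio:

import numpy as np
from gradio.processing_utils import save_audio_to_cache
from gradio.utils import get_upload_folder

# Placeholder one-second mono buffer standing in for a real WebRTC capture.
sample_rate, samples = 16_000, np.zeros((1, 16_000), dtype=np.int16)

# Same call as in generate() below: write the buffer into Gradio's upload
# cache and get back a file path the existing pipeline can consume.
path = save_audio_to_cache(samples, sample_rate, format="mp3", cache_dir=get_upload_folder())
print(path)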
 
@@ -152,9 +154,8 @@ def process_history(history: list[dict]) -> list[dict]:
     return messages
 
 
-@spaces.GPU(duration=120)
 @torch.inference_mode()
-def generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
+def _generate(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     if not validate_media_constraints(message):
         yield ""
         return
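
This hunk splits the old entry point in two: the model-facing generator is renamed _generate and keeps @torch.inference_mode(), while the spaces.GPU decorator moves to the new WebRTC handler below, so a ZeroGPU slot is only requested at the outer entry point, once per voice turn. A skeleton of the split, with bodies elided and the time_limit argument taken from this commit as written (the old code used duration):

from collections.abc import Iterator

import spaces
import torch

@torch.inference_mode()
def _generate(message: dict) -> Iterator[str]:
    yield "..."  # stands in for the real streaming body

@spaces.GPU(time_limit=120)  # a ZeroGPU slot is requested here, once per turn
def generate(message: dict) -> Iterator[str]:
    yield from _generate(message)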
@@ -199,54 +200,39 @@ def generate(message: dict, history: list[dict], system_prompt: str = "", max_ne
         yield output
 
 
-examples = [
-    [
-        {
-            "text": "What is the capital of France?",
-            "files": [],
-        }
-    ],
-    [
-        {
-            "text": "Describe this image in detail.",
-            "files": ["assets/cat.jpeg"],
-        }
-    ],
-    [
-        {
-            "text": "Transcribe the following speech segment in English.",
-            "files": ["assets/speech.wav"],
-        }
-    ],
-    [
-        {
-            "text": "Transcribe the following speech segment in English.",
-            "files": ["assets/speech2.wav"],
-        }
-    ],
-]
-
-demo = gr.ChatInterface(
-    fn=generate,
-    type="messages",
-    textbox=gr.MultimodalTextbox(
-        file_types=list(IMAGE_FILE_TYPES + VIDEO_FILE_TYPES + AUDIO_FILE_TYPES),
-        file_count="multiple",
-        autofocus=True,
-    ),
-    multimodal=True,
-    additional_inputs=[
-        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
-        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
-    ],
-    stop_btn=False,
-    title="Gemma 3n E4B it",
-    examples=examples,
-    run_examples_on_click=False,
-    cache_examples=False,
-    css_paths="style.css",
-    delete_cache=(1800, 1800),
-)
+@spaces.GPU(time_limit=120)
+def generate(data: WebRTCData, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512, image=None):
+    message = {"text": data.textbox, "files": [save_audio_to_cache(data.audio[1], data.audio[0], format="mp3", cache_dir=get_upload_folder())]}
+    new_message = {"role": "assistant", "content": ""}
+    for output in _generate(message, history, system_prompt, max_new_tokens):
+        new_message["content"] += output
+        yield AdditionalOutputs(history + [new_message])
+
+
+
+with gr.Blocks() as demo:
+    chatbot = gr.Chatbot(type="messages")
+    webrtc = WebRTC(
+        modality="audio",
+        mode="send",
+        variant="textbox",
+        rtc_configuration=get_hf_turn_credentials,
+        server_rtc_configuration=get_hf_turn_credentials(ttl=3_600 * 24 * 30),
+    )
+    with gr.Accordion(label="Additional Inputs"):
+        sp = gr.Textbox(label="System Prompt", value="You are a helpful assistant.")
+        slider = gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700)
+        image = gr.Image()
+
+    webrtc.stream(
+        ReplyOnPause(generate),  # type: ignore
+        inputs=[webrtc, chatbot, sp, slider, image],
+        outputs=[chatbot],
+        concurrency_limit=100,
+    )
+    webrtc.on_additional_outputs(
+        lambda old, new: new, inputs=[chatbot], outputs=[chatbot], concurrency_limit=100
+    )
 
 if __name__ == "__main__":
     demo.launch()
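
Taken together, the new Blocks wiring follows fastrtc's standard pattern: the WebRTC component streams each paused utterance into the handler, the handler yields AdditionalOutputs carrying the updated history, and on_additional_outputs copies that history into the Chatbot. A self-contained sketch of the same pattern, with a toy echo handler standing in for the Gemma pipeline and TURN credentials omitted so it runs locally:

import gradio as gr
from fastrtc import AdditionalOutputs, ReplyOnPause, WebRTC, WebRTCData

def respond(data: WebRTCData, history: list[dict]):
    # One turn: fastrtc hands over the textbox value (and the paused audio
    # in data.audio); the updated chat history goes back as an additional output.
    user = {"role": "user", "content": data.textbox}
    reply = {"role": "assistant", "content": f"echo: {data.textbox}"}
    yield AdditionalOutputs(history + [user, reply])

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    webrtc = WebRTC(modality="audio", mode="send", variant="textbox")

    webrtc.stream(ReplyOnPause(respond), inputs=[webrtc, chatbot], outputs=[chatbot])
    # Mirror the AdditionalOutputs payload into the visible Chatbot.
    webrtc.on_additional_outputs(lambda old, new: new, inputs=[chatbot], outputs=[chatbot])

if __name__ == "__main__":
    demo.launch()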
requirements.txt CHANGED
@@ -311,3 +311,4 @@ uvicorn==0.34.3
     # via gradio
 websockets==15.0.1
     # via gradio-client
+fastrtc[vad]