GLorr committed on
Commit 6c09f76 · verified · 1 Parent(s): b6b7427

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,15 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ # Environment variables
+ .env
+
+ .vscode/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,28 @@
+ repos:
+   - repo: https://github.com/PyCQA/bandit
+     rev: 1.7.4
+     hooks:
+       - id: bandit
+         name: bandit
+         types: [python]
+
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     # Ruff version.
+     rev: v0.4.8
+     hooks:
+       # Run the linter.
+       - id: ruff
+       # Run the formatter.
+       - id: ruff-format
+
+   - repo: https://github.com/psf/black
+     rev: 23.1.0
+     hooks:
+       - id: black
+         name: black
+
+   - repo: https://github.com/pre-commit/mirrors-isort
+     rev: v5.10.1
+     hooks:
+       - id: isort
+         args: ["--profile", "black"]
.python-version ADDED
@@ -0,0 +1 @@
+ 3.11.9
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
+ # Automatically created by ruff.
+ *
.ruff_cache/0.4.8/17181755630229836148 ADDED
Binary file (187 Bytes).
.ruff_cache/0.4.8/2516455456322530856 ADDED
Binary file (291 Bytes).
.ruff_cache/0.4.8/3664365949595148797 ADDED
Binary file (222 Bytes).
.ruff_cache/0.9.6/12093191028265889985 ADDED
Binary file (236 Bytes).
.ruff_cache/0.9.6/16582661031577879600 ADDED
Binary file (187 Bytes).
.ruff_cache/0.9.6/6136549848780317009 ADDED
Binary file (222 Bytes).
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
+ Signature: 8a477f597d28d172789f06886806bc55
README.md CHANGED
@@ -1,12 +1,78 @@
  ---
- title: ML6 Gemini Demo
- emoji: 🔥
- colorFrom: red
- colorTo: gray
  sdk: gradio
- sdk_version: 5.23.1
- app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: ML6-Gemini-Demo
+ app_file: src/app.py
  sdk: gradio
+ sdk_version: 5.23.0
  ---
+ # Gemini Voice Agent Demo
+
+ This repo contains a demo that uses the Gemini Multimodal Live API to create a voice-based agent that conducts professional technical screening interviews.
+
+ ## Technical Overview
+
+ The system is built on FastRTC and Gradio, which provide the real-time voice UI.
+
+ ### About the modality
+
+ You can configure the output modality; a configuration sketch follows the list below.
+
+ - If set to AUDIO:
+   - The agent responds with native audio.
+   - There is no text output, so no transcription is available.
+ - If set to TEXT:
+   - The agent responds with text.
+   - The text output is synthesized to audio using the Cloud Text-to-Speech API.
+   - Transcriptions are available.
+
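A minimal sketch of how the modality could be selected when opening the Live API session. It mirrors the `LiveConnectConfig` usage in `src copy/app2.py`; the `OUTPUT_MODALITY` variable and the placeholder values for `system_prompt` and `TOOLS` are illustrative, not part of the repo:

```python
# Sketch only: selecting the output modality for the Live API session.
# Mirrors the LiveConnectConfig usage in "src copy/app2.py"; system_prompt and
# TOOLS stand in for values built from the Jinja2 prompt and tools.py.
from google.genai.types import FunctionDeclaration, LiveConnectConfig, Tool

system_prompt = "..."   # rendered from src/prompts/default_prompt.jinja2
TOOLS: list[dict] = []  # function declarations from tools.py (not shown in this commit)

OUTPUT_MODALITY = "AUDIO"  # or "TEXT"; TEXT responses are then synthesized via Cloud TTS

config = LiveConnectConfig(
    system_instruction={"parts": [{"text": system_prompt}], "role": "user"},
    tools=[Tool(function_declarations=[FunctionDeclaration(**tool) for tool in TOOLS])],
    response_modalities=[OUTPUT_MODALITY],
)
```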
+ ### Function Calling
+
+ Two functions can be called by the model (a sketch of the declarations follows below):
+ - Answer validation
+   - checks the answer type against the expected type
+   - stores the answer
+ - Log input
+   - logs the user input
+   - this serves as a transcription of the incoming audio
+
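The real declarations live in `tools.py`, which is not part of this commit, so the following is only a hypothetical sketch of what the two declarations could look like in the dict format that `src copy/app2.py` expands via `FunctionDeclaration(**tool)`. The names and parameter fields are assumptions:

```python
# Hypothetical sketch of the two tool declarations; names and parameters are
# assumptions, since tools.py is not included in this diff.
TOOLS = [
    {
        "name": "validate_answer",
        "description": "Check an answer against the expected type and store it.",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "question_id": {"type": "INTEGER"},
                "answer": {"type": "STRING"},
            },
            "required": ["question_id", "answer"],
        },
    },
    {
        "name": "store_input",
        "description": "Log a user or bot utterance as a transcript entry.",
        "parameters": {
            "type": "OBJECT",
            "properties": {
                "role": {"type": "STRING"},
                "input": {"type": "STRING"},
            },
            "required": ["role", "input"],
        },
    },
]
```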
+ ## Getting Started
+
+ To run the application, follow these steps:
+
+ 1. Install uv (if not already installed):
+    `curl -LsSf https://astral.sh/uv/install.sh | sh`
+
+ 2. Install the dependencies:
+    `uv sync`
+
+ 3. Set up the environment variables for either GenAI or Vertex AI (see below).
+
+ 4. Run the application:
+    `python src/app.py`
+
+ 5. Visit `http://127.0.0.1:7860` in your browser to interact with the voice agent.
+
+
+ ### GenAI vs VertexAI
+
+ "gemini-2.0-flash-exp" can be used with both GenAI and Vertex AI. [more info](https://github.com/heiko-hotz/gemini-multimodal-live-dev-guide?tab=readme-ov-file)
+
+ - GenAI only requires a GEMINI_API_KEY environment variable [link](https://ai.google.dev/gemini-api/docs/api-key)
+ - Vertex AI requires a GCP project and the following environment variables:
+ ```
+ export GOOGLE_CLOUD_PROJECT=YOUR_PROJECT_ID
+ export GOOGLE_CLOUD_LOCATION=europe-west4
+ export GOOGLE_GENAI_USE_VERTEXAI=True
+ ```
+
+ Depending on the `GOOGLE_GENAI_USE_VERTEXAI` flag, this demo uses either GenAI or Vertex AI; a sketch of the client selection is shown below.
+
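A sketch of how the flag could drive the client choice with the google-genai SDK. The environment variable names follow the README above; the exact wiring in `src/app.py` is not shown in this commit, so treat this as illustrative:

```python
# Illustrative sketch: pick GenAI or Vertex AI based on GOOGLE_GENAI_USE_VERTEXAI.
import os

from google import genai

if os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in ("true", "1"):
    client = genai.Client(
        vertexai=True,
        project=os.getenv("GOOGLE_CLOUD_PROJECT"),
        location=os.getenv("GOOGLE_CLOUD_LOCATION"),
    )
else:
    client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
```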
+ ### Note
+
+ The gradio-webrtc install fails unless you have ffmpeg@6; on macOS:
+
+ ```
+ brew uninstall ffmpeg
+ brew install ffmpeg@6
+ brew link ffmpeg@6
+ ```
pyproject.toml ADDED
@@ -0,0 +1,21 @@
+ [project]
+ name = "gemini-voice-agents"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.11.9"
+ dependencies = [
+     "fastrtc>=0.0.17",
+     "google>=3.0.0",
+     "google-cloud>=0.34.0",
+     "google-cloud-texttospeech>=2.25.1",
+     "google-genai>=1.7.0",
+     "gradio>=5.23.0",
+     "numpy>=2.1.3",
+ ]
+
+ [dependency-groups]
+ dev = [
+     "ruff>=0.9.6",
+     "pre-commit>=4.1",
+ ]
questions.json ADDED
@@ -0,0 +1,27 @@
+ [
+     {
+         "id": 1,
+         "question": "What is your full name?",
+         "answer_format": "str"
+     },
+     {
+         "id": 2,
+         "question": "What is your current job title?",
+         "answer_format": "str"
+     },
+     {
+         "id": 3,
+         "question": "How many years of relevant experience do you have?",
+         "answer_format": "int"
+     },
+     {
+         "id": 4,
+         "question": "Are you looking for a new job?",
+         "answer_format": "bool"
+     },
+     {
+         "id": 5,
+         "question": "List your three strongest technical skills.",
+         "answer_format": "list[str]"
+     }
+ ]
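The `answer_format` strings suggest a simple type check during answer validation. A hypothetical sketch of that check follows; the real logic lives in `tools.py`, which is not part of this diff:

```python
# Hypothetical sketch of checking an answer against its declared format.
def matches_format(answer, answer_format: str) -> bool:
    if answer_format == "str":
        return isinstance(answer, str)
    if answer_format == "int":
        # Exclude bools, which are a subclass of int in Python.
        return isinstance(answer, int) and not isinstance(answer, bool)
    if answer_format == "bool":
        return isinstance(answer, bool)
    if answer_format == "list[str]":
        return isinstance(answer, list) and all(isinstance(x, str) for x in answer)
    return False
```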
src copy/app.py ADDED
@@ -0,0 +1,506 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2025 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ ## Setup
17
+
18
+ The gradio-webrtc install fails unless you have ffmpeg@6, on mac:
19
+
20
+ ```
21
+ brew uninstall ffmpeg
22
+ brew install ffmpeg@6
23
+ brew link ffmpeg@6
24
+ ```
25
+
26
+ Create a virtual python environment, then install the dependencies for this script:
27
+
28
+ ```
29
+ pip install websockets numpy gradio-webrtc "gradio>=5.9.1"
30
+ ```
31
+
32
+ If installation fails, it may be due to the ffmpeg requirement above.
33
+
34
+ Before running this script, ensure the `GOOGLE_API_KEY` environment variable is set:
35
+
36
+ ```
37
+ $ export GOOGLE_API_KEY='add your key here'
38
+ ```
39
+
40
+ You can get an api-key from Google AI Studio (https://aistudio.google.com/apikey)
41
+
42
+ ## Run
43
+
44
+ To run the script:
45
+
46
+ ```
47
+ python gemini_gradio_audio.py
48
+ ```
49
+
50
+ On the gradio page (http://127.0.0.1:7860/) click record, and talk, gemini will reply. But note that interruptions
51
+ don't work.
52
+
53
+ """
54
+
55
+ import base64
56
+ import json
57
+ import os
58
+ import wave
59
+ import itertools
60
+
61
+ import gradio as gr
62
+ import numpy as np
63
+ import websockets.sync.client
64
+ from gradio_webrtc import StreamHandler, WebRTC
65
+ from jinja2 import Template
66
+ import threading
67
+ import queue
68
+
69
+
70
+ from tools import FUNCTION_MAP, TOOLS
71
+ from google.cloud import texttospeech
72
+
73
+ # logging.basicConfig(
74
+ # level=logging.INFO,
75
+ # format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
76
+ # )
77
+ # logger = logging.getLogger(__name__)
78
+
79
+
80
+ with open("questions.json", "r") as f:
81
+ questions_dict = json.load(f)
82
+
83
+ with open("src/prompts/default_prompt.jinja2") as f:
84
+ template_str = f.read()
85
+ template = Template(template_str)
86
+ system_prompt = template.render(questions=json.dumps(questions_dict, indent=4))
87
+
88
+ print(system_prompt)
89
+
90
+
91
+ # TOOLS = types.GenerateContentConfig(tools=[validate_answer])
92
+
93
+
94
+ __version__ = "0.0.3"
95
+
96
+ KEY_NAME = "GOOGLE_API_KEY"
97
+
98
+
99
+ # Configuration and Utilities
100
+ class GeminiConfig:
101
+ """Configuration settings for Gemini API."""
102
+
103
+ def __init__(self):
104
+ self.api_key = os.getenv(KEY_NAME)
105
+ self.host = "generativelanguage.googleapis.com"
106
+ self.model = "models/gemini-2.0-flash-exp"
107
+ self.ws_url = f"wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
108
+
109
+ class TTSStreamer:
110
+ def __init__(self):
111
+ self.client = texttospeech.TextToSpeechClient()
112
+ self.text_queue = queue.Queue()
113
+ self.audio_queue = queue.Queue()
114
+
115
+ def start_stream(self):
116
+ streaming_config = texttospeech.StreamingSynthesizeConfig(
117
+ voice=texttospeech.VoiceSelectionParams(
118
+ name="en-US-Journey-D",
119
+ language_code="en-US"
120
+ )
121
+ )
122
+ config_request = texttospeech.StreamingSynthesizeRequest(
123
+ streaming_config=streaming_config
124
+ )
125
+
126
+ def request_generator():
127
+ while True:
128
+ try:
129
+ text = self.text_queue.get()
130
+ if text is None: # Poison pill to stop
131
+ break
132
+ yield texttospeech.StreamingSynthesizeRequest(
133
+ input=texttospeech.StreamingSynthesisInput(text=text)
134
+ )
135
+ except queue.Empty:
136
+ continue
137
+
138
+ def audio_processor():
139
+ responses = self.client.streaming_synthesize(
140
+ itertools.chain([config_request], request_generator())
141
+ )
142
+ print(f"Responses: {responses}")
143
+ for response in responses:
144
+ self.audio_queue.put(response.audio_content)
145
+
146
+ self.processor_thread = threading.Thread(target=audio_processor)
147
+ self.processor_thread.start()
148
+
149
+ def send_text(self, text: str):
150
+ """Send text to be synthesized."""
151
+ self.text_queue.put(text)
152
+
153
+ def get_audio(self):
154
+ """Get the next chunk of audio bytes."""
155
+ try:
156
+ return self.audio_queue.get_nowait()
157
+ except queue.Empty:
158
+ return None
159
+
160
+ def stop(self):
161
+ """Stop the streaming synthesis."""
162
+ self.text_queue.put(None) # Send poison pill
163
+ if self.processor_thread:
164
+ self.processor_thread.join()
165
+
166
+
167
+ class AudioProcessor:
168
+ """Handles encoding and decoding of audio data."""
169
+
170
+ @staticmethod
171
+ def encode_audio(data, sample_rate):
172
+ """Encodes audio data to base64."""
173
+ encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
174
+ return {
175
+ "realtimeInput": {
176
+ "mediaChunks": [
177
+ {
178
+ "mimeType": f"audio/pcm;rate={sample_rate}",
179
+ "data": encoded,
180
+ }
181
+ ],
182
+ },
183
+ }
184
+
185
+ @staticmethod
186
+ def process_audio_response(data):
187
+ """Decodes audio data from base64."""
188
+ audio_data = base64.b64decode(data)
189
+ return np.frombuffer(audio_data, dtype=np.int16)
190
+
191
+
192
+ # Gemini Interaction Handler
193
+ class GeminiHandler(StreamHandler):
194
+ """Handles streaming interactions with the Gemini API."""
195
+
196
+ def __init__(
197
+ self,
198
+ audio_file=None,
199
+ expected_layout="mono",
200
+ output_sample_rate=24000,
201
+ output_frame_size=480,
202
+ ) -> None:
203
+ super().__init__(
204
+ expected_layout,
205
+ output_sample_rate,
206
+ output_frame_size,
207
+ input_sample_rate=24000,
208
+ )
209
+ self.config = GeminiConfig()
210
+ self.ws = None
211
+ self.all_output_data = None
212
+ self.audio_processor = AudioProcessor()
213
+ self.audio_file = audio_file
214
+ self.text_buffer = ""
215
+ self.tts_engine = None
216
+
217
+ def copy(self):
218
+ """Creates a copy of the GeminiHandler instance."""
219
+ return GeminiHandler(
220
+ expected_layout=self.expected_layout,
221
+ output_sample_rate=self.output_sample_rate,
222
+ output_frame_size=self.output_frame_size,
223
+ )
224
+
225
+ def _initialize_websocket(self):
226
+ """Initializes the WebSocket connection to the Gemini API."""
227
+ try:
228
+ self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=3000)
229
+ setup_request = {
230
+ "setup": {
231
+ "model": self.config.model,
232
+ "tools": [{"functionDeclarations": TOOLS}],
233
+ "generationConfig": {"responseModalities": "TEXT"},
234
+ "systemInstruction": {
235
+ "parts": [{"text": system_prompt}],
236
+ "role": "user",
237
+ },
238
+ }
239
+ }
240
+ self.ws.send(json.dumps(setup_request))
241
+ setup_response = json.loads(self.ws.recv())
242
+ print(f"Setup response: {setup_response}")
243
+
244
+ if self.audio_file:
245
+ self.input_audio_file(self.audio_file)
246
+ print("Audio file sent")
247
+
248
+ except websockets.exceptions.WebSocketException as e:
249
+ print(f"WebSocket connection failed: {str(e)}")
250
+ self.ws = None
251
+ except Exception as e:
252
+ print(f"Setup failed: {str(e)}")
253
+ self.ws = None
254
+
255
+ def input_audio_file(self, audio_file):
256
+ """Processes an audio file and sends it to the Gemini API."""
257
+ try:
258
+ with wave.open(audio_file, "rb") as wf:
259
+ data = wf.readframes(wf.getnframes())
260
+ self.receive((wf.getframerate(), np.frombuffer(data, dtype=np.int16)))
261
+ except Exception as e:
262
+ print(f"Error in input_audio_file: {str(e)}")
263
+
264
+ def receive(self, frame: tuple[int, np.ndarray]) -> None:
265
+ """Receives audio/video data, encodes it, and sends it to the Gemini API."""
266
+ try:
267
+ if not self.ws:
268
+ self._initialize_websocket()
269
+
270
+ sample_rate, array = frame
271
+ message = {"realtimeInput": {"mediaChunks": []}}
272
+
273
+ if sample_rate > 0 and array is not None:
274
+ array = array.squeeze()
275
+ audio_data = self.audio_processor.encode_audio(
276
+ array, self.output_sample_rate
277
+ )
278
+ message["realtimeInput"]["mediaChunks"].append(
279
+ {
280
+ "mimeType": f"audio/pcm;rate={self.output_sample_rate}",
281
+ "data": audio_data["realtimeInput"]["mediaChunks"][0]["data"],
282
+ }
283
+ )
284
+
285
+ if message["realtimeInput"]["mediaChunks"]:
286
+ self.ws.send(json.dumps(message))
287
+ except Exception as e:
288
+ print(f"Error in receive: {str(e)}")
289
+ if self.ws:
290
+ self.ws.close()
291
+ self.ws = None
292
+
293
+ def handle_tool_call(self, tool_call):
294
+ print(" ", tool_call)
295
+ for fc in tool_call["functionCalls"]:
296
+ print(f"Function call: {fc}")
297
+ # Call the function
298
+ try:
299
+ result = {"output": FUNCTION_MAP[fc["name"]](**fc["args"])}
300
+ except Exception as e:
301
+ result = {"error": str(e)}
302
+
303
+ # Send the response back
304
+ msg = {
305
+ "tool_response": {
306
+ "function_responses": [
307
+ {"id": fc["id"], "name": fc["name"], "response": result}
308
+ ]
309
+ }
310
+ }
311
+ print(f"function response: {msg}")
312
+ self.ws.send(json.dumps(msg))
313
+
314
+ def _output_data(self, audio_array):
315
+ """Processes audio output data from the WebSocket response."""
316
+ if self.all_output_data is None:
317
+ self.all_output_data = audio_array
318
+ else:
319
+ self.all_output_data = np.concatenate((self.all_output_data, audio_array))
320
+
321
+ while self.all_output_data.shape[-1] >= self.output_frame_size:
322
+ yield (
323
+ self.output_sample_rate,
324
+ self.all_output_data[: self.output_frame_size].reshape(1, -1),
325
+ )
326
+ self.all_output_data = self.all_output_data[self.output_frame_size :]
327
+
328
+ def _process_server_content(self, content):
329
+ """Processes audio output data from the WebSocket response."""
330
+ if response := content.get("modelTurn", {}):
331
+ if parts := response.get("parts"):
332
+ for part in parts:
333
+ print(f"Part: {part}")
334
+ data = part.get("inlineData", {}).get("data", "")
335
+ if data:
336
+ audio_array = self.audio_processor.process_audio_response(data)
337
+ yield from self._output_data(audio_array)
338
+
339
+ text = part.get("text", "")
340
+ if text:
341
+ self.text_buffer += text
342
+
343
+
344
+
345
+ # audio_array = self._text_to_audio(text)
346
+ # yield from self._output_data(audio_array)
347
+ # # self.text_buffer += text
348
+
349
+ # Check if the turn is complete and process the text buffer into audio
350
+ if content.get("turnComplete"):
351
+ if self.text_buffer:
352
+ audio_array = self._text_to_audio(self.text_buffer)
353
+ yield from self._output_data(audio_array)
354
+ self.text_buffer = ""
355
+
356
+
357
+ def _text_to_audio(self, text: str) -> np.ndarray:
358
+ """Convert text to audio using Google Cloud TTS streaming."""
359
+
360
+ client = texttospeech.TextToSpeechClient()
361
+
362
+ # Configure synthesis
363
+ synthesis_input = texttospeech.SynthesisInput(text=text)
364
+ voice = texttospeech.VoiceSelectionParams(
365
+ name="en-IN-Chirp-HD-O",
366
+ language_code="en-IN"
367
+ )
368
+ audio_config = texttospeech.AudioConfig(
369
+ audio_encoding=texttospeech.AudioEncoding.LINEAR16
370
+ )
371
+
372
+ # Get response in a single request
373
+ try:
374
+ response = client.synthesize_speech(
375
+ input=synthesis_input,
376
+ voice=voice,
377
+ audio_config=audio_config
378
+ )
379
+ return np.frombuffer(response.audio_content, dtype=np.int16)
380
+ except Exception as e:
381
+ print(f"Error in speech synthesis: {e}")
382
+ return np.array([], dtype=np.int16)
383
+
384
+
385
+ def generator(self):
386
+ """Generates audio output from the WebSocket stream."""
387
+ while True:
388
+ if not self.ws:
389
+ print("WebSocket not connected")
390
+ yield None
391
+ continue
392
+
393
+ try:
394
+ message = self.ws.recv(timeout=30)
395
+ msg = json.loads(message)
396
+
397
+ # {'serverContent': {'modelTurn': {'parts': [{'text': 'Hello'}]}}}
398
+ # {'serverContent': {'modelTurn': {'parts': [{'text': ', good morning! Thank you for taking my call. My name is [Your'}]}}}
399
+ # {'serverContent': {'modelTurn': {'parts': [{'text': " Name] and I'm a technical recruiter. I'm conducting a quick"}]}}}
400
+ # {'serverContent': {'modelTurn': {'parts': [{'text': ' initial screening, is that okay with you?\n'}]}}}
401
+ # {'serverContent': {'turnComplete': True}}
402
+
403
+ if "serverContent" in msg:
404
+ content = msg["serverContent"]
405
+ yield from self._process_server_content(content)
406
+ elif "toolCall" in msg:
407
+ yield from self.handle_tool_call(msg["toolCall"])
408
+
409
+ except TimeoutError:
410
+ print("Timeout waiting for server response")
411
+ yield None
412
+ except Exception:
413
+ yield None
414
+
415
+ def emit(self) -> tuple[int, np.ndarray] | None:
416
+ """Emits the next audio chunk from the generator."""
417
+ if not self.ws:
418
+ return None
419
+ if not hasattr(self, "_generator"):
420
+ self._generator = self.generator()
421
+ try:
422
+ return next(self._generator)
423
+ except StopIteration:
424
+ self.reset()
425
+ return None
426
+
427
+ def reset(self) -> None:
428
+ """Resets the generator and output data."""
429
+ if hasattr(self, "_generator"):
430
+ delattr(self, "_generator")
431
+ self.all_output_data = None
432
+
433
+ def shutdown(self) -> None:
434
+ """Closes the WebSocket connection."""
435
+ if self.ws:
436
+ self.ws.close()
437
+
438
+ def check_connection(self):
439
+ """Checks if the WebSocket connection is active."""
440
+ try:
441
+ if not self.ws or self.ws.closed:
442
+ self._initialize_websocket()
443
+ return True
444
+ except Exception as e:
445
+ print(f"Connection check failed: {str(e)}")
446
+ return False
447
+
448
+
449
+ def update_answers():
450
+ with open("answers.json", "r") as f:
451
+ return json.load(f)
452
+
453
+
454
+ # Main Gradio Interface
455
+ def registry(name: str, token: str | None = None, **kwargs):
456
+ """Sets up and returns the Gradio interface."""
457
+ api_key = token or os.environ.get(KEY_NAME)
458
+ if not api_key:
459
+ raise ValueError(f"{KEY_NAME} environment variable is not set.")
460
+
461
+ interface = gr.Blocks()
462
+ with interface:
463
+ with gr.Tabs():
464
+ with gr.TabItem("Voice Chat"):
465
+ gr.HTML(
466
+ """
467
+ <div style='text-align: left'>
468
+ <h1>ML6 Voice Demo - Function Calling and Custom Output Voice</h1>
469
+ </div>
470
+ """
471
+ )
472
+ gemini_handler = GeminiHandler()
473
+ # gemini_handler = ThreeStepHandler()
474
+
475
+ with gr.Row():
476
+ audio = WebRTC(
477
+ label="Voice Chat", modality="audio", mode="send-receive"
478
+ )
479
+
480
+ # Add display components for questions and answers
481
+ with gr.Row():
482
+ with gr.Column():
483
+ gr.JSON(
484
+ label="Questions",
485
+ value=questions_dict,
486
+ )
487
+ with gr.Column():
488
+ gr.JSON(update_answers, label="Collected Answers", every=1)
489
+
490
+ audio.stream(
491
+ gemini_handler,
492
+ inputs=[audio], # Add audio_file to inputs
493
+ outputs=[audio],
494
+ time_limit=600,
495
+ concurrency_limit=10,
496
+ )
497
+
498
+ return interface
499
+
500
+
501
+ # Launch the Gradio interface
502
+ gr.load(
503
+ name="gemini-2.0-flash-exp",
504
+ src=registry,
505
+ ).launch()
506
+
src copy/app2.py ADDED
@@ -0,0 +1,308 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright 2025 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ ## Setup
17
+
18
+ The gradio-webrtc install fails unless you have ffmpeg@6, on mac:
19
+
20
+ ```
21
+ brew uninstall ffmpeg
22
+ brew install ffmpeg@6
23
+ brew link ffmpeg@6
24
+ ```
25
+
26
+ Create a virtual python environment, then install the dependencies for this script:
27
+
28
+ ```
29
+ pip install websockets numpy gradio-webrtc "gradio>=5.9.1"
30
+ ```
31
+
32
+ If installation fails, it may be due to the ffmpeg requirement above.
33
+
34
+ Before running this script, ensure the `GOOGLE_API_KEY` environment variable is set:
35
+
36
+ ```
37
+ $ export GOOGLE_API_KEY='add your key here'
38
+ ```
39
+
40
+ You can get an api-key from Google AI Studio (https://aistudio.google.com/apikey)
41
+
42
+ ## Run
43
+
44
+ To run the script:
45
+
46
+ ```
47
+ python gemini_gradio_audio.py
48
+ ```
49
+
50
+ On the gradio page (http://127.0.0.1:7860/) click record, and talk, gemini will reply. But note that interruptions
51
+ don't work.
52
+
53
+ """
54
+
55
+ import asyncio
56
+ import json
57
+ import os
58
+ from typing import Literal
59
+ import base64
60
+
61
+ import gradio as gr
62
+ import numpy as np
63
+ from fastrtc import (
64
+ AsyncStreamHandler,
65
+ WebRTC,
66
+ wait_for_item,
67
+ )
68
+ from jinja2 import Template
69
+ from google import genai
70
+ from google.genai.types import LiveConnectConfig, Tool, FunctionDeclaration
71
+
72
+ from google.cloud import texttospeech
73
+
74
+ from tools import FUNCTION_MAP, TOOLS
75
+
76
+ with open("questions.json", "r") as f:
77
+ questions_dict = json.load(f)
78
+
79
+ with open("src/prompts/default_prompt.jinja2") as f:
80
+ template_str = f.read()
81
+ template = Template(template_str)
82
+ system_prompt = template.render(questions=json.dumps(questions_dict, indent=4))
83
+
84
+
85
+
86
+
87
+ class TTSConfig:
88
+ def __init__(self):
89
+ self.client = texttospeech.TextToSpeechClient()
90
+ self.voice = texttospeech.VoiceSelectionParams(
91
+ name="en-US-Chirp3-HD-Charon",
92
+ language_code="en-US"
93
+ )
94
+ self.audio_config = texttospeech.AudioConfig(
95
+ audio_encoding=texttospeech.AudioEncoding.LINEAR16
96
+ )
97
+
98
+
99
+ class AsyncGeminiHandler(AsyncStreamHandler):
100
+ """Simple Async Gemini Handler"""
101
+
102
+ def __init__(
103
+ self,
104
+ expected_layout: Literal["mono"] = "mono",
105
+ output_sample_rate: int = 24000,
106
+ output_frame_size: int = 480,
107
+ ) -> None:
108
+ super().__init__(
109
+ expected_layout,
110
+ output_sample_rate,
111
+ output_frame_size,
112
+ input_sample_rate=16000,
113
+ )
114
+ self.input_queue: asyncio.Queue = asyncio.Queue()
115
+ self.output_queue: asyncio.Queue = asyncio.Queue()
116
+ self.text_queue: asyncio.Queue = asyncio.Queue()
117
+ self.quit: asyncio.Event = asyncio.Event()
118
+ self.chunk_size = 1024
119
+
120
+ self.tts_config: TTSConfig | None = TTSConfig()
121
+ self.text_buffer = ""
122
+
123
+ def copy(self) -> "AsyncGeminiHandler":
124
+ return AsyncGeminiHandler(
125
+ expected_layout="mono",
126
+ output_sample_rate=self.output_sample_rate,
127
+ output_frame_size=self.output_frame_size,
128
+ )
129
+
130
+ def _encode_audio(self, data: np.ndarray) -> str:
131
+ """Encode Audio data to send to the server"""
132
+ return base64.b64encode(data.tobytes()).decode("UTF-8")
133
+
134
+
135
+ async def receive(self, frame: tuple[int, np.ndarray]) -> None:
136
+ _, array = frame
137
+ array = array.squeeze()
138
+ audio_message = self._encode_audio(array)
139
+ self.input_queue.put_nowait(audio_message)
140
+
141
+ async def emit(self) -> tuple[int, np.ndarray] | None:
142
+ return await wait_for_item(self.output_queue)
143
+
144
+ async def start_up(self) -> None:
145
+ client = genai.Client(
146
+ api_key=os.getenv("GOOGLE_API_KEY"),
147
+ http_options={"api_version": "v1alpha"},
148
+ )
149
+
150
+
151
+ config = LiveConnectConfig(
152
+ system_instruction={
153
+ "parts": [{"text": system_prompt}],
154
+ "role": "user",
155
+ },
156
+ tools=[Tool(function_declarations=[FunctionDeclaration(**tool) for tool in TOOLS])],
157
+ response_modalities=["AUDIO"],
158
+ )
159
+
160
+ async with (
161
+ client.aio.live.connect(model="gemini-2.0-flash-exp", config=config) as session,
162
+ asyncio.TaskGroup() as tg
163
+ ):
164
+ self.session = session
165
+
166
+ tasks = [
167
+ tg.create_task(self.process()),
168
+ tg.create_task(self.send_realtime()),
169
+ tg.create_task(self.tts()),
170
+ ]
171
+
172
+ async def process(self) -> None:
173
+ while True:
174
+ try:
175
+ turn = self.session.receive()
176
+ async for response in turn:
177
+ if data := response.data:
178
+ array = np.frombuffer(data, dtype=np.int16)
179
+ self.output_queue.put_nowait((self.output_sample_rate, array))
180
+ continue
181
+
182
+ if text := response.text:
183
+ print(f"Received text: {text}")
184
+ self.text_buffer += text
185
+
186
+ if response.tool_call is not None:
187
+ for tool in response.tool_call.function_calls:
188
+ tool_response = FUNCTION_MAP[tool.name](**tool.args)
189
+ print(f"Calling tool: {tool.name}")
190
+ print(f"Tool response: {tool_response}")
191
+ await self.session.send(
192
+ input=tool_response, end_of_turn=True
193
+ )
194
+ await asyncio.sleep(0.1)
195
+
196
+ if sc := response.server_content:
197
+ if sc.turn_complete and self.text_buffer:
198
+ self.text_queue.put_nowait(self.text_buffer)
199
+ FUNCTION_MAP["store_input"](
200
+ role="bot",
201
+ input=self.text_buffer
202
+ )
203
+ self.text_buffer = ""
204
+
205
+ except Exception as e:
206
+ print(f"Error in processing: {e}")
207
+ await asyncio.sleep(0.1)
208
+
209
+ async def send_realtime(self) -> None:
210
+ """Send real-time audio data to model."""
211
+ while True:
212
+ try:
213
+ data = await self.input_queue.get()
214
+ msg = {"data": data, "mime_type": "audio/pcm"}
215
+ await self.session.send(input=msg)
216
+ except Exception as e:
217
+ print(f"Error in real-time sending: {e}")
218
+ await asyncio.sleep(0.1)
219
+
220
+ async def tts(self) -> None:
221
+
222
+ while True:
223
+ try:
224
+ text = await self.text_queue.get()
225
+ # Get response in a single request
226
+ if text:
227
+ response = self.tts_config.client.synthesize_speech(
228
+ input=texttospeech.SynthesisInput(text=text),
229
+ voice=self.tts_config.voice,
230
+ audio_config=self.tts_config.audio_config
231
+ )
232
+ array = np.frombuffer(response.audio_content, dtype=np.int16)
233
+ self.output_queue.put_nowait((self.output_sample_rate, array))
234
+
235
+ except Exception as e:
236
+ print(f"Error in TTS: {e}")
237
+ await asyncio.sleep(0.1)
238
+
239
+
240
+ def shutdown(self) -> None:
241
+ self.quit.set()
242
+
243
+
244
+ def reload_json(path):
245
+ with open(path, "r") as f:
246
+ return json.load(f)
247
+
248
+ # Main Gradio Interface
249
+ def registry(name: str, token: str | None = None, **kwargs):
250
+ """Sets up and returns the Gradio interface."""
251
+
252
+ interface = gr.Blocks()
253
+ with interface:
254
+ with gr.Tabs():
255
+ with gr.TabItem("Voice Chat"):
256
+ gr.HTML(
257
+ """
258
+ <div style='text-align: left'>
259
+ <h1>ML6 Voice Demo - Function Calling and Custom Output Voice</h1>
260
+ </div>
261
+ """
262
+ )
263
+ gemini_handler = AsyncGeminiHandler()
264
+
265
+ with gr.Row():
266
+ audio = WebRTC(
267
+ label="Voice Chat", modality="audio", mode="send-receive"
268
+ )
269
+
270
+ # Add display components for questions and answers
271
+ with gr.Row():
272
+ with gr.Column():
273
+ gr.JSON(
274
+ label="Questions",
275
+ value=questions_dict,
276
+ )
277
+ # with gr.Column():
278
+ # gr.JSON(reload_json, inputs=gr.Text(value="/Users/georgeslorre/ML6/internal/gemini-voice-agents/conversation.json", visible=False), label="Conversation", every=1)
279
+ with gr.Column():
280
+ gr.JSON(reload_json, inputs=gr.Text(value="/Users/georgeslorre/ML6/internal/gemini-voice-agents/answers.json", visible=False),label="Collected Answers", every=1)
281
+
282
+
283
+ audio.stream(
284
+ gemini_handler,
285
+ inputs=[audio], # Add audio_file to inputs
286
+ outputs=[audio],
287
+ time_limit=600,
288
+ concurrency_limit=10,
289
+ )
290
+
291
+ return interface
292
+
293
+ # Function to clear JSON files
294
+ def clear_json_files():
295
+ with open("/Users/georgeslorre/ML6/internal/gemini-voice-agents/conversation.json", "w") as f:
296
+ json.dump([], f)
297
+ with open("/Users/georgeslorre/ML6/internal/gemini-voice-agents/answers.json", "w") as f:
298
+ json.dump({}, f)
299
+
300
+ # Clear files before launching
301
+ clear_json_files()
302
+
303
+ # Launch the Gradio interface
304
+ gr.load(
305
+ name="gemini-2.0-flash-exp",
306
+ src=registry,
307
+ ).launch()
308
+
src copy/app3.py ADDED
File without changes
src copy/helpers/loop.py ADDED
@@ -0,0 +1,274 @@
1
+ """Helper for audio loop."""
2
+
3
+ import asyncio
4
+ import logging
5
+ import traceback
6
+ import wave
7
+ from typing import Optional
8
+
9
+ import pyaudio
10
+ from google import genai
11
+
12
+ from models import AudioConfig, ModelConfig
13
+ from tools import FUNCTION_MAP
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class TextLoop:
19
+ def __init__(self, model_config: ModelConfig):
20
+ self.model_config = model_config
21
+ self.client = self._setup_client()
22
+ self.session = None
23
+
24
+ def _setup_client(self) -> genai.Client:
25
+ """Initialize the Gemini client."""
26
+ return genai.Client(
27
+ api_key=self.model_config.api_key,
28
+ http_options={"api_version": "v1alpha"},
29
+ )
30
+
31
+ async def send_text(self) -> None:
32
+ """Handle text input and send to model."""
33
+ while True:
34
+ try:
35
+ text = await asyncio.to_thread(input, "message > ")
36
+ if text.lower() == "q":
37
+ break
38
+ await self.session.send(input=text or ".", end_of_turn=True)
39
+ except Exception as e:
40
+ logger.error(f"Error sending text: {e}")
41
+ await asyncio.sleep(0.1)
42
+
43
+ async def receive_text(self) -> None:
44
+ """Process and handle model responses."""
45
+ while True:
46
+ try:
47
+ turn = self.session.receive()
48
+ async for response in turn:
49
+ if text := response.text:
50
+ logger.info(text)
51
+ if response.tool_call is not None:
52
+ for tool in response.tool_call.function_calls:
53
+ tool_response = FUNCTION_MAP[tool.name](**tool.args)
54
+ logger.info(tool_response)
55
+ await self.session.send(
56
+ input=tool_response, end_of_turn=True
57
+ )
58
+ await asyncio.sleep(0.1)
59
+ except Exception as e:
60
+ logger.error(f"Error receiving text: {e}")
61
+ await asyncio.sleep(0.1)
62
+
63
+ async def run(self):
64
+ try:
65
+ async with (
66
+ self.client.aio.live.connect(
67
+ model=self.model_config.name,
68
+ config={
69
+ "system_instruction": self.model_config.system_instruction,
70
+ "tools": self.model_config.tools,
71
+ "generation_config": self.model_config.generation_config,
72
+ },
73
+ ) as session,
74
+ asyncio.TaskGroup() as tg,
75
+ ):
76
+ self.session = session
77
+ tasks = [
78
+ tg.create_task(self.send_text()),
79
+ tg.create_task(self.receive_text()),
80
+ ]
81
+
82
+ await tasks[0] # Wait for send_text to complete
83
+ raise asyncio.CancelledError("User requested exit")
84
+
85
+ except asyncio.CancelledError:
86
+ logger.info("Shutting down...")
87
+ except Exception as e:
88
+ logger.error(f"Error in main loop: {e}")
89
+ logger.debug(traceback.format_exc())
90
+
91
+
92
+ class AudioLoop:
93
+ """Handles real-time audio streaming and processing."""
94
+
95
+ def __init__(
96
+ self,
97
+ audio_config: AudioConfig,
98
+ model_config: ModelConfig,
99
+ function_map: Optional[dict[str, callable]] = FUNCTION_MAP,
100
+ instruction_audio: Optional[str] = None,
101
+ ):
102
+ """Initialize the audio loop.
103
+
104
+ Args:
105
+ audio_config (AudioConfig): Audio configuration settings
106
+ model_config (ModelConfig): Model configuration settings
107
+ function_map (Optional[dict[str, callable]]): Function map
108
+ """
109
+ self.audio_config = audio_config
110
+ self.model_config = model_config
111
+
112
+ self.audio_in_queue: Optional[asyncio.Queue] = None
113
+ self.out_queue: Optional[asyncio.Queue] = None
114
+ self.session = None
115
+ self.audio_stream = None
116
+ self.client = self._setup_client()
117
+ self.instruction_audio = instruction_audio
118
+
119
+ self.function_map = function_map
120
+
121
+ def _setup_client(self) -> genai.Client:
122
+ """Initialize the Gemini client."""
123
+ return genai.Client(
124
+ api_key=self.model_config.api_key,
125
+ http_options={"api_version": "v1alpha"},
126
+ )
127
+
128
+ async def send_text(self) -> None:
129
+ """Handle text input and send to model."""
130
+ while True:
131
+ try:
132
+ text = await asyncio.to_thread(input, "message > ")
133
+ if text.lower() == "q":
134
+ break
135
+ await self.session.send(input=text or ".", end_of_turn=True)
136
+ except Exception as e:
137
+ logger.error(f"Error sending text: {e}")
138
+ await asyncio.sleep(0.1)
139
+
140
+ async def send_realtime(self) -> None:
141
+ """Send real-time audio data to model."""
142
+ while True:
143
+ try:
144
+ msg = await self.out_queue.get()
145
+ await self.session.send(input=msg)
146
+ except Exception as e:
147
+ logger.error(f"Error in real-time sending: {e}")
148
+ await asyncio.sleep(0.1)
149
+
150
+ def input_audio_file(self, file_path: str):
151
+ """Read audio file and stream to the model."""
152
+ try:
153
+ with wave.open(file_path, "rb") as wave_file:
154
+ data = wave_file.readframes(wave_file.getnframes())
155
+ self.out_queue.put_nowait({"data": data, "mime_type": "audio/pcm"})
156
+ except Exception as e:
157
+ logger.error(f"Error reading audio file: {e}")
158
+
159
+ async def listen_audio(self) -> None:
160
+ """Capture and process audio input."""
161
+ try:
162
+ pya = pyaudio.PyAudio()
163
+ mic_info = pya.get_default_input_device_info()
164
+ self.audio_stream = await asyncio.to_thread(
165
+ pya.open,
166
+ format=self.audio_config.format,
167
+ channels=self.audio_config.channels,
168
+ rate=self.audio_config.send_sample_rate,
169
+ input=True,
170
+ input_device_index=mic_info["index"],
171
+ frames_per_buffer=self.audio_config.chunk_size,
172
+ )
173
+
174
+ kwargs = {"exception_on_overflow": False} if __debug__ else {}
175
+
176
+ while True:
177
+ data = await asyncio.to_thread(
178
+ self.audio_stream.read,
179
+ self.audio_config.chunk_size,
180
+ **kwargs,
181
+ )
182
+ await self.out_queue.put({"data": data, "mime_type": "audio/pcm"})
183
+ except Exception as e:
184
+ logger.error(f"Error in audio listening: {e}")
185
+ if self.audio_stream:
186
+ self.audio_stream.close()
187
+
188
+ async def receive_audio(self) -> None:
189
+ """Process and handle model responses."""
190
+ while True:
191
+ try:
192
+ turn = self.session.receive()
193
+ async for response in turn:
194
+ if data := response.data:
195
+ self.audio_in_queue.put_nowait(data)
196
+ continue
197
+ if text := response.text:
198
+ logger.info(text)
199
+ if response.tool_call is not None:
200
+ for tool in response.tool_call.function_calls:
201
+ tool_response = FUNCTION_MAP[tool.name](**tool.args)
202
+ logger.info(tool_response)
203
+ await self.session.send(
204
+ input=tool_response, end_of_turn=True
205
+ )
206
+ await asyncio.sleep(0.1)
207
+
208
+ # Clear queue on turn completion
209
+ while not self.audio_in_queue.empty():
210
+ self.audio_in_queue.get_nowait()
211
+ except Exception as e:
212
+ logger.error(f"Error receiving audio: {e}")
213
+ await asyncio.sleep(0.1)
214
+
215
+ async def play_audio(self) -> None:
216
+ """Play received audio through output device."""
217
+ try:
218
+ pya = pyaudio.PyAudio()
219
+ stream = await asyncio.to_thread(
220
+ pya.open,
221
+ format=self.audio_config.format,
222
+ channels=self.audio_config.channels,
223
+ rate=self.audio_config.receive_sample_rate,
224
+ output=True,
225
+ )
226
+
227
+ while True:
228
+ bytestream = await self.audio_in_queue.get()
229
+ await asyncio.to_thread(stream.write, bytestream)
230
+ except Exception as e:
231
+ logger.error(f"Error playing audio: {e}")
232
+ if "stream" in locals():
233
+ stream.close()
234
+
235
+ async def run(self) -> None:
236
+ """Main execution loop."""
237
+ try:
238
+ async with (
239
+ self.client.aio.live.connect(
240
+ model=self.model_config.name,
241
+ config={
242
+ "system_instruction": self.model_config.system_instruction,
243
+ "tools": self.model_config.tools,
244
+ "generation_config": self.model_config.generation_config,
245
+ },
246
+ ) as session,
247
+ asyncio.TaskGroup() as tg,
248
+ ):
249
+ self.session = session
250
+ self.audio_in_queue = asyncio.Queue()
251
+ self.out_queue = asyncio.Queue(maxsize=5)
252
+
253
+ if self.instruction_audio:
254
+ self.input_audio_file(file_path=self.instruction_audio)
255
+
256
+ tasks = [
257
+ tg.create_task(self.send_text()),
258
+ tg.create_task(self.send_realtime()),
259
+ tg.create_task(self.listen_audio()),
260
+ tg.create_task(self.receive_audio()),
261
+ tg.create_task(self.play_audio()),
262
+ ]
263
+
264
+ await tasks[0] # Wait for send_text to complete
265
+ raise asyncio.CancelledError("User requested exit")
266
+
267
+ except asyncio.CancelledError:
268
+ logger.info("Shutting down...")
269
+ except Exception as e:
270
+ logger.error(f"Error in main loop: {e}")
271
+ logger.debug(traceback.format_exc())
272
+ finally:
273
+ if self.audio_stream:
274
+ self.audio_stream.close()
src copy/helpers/prompts.py ADDED
@@ -0,0 +1,12 @@
1
+ """This module contains the prompts for the application."""
2
+
3
+ # import jinja2 template prompt
4
+
5
+ from jinja2 import Template
6
+
7
+
8
+ def load_prompt(prompt_path: str) -> str:
9
+ """Load the prompt from the given path."""
10
+ with open(prompt_path, "r", encoding="utf-8") as file:
11
+ prompt = Template(file.read())
12
+ return prompt.render()
src copy/helpers/session.py ADDED
@@ -0,0 +1,50 @@
1
+ import json
2
+ from dataclasses import dataclass
3
+ from datetime import datetime
4
+
5
+ from jinja2 import Template
6
+
7
+
8
+ @dataclass
9
+ class Question:
10
+ id: int
11
+ text: str
12
+ answer_format: type
13
+ user_answer: any = None
14
+
15
+
16
+ class Session:
17
+ def __init__(self, questions):
18
+ self.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
19
+ self.questions = questions
20
+ # self.questions = self.process_questions(questions)
21
+
22
+ @staticmethod
23
+ def process_questions(questions):
24
+ qq = {}
25
+ for q in questions:
26
+ if q["answer_format"] == "number":
27
+ Q = Question(q["id"], q["text"], int, None)
28
+ elif q["answer_format"] == "text":
29
+ Q = Question(q["id"], q["text"], str, None)
30
+ elif q["answer_format"] == "list":
31
+ Q = Question(q["id"], q["text"], list, None)
32
+ else:
33
+ raise ValueError("Invalid answer format")
34
+ qq[q["id"]] = Q
35
+ return qq
36
+
37
+ def answer_question(self, question_id, user_answer):
38
+ self.questions[question_id].user_answer = user_answer
39
+
40
+ def get_next_question(self):
41
+ for q in self.questions:
42
+ if q.user_answer:
43
+ return q
44
+ return False
45
+
46
+ def zero_shot_prompt(self, prompt_template_path):
47
+ with open(prompt_template_path) as f:
48
+ template_str = f.read()
49
+ template = Template(template_str)
50
+ return template.render(questions=json.dumps(self.questions, indent=4))
src copy/index.html ADDED
@@ -0,0 +1,452 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8">
6
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
7
+ <title>Gemini Voice Chat</title>
8
+ <style>
9
+ :root {
10
+ --color-accent: #6366f1;
11
+ --color-background: #0f172a;
12
+ --color-surface: #1e293b;
13
+ --color-text: #e2e8f0;
14
+ --boxSize: 8px;
15
+ --gutter: 4px;
16
+ }
17
+
18
+ body {
19
+ margin: 0;
20
+ padding: 0;
21
+ background-color: var(--color-background);
22
+ color: var(--color-text);
23
+ font-family: system-ui, -apple-system, sans-serif;
24
+ min-height: 100vh;
25
+ display: flex;
26
+ flex-direction: column;
27
+ align-items: center;
28
+ justify-content: center;
29
+ }
30
+
31
+ .container {
32
+ width: 90%;
33
+ max-width: 800px;
34
+ background-color: var(--color-surface);
35
+ padding: 2rem;
36
+ border-radius: 1rem;
37
+ box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.25);
38
+ }
39
+
40
+ .wave-container {
41
+ position: relative;
42
+ display: flex;
43
+ min-height: 100px;
44
+ max-height: 128px;
45
+ justify-content: center;
46
+ align-items: center;
47
+ margin: 2rem 0;
48
+ }
49
+
50
+ .box-container {
51
+ display: flex;
52
+ justify-content: space-between;
53
+ height: 64px;
54
+ width: 100%;
55
+ }
56
+
57
+ .box {
58
+ height: 100%;
59
+ width: var(--boxSize);
60
+ background: var(--color-accent);
61
+ border-radius: 8px;
62
+ transition: transform 0.05s ease;
63
+ }
64
+
65
+ .controls {
66
+ display: grid;
67
+ gap: 1rem;
68
+ margin-bottom: 2rem;
69
+ }
70
+
71
+ .input-group {
72
+ display: flex;
73
+ flex-direction: column;
74
+ gap: 0.5rem;
75
+ }
76
+
77
+ label {
78
+ font-size: 0.875rem;
79
+ font-weight: 500;
80
+ }
81
+
82
+ input,
83
+ select {
84
+ padding: 0.75rem;
85
+ border-radius: 0.5rem;
86
+ border: 1px solid rgba(255, 255, 255, 0.1);
87
+ background-color: var(--color-background);
88
+ color: var(--color-text);
89
+ font-size: 1rem;
90
+ }
91
+
92
+ button {
93
+ padding: 1rem 2rem;
94
+ border-radius: 0.5rem;
95
+ border: none;
96
+ background-color: var(--color-accent);
97
+ color: white;
98
+ font-weight: 600;
99
+ cursor: pointer;
100
+ transition: all 0.2s ease;
101
+ }
102
+
103
+ button:hover {
104
+ opacity: 0.9;
105
+ transform: translateY(-1px);
106
+ }
107
+
108
+ .icon-with-spinner {
109
+ display: flex;
110
+ align-items: center;
111
+ justify-content: center;
112
+ gap: 12px;
113
+ min-width: 180px;
114
+ }
115
+
116
+ .spinner {
117
+ width: 20px;
118
+ height: 20px;
119
+ border: 2px solid white;
120
+ border-top-color: transparent;
121
+ border-radius: 50%;
122
+ animation: spin 1s linear infinite;
123
+ flex-shrink: 0;
124
+ }
125
+
126
+ @keyframes spin {
127
+ to {
128
+ transform: rotate(360deg);
129
+ }
130
+ }
131
+
132
+ .pulse-container {
133
+ display: flex;
134
+ align-items: center;
135
+ justify-content: center;
136
+ gap: 12px;
137
+ min-width: 180px;
138
+ }
139
+
140
+ .pulse-circle {
141
+ width: 20px;
142
+ height: 20px;
143
+ border-radius: 50%;
144
+ background-color: white;
145
+ opacity: 0.2;
146
+ flex-shrink: 0;
147
+ transform: translateX(-0%) scale(var(--audio-level, 1));
148
+ transition: transform 0.1s ease;
149
+ }
150
+
151
+ /* Add styles for toast notifications */
152
+ .toast {
153
+ position: fixed;
154
+ top: 20px;
155
+ left: 50%;
156
+ transform: translateX(-50%);
157
+ padding: 16px 24px;
158
+ border-radius: 4px;
159
+ font-size: 14px;
160
+ z-index: 1000;
161
+ display: none;
162
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
163
+ }
164
+
165
+ .toast.error {
166
+ background-color: #f44336;
167
+ color: white;
168
+ }
169
+
170
+ .toast.warning {
171
+ background-color: #ffd700;
172
+ color: black;
173
+ }
174
+ </style>
175
+ </head>
176
+
177
+
178
+ <body>
179
+ <!-- Add toast element after body opening tag -->
180
+ <div id="error-toast" class="toast"></div>
181
+ <div style="text-align: center">
182
+ <h1>Gemini Voice Chat</h1>
183
+ <p>Speak with Gemini using real-time audio streaming</p>
184
+ <p>
185
+ Get a Gemini API key
186
+ <a href="https://ai.google.dev/gemini-api/docs/api-key">here</a>
187
+ </p>
188
+ </div>
189
+ <div class="container">
190
+ <div class="controls">
191
+ <div class="input-group">
192
+ <label for="api-key">API Key</label>
193
+ <input type="password" id="api-key" placeholder="Enter your API key">
194
+ </div>
195
+ <div class="input-group">
196
+ <label for="voice">Voice</label>
197
+ <select id="voice">
198
+ <option value="Puck">Puck</option>
199
+ <option value="Charon">Charon</option>
200
+ <option value="Kore">Kore</option>
201
+ <option value="Fenrir">Fenrir</option>
202
+ <option value="Aoede">Aoede</option>
203
+ </select>
204
+ </div>
205
+ </div>
206
+
207
+ <div class="wave-container">
208
+ <div class="box-container">
209
+ <!-- Boxes will be dynamically added here -->
210
+ </div>
211
+ </div>
212
+
213
+ <button id="start-button">Start Recording</button>
214
+ </div>
215
+
216
+ <audio id="audio-output"></audio>
217
+
218
+ <script>
219
+ let peerConnection;
220
+ let audioContext;
221
+ let dataChannel;
222
+ let isRecording = false;
223
+ let webrtc_id;
224
+
225
+ const startButton = document.getElementById('start-button');
226
+ const apiKeyInput = document.getElementById('api-key');
227
+ const voiceSelect = document.getElementById('voice');
228
+ const audioOutput = document.getElementById('audio-output');
229
+ const boxContainer = document.querySelector('.box-container');
230
+
231
+ const numBars = 32;
232
+ for (let i = 0; i < numBars; i++) {
233
+ const box = document.createElement('div');
234
+ box.className = 'box';
235
+ boxContainer.appendChild(box);
236
+ }
237
+
238
+ function updateButtonState() {
239
+ if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
240
+ startButton.innerHTML = `
241
+ <div class="icon-with-spinner">
242
+ <div class="spinner"></div>
243
+ <span>Connecting...</span>
244
+ </div>
245
+ `;
246
+ } else if (peerConnection && peerConnection.connectionState === 'connected') {
247
+ startButton.innerHTML = `
248
+ <div class="pulse-container">
249
+ <div class="pulse-circle"></div>
250
+ <span>Stop Recording</span>
251
+ </div>
252
+ `;
253
+ } else {
254
+ startButton.innerHTML = 'Start Recording';
255
+ }
256
+ }
257
+
258
+ function showError(message) {
259
+ const toast = document.getElementById('error-toast');
260
+ toast.textContent = message;
261
+ toast.className = 'toast error';
262
+ toast.style.display = 'block';
263
+
264
+ // Hide toast after 5 seconds
265
+ setTimeout(() => {
266
+ toast.style.display = 'none';
267
+ }, 5000);
268
+ }
269
+
270
+ async function setupWebRTC() {
271
+ const config = __RTC_CONFIGURATION__;
272
+ peerConnection = new RTCPeerConnection(config);
273
+ webrtc_id = Math.random().toString(36).substring(7);
274
+
275
+ const timeoutId = setTimeout(() => {
276
+ const toast = document.getElementById('error-toast');
277
+ toast.textContent = "Connection is taking longer than usual. Are you on a VPN?";
278
+ toast.className = 'toast warning';
279
+ toast.style.display = 'block';
280
+
281
+ // Hide warning after 5 seconds
282
+ setTimeout(() => {
283
+ toast.style.display = 'none';
284
+ }, 5000);
285
+ }, 5000);
286
+
287
+ try {
288
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
289
+ stream.getTracks().forEach(track => peerConnection.addTrack(track, stream));
290
+
291
+ // Update audio visualization setup
292
+ audioContext = new AudioContext();
293
+ analyser_input = audioContext.createAnalyser();
294
+ const source = audioContext.createMediaStreamSource(stream);
295
+ source.connect(analyser_input);
296
+ analyser_input.fftSize = 64;
297
+ dataArray_input = new Uint8Array(analyser_input.frequencyBinCount);
298
+
299
+ function updateAudioLevel() {
300
+ analyser_input.getByteFrequencyData(dataArray_input);
301
+ const average = Array.from(dataArray_input).reduce((a, b) => a + b, 0) / dataArray_input.length;
302
+ const audioLevel = average / 255;
303
+
304
+ const pulseCircle = document.querySelector('.pulse-circle');
305
+ if (pulseCircle) {
306
+ console.log("audioLevel", audioLevel);
307
+ pulseCircle.style.setProperty('--audio-level', 1 + audioLevel);
308
+ }
309
+
310
+ animationId = requestAnimationFrame(updateAudioLevel);
311
+ }
312
+ updateAudioLevel();
313
+
314
+ // Add connection state change listener
315
+ peerConnection.addEventListener('connectionstatechange', () => {
316
+ console.log('connectionstatechange', peerConnection.connectionState);
317
+ if (peerConnection.connectionState === 'connected') {
318
+ clearTimeout(timeoutId);
319
+ const toast = document.getElementById('error-toast');
320
+ toast.style.display = 'none';
321
+ }
322
+ updateButtonState();
323
+ });
324
+
325
+ // Handle incoming audio
326
+ peerConnection.addEventListener('track', (evt) => {
327
+ if (audioOutput && audioOutput.srcObject !== evt.streams[0]) {
328
+ audioOutput.srcObject = evt.streams[0];
329
+ audioOutput.play();
330
+
331
+ // Set up audio visualization on the output stream
332
+ audioContext = new AudioContext();
333
+ analyser = audioContext.createAnalyser();
334
+ const source = audioContext.createMediaStreamSource(evt.streams[0]);
335
+ source.connect(analyser);
336
+ analyser.fftSize = 2048;
337
+ dataArray = new Uint8Array(analyser.frequencyBinCount);
338
+ updateVisualization();
339
+ }
340
+ });
341
+
342
+ // Create data channel for messages
343
+ dataChannel = peerConnection.createDataChannel('text');
344
+ dataChannel.onmessage = (event) => {
345
+ const eventJson = JSON.parse(event.data);
346
+ if (eventJson.type === "error") {
347
+ showError(eventJson.message);
348
+ } else if (eventJson.type === "send_input") {
349
+ fetch('/input_hook', {
350
+ method: 'POST',
351
+ headers: {
352
+ 'Content-Type': 'application/json',
353
+ },
354
+ body: JSON.stringify({
355
+ webrtc_id: webrtc_id,
356
+ api_key: apiKeyInput.value,
357
+ voice_name: voiceSelect.value
358
+ })
359
+ });
360
+ }
361
+ };
362
+
363
+ // Create and send offer
364
+ const offer = await peerConnection.createOffer();
365
+ await peerConnection.setLocalDescription(offer);
366
+
367
+ await new Promise((resolve) => {
368
+ if (peerConnection.iceGatheringState === "complete") {
369
+ resolve();
370
+ } else {
371
+ const checkState = () => {
372
+ if (peerConnection.iceGatheringState === "complete") {
373
+ peerConnection.removeEventListener("icegatheringstatechange", checkState);
374
+ resolve();
375
+ }
376
+ };
377
+ peerConnection.addEventListener("icegatheringstatechange", checkState);
378
+ }
379
+ });
380
+
381
+ const response = await fetch('/webrtc/offer', {
382
+ method: 'POST',
383
+ headers: { 'Content-Type': 'application/json' },
384
+ body: JSON.stringify({
385
+ sdp: peerConnection.localDescription.sdp,
386
+ type: peerConnection.localDescription.type,
387
+ webrtc_id: webrtc_id,
388
+ })
389
+ });
390
+
391
+ const serverResponse = await response.json();
392
+
393
+ if (serverResponse.status === 'failed') {
394
+ showError(serverResponse.meta.error === 'concurrency_limit_reached'
395
+ ? `Too many connections. Maximum limit is ${serverResponse.meta.limit}`
396
+ : serverResponse.meta.error);
397
+ stopWebRTC();
398
+ startButton.textContent = 'Start Recording';
399
+ return;
400
+ }
401
+
402
+ await peerConnection.setRemoteDescription(serverResponse);
403
+ } catch (err) {
404
+ clearTimeout(timeoutId);
405
+ console.error('Error setting up WebRTC:', err);
406
+ showError('Failed to establish connection. Please try again.');
407
+ stopWebRTC();
408
+ startButton.textContent = 'Start Recording';
409
+ }
410
+ }
411
+
412
+ function updateVisualization() {
413
+ if (!analyser) return;
414
+
415
+ analyser.getByteFrequencyData(dataArray);
416
+ const bars = document.querySelectorAll('.box');
417
+
418
+ for (let i = 0; i < bars.length; i++) {
419
+ const barHeight = (dataArray[i] / 255) * 2;
420
+ bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
421
+ }
422
+
423
+ animationId = requestAnimationFrame(updateVisualization);
424
+ }
425
+
426
+ function stopWebRTC() {
427
+ if (peerConnection) {
428
+ peerConnection.close();
429
+ }
430
+ if (animationId) {
431
+ cancelAnimationFrame(animationId);
432
+ }
433
+ if (audioContext) {
434
+ audioContext.close();
435
+ }
436
+ updateButtonState();
437
+ }
438
+
439
+ startButton.addEventListener('click', () => {
440
+ if (!isRecording) {
441
+ setupWebRTC();
442
+ startButton.classList.add('recording');
443
+ } else {
444
+ stopWebRTC();
445
+ startButton.classList.remove('recording');
446
+ }
447
+ isRecording = !isRecording;
448
+ });
449
+ </script>
450
+ </body>
451
+
452
+ </html>
src copy/models.py ADDED
@@ -0,0 +1,30 @@
1
+ """Data models for the application."""
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import pyaudio
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+
10
+
11
+ @dataclass
12
+ class AudioConfig:
13
+ """Audio configuration settings."""
14
+
15
+ format: int = pyaudio.paInt16
16
+ channels: int = 1
17
+ send_sample_rate: int = 16000
18
+ receive_sample_rate: int = 24000
19
+ chunk_size: int = 1024
20
+
21
+
22
+ @dataclass
23
+ class ModelConfig:
24
+ """Gemini model configuration."""
25
+
26
+ api_key: str
27
+ name: str
28
+ tools: dict
29
+ generation_config: dict
30
+ system_instruction: str
src copy/prompts/default_prompt.jinja2 ADDED
@@ -0,0 +1,41 @@
1
+ # Personality and Tone
2
+ ## Identity
3
+ You are a friendly recruiter who conducts initial screening calls with candidates. You speak clear, professional English.
4
+
5
+ YOU ARE THE RECRUITER AND THE USER IS THE CANDIDATE, THE USER MUST ANSWER THE QUESTIONS.
6
+
7
+ ## Tone and Language
8
+ - You are polite and professional.
9
+ - Use complete sentences
10
+ - Maintain a formal but warm demeanor
11
+ - Avoid slang or casual language
12
+
13
+ ## Task
14
+ Your sole responsibility is to conduct brief initial screenings with candidates by following these exact steps:
15
+
16
+ # Strict Interview Protocol
17
+
18
+ 1. ANSWER PROCESSING AND VALIDATION:
19
+ - ESSENTIAL INFO: Extract only the key information from candidate's response
20
+ - you MUST store the extracted information using validate_answer_tool
21
+ - VALIDATION: Use validate_answer_tool with the distilled answer ONLY
22
+ - ACKNOWLEDGE: Briefly acknowledge the candidate's response
23
+ - IMPORTANT: Never reveal validation process to candidates
24
+ - If validation fails, repeat question
25
+
26
+ 2. ANSWER VALIDATION PROTOCOL:
27
+ - If answer is VALID: Proceed to next question
28
+ - If answer is INVALID: Repeat the same question
29
+ - No exceptions to this rule
30
+
31
+ 3. INTERVIEW CONCLUSION:
32
+ - Only conclude after ALL questions are asked and validated
33
+ - End with a professional thank you message
34
+ - No additional commentary or questions allowed
35
+
36
+ DO NOT deviate from these protocols under any circumstances.
37
+
38
+
39
+ QUESTIONS SEQUENCE:
40
+ - You MUST ask questions in the exact order provided in:
41
+ {{ questions }}
src copy/run.py ADDED
@@ -0,0 +1,96 @@
1
+ """Real-time Speech Interface
2
+
3
+ This module provides a real-time speech interface using Google's Gemini model.
4
+ It handles bidirectional audio streaming with automatic speech recognition and synthesis.
5
+
6
+ Important:
7
+ Use headphones to prevent audio feedback and echo issues.
8
+ """
9
+
10
+ import argparse
11
+ import asyncio
12
+ import json
13
+ import logging
14
+ import os
15
+ import traceback
16
+
17
+ from helpers.loop import AudioLoop, TextLoop
18
+ from helpers.session import Session
19
+ from models import AudioConfig, ModelConfig
20
+ from tools import TOOLS
21
+
22
+ # Configure logging
23
+ logging.basicConfig(
24
+ level=logging.INFO,
25
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
26
+ )
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ def main(
31
+ modality: str = "text", system_prompt: str = None, instruction_audio: str = None
32
+ ) -> None:
33
+ """Entry point for the application."""
34
+ try:
35
+ model_config = ModelConfig(
36
+ api_key=os.environ.get("GOOGLE_API_KEY"),
37
+ name="models/gemini-2.0-flash-exp",
38
+ system_instruction=system_prompt,
39
+ tools=TOOLS,
40
+ generation_config={
41
+ "response_modalities": modality.upper(),
42
+ },
43
+ )
44
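+        # Pick the interaction loop that matches the requested response modality.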
+ if modality == "audio":
45
+ loop_instance = AudioLoop(
46
+ audio_config=AudioConfig(),
47
+ model_config=model_config,
48
+ instruction_audio=instruction_audio,
49
+ )
50
+ elif modality == "text":
51
+ loop_instance = TextLoop(model_config=model_config)
52
+ else:
53
+ raise ValueError("Invalid modality")
54
+ asyncio.run(loop_instance.run(), debug=True)
55
+ except KeyboardInterrupt:
56
+ logger.info("Application terminated by user")
57
+ except Exception as e:
58
+ logger.error(f"Application error: {e}")
59
+ logger.debug(traceback.format_exc())
60
+
61
+
62
+ if __name__ == "__main__":
63
+ parser = argparse.ArgumentParser(description="Real-time Speech Interface")
64
+ parser.add_argument(
65
+ "-m",
66
+ "--modality",
67
+ choices=["text", "audio"],
68
+ help="Response modality",
69
+ required=True,
70
+ )
71
+ parser.add_argument(
72
+ "--instruction-audio",
73
+ type=str,
74
+ help="Path to audio instructions (.wav file)",
75
+ required=False,
76
+ )
77
+ parser.add_argument(
78
+ "-q",
79
+ "--questions",
80
+ type=str,
81
+ help="Path to JSON file containing questions",
82
+ required=True,
83
+ )
84
+ args = parser.parse_args()
85
+ with open(args.questions, "r") as f:
86
+ questions_dict = json.load(f)
87
+
88
+ session = Session(questions=questions_dict)
89
+ system_prompt = session.zero_shot_prompt("src/prompts/default_prompt.jinja2")
90
+ print(system_prompt)
91
+
92
+ main(
93
+ modality=args.modality,
94
+ system_prompt=system_prompt,
95
+ instruction_audio=args.instruction_audio,
96
+ )
src copy/tools/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """Tools package for API integrations."""
2
+
3
+ from .functions import validate_answer, validate_answer_tool, store_input, store_input_tool
4
+
5
+ # Map of function names to their implementations
6
+ FUNCTION_MAP = {
7
+ "validate_answer": validate_answer,
8
+ "store_input": store_input,
9
+ }
10
+
11
+ # List of all available tools
12
+ # TOOLS = [validate_answer_tool, store_input_tool]
13
+ TOOLS = [validate_answer_tool]
14
+
src copy/tools/functions.py ADDED
@@ -0,0 +1,148 @@
1
+ import json
2
+ import logging
3
+ import os
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ """Schedule meeting integration function."""
8
+
9
+
10
+ def fetch_next_question() -> str:
11
+ """Fetch the next question.
12
+
13
+ Returns:
14
+ str: The next question.
15
+ """
16
+ questions = [
17
+ "What is the capital of France?",
18
+ "What is 2 + 2?",
19
+ "Who wrote Romeo and Juliet?",
20
+ "What is the chemical symbol for gold?",
21
+ "Which planet is known as the Red Planet?",
22
+ ]
23
+     question = questions[0]  # Placeholder: always returns the first question in the list
24
+
25
+ return f"You need to ask the candidate following question: `{question}`. Allow the candidate some time to respond "
26
+
27
+
28
+ fetch_next_question_tool = {
29
+ "name": "fetch_next_question",
30
+ "description": "Fetch the next question",
31
+ }
32
+
33
+
34
+ def validate_answer(
35
+     question_id: int, answer: str, answer_type: str
36
+ ) -> str:
37
+ """Validate the user's answer against an expected answer type.
38
+
39
+     Args:
+         question_id (int): The identifier of the question being validated
40
+ answer (str): The user's provided answer to validate
41
+         answer_type (str): Name of the expected python type that the answer should match (e.g. "str", "int", "list")
42
+
43
+     Returns:
+         str: "Answer is valid" if the answer matches the expected type
44
+
45
+ Raises:
46
+ ValueError: If the answer's type does not match the expected answer_type
47
+
48
+ Example:
49
+ >>> validate_answer(1, "42", str)
50
+         'Answer is valid'
51
+ >>> validate_answer(1, 42, str)
52
+ ValueError: Invalid answer type
53
+ """
54
+ logging.info(
55
+ {
56
+ "question_id": question_id,
57
+ "answer": answer,
58
+ "answer_type": answer_type,
59
+ }
60
+ )
61
+     if type(answer).__name__ != answer_type:  # answer_type is the type name supplied by the model, e.g. "str"
62
+ raise ValueError("Invalid answer type")
63
+
64
+ # Create or load the answers file
65
+ answers_file = "/Users/georgeslorre/ML6/internal/gemini-voice-agents/answers.json"
66
+ answers = []
67
+
68
+ if os.path.exists(answers_file):
69
+ with open(answers_file, "r") as f:
70
+ answers = json.load(f)
71
+
72
+ # Append new answer
73
+     answers.append({"question_id": question_id, "answer": answer})
74
+
75
+ # Write back to file
76
+ with open(answers_file, "w") as f:
77
+ json.dump(answers, f, indent=2)
78
+
79
+ return "Answer is valid"
80
+
81
+
82
+ validate_answer_tool = {
83
+ "name": "validate_answer",
84
+ "description": "Validate the user's answer against an expected answer type",
85
+ "parameters": {
86
+ "type": "OBJECT",
87
+ "properties": {
88
+ "question_id": {
89
+ "type": "INTEGER",
90
+ "description": "The identifier of the question being validated"
91
+ },
92
+ "answer": {
93
+ "type": "STRING",
94
+ "description": "The user's provided answer to validate"
95
+ },
96
+ "answer_type": {
97
+ "type": "STRING",
98
+ "description": "The expected python type that the answer should match (e.g. str, int, list)"
99
+ }
100
+ },
101
+ "required": ["question_id", "answer", "answer_type"]
102
+ }
103
+ }
104
+
105
+
106
+ def store_input(role: str, input: str) -> str:
107
+ """Store conversation input in a JSON file.
108
+
109
+ Args:
110
+ role (str): The role of the speaker (user or assistant)
111
+ input (str): The text input to store
112
+
113
+ Returns:
114
+ str: Confirmation message
115
+ """
116
+ conversation_file = "/Users/georgeslorre/ML6/internal/gemini-voice-agents/conversation.json"
117
+ conversation = []
118
+
119
+ if os.path.exists(conversation_file):
120
+ with open(conversation_file, "r") as f:
121
+ conversation = json.load(f)
122
+
123
+ conversation.append({"role": role, "content": input})
124
+
125
+ with open(conversation_file, "w") as f:
126
+ json.dump(conversation, f, indent=2)
127
+
128
+ return "Input stored successfully"
129
+
130
+
131
+
132
+ store_input_tool = {
133
+ "name": "store_input",
134
+ "description": "Store user input in conversation history",
135
+ "parameters": {
136
+ "type": "OBJECT",
137
+ "properties": {
138
+ "role": {
139
+ "type": "STRING",
140
+ "description": "The role of the speaker (user or assistant)"
141
+ },
142
+ "input": {
143
+ "type": "STRING",
144
+ "description": "The text input to store"
145
+ }
146
+ }
147
+ }
148
+ }
src copy/tts.py ADDED
@@ -0,0 +1,103 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2024 Google LLC
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+
17
+ """Google Cloud Text-To-Speech API streaming sample with input/output streams."""
18
+
19
+ from google.cloud import texttospeech
20
+ import itertools
21
+ import queue
22
+ import threading
23
+
24
+ class TTSStreamer:
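+     """Stream text into the Cloud TTS streaming API and expose synthesized audio chunks via queues."""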
25
+ def __init__(self):
26
+ self.client = texttospeech.TextToSpeechClient()
27
+ self.text_queue = queue.Queue()
28
+         self.audio_queue = queue.Queue()
+         self.processor_thread = None  # Created in start_stream()
29
+
30
+ def start_stream(self):
31
+ streaming_config = texttospeech.StreamingSynthesizeConfig(
32
+ voice=texttospeech.VoiceSelectionParams(
33
+ name="en-US-Journey-D",
34
+ language_code="en-US"
35
+ )
36
+ )
37
+ config_request = texttospeech.StreamingSynthesizeRequest(
38
+ streaming_config=streaming_config
39
+ )
40
+
41
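+         # Yields a synthesis request for each queued text item; a None item ends the stream.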
+ def request_generator():
42
+ while True:
43
+ try:
44
+ text = self.text_queue.get()
45
+ if text is None: # Poison pill to stop
46
+ break
47
+ yield texttospeech.StreamingSynthesizeRequest(
48
+ input=texttospeech.StreamingSynthesisInput(text=text)
49
+ )
50
+ except queue.Empty:
51
+ continue
52
+
53
+ def audio_processor():
54
+ responses = self.client.streaming_synthesize(
55
+ itertools.chain([config_request], request_generator())
56
+ )
57
+
58
+ for response in responses:
59
+ self.audio_queue.put(response.audio_content)
60
+
61
+ self.processor_thread = threading.Thread(target=audio_processor)
62
+ self.processor_thread.start()
63
+
64
+ def send_text(self, text: str):
65
+ """Send text to be synthesized."""
66
+ self.text_queue.put(text)
67
+
68
+ def get_audio(self):
69
+ """Get the next chunk of audio bytes."""
70
+ try:
71
+ return self.audio_queue.get_nowait()
72
+ except queue.Empty:
73
+ return None
74
+
75
+ def stop(self):
76
+ """Stop the streaming synthesis."""
77
+ self.text_queue.put(None) # Send poison pill
78
+ if self.processor_thread:
79
+ self.processor_thread.join()
80
+
81
+ def main():
82
+ tts = TTSStreamer()
83
+ tts.start_stream()
84
+
85
+ # Example usage
86
+ try:
87
+ while True:
88
+ text = input("Enter text (or 'q' to quit): ")
89
+ if text.lower() == 'q':
90
+ break
91
+ tts.send_text(text)
92
+
93
+ # Get and print audio bytes
94
+ while True:
95
+ audio_chunk = tts.get_audio()
96
+ if audio_chunk is None:
97
+ break
98
+ print(f"Received audio chunk of {len(audio_chunk)} bytes")
99
+ finally:
100
+ tts.stop()
101
+
102
+ if __name__ == "__main__":
103
+ main()
src/app.py ADDED
@@ -0,0 +1,302 @@
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import os
5
+ from typing import Literal
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ from fastrtc import AsyncStreamHandler, WebRTC, wait_for_item
10
+ from google import genai
11
+ from google.cloud import texttospeech
12
+ from google.genai.types import FunctionDeclaration, LiveConnectConfig, Tool
13
+
14
+ import helpers.datastore as datastore
15
+ from helpers.prompts import load_prompt
16
+ from tools import FUNCTION_MAP, TOOLS
17
+
18
+ with open("questions.json", "r") as f:
19
+ questions_dict = json.load(f)
20
+
21
+
22
+ datastore.DATA_STORE["questions"] = questions_dict
23
+
24
+ SYSTEM_PROMPT = load_prompt(
25
+ "src/prompts/default_prompt.jinja2", questions=questions_dict
26
+ )
27
+
28
+
29
+ class TTSConfig:
30
+ def __init__(self):
31
+ self.client = texttospeech.TextToSpeechClient()
32
+ self.voice = texttospeech.VoiceSelectionParams(
33
+ name="en-US-Chirp3-HD-Charon", language_code="en-US"
34
+ )
35
+ self.audio_config = texttospeech.AudioConfig(
36
+ audio_encoding=texttospeech.AudioEncoding.LINEAR16
37
+ )
38
+
39
+
40
+ class AsyncGeminiHandler(AsyncStreamHandler):
41
+ """Simple Async Gemini Handler"""
42
+
43
+ def __init__(
44
+ self,
45
+ expected_layout: Literal["mono"] = "mono",
46
+ output_sample_rate: int = 24000,
47
+ output_frame_size: int = 480,
48
+ ) -> None:
49
+ super().__init__(
50
+ expected_layout,
51
+ output_sample_rate,
52
+ output_frame_size,
53
+ input_sample_rate=16000,
54
+ )
55
+ self.input_queue: asyncio.Queue = asyncio.Queue()
56
+ self.output_queue: asyncio.Queue = asyncio.Queue()
57
+ self.text_queue: asyncio.Queue = asyncio.Queue()
58
+ self.quit: asyncio.Event = asyncio.Event()
59
+ self.chunk_size = 1024
60
+
61
+ self.tts_config: TTSConfig | None = TTSConfig()
62
+ self.text_buffer = ""
63
+
64
+ def copy(self) -> "AsyncGeminiHandler":
65
+ return AsyncGeminiHandler(
66
+ expected_layout="mono",
67
+ output_sample_rate=self.output_sample_rate,
68
+ output_frame_size=self.output_frame_size,
69
+ )
70
+
71
+ def _encode_audio(self, data: np.ndarray) -> str:
72
+ """Encode Audio data to send to the server"""
73
+ return base64.b64encode(data.tobytes()).decode("UTF-8")
74
+
75
+ async def receive(self, frame: tuple[int, np.ndarray]) -> None:
76
+ """Receives and processes audio frames asynchronously."""
77
+ _, array = frame
78
+ array = array.squeeze()
79
+ audio_message = self._encode_audio(array)
80
+ self.input_queue.put_nowait(audio_message)
81
+
82
+ async def emit(self) -> tuple[int, np.ndarray] | None:
83
+ """Asynchronously emits items from the output queue."""
84
+ return await wait_for_item(self.output_queue)
85
+
86
+ async def start_up(self) -> None:
87
+ """Initialize and start the voice agent application.
88
+
89
+ This asynchronous method sets up the Gemini API client, configures the live connection,
90
+ and starts three concurrent tasks for receiving, processing and sending information.
91
+
92
+ Returns:
93
+ None
94
+
95
+ Raises:
96
+ ValueError: If GEMINI_API_KEY is not provided when required.
97
+
98
+ """
99
+ if not os.getenv("GOOGLE_GENAI_USE_VERTEXAI") == "True":
100
+ api_key = os.getenv("GEMINI_API_KEY")
101
+ if not api_key:
102
+ raise ValueError("API Key is required")
103
+
104
+ client = genai.Client(
105
+ api_key=api_key,
106
+ http_options={"api_version": "v1alpha"},
107
+ )
108
+ else:
109
+ client = genai.Client(http_options={"api_version": "v1beta1"})
110
+
111
+ config = LiveConnectConfig(
112
+ system_instruction={
113
+ "parts": [{"text": SYSTEM_PROMPT}],
114
+ "role": "user",
115
+ },
116
+ tools=[
117
+ Tool(
118
+ function_declarations=[
119
+ FunctionDeclaration(**tool) for tool in TOOLS
120
+ ]
121
+ )
122
+ ],
123
+ response_modalities=["AUDIO"],
124
+ )
125
+
126
+ async with (
127
+ client.aio.live.connect(
128
+ model="gemini-2.0-flash-exp", config=config
129
+ ) as session, # setup the live connection session (websocket)
130
+ asyncio.TaskGroup() as tg, # create a task group to run multiple tasks concurrently
131
+ ):
132
+ self.session = session
133
+
134
+ # these tasks will run concurrently and continuously
135
+ [
136
+ tg.create_task(self.process()),
137
+ tg.create_task(self.send_realtime()),
138
+ tg.create_task(self.tts()),
139
+ ]
140
+
141
+ async def process(self) -> None:
142
+ """Process responses from the session in a continuous loop.
143
+
144
+ This asynchronous method handles different types of responses from the session:
145
+ - Audio data: Processes and queues audio data with the specified sample rate
146
+ - Text data: Accumulates received text in a buffer
147
+ - Tool calls: Executes registered functions and sends their responses back
148
+ - Server content: Handles turn completion and stores conversation history
149
+
150
+ The method runs indefinitely until interrupted, handling any exceptions that occur
151
+ during processing by logging them and continuing after a brief delay.
152
+
153
+ Returns:
154
+ None
155
+
156
+ Raises:
157
+ Exception: Any exceptions during processing are caught and logged
158
+ """
159
+ while True:
160
+ try:
161
+ turn = self.session.receive()
162
+ async for response in turn:
163
+ if data := response.data:
164
+ # audio data
165
+ array = np.frombuffer(data, dtype=np.int16)
166
+ self.output_queue.put_nowait((self.output_sample_rate, array))
167
+ continue
168
+
169
+ if text := response.text:
170
+ # text data
171
+ print(f"Received text: {text}")
172
+ self.text_buffer += text
173
+
174
+ if response.tool_call is not None:
175
+ # function calling
176
+ for tool in response.tool_call.function_calls:
177
+ try:
178
+ tool_response = FUNCTION_MAP[tool.name](**tool.args)
179
+ print(f"Calling tool: {tool.name}")
180
+ print(f"Tool response: {tool_response}")
181
+ await self.session.send(
182
+ input=tool_response, end_of_turn=True
183
+ )
184
+ await asyncio.sleep(0.1)
185
+ except Exception as e:
186
+ print(f"Error in tool call: {e}")
187
+ await asyncio.sleep(0.1)
188
+
189
+ if sc := response.server_content:
190
+ # check if bot's turn is complete
191
+ if sc.turn_complete and self.text_buffer:
192
+ self.text_queue.put_nowait(self.text_buffer)
193
+ FUNCTION_MAP["store_input"](
194
+ role="bot", input=self.text_buffer
195
+ )
196
+ self.text_buffer = ""
197
+
198
+ except Exception as e:
199
+ print(f"Error in processing: {e}")
200
+ await asyncio.sleep(0.1)
201
+
202
+ async def send_realtime(self) -> None:
203
+ """Send real-time audio data to model.
204
+
205
+ This method continuously reads audio data from an input queue and sends it to a model
206
+ session in real-time. It runs in an infinite loop until interrupted.
207
+
208
+ The audio data is sent with mime type 'audio/pcm'. If an error occurs during sending,
209
+ it will be printed and the method will sleep briefly before retrying.
210
+
211
+ Returns:
212
+ None
213
+
214
+ Raises:
215
+ Exception: Any exceptions during queue access or session sending will be caught and logged.
216
+ """
217
+ while True:
218
+ try:
219
+ data = await self.input_queue.get()
220
+ msg = {"data": data, "mime_type": "audio/pcm"}
221
+ await self.session.send(input=msg)
222
+ except Exception as e:
223
+ print(f"Error in real-time sending: {e}")
224
+ await asyncio.sleep(0.1)
225
+
226
+ async def tts(self) -> None:
227
+ while True:
228
+ try:
229
+ text = await self.text_queue.get()
230
+ # Get response in a single request
231
+ if text:
232
+ response = self.tts_config.client.synthesize_speech(
233
+ input=texttospeech.SynthesisInput(text=text),
234
+ voice=self.tts_config.voice,
235
+ audio_config=self.tts_config.audio_config,
236
+ )
237
+ array = np.frombuffer(response.audio_content, dtype=np.int16)
238
+ self.output_queue.put_nowait((self.output_sample_rate, array))
239
+
240
+ except Exception as e:
241
+ print(f"Error in TTS: {e}")
242
+ await asyncio.sleep(0.1)
243
+
244
+ def shutdown(self) -> None:
245
+ self.quit.set()
246
+
247
+
248
+ # Main Gradio Interface
249
+ def registry(*args, **kwargs):
250
+ """Sets up and returns the Gradio interface."""
251
+
252
+ interface = gr.Blocks()
253
+ with interface:
254
+ with gr.Tabs():
255
+ with gr.TabItem("Voice Chat"):
256
+ gr.HTML(
257
+ """
258
+ <div style='text-align: left'>
259
+ <h1>ML6 Voice Demo</h1>
260
+ </div>
261
+ """
262
+ )
263
+
264
+ gemini_handler = AsyncGeminiHandler()
265
+
266
+ with gr.Row():
267
+ audio = WebRTC(
268
+ label="Voice Chat",
269
+ modality="audio",
270
+ mode="send-receive",
271
+ )
272
+
273
+ # Add display components for questions and answers
274
+ with gr.Row():
275
+ with gr.Column():
276
+ gr.JSON(
277
+ label="Questions",
278
+ value=datastore.DATA_STORE["questions"],
279
+ )
280
+ with gr.Column():
281
+ gr.JSON(
282
+ label="Answers",
283
+ value=lambda: datastore.DATA_STORE["answers"],
284
+ every=1,
285
+ )
286
+
287
+ audio.stream(
288
+ gemini_handler,
289
+ inputs=[audio],
290
+ outputs=[audio],
291
+ time_limit=600,
292
+ concurrency_limit=10,
293
+ )
294
+
295
+ return interface
296
+
297
+
298
+ # Launch the Gradio interface
299
+ gr.load(
300
+ name="demo",
301
+ src=registry,
302
+ ).launch()
src/helpers/datastore.py ADDED
@@ -0,0 +1,5 @@
1
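+ # Shared in-memory store for the questions, collected answers, and the running conversation.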
+ DATA_STORE = {
2
+ "questions": [],
3
+ "answers": [],
4
+ "conversation:": [],
5
+ }
src/helpers/prompts.py ADDED
@@ -0,0 +1,12 @@
1
+ """This module contains the prompts for the application."""
2
+
3
+ import json
4
+
5
+ from jinja2 import Template
6
+
7
+
8
+ def load_prompt(prompt_path: str, **kwargs) -> str:
9
+ """Load the prompt from the given path."""
10
+ with open(prompt_path, "r", encoding="utf-8") as file:
11
+ prompt = Template(file.read())
12
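+     # JSON-encode each keyword value so structures such as the questions dict render as literal JSON in the prompt.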
+ return prompt.render(**{k: json.dumps(v) for k, v in kwargs.items()})
src/prompts/default_prompt.jinja2 ADDED
@@ -0,0 +1,41 @@
1
+ # Personality and Tone
2
+ ## Identity
3
+ You are a friendly recruiter who conducts initial screening calls with candidates. You speak clear, professional English.
4
+
5
+ YOU ARE THE RECRUITER AND THE USER IS THE CANDIDATE, THE USER MUST ANSWER THE QUESTIONS.
6
+
7
+ ## Tone and Language
8
+ - You are polite and professional.
9
+ - Use complete sentences
10
+ - Maintain a formal but warm demeanor
11
+ - Avoid slang or casual language
12
+
13
+ ## Task
14
+ Your sole responsibility is to conduct brief initial screenings with candidates by following these exact steps:
15
+
16
+ # Strict Interview Protocol
17
+
18
+ 1. ANSWER PROCESSING AND VALIDATION:
19
+ - ESSENTIAL INFO: Extract only the key information from candidate's response
20
+ - you MUST store the extracted information using validate_answer_tool
21
+ - VALIDATION: Use validate_answer_tool with the distilled answer ONLY
22
+ - ACKNOWLEDGE: Briefly acknowledge the candidate's response
23
+ - IMPORTANT: Never reveal validation process to candidates
24
+ - If validation fails, repeat question
25
+
26
+ 2. ANSWER VALIDATION PROTOCOL:
27
+ - If answer is VALID: Proceed to next question
28
+ - If answer is INVALID: Repeat the same question
29
+ - No exceptions to this rule
30
+
31
+ 3. INTERVIEW CONCLUSION:
32
+ - Only conclude after ALL questions are asked and validated
33
+ - End with a professional thank you message
34
+ - No additional commentary or questions allowed
35
+
36
+ DO NOT deviate from these protocols under any circumstances.
37
+
38
+
39
+ QUESTIONS SEQUENCE:
40
+ - You MUST ask questions in the exact order provided in:
41
+ {{ questions }}
src/tools/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ """Tools package for API integrations."""
2
+
3
+ from .functions import (
4
+ store_input,
5
+ store_input_tool,
6
+ validate_answer,
7
+ validate_answer_tool,
8
+ )
9
+
10
+ # Map of function names to their implementations
11
+ FUNCTION_MAP = {
12
+ "validate_answer": validate_answer,
13
+ "store_input": store_input,
14
+ }
15
+
16
+ # List of all available tools
17
+ TOOLS = [store_input_tool, validate_answer_tool]
src/tools/functions.py ADDED
@@ -0,0 +1,103 @@
1
+ import logging
2
+
3
+ import helpers.datastore as datastore
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
+ def validate_answer(
9
+     question_id: int, answer: str, answer_type: str
10
+ ) -> str:
11
+ """Validate the user's answer against an expected answer type.
12
+
13
+     Args:
+         question_id (int): The identifier of the question being validated
14
+ answer (str): The user's provided answer to validate
15
+         answer_type (str): Name of the expected python type that the answer should match (e.g. "str", "int", "list")
16
+
17
+     Returns:
+         str: "Answer is valid" if the answer matches the expected type
18
+
19
+ Raises:
20
+ ValueError: If the answer's type does not match the expected answer_type
21
+
22
+ Example:
23
+ >>> validate_answer(1, "42", str)
24
+         'Answer is valid'
25
+ >>> validate_answer(1, 42, str)
26
+ ValueError: Invalid answer type
27
+ """
28
+
29
+ logging.info(
30
+ {
31
+ "question_id": question_id,
32
+ "answer": answer,
33
+ "answer_type": answer_type,
34
+ }
35
+ )
36
+     if type(answer).__name__ != answer_type:  # answer_type is the type name supplied by the model, e.g. "str"
37
+ raise ValueError("Invalid answer type")
38
+
39
+ datastore.DATA_STORE["answers"].append(
40
+ {"question_id": question_id, "answer": answer}
41
+ )
42
+
43
+ return "Answer is valid"
44
+
45
+
46
+ validate_answer_tool = {
47
+ "name": "validate_answer",
48
+ "description": "Validate the user's answer against an expected answer type",
49
+ "parameters": {
50
+ "type": "OBJECT",
51
+ "properties": {
52
+ "question_id": {
53
+ "type": "INTEGER",
54
+ "description": "The identifier of the question being validated",
55
+ },
56
+ "answer": {
57
+ "type": "STRING",
58
+ "description": "The user's provided answer to validate",
59
+ },
60
+ "answer_type": {
61
+ "type": "STRING",
62
+ "description": "The expected python type that the answer should match (e.g. str, int, list)",
63
+ },
64
+ },
65
+ "required": ["question_id", "answer", "answer_type"],
66
+ },
67
+ }
68
+
69
+
70
+ def store_input(role: str, input: str) -> str:
71
+ """Store conversation input in a JSON file.
72
+
73
+ Args:
74
+ role (str): The role of the speaker (user or assistant)
75
+ input (str): The text input to store
76
+
77
+ Returns:
78
+ str: Confirmation message
79
+ """
80
+     logger.debug("Current data store: %s", datastore.DATA_STORE)
81
+ conversation = datastore.DATA_STORE.get("conversation")
82
+ if conversation is None:
83
+ datastore.DATA_STORE["conversation"] = [{"role": role, "input": input}]
84
+ else:
85
+ datastore.DATA_STORE["conversation"].append({"role": role, "input": input})
86
+
87
+ return "Input stored successfully"
88
+
89
+
90
+ store_input_tool = {
91
+ "name": "store_input",
92
+ "description": "Store user input in conversation history",
93
+ "parameters": {
94
+ "type": "OBJECT",
95
+ "properties": {
96
+ "role": {
97
+ "type": "STRING",
98
+ "description": "The role of the speaker (user or assistant)",
99
+ },
100
+ "input": {"type": "STRING", "description": "The text input to store"},
101
+ },
102
+ },
103
+ }
uv.lock ADDED
The diff for this file is too large to render. See raw diff