Upload folder using huggingface_hub
- .gitignore +15 -0
- .pre-commit-config.yaml +28 -0
- .python-version +1 -0
- .ruff_cache/.gitignore +2 -0
- .ruff_cache/0.4.8/17181755630229836148 +0 -0
- .ruff_cache/0.4.8/2516455456322530856 +0 -0
- .ruff_cache/0.4.8/3664365949595148797 +0 -0
- .ruff_cache/0.9.6/12093191028265889985 +0 -0
- .ruff_cache/0.9.6/16582661031577879600 +0 -0
- .ruff_cache/0.9.6/6136549848780317009 +0 -0
- .ruff_cache/CACHEDIR.TAG +1 -0
- README.md +74 -8
- pyproject.toml +21 -0
- questions.json +27 -0
- src copy/app.py +506 -0
- src copy/app2.py +308 -0
- src copy/app3.py +0 -0
- src copy/helpers/loop.py +274 -0
- src copy/helpers/prompts.py +12 -0
- src copy/helpers/session.py +50 -0
- src copy/index.html +452 -0
- src copy/models.py +30 -0
- src copy/prompts/default_prompt.jinja2 +41 -0
- src copy/run.py +96 -0
- src copy/tools/__init__.py +14 -0
- src copy/tools/functions.py +148 -0
- src copy/tts.py +103 -0
- src/app.py +302 -0
- src/helpers/datastore.py +5 -0
- src/helpers/prompts.py +12 -0
- src/prompts/default_prompt.jinja2 +41 -0
- src/tools/__init__.py +17 -0
- src/tools/functions.py +103 -0
- uv.lock +0 -0
.gitignore
ADDED
@@ -0,0 +1,15 @@
+# Python-generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# Virtual environments
+.venv
+
+# Environment variables
+.env
+
+.vscode/
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,28 @@
+repos:
+  - repo: https://github.com/PyCQA/bandit
+    rev: 1.7.4
+    hooks:
+      - id: bandit
+        name: bandit
+        types: [python]
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.4.8
+    hooks:
+      # Run the linter.
+      - id: ruff
+      # Run the formatter.
+      - id: ruff-format
+
+  - repo: https://github.com/psf/black
+    rev: 23.1.0
+    hooks:
+      - id: black
+        name: black
+
+  - repo: https://github.com/pre-commit/mirrors-isort
+    rev: v5.10.1
+    hooks:
+      - id: isort
+        args: ["--profile", "black"]
.python-version
ADDED
@@ -0,0 +1 @@
+3.11.9
.ruff_cache/.gitignore
ADDED
@@ -0,0 +1,2 @@
+# Automatically created by ruff.
+*
.ruff_cache/0.4.8/17181755630229836148
ADDED
Binary file (187 Bytes).
.ruff_cache/0.4.8/2516455456322530856
ADDED
Binary file (291 Bytes).
.ruff_cache/0.4.8/3664365949595148797
ADDED
Binary file (222 Bytes).
.ruff_cache/0.9.6/12093191028265889985
ADDED
Binary file (236 Bytes).
.ruff_cache/0.9.6/16582661031577879600
ADDED
Binary file (187 Bytes).
.ruff_cache/0.9.6/6136549848780317009
ADDED
Binary file (222 Bytes).
.ruff_cache/CACHEDIR.TAG
ADDED
@@ -0,0 +1 @@
+Signature: 8a477f597d28d172789f06886806bc55
README.md
CHANGED
@@ -1,12 +1,78 @@
 ---
-title: ML6
-
-colorFrom: red
-colorTo: gray
+title: ML6-Gemini-Demo
+app_file: src/app.py
 sdk: gradio
-sdk_version: 5.23.
-app_file: app.py
-pinned: false
+sdk_version: 5.23.0
 ---
+# Gemini Voice Agent Demo
 
-
+This repo contains a demo using the Gemini MultiModal API to create a voice-based agent that can conduct professional technical screening interviews.
+
+## Technical Overview
+
+The system is based on FastRTC and Gradio to provide a real-time voice UI.
+
+### About the modality
+
+You can configure the output modality:
+
+- If set to AUDIO:
+  - The agent will respond with an audio response.
+  - There is no text output, so no transcription is available.
+- If set to TEXT:
+  - The agent will respond with a text response.
+  - The text output will be synthesized to audio using the TTS API.
+  - Transcriptions are available.
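The modality maps directly onto the Live API session config. A minimal sketch of the switch, modelled on the `LiveConnectConfig` call in `src copy/app2.py` — the `USE_AUDIO` toggle here is hypothetical, since the repo hard-codes one modality per app variant:

```python
from google.genai.types import LiveConnectConfig

USE_AUDIO = True  # hypothetical toggle, not a setting in this repo

config = LiveConnectConfig(
    # AUDIO: the model streams speech directly; no transcript is produced.
    # TEXT: the model streams text, which is then synthesized via the TTS API.
    response_modalities=["AUDIO"] if USE_AUDIO else ["TEXT"],
)
```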
+
+### Function Calling
+
+There are 2 functions that can be called:
+- Answer validation
+  - will check the answer type against the expected type
+  - will store the answer
+- Log Input
+  - will log the user input
+  - this is a form of transcribing the incoming audio
+
|
| 38 |
+
## Getting Started
|
| 39 |
+
|
| 40 |
+
To run the application, follow these steps:
|
| 41 |
+
|
| 42 |
+
1. Install uv (if not already installed):
|
| 43 |
+
`curl -LsSf https://astral.sh/uv/install.sh | sh`
|
| 44 |
+
|
| 45 |
+
2. Install dependencies:
|
| 46 |
+
`uv sync`
|
| 47 |
+
|
| 48 |
+
3. Setup the environment variables for either GenAI or VertexAI (see below)
|
| 49 |
+
|
| 50 |
+
4. Run the application:
|
| 51 |
+
`python src/app.py`
|
| 52 |
+
|
| 53 |
+
5. Visit `http://127.0.0.1:7860` in your browser to interact with the voice agent.
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
### GenAI vs VertexAI
|
| 57 |
+
|
| 58 |
+
"gemini-2.0-flash-exp" can be used in both GenAI and VertexAI. [more info](https://github.com/heiko-hotz/gemini-multimodal-live-dev-guide?tab=readme-ov-file)
|
| 59 |
+
|
| 60 |
+
- GenAI requires just a GEMINI_API_KEY environment variable [link](https://ai.google.dev/gemini-api/docs/api-key)
|
| 61 |
+
- VertexAI requires a GCP project and the following environment variables:
|
| 62 |
+
```
|
| 63 |
+
export GOOGLE_CLOUD_PROJECT=YOUR_PROJECT_ID
|
| 64 |
+
export GOOGLE_CLOUD_LOCATION=europe-west4
|
| 65 |
+
export GOOGLE_GENAI_USE_VERTEXAI=True
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
Depending `GOOGLE_GENAI_USE_VERTEXAI` flag this demo will use either GenAI or VertexAI.
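The `google-genai` client reads these variables at construction time, so no code change is needed to switch back ends. A minimal sketch of the equivalent explicit logic, assuming the documented `genai.Client` keyword arguments:

```python
import os

from google import genai

if os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true":
    # VertexAI: authenticate via the GCP project and location.
    client = genai.Client(
        vertexai=True,
        project=os.environ["GOOGLE_CLOUD_PROJECT"],
        location=os.environ["GOOGLE_CLOUD_LOCATION"],
    )
else:
    # GenAI: a plain API key is enough.
    client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
```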
+
+### Note
+
+The gradio-webrtc install fails unless you have ffmpeg@6; on macOS:
+
+```
+brew uninstall ffmpeg
+brew install ffmpeg@6
+brew link ffmpeg@6
+```
pyproject.toml
ADDED
@@ -0,0 +1,21 @@
+[project]
+name = "gemini-voice-agents"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.11.9"
+dependencies = [
+    "fastrtc>=0.0.17",
+    "google>=3.0.0",
+    "google-cloud>=0.34.0",
+    "google-cloud-texttospeech>=2.25.1",
+    "google-genai>=1.7.0",
+    "gradio>=5.23.0",
+    "numpy>=2.1.3",
+]
+
+[dependency-groups]
+dev = [
+    "ruff>=0.9.6",
+    "pre-commit>=4.1",
+]
questions.json
ADDED
@@ -0,0 +1,27 @@
+[
+    {
+        "id": 1,
+        "question": "What is your full name?",
+        "answer_format": "str"
+    },
+    {
+        "id": 2,
+        "question": "What is your current job title?",
+        "answer_format": "str"
+    },
+    {
+        "id": 3,
+        "question": "How many years of relevant experience do you have?",
+        "answer_format": "int"
+    },
+    {
+        "id": 4,
+        "question": "Are you looking for a new job?",
+        "answer_format": "bool"
+    },
+    {
+        "id": 5,
+        "question": "List your three strongest technical skills.",
+        "answer_format": "list[str]"
+    }
+]
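These `answer_format` strings are what the answer-validation tool checks responses against. A hedged sketch of such a check — the parsing rules below are assumptions for illustration, not the actual logic in `src/tools/functions.py`:

```python
import json

def matches_format(answer: str, answer_format: str) -> bool:
    """Illustrative type check keyed on the answer_format strings above."""
    answer = answer.strip()
    if answer_format == "str":
        return bool(answer)
    if answer_format == "int":
        return answer.lstrip("-").isdigit()
    if answer_format == "bool":
        return answer.lower() in {"yes", "no", "true", "false"}
    if answer_format == "list[str]":
        return any(part.strip() for part in answer.split(","))
    return False

with open("questions.json") as f:
    questions = json.load(f)

print(matches_format("7", questions[2]["answer_format"]))  # True: id 3 expects an int
```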
src copy/app.py
ADDED
@@ -0,0 +1,506 @@
+# -*- coding: utf-8 -*-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+## Setup
+
+The gradio-webrtc install fails unless you have ffmpeg@6; on macOS:
+
+```
+brew uninstall ffmpeg
+brew install ffmpeg@6
+brew link ffmpeg@6
+```
+
+Create a virtual python environment, then install the dependencies for this script:
+
+```
+pip install websockets numpy gradio-webrtc "gradio>=5.9.1"
+```
+
+If installation fails, it may be due to the ffmpeg version (see above).
+
+Before running this script, ensure the `GOOGLE_API_KEY` environment variable is set:
+
+```
+$ export GOOGLE_API_KEY='add your key here'
+```
+
+You can get an api-key from Google AI Studio (https://aistudio.google.com/apikey)
+
+## Run
+
+To run the script:
+
+```
+python gemini_gradio_audio.py
+```
+
+On the gradio page (http://127.0.0.1:7860/) click record, and talk, gemini will reply. But note that interruptions
+don't work.
+"""
+
+import base64
+import json
+import os
+import wave
+import itertools
+
+import gradio as gr
+import numpy as np
+import websockets.sync.client
+from gradio_webrtc import StreamHandler, WebRTC
+from jinja2 import Template
+import threading
+import queue
+
+from tools import FUNCTION_MAP, TOOLS
+from google.cloud import texttospeech
+
+# logging.basicConfig(
+#     level=logging.INFO,
+#     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+# )
+# logger = logging.getLogger(__name__)
+
+
+with open("questions.json", "r") as f:
+    questions_dict = json.load(f)
+
+with open("src/prompts/default_prompt.jinja2") as f:
+    template_str = f.read()
+template = Template(template_str)
+system_prompt = template.render(questions=json.dumps(questions_dict, indent=4))
+
+print(system_prompt)
+
+
+# TOOLS = types.GenerateContentConfig(tools=[validate_answer])
+
+
+__version__ = "0.0.3"
+
+KEY_NAME = "GOOGLE_API_KEY"
+
+
+# Configuration and Utilities
+class GeminiConfig:
+    """Configuration settings for Gemini API."""
+
+    def __init__(self):
+        self.api_key = os.getenv(KEY_NAME)
+        self.host = "generativelanguage.googleapis.com"
+        self.model = "models/gemini-2.0-flash-exp"
+        self.ws_url = f"wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
+
+
+class TTSStreamer:
+    def __init__(self):
+        self.client = texttospeech.TextToSpeechClient()
+        self.text_queue = queue.Queue()
+        self.audio_queue = queue.Queue()
+
+    def start_stream(self):
+        streaming_config = texttospeech.StreamingSynthesizeConfig(
+            voice=texttospeech.VoiceSelectionParams(
+                name="en-US-Journey-D", language_code="en-US"
+            )
+        )
+        config_request = texttospeech.StreamingSynthesizeRequest(
+            streaming_config=streaming_config
+        )
+
+        def request_generator():
+            while True:
+                try:
+                    text = self.text_queue.get()
+                    if text is None:  # Poison pill to stop
+                        break
+                    yield texttospeech.StreamingSynthesizeRequest(
+                        input=texttospeech.StreamingSynthesisInput(text=text)
+                    )
+                except queue.Empty:
+                    continue
+
+        def audio_processor():
+            responses = self.client.streaming_synthesize(
+                itertools.chain([config_request], request_generator())
+            )
+            print(f"Responses: {responses}")
+            for response in responses:
+                self.audio_queue.put(response.audio_content)
+
+        self.processor_thread = threading.Thread(target=audio_processor)
+        self.processor_thread.start()
+
+    def send_text(self, text: str):
+        """Send text to be synthesized."""
+        self.text_queue.put(text)
+
+    def get_audio(self):
+        """Get the next chunk of audio bytes."""
+        try:
+            return self.audio_queue.get_nowait()
+        except queue.Empty:
+            return None
+
+    def stop(self):
+        """Stop the streaming synthesis."""
+        self.text_queue.put(None)  # Send poison pill
+        if self.processor_thread:
+            self.processor_thread.join()
+
+
+class AudioProcessor:
+    """Handles encoding and decoding of audio data."""
+
+    @staticmethod
+    def encode_audio(data, sample_rate):
+        """Encodes audio data to base64."""
+        encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
+        return {
+            "realtimeInput": {
+                "mediaChunks": [
+                    {
+                        "mimeType": f"audio/pcm;rate={sample_rate}",
+                        "data": encoded,
+                    }
+                ],
+            },
+        }
+
+    @staticmethod
+    def process_audio_response(data):
+        """Decodes audio data from base64."""
+        audio_data = base64.b64decode(data)
+        return np.frombuffer(audio_data, dtype=np.int16)
+
+
+# Gemini Interaction Handler
+class GeminiHandler(StreamHandler):
+    """Handles streaming interactions with the Gemini API."""
+
+    def __init__(
+        self,
+        audio_file=None,
+        expected_layout="mono",
+        output_sample_rate=24000,
+        output_frame_size=480,
+    ) -> None:
+        super().__init__(
+            expected_layout,
+            output_sample_rate,
+            output_frame_size,
+            input_sample_rate=24000,
+        )
+        self.config = GeminiConfig()
+        self.ws = None
+        self.all_output_data = None
+        self.audio_processor = AudioProcessor()
+        self.audio_file = audio_file
+        self.text_buffer = ""
+        self.tts_engine = None
+
+    def copy(self):
+        """Creates a copy of the GeminiHandler instance."""
+        return GeminiHandler(
+            expected_layout=self.expected_layout,
+            output_sample_rate=self.output_sample_rate,
+            output_frame_size=self.output_frame_size,
+        )
+
+    def _initialize_websocket(self):
+        """Initializes the WebSocket connection to the Gemini API."""
+        try:
+            self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=3000)
+            setup_request = {
+                "setup": {
+                    "model": self.config.model,
+                    "tools": [{"functionDeclarations": TOOLS}],
+                    "generationConfig": {"responseModalities": "TEXT"},
+                    "systemInstruction": {
+                        "parts": [{"text": system_prompt}],
+                        "role": "user",
+                    },
+                }
+            }
+            self.ws.send(json.dumps(setup_request))
+            setup_response = json.loads(self.ws.recv())
+            print(f"Setup response: {setup_response}")
+
+            if self.audio_file:
+                self.input_audio_file(self.audio_file)
+                print("Audio file sent")
+
+        except websockets.exceptions.WebSocketException as e:
+            print(f"WebSocket connection failed: {str(e)}")
+            self.ws = None
+        except Exception as e:
+            print(f"Setup failed: {str(e)}")
+            self.ws = None
+
+    def input_audio_file(self, audio_file):
+        """Processes an audio file and sends it to the Gemini API."""
+        try:
+            with wave.open(audio_file, "rb") as wf:
+                data = wf.readframes(wf.getnframes())
+                self.receive((wf.getframerate(), np.frombuffer(data, dtype=np.int16)))
+        except Exception as e:
+            print(f"Error in input_audio_file: {str(e)}")
+
+    def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        """Receives audio/video data, encodes it, and sends it to the Gemini API."""
+        try:
+            if not self.ws:
+                self._initialize_websocket()
+
+            sample_rate, array = frame
+            message = {"realtimeInput": {"mediaChunks": []}}
+
+            if sample_rate > 0 and array is not None:
+                array = array.squeeze()
+                audio_data = self.audio_processor.encode_audio(
+                    array, self.output_sample_rate
+                )
+                message["realtimeInput"]["mediaChunks"].append(
+                    {
+                        "mimeType": f"audio/pcm;rate={self.output_sample_rate}",
+                        "data": audio_data["realtimeInput"]["mediaChunks"][0]["data"],
+                    }
+                )
+
+            if message["realtimeInput"]["mediaChunks"]:
+                self.ws.send(json.dumps(message))
+        except Exception as e:
+            print(f"Error in receive: {str(e)}")
+            if self.ws:
+                self.ws.close()
+                self.ws = None
+
+    def handle_tool_call(self, tool_call):
+        print(" ", tool_call)
+        for fc in tool_call["functionCalls"]:
+            print(f"Function call: {fc}")
+            # Call the function
+            try:
+                result = {"output": FUNCTION_MAP[fc["name"]](**fc["args"])}
+            except Exception as e:
+                result = {"error": str(e)}
+
+            # Send the response back
+            msg = {
+                "tool_response": {
+                    "function_responses": [
+                        {"id": fc["id"], "name": fc["name"], "response": result}
+                    ]
+                }
+            }
+            print(f"function response: {msg}")
+            self.ws.send(json.dumps(msg))
+
+    def _output_data(self, audio_array):
+        """Processes audio output data from the WebSocket response."""
+        if self.all_output_data is None:
+            self.all_output_data = audio_array
+        else:
+            self.all_output_data = np.concatenate((self.all_output_data, audio_array))
+
+        while self.all_output_data.shape[-1] >= self.output_frame_size:
+            yield (
+                self.output_sample_rate,
+                self.all_output_data[: self.output_frame_size].reshape(1, -1),
+            )
+            self.all_output_data = self.all_output_data[self.output_frame_size :]
+
+    def _process_server_content(self, content):
+        """Processes model output from the WebSocket response."""
+        if response := content.get("modelTurn", {}):
+            if parts := response.get("parts"):
+                for part in parts:
+                    print(f"Part: {part}")
+                    data = part.get("inlineData", {}).get("data", "")
+                    if data:
+                        audio_array = self.audio_processor.process_audio_response(data)
+                        yield from self._output_data(audio_array)
+
+                    text = part.get("text", "")
+                    if text:
+                        self.text_buffer += text
+
+        # Check if the turn is complete and process the text buffer into audio
+        if content.get("turnComplete"):
+            if self.text_buffer:
+                audio_array = self._text_to_audio(self.text_buffer)
+                yield from self._output_data(audio_array)
+                self.text_buffer = ""
+
+    def _text_to_audio(self, text: str) -> np.ndarray:
+        """Convert text to audio using Google Cloud TTS."""
+        client = texttospeech.TextToSpeechClient()
+
+        # Configure synthesis
+        synthesis_input = texttospeech.SynthesisInput(text=text)
+        voice = texttospeech.VoiceSelectionParams(
+            name="en-IN-Chirp-HD-O", language_code="en-IN"
+        )
+        audio_config = texttospeech.AudioConfig(
+            audio_encoding=texttospeech.AudioEncoding.LINEAR16
+        )
+
+        # Get response in a single request
+        try:
+            response = client.synthesize_speech(
+                input=synthesis_input, voice=voice, audio_config=audio_config
+            )
+            return np.frombuffer(response.audio_content, dtype=np.int16)
+        except Exception as e:
+            print(f"Error in speech synthesis: {e}")
+            return np.array([], dtype=np.int16)
+
+    def generator(self):
+        """Generates audio output from the WebSocket stream."""
+        while True:
+            if not self.ws:
+                print("WebSocket not connected")
+                yield None
+                continue
+
+            try:
+                message = self.ws.recv(timeout=30)
+                msg = json.loads(message)
+
+                # Example messages:
+                # {'serverContent': {'modelTurn': {'parts': [{'text': 'Hello'}]}}}
+                # {'serverContent': {'turnComplete': True}}
+
+                if "serverContent" in msg:
+                    content = msg["serverContent"]
+                    yield from self._process_server_content(content)
+                elif "toolCall" in msg:
+                    # handle_tool_call is not a generator, so call it directly
+                    self.handle_tool_call(msg["toolCall"])
+
+            except TimeoutError:
+                print("Timeout waiting for server response")
+                yield None
+            except Exception:
+                yield None
+
+    def emit(self) -> tuple[int, np.ndarray] | None:
+        """Emits the next audio chunk from the generator."""
+        if not self.ws:
+            return None
+        if not hasattr(self, "_generator"):
+            self._generator = self.generator()
+        try:
+            return next(self._generator)
+        except StopIteration:
+            self.reset()
+            return None
+
+    def reset(self) -> None:
+        """Resets the generator and output data."""
+        if hasattr(self, "_generator"):
+            delattr(self, "_generator")
+        self.all_output_data = None
+
+    def shutdown(self) -> None:
+        """Closes the WebSocket connection."""
+        if self.ws:
+            self.ws.close()
+
+    def check_connection(self):
+        """Checks if the WebSocket connection is active."""
+        try:
+            if not self.ws or self.ws.closed:
+                self._initialize_websocket()
+            return True
+        except Exception as e:
+            print(f"Connection check failed: {str(e)}")
+            return False
+
+
+def update_answers():
+    with open("answers.json", "r") as f:
+        return json.load(f)
+
+
+# Main Gradio Interface
+def registry(name: str, token: str | None = None, **kwargs):
+    """Sets up and returns the Gradio interface."""
+    api_key = token or os.environ.get(KEY_NAME)
+    if not api_key:
+        raise ValueError(f"{KEY_NAME} environment variable is not set.")
+
+    interface = gr.Blocks()
+    with interface:
+        with gr.Tabs():
+            with gr.TabItem("Voice Chat"):
+                gr.HTML(
+                    """
+                    <div style='text-align: left'>
+                        <h1>ML6 Voice Demo - Function Calling and Custom Output Voice</h1>
+                    </div>
+                    """
+                )
+                gemini_handler = GeminiHandler()
+
+                with gr.Row():
+                    audio = WebRTC(
+                        label="Voice Chat", modality="audio", mode="send-receive"
+                    )
+
+                # Add display components for questions and answers
+                with gr.Row():
+                    with gr.Column():
+                        gr.JSON(
+                            label="Questions",
+                            value=questions_dict,
+                        )
+                    with gr.Column():
+                        gr.JSON(update_answers, label="Collected Answers", every=1)
+
+                audio.stream(
+                    gemini_handler,
+                    inputs=[audio],
+                    outputs=[audio],
+                    time_limit=600,
+                    concurrency_limit=10,
+                )
+
+    return interface
+
+
+# Launch the Gradio interface
+gr.load(
+    name="gemini-2.0-flash-exp",
+    src=registry,
+).launch()
src copy/app2.py
ADDED
@@ -0,0 +1,308 @@
+# -*- coding: utf-8 -*-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+## Setup
+
+The gradio-webrtc install fails unless you have ffmpeg@6; on macOS:
+
+```
+brew uninstall ffmpeg
+brew install ffmpeg@6
+brew link ffmpeg@6
+```
+
+Create a virtual python environment, then install the dependencies for this script:
+
+```
+pip install websockets numpy gradio-webrtc "gradio>=5.9.1"
+```
+
+If installation fails, it may be due to the ffmpeg version (see above).
+
+Before running this script, ensure the `GOOGLE_API_KEY` environment variable is set:
+
+```
+$ export GOOGLE_API_KEY='add your key here'
+```
+
+You can get an api-key from Google AI Studio (https://aistudio.google.com/apikey)
+
+## Run
+
+To run the script:
+
+```
+python gemini_gradio_audio.py
+```
+
+On the gradio page (http://127.0.0.1:7860/) click record, and talk, gemini will reply. But note that interruptions
+don't work.
+"""
+
+import asyncio
+import json
+import os
+from typing import Literal
+import base64
+
+import gradio as gr
+import numpy as np
+from fastrtc import (
+    AsyncStreamHandler,
+    WebRTC,
+    wait_for_item,
+)
+from jinja2 import Template
+from google import genai
+from google.genai.types import LiveConnectConfig, Tool, FunctionDeclaration
+
+from google.cloud import texttospeech
+
+from tools import FUNCTION_MAP, TOOLS
+
+with open("questions.json", "r") as f:
+    questions_dict = json.load(f)
+
+with open("src/prompts/default_prompt.jinja2") as f:
+    template_str = f.read()
+template = Template(template_str)
+system_prompt = template.render(questions=json.dumps(questions_dict, indent=4))
+
+
+class TTSConfig:
+    def __init__(self):
+        self.client = texttospeech.TextToSpeechClient()
+        self.voice = texttospeech.VoiceSelectionParams(
+            name="en-US-Chirp3-HD-Charon", language_code="en-US"
+        )
+        self.audio_config = texttospeech.AudioConfig(
+            audio_encoding=texttospeech.AudioEncoding.LINEAR16
+        )
+
+
+class AsyncGeminiHandler(AsyncStreamHandler):
+    """Simple Async Gemini Handler."""
+
+    def __init__(
+        self,
+        expected_layout: Literal["mono"] = "mono",
+        output_sample_rate: int = 24000,
+        output_frame_size: int = 480,
+    ) -> None:
+        super().__init__(
+            expected_layout,
+            output_sample_rate,
+            output_frame_size,
+            input_sample_rate=16000,
+        )
+        self.input_queue: asyncio.Queue = asyncio.Queue()
+        self.output_queue: asyncio.Queue = asyncio.Queue()
+        self.text_queue: asyncio.Queue = asyncio.Queue()
+        self.quit: asyncio.Event = asyncio.Event()
+        self.chunk_size = 1024
+
+        self.tts_config: TTSConfig | None = TTSConfig()
+        self.text_buffer = ""
+
+    def copy(self) -> "AsyncGeminiHandler":
+        return AsyncGeminiHandler(
+            expected_layout="mono",
+            output_sample_rate=self.output_sample_rate,
+            output_frame_size=self.output_frame_size,
+        )
+
+    def _encode_audio(self, data: np.ndarray) -> str:
+        """Encode audio data to send to the server."""
+        return base64.b64encode(data.tobytes()).decode("UTF-8")
+
+    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
+        _, array = frame
+        array = array.squeeze()
+        audio_message = self._encode_audio(array)
+        self.input_queue.put_nowait(audio_message)
+
+    async def emit(self) -> tuple[int, np.ndarray] | None:
+        return await wait_for_item(self.output_queue)
+
+    async def start_up(self) -> None:
+        client = genai.Client(
+            api_key=os.getenv("GOOGLE_API_KEY"),
+            http_options={"api_version": "v1alpha"},
+        )
+
+        config = LiveConnectConfig(
+            system_instruction={
+                "parts": [{"text": system_prompt}],
+                "role": "user",
+            },
+            tools=[
+                Tool(
+                    function_declarations=[
+                        FunctionDeclaration(**tool) for tool in TOOLS
+                    ]
+                )
+            ],
+            response_modalities=["AUDIO"],
+        )
+
+        async with (
+            client.aio.live.connect(
+                model="gemini-2.0-flash-exp", config=config
+            ) as session,
+            asyncio.TaskGroup() as tg,
+        ):
+            self.session = session
+
+            tasks = [
+                tg.create_task(self.process()),
+                tg.create_task(self.send_realtime()),
+                tg.create_task(self.tts()),
+            ]
+
+    async def process(self) -> None:
+        while True:
+            try:
+                turn = self.session.receive()
+                async for response in turn:
+                    if data := response.data:
+                        array = np.frombuffer(data, dtype=np.int16)
+                        self.output_queue.put_nowait((self.output_sample_rate, array))
+                        continue
+
+                    if text := response.text:
+                        print(f"Received text: {text}")
+                        self.text_buffer += text
+
+                    if response.tool_call is not None:
+                        for tool in response.tool_call.function_calls:
+                            tool_response = FUNCTION_MAP[tool.name](**tool.args)
+                            print(f"Calling tool: {tool.name}")
+                            print(f"Tool response: {tool_response}")
+                            await self.session.send(
+                                input=tool_response, end_of_turn=True
+                            )
+                            await asyncio.sleep(0.1)
+
+                    if sc := response.server_content:
+                        if sc.turn_complete and self.text_buffer:
+                            self.text_queue.put_nowait(self.text_buffer)
+                            FUNCTION_MAP["store_input"](
+                                role="bot", input=self.text_buffer
+                            )
+                            self.text_buffer = ""
+
+            except Exception as e:
+                print(f"Error in processing: {e}")
+                await asyncio.sleep(0.1)
+
+    async def send_realtime(self) -> None:
+        """Send real-time audio data to model."""
+        while True:
+            try:
+                data = await self.input_queue.get()
+                msg = {"data": data, "mime_type": "audio/pcm"}
+                await self.session.send(input=msg)
+            except Exception as e:
+                print(f"Error in real-time sending: {e}")
+                await asyncio.sleep(0.1)
+
+    async def tts(self) -> None:
+        """Synthesize buffered text turns to audio via the TTS API."""
+        while True:
+            try:
+                text = await self.text_queue.get()
+                # Get response in a single request
+                if text:
+                    response = self.tts_config.client.synthesize_speech(
+                        input=texttospeech.SynthesisInput(text=text),
+                        voice=self.tts_config.voice,
+                        audio_config=self.tts_config.audio_config,
+                    )
+                    array = np.frombuffer(response.audio_content, dtype=np.int16)
+                    self.output_queue.put_nowait((self.output_sample_rate, array))
+
+            except Exception as e:
+                print(f"Error in TTS: {e}")
+                await asyncio.sleep(0.1)
+
+    def shutdown(self) -> None:
+        self.quit.set()
+
+
+def reload_json(path):
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+# Main Gradio Interface
+def registry(name: str, token: str | None = None, **kwargs):
+    """Sets up and returns the Gradio interface."""
+    interface = gr.Blocks()
+    with interface:
+        with gr.Tabs():
+            with gr.TabItem("Voice Chat"):
+                gr.HTML(
+                    """
+                    <div style='text-align: left'>
+                        <h1>ML6 Voice Demo - Function Calling and Custom Output Voice</h1>
+                    </div>
+                    """
+                )
+                gemini_handler = AsyncGeminiHandler()
+
+                with gr.Row():
+                    audio = WebRTC(
+                        label="Voice Chat", modality="audio", mode="send-receive"
+                    )
+
+                # Add display components for questions and answers
+                with gr.Row():
+                    with gr.Column():
+                        gr.JSON(
+                            label="Questions",
+                            value=questions_dict,
+                        )
+                    # with gr.Column():
+                    #     gr.JSON(reload_json, inputs=gr.Text(value="/Users/georgeslorre/ML6/internal/gemini-voice-agents/conversation.json", visible=False), label="Conversation", every=1)
+                    with gr.Column():
+                        gr.JSON(
+                            reload_json,
+                            inputs=gr.Text(
+                                value="/Users/georgeslorre/ML6/internal/gemini-voice-agents/answers.json",
+                                visible=False,
+                            ),
+                            label="Collected Answers",
+                            every=1,
+                        )
+
+                audio.stream(
+                    gemini_handler,
+                    inputs=[audio],
+                    outputs=[audio],
+                    time_limit=600,
+                    concurrency_limit=10,
+                )
+
+    return interface
+
+
+# Function to clear JSON files
+def clear_json_files():
+    with open("/Users/georgeslorre/ML6/internal/gemini-voice-agents/conversation.json", "w") as f:
+        json.dump([], f)
+    with open("/Users/georgeslorre/ML6/internal/gemini-voice-agents/answers.json", "w") as f:
+        json.dump({}, f)
+
+
+# Clear files before launching
+clear_json_files()
+
+# Launch the Gradio interface
+gr.load(
+    name="gemini-2.0-flash-exp",
+    src=registry,
+).launch()
src copy/app3.py
ADDED
File without changes
src copy/helpers/loop.py
ADDED
@@ -0,0 +1,274 @@
+"""Helper for audio loop."""
+
+import asyncio
+import logging
+import traceback
+import wave
+from typing import Optional
+
+import pyaudio
+from google import genai
+
+from models import AudioConfig, ModelConfig
+from tools import FUNCTION_MAP
+
+logger = logging.getLogger(__name__)
+
+
+class TextLoop:
+    def __init__(self, model_config: ModelConfig):
+        self.model_config = model_config
+        self.client = self._setup_client()
+        self.session = None
+
+    def _setup_client(self) -> genai.Client:
+        """Initialize the Gemini client."""
+        return genai.Client(
+            api_key=self.model_config.api_key,
+            http_options={"api_version": "v1alpha"},
+        )
+
+    async def send_text(self) -> None:
+        """Handle text input and send to model."""
+        while True:
+            try:
+                text = await asyncio.to_thread(input, "message > ")
+                if text.lower() == "q":
+                    break
+                await self.session.send(input=text or ".", end_of_turn=True)
+            except Exception as e:
+                logger.error(f"Error sending text: {e}")
+                await asyncio.sleep(0.1)
+
+    async def receive_text(self) -> None:
+        """Process and handle model responses."""
+        while True:
+            try:
+                turn = self.session.receive()
+                async for response in turn:
+                    if text := response.text:
+                        logger.info(text)
+                    if response.tool_call is not None:
+                        for tool in response.tool_call.function_calls:
+                            tool_response = FUNCTION_MAP[tool.name](**tool.args)
+                            logger.info(tool_response)
+                            await self.session.send(
+                                input=tool_response, end_of_turn=True
+                            )
+                            await asyncio.sleep(0.1)
+            except Exception as e:
+                logger.error(f"Error receiving text: {e}")
+                await asyncio.sleep(0.1)
+
+    async def run(self):
+        try:
+            async with (
+                self.client.aio.live.connect(
+                    model=self.model_config.name,
+                    config={
+                        "system_instruction": self.model_config.system_instruction,
+                        "tools": self.model_config.tools,
+                        "generation_config": self.model_config.generation_config,
+                    },
+                ) as session,
+                asyncio.TaskGroup() as tg,
+            ):
+                self.session = session
+                tasks = [
+                    tg.create_task(self.send_text()),
+                    tg.create_task(self.receive_text()),
+                ]
+
+                await tasks[0]  # Wait for send_text to complete
+                raise asyncio.CancelledError("User requested exit")
+
+        except asyncio.CancelledError:
+            logger.info("Shutting down...")
+        except Exception as e:
+            logger.error(f"Error in main loop: {e}")
+            logger.debug(traceback.format_exc())
+
+
+class AudioLoop:
+    """Handles real-time audio streaming and processing."""
+
+    def __init__(
+        self,
+        audio_config: AudioConfig,
+        model_config: ModelConfig,
+        function_map: Optional[dict[str, callable]] = FUNCTION_MAP,
+        instruction_audio: Optional[str] = None,
+    ):
+        """Initialize the audio loop.
+
+        Args:
+            audio_config (AudioConfig): Audio configuration settings
+            model_config (ModelConfig): Model configuration settings
+            function_map (Optional[dict[str, callable]]): Function map
+        """
+        self.audio_config = audio_config
+        self.model_config = model_config
+
+        self.audio_in_queue: Optional[asyncio.Queue] = None
+        self.out_queue: Optional[asyncio.Queue] = None
+        self.session = None
+        self.audio_stream = None
+        self.client = self._setup_client()
+        self.instruction_audio = instruction_audio
+
+        self.function_map = function_map
+
+    def _setup_client(self) -> genai.Client:
+        """Initialize the Gemini client."""
+        return genai.Client(
+            api_key=self.model_config.api_key,
+            http_options={"api_version": "v1alpha"},
+        )
+
+    async def send_text(self) -> None:
+        """Handle text input and send to model."""
+        while True:
+            try:
+                text = await asyncio.to_thread(input, "message > ")
+                if text.lower() == "q":
+                    break
+                await self.session.send(input=text or ".", end_of_turn=True)
+            except Exception as e:
+                logger.error(f"Error sending text: {e}")
+                await asyncio.sleep(0.1)
+
+    async def send_realtime(self) -> None:
+        """Send real-time audio data to model."""
+        while True:
+            try:
+                msg = await self.out_queue.get()
+                await self.session.send(input=msg)
+            except Exception as e:
+                logger.error(f"Error in real-time sending: {e}")
+                await asyncio.sleep(0.1)
+
+    def input_audio_file(self, file_path: str):
+        """Read audio file and stream to the model."""
+        try:
+            with wave.open(file_path, "rb") as wave_file:
+                data = wave_file.readframes(wave_file.getnframes())
+                self.out_queue.put_nowait({"data": data, "mime_type": "audio/pcm"})
+        except Exception as e:
+            logger.error(f"Error reading audio file: {e}")
+
+    async def listen_audio(self) -> None:
+        """Capture and process audio input."""
+        try:
+            pya = pyaudio.PyAudio()
+            mic_info = pya.get_default_input_device_info()
+            self.audio_stream = await asyncio.to_thread(
+                pya.open,
+                format=self.audio_config.format,
+                channels=self.audio_config.channels,
+                rate=self.audio_config.send_sample_rate,
+                input=True,
+                input_device_index=mic_info["index"],
+                frames_per_buffer=self.audio_config.chunk_size,
+            )
+
+            kwargs = {"exception_on_overflow": False} if __debug__ else {}
+
+            while True:
+                data = await asyncio.to_thread(
+                    self.audio_stream.read,
+                    self.audio_config.chunk_size,
+                    **kwargs,
+                )
+                await self.out_queue.put({"data": data, "mime_type": "audio/pcm"})
+        except Exception as e:
+            logger.error(f"Error in audio listening: {e}")
+            if self.audio_stream:
+                self.audio_stream.close()
+
+    async def receive_audio(self) -> None:
+        """Process and handle model responses."""
+        while True:
+            try:
+                turn = self.session.receive()
+                async for response in turn:
+                    if data := response.data:
+                        self.audio_in_queue.put_nowait(data)
+                        continue
+                    if text := response.text:
+                        logger.info(text)
+                    if response.tool_call is not None:
+                        for tool in response.tool_call.function_calls:
+                            tool_response = FUNCTION_MAP[tool.name](**tool.args)
+                            logger.info(tool_response)
+                            await self.session.send(
+                                input=tool_response, end_of_turn=True
+                            )
+                            await asyncio.sleep(0.1)
+
+                # Clear queue on turn completion
+                while not self.audio_in_queue.empty():
+                    self.audio_in_queue.get_nowait()
+            except Exception as e:
+                logger.error(f"Error receiving audio: {e}")
+                await asyncio.sleep(0.1)
+
+    async def play_audio(self) -> None:
+        """Play received audio through output device."""
+        try:
+            pya = pyaudio.PyAudio()
+            stream = await asyncio.to_thread(
+                pya.open,
+                format=self.audio_config.format,
+                channels=self.audio_config.channels,
+                rate=self.audio_config.receive_sample_rate,
+                output=True,
+            )
+
+            while True:
+                bytestream = await self.audio_in_queue.get()
+                await asyncio.to_thread(stream.write, bytestream)
+        except Exception as e:
+            logger.error(f"Error playing audio: {e}")
+            if "stream" in locals():
+                stream.close()
+
+    async def run(self) -> None:
+        """Main execution loop."""
+        try:
+            async with (
+                self.client.aio.live.connect(
+                    model=self.model_config.name,
+                    config={
+                        "system_instruction": self.model_config.system_instruction,
+                        "tools": self.model_config.tools,
+                        "generation_config": self.model_config.generation_config,
+                    },
+                ) as session,
+                asyncio.TaskGroup() as tg,
+            ):
+                self.session = session
+                self.audio_in_queue = asyncio.Queue()
+                self.out_queue = asyncio.Queue(maxsize=5)
+
+                if self.instruction_audio:
+                    self.input_audio_file(file_path=self.instruction_audio)
+
+                tasks = [
+                    tg.create_task(self.send_text()),
+                    tg.create_task(self.send_realtime()),
+                    tg.create_task(self.listen_audio()),
+                    tg.create_task(self.receive_audio()),
+                    tg.create_task(self.play_audio()),
+                ]
+
+                await tasks[0]  # Wait for send_text to complete
+                raise asyncio.CancelledError("User requested exit")
+
+        except asyncio.CancelledError:
+            logger.info("Shutting down...")
+        except Exception as e:
+            logger.error(f"Error in main loop: {e}")
+            logger.debug(traceback.format_exc())
+        finally:
+            if self.audio_stream:
+                self.audio_stream.close()
src copy/helpers/prompts.py
ADDED
@@ -0,0 +1,12 @@
+"""This module contains the prompts for the application."""
+
+# import jinja2 template prompt
+
+from jinja2 import Template
+
+
+def load_prompt(prompt_path: str) -> str:
+    """Load the prompt from the given path."""
+    with open(prompt_path, "r", encoding="utf-8") as file:
+        prompt = Template(file.read())
+    return prompt.render()
src copy/helpers/session.py
ADDED
@@ -0,0 +1,50 @@
import json
from dataclasses import dataclass
from datetime import datetime

from jinja2 import Template


@dataclass
class Question:
    id: int
    text: str
    answer_format: type
    user_answer: any = None


class Session:
    def __init__(self, questions):
        self.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.questions = questions
        # self.questions = self.process_questions(questions)

    @staticmethod
    def process_questions(questions):
        qq = {}
        for q in questions:
            if q["answer_format"] == "number":
                Q = Question(q["id"], q["text"], int, None)
            elif q["answer_format"] == "text":
                Q = Question(q["id"], q["text"], str, None)
            elif q["answer_format"] == "list":
                Q = Question(q["id"], q["text"], list, None)
            else:
                raise ValueError("Invalid answer format")
            qq[q["id"]] = Q
        return qq

    def answer_question(self, question_id, user_answer):
        self.questions[question_id].user_answer = user_answer

    def get_next_question(self):
        for q in self.questions:
            if not q.user_answer:  # next *unanswered* question (original checked the inverse)
                return q
        return False

    def zero_shot_prompt(self, prompt_template_path):
        with open(prompt_template_path) as f:
            template_str = f.read()
        template = Template(template_str)
        return template.render(questions=json.dumps(self.questions, indent=4))
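`process_questions` expects each question dict to carry `id`, `text`, and an `answer_format` of `number`, `text`, or `list`. A quick sketch with a hypothetical payload shaped to that contract (assuming the module above is importable as `helpers.session`):

```python
from helpers.session import Session  # the module added above

# Hypothetical questions; the real questions.json may differ in detail.
questions = [
    {"id": 0, "text": "How many years of experience do you have?", "answer_format": "number"},
    {"id": 1, "text": "Which programming languages do you know?", "answer_format": "list"},
]

session = Session(questions)
by_id = Session.process_questions(questions)
print(by_id[0].answer_format)  # -> <class 'int'>
print(session.session_id)      # e.g. '20250101_093000'
```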
src copy/index.html
ADDED
@@ -0,0 +1,452 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Gemini Voice Chat</title>
    <style>
        :root {
            --color-accent: #6366f1;
            --color-background: #0f172a;
            --color-surface: #1e293b;
            --color-text: #e2e8f0;
            --boxSize: 8px;
            --gutter: 4px;
        }

        body {
            margin: 0;
            padding: 0;
            background-color: var(--color-background);
            color: var(--color-text);
            font-family: system-ui, -apple-system, sans-serif;
            min-height: 100vh;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
        }

        .container {
            width: 90%;
            max-width: 800px;
            background-color: var(--color-surface);
            padding: 2rem;
            border-radius: 1rem;
            box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.25);
        }

        .wave-container {
            position: relative;
            display: flex;
            min-height: 100px;
            max-height: 128px;
            justify-content: center;
            align-items: center;
            margin: 2rem 0;
        }

        .box-container {
            display: flex;
            justify-content: space-between;
            height: 64px;
            width: 100%;
        }

        .box {
            height: 100%;
            width: var(--boxSize);
            background: var(--color-accent);
            border-radius: 8px;
            transition: transform 0.05s ease;
        }

        .controls {
            display: grid;
            gap: 1rem;
            margin-bottom: 2rem;
        }

        .input-group {
            display: flex;
            flex-direction: column;
            gap: 0.5rem;
        }

        label {
            font-size: 0.875rem;
            font-weight: 500;
        }

        input,
        select {
            padding: 0.75rem;
            border-radius: 0.5rem;
            border: 1px solid rgba(255, 255, 255, 0.1);
            background-color: var(--color-background);
            color: var(--color-text);
            font-size: 1rem;
        }

        button {
            padding: 1rem 2rem;
            border-radius: 0.5rem;
            border: none;
            background-color: var(--color-accent);
            color: white;
            font-weight: 600;
            cursor: pointer;
            transition: all 0.2s ease;
        }

        button:hover {
            opacity: 0.9;
            transform: translateY(-1px);
        }

        .icon-with-spinner {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 12px;
            min-width: 180px;
        }

        .spinner {
            width: 20px;
            height: 20px;
            border: 2px solid white;
            border-top-color: transparent;
            border-radius: 50%;
            animation: spin 1s linear infinite;
            flex-shrink: 0;
        }

        @keyframes spin {
            to {
                transform: rotate(360deg);
            }
        }

        .pulse-container {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 12px;
            min-width: 180px;
        }

        .pulse-circle {
            width: 20px;
            height: 20px;
            border-radius: 50%;
            background-color: white;
            opacity: 0.2;
            flex-shrink: 0;
            transform: translateX(-0%) scale(var(--audio-level, 1));
            transition: transform 0.1s ease;
        }

        /* Add styles for toast notifications */
        .toast {
            position: fixed;
            top: 20px;
            left: 50%;
            transform: translateX(-50%);
            padding: 16px 24px;
            border-radius: 4px;
            font-size: 14px;
            z-index: 1000;
            display: none;
            box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
        }

        .toast.error {
            background-color: #f44336;
            color: white;
        }

        .toast.warning {
            background-color: #ffd700;
            color: black;
        }
    </style>
</head>


<body>
    <!-- Add toast element after body opening tag -->
    <div id="error-toast" class="toast"></div>
    <div style="text-align: center">
        <h1>Gemini Voice Chat</h1>
        <p>Speak with Gemini using real-time audio streaming</p>
        <p>
            Get a Gemini API key
            <a href="https://ai.google.dev/gemini-api/docs/api-key">here</a>
        </p>
    </div>
    <div class="container">
        <div class="controls">
            <div class="input-group">
                <label for="api-key">API Key</label>
                <input type="password" id="api-key" placeholder="Enter your API key">
            </div>
            <div class="input-group">
                <label for="voice">Voice</label>
                <select id="voice">
                    <option value="Puck">Puck</option>
                    <option value="Charon">Charon</option>
                    <option value="Kore">Kore</option>
                    <option value="Fenrir">Fenrir</option>
                    <option value="Aoede">Aoede</option>
                </select>
            </div>
        </div>

        <div class="wave-container">
            <div class="box-container">
                <!-- Boxes will be dynamically added here -->
            </div>
        </div>

        <button id="start-button">Start Recording</button>
    </div>

    <audio id="audio-output"></audio>

    <script>
        let peerConnection;
        let audioContext;
        let dataChannel;
        let isRecording = false;
        let webrtc_id;

        const startButton = document.getElementById('start-button');
        const apiKeyInput = document.getElementById('api-key');
        const voiceSelect = document.getElementById('voice');
        const audioOutput = document.getElementById('audio-output');
        const boxContainer = document.querySelector('.box-container');

        const numBars = 32;
        for (let i = 0; i < numBars; i++) {
            const box = document.createElement('div');
            box.className = 'box';
            boxContainer.appendChild(box);
        }

        function updateButtonState() {
            if (peerConnection && (peerConnection.connectionState === 'connecting' || peerConnection.connectionState === 'new')) {
                startButton.innerHTML = `
                    <div class="icon-with-spinner">
                        <div class="spinner"></div>
                        <span>Connecting...</span>
                    </div>
                `;
            } else if (peerConnection && peerConnection.connectionState === 'connected') {
                startButton.innerHTML = `
                    <div class="pulse-container">
                        <div class="pulse-circle"></div>
                        <span>Stop Recording</span>
                    </div>
                `;
            } else {
                startButton.innerHTML = 'Start Recording';
            }
        }

        function showError(message) {
            const toast = document.getElementById('error-toast');
            toast.textContent = message;
            toast.className = 'toast error';
            toast.style.display = 'block';

            // Hide toast after 5 seconds
            setTimeout(() => {
                toast.style.display = 'none';
            }, 5000);
        }

        async function setupWebRTC() {
            const config = __RTC_CONFIGURATION__;
            peerConnection = new RTCPeerConnection(config);
            webrtc_id = Math.random().toString(36).substring(7);

            const timeoutId = setTimeout(() => {
                const toast = document.getElementById('error-toast');
                toast.textContent = "Connection is taking longer than usual. Are you on a VPN?";
                toast.className = 'toast warning';
                toast.style.display = 'block';

                // Hide warning after 5 seconds
                setTimeout(() => {
                    toast.style.display = 'none';
                }, 5000);
            }, 5000);

            try {
                const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
                stream.getTracks().forEach(track => peerConnection.addTrack(track, stream));

                // Update audio visualization setup
                audioContext = new AudioContext();
                analyser_input = audioContext.createAnalyser();
                const source = audioContext.createMediaStreamSource(stream);
                source.connect(analyser_input);
                analyser_input.fftSize = 64;
                dataArray_input = new Uint8Array(analyser_input.frequencyBinCount);

                function updateAudioLevel() {
                    analyser_input.getByteFrequencyData(dataArray_input);
                    const average = Array.from(dataArray_input).reduce((a, b) => a + b, 0) / dataArray_input.length;
                    const audioLevel = average / 255;

                    const pulseCircle = document.querySelector('.pulse-circle');
                    if (pulseCircle) {
                        console.log("audioLevel", audioLevel);
                        pulseCircle.style.setProperty('--audio-level', 1 + audioLevel);
                    }

                    animationId = requestAnimationFrame(updateAudioLevel);
                }
                updateAudioLevel();

                // Add connection state change listener
                peerConnection.addEventListener('connectionstatechange', () => {
                    console.log('connectionstatechange', peerConnection.connectionState);
                    if (peerConnection.connectionState === 'connected') {
                        clearTimeout(timeoutId);
                        const toast = document.getElementById('error-toast');
                        toast.style.display = 'none';
                    }
                    updateButtonState();
                });

                // Handle incoming audio
                peerConnection.addEventListener('track', (evt) => {
                    if (audioOutput && audioOutput.srcObject !== evt.streams[0]) {
                        audioOutput.srcObject = evt.streams[0];
                        audioOutput.play();

                        // Set up audio visualization on the output stream
                        audioContext = new AudioContext();
                        analyser = audioContext.createAnalyser();
                        const source = audioContext.createMediaStreamSource(evt.streams[0]);
                        source.connect(analyser);
                        analyser.fftSize = 2048;
                        dataArray = new Uint8Array(analyser.frequencyBinCount);
                        updateVisualization();
                    }
                });

                // Create data channel for messages
                dataChannel = peerConnection.createDataChannel('text');
                dataChannel.onmessage = (event) => {
                    const eventJson = JSON.parse(event.data);
                    if (eventJson.type === "error") {
                        showError(eventJson.message);
                    } else if (eventJson.type === "send_input") {
                        fetch('/input_hook', {
                            method: 'POST',
                            headers: {
                                'Content-Type': 'application/json',
                            },
                            body: JSON.stringify({
                                webrtc_id: webrtc_id,
                                api_key: apiKeyInput.value,
                                voice_name: voiceSelect.value
                            })
                        });
                    }
                };

                // Create and send offer
                const offer = await peerConnection.createOffer();
                await peerConnection.setLocalDescription(offer);

                await new Promise((resolve) => {
                    if (peerConnection.iceGatheringState === "complete") {
                        resolve();
                    } else {
                        const checkState = () => {
                            if (peerConnection.iceGatheringState === "complete") {
                                peerConnection.removeEventListener("icegatheringstatechange", checkState);
                                resolve();
                            }
                        };
                        peerConnection.addEventListener("icegatheringstatechange", checkState);
                    }
                });

                const response = await fetch('/webrtc/offer', {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        sdp: peerConnection.localDescription.sdp,
                        type: peerConnection.localDescription.type,
                        webrtc_id: webrtc_id,
                    })
                });

                const serverResponse = await response.json();

                if (serverResponse.status === 'failed') {
                    showError(serverResponse.meta.error === 'concurrency_limit_reached'
                        ? `Too many connections. Maximum limit is ${serverResponse.meta.limit}`
                        : serverResponse.meta.error);
                    stopWebRTC();  // was stop(), which resolves to window.stop
                    startButton.textContent = 'Start Recording';
                    return;
                }

                await peerConnection.setRemoteDescription(serverResponse);
            } catch (err) {
                clearTimeout(timeoutId);
                console.error('Error setting up WebRTC:', err);
                showError('Failed to establish connection. Please try again.');
                stopWebRTC();  // was stop(), which resolves to window.stop
                startButton.textContent = 'Start Recording';
            }
        }

        function updateVisualization() {
            if (!analyser) return;

            analyser.getByteFrequencyData(dataArray);
            const bars = document.querySelectorAll('.box');

            for (let i = 0; i < bars.length; i++) {
                const barHeight = (dataArray[i] / 255) * 2;
                bars[i].style.transform = `scaleY(${Math.max(0.1, barHeight)})`;
            }

            animationId = requestAnimationFrame(updateVisualization);
        }

        function stopWebRTC() {
            if (peerConnection) {
                peerConnection.close();
            }
            if (animationId) {
                cancelAnimationFrame(animationId);
            }
            if (audioContext) {
                audioContext.close();
            }
            updateButtonState();
        }

        startButton.addEventListener('click', () => {
            if (!isRecording) {
                setupWebRTC();
                startButton.classList.add('recording');
            } else {
                stopWebRTC();
                startButton.classList.remove('recording');
            }
            isRecording = !isRecording;
        });
    </script>
</body>

</html>
src copy/models.py
ADDED
@@ -0,0 +1,30 @@
"""Data models for the application."""

from dataclasses import dataclass

import pyaudio
from dotenv import load_dotenv

load_dotenv()


@dataclass
class AudioConfig:
    """Audio configuration settings."""

    format: int = pyaudio.paInt16
    channels: int = 1
    send_sample_rate: int = 16000
    receive_sample_rate: int = 24000
    chunk_size: int = 1024


@dataclass
class ModelConfig:
    """Gemini model configuration."""

    api_key: str
    name: str
    tools: dict
    generation_config: dict
    system_instruction: str
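Both dataclasses are plain containers, so constructing them is direct. Note that `tools` is annotated `dict` while `run.py` passes the `TOOLS` list; dataclasses accept that because annotations are not enforced at runtime. A sketch with placeholder values (assuming `src copy` is on `sys.path` so `models` imports):

```python
import os

from models import AudioConfig, ModelConfig  # the dataclasses above

audio = AudioConfig()  # defaults: 16 kHz capture, 24 kHz playback, 1 channel
model = ModelConfig(
    api_key=os.environ.get("GOOGLE_API_KEY", ""),
    name="models/gemini-2.0-flash-exp",
    tools=[],  # run.py passes tools.TOOLS here
    generation_config={"response_modalities": "AUDIO"},
    system_instruction="You are a helpful assistant.",  # placeholder prompt
)
print(audio.receive_sample_rate)  # -> 24000
```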
src copy/prompts/default_prompt.jinja2
ADDED
@@ -0,0 +1,41 @@
# Personality and Tone
## Identity
You are a friendly recruiter who conducts initial screening calls with candidates. You speak clear, professional English.

YOU ARE THE RECRUITER AND THE USER IS THE CANDIDATE, THE USER MUST ANSWER THE QUESTIONS.

## Tone and Language
- You are polite and professional.
- Use complete sentences
- Maintain a formal but warm demeanor
- Avoid slang or casual language

## Task
Your sole responsibility is to conduct brief initial screenings with candidates by following these exact steps:

# Strict Interview Protocol

1. ANSWER PROCESSING AND VALIDATION:
- ESSENTIAL INFO: Extract only the key information from candidate's response
- you MUST store the extracted information using validate_answer_tool
- VALIDATION: Use validate_answer_tool with the distilled answer ONLY
- ACKNOWLEDGE: Briefly acknowledge the candidate's response
- IMPORTANT: Never reveal validation process to candidates
- If validation fails, repeat question

2. ANSWER VALIDATION PROTOCOL:
- If answer is VALID: Proceed to next question
- If answer is INVALID: Repeat the same question
- No exceptions to this rule

3. INTERVIEW CONCLUSION:
- Only conclude after ALL questions are asked and validated
- End with a professional thank you message
- No additional commentary or questions allowed

DO NOT deviate from these protocols under any circumstances.


QUESTIONS SEQUENCE:
- You MUST ask questions in the exact order provided in:
{{ questions }}
src copy/run.py
ADDED
@@ -0,0 +1,96 @@
"""Real-time Speech Interface

This module provides a real-time speech interface using Google's Gemini model.
It handles bidirectional audio streaming with automatic speech recognition and synthesis.

Important:
    Use headphones to prevent audio feedback and echo issues.
"""

import argparse
import asyncio
import json
import logging
import os
import traceback

from helpers.loop import AudioLoop, TextLoop
from helpers.session import Session
from models import AudioConfig, ModelConfig
from tools import TOOLS

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


def main(
    modality: str = "text", system_prompt: str = None, instruction_audio: str = None
) -> None:
    """Entry point for the application."""
    try:
        model_config = ModelConfig(
            api_key=os.environ.get("GOOGLE_API_KEY"),
            name="models/gemini-2.0-flash-exp",
            system_instruction=system_prompt,
            tools=TOOLS,
            generation_config={
                "response_modalities": modality.upper(),
            },
        )
        if modality == "audio":
            loop_instance = AudioLoop(
                audio_config=AudioConfig(),
                model_config=model_config,
                instruction_audio=instruction_audio,
            )
        elif modality == "text":
            loop_instance = TextLoop(model_config=model_config)
        else:
            raise ValueError("Invalid modality")
        asyncio.run(loop_instance.run(), debug=True)
    except KeyboardInterrupt:
        logger.info("Application terminated by user")
    except Exception as e:
        logger.error(f"Application error: {e}")
        logger.debug(traceback.format_exc())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Real-time Speech Interface")
    parser.add_argument(
        "-m",
        "--modality",
        choices=["text", "audio"],
        help="Response modality",
        required=True,
    )
    parser.add_argument(
        "--instruction-audio",
        type=str,
        help="Path to audio instructions (.wav file)",
        required=False,
    )
    parser.add_argument(
        "-q",
        "--questions",
        type=str,
        help="Path to JSON file containing questions",
        required=True,
    )
    args = parser.parse_args()
    with open(args.questions, "r") as f:
        questions_dict = json.load(f)

    session = Session(questions=questions_dict)
    system_prompt = session.zero_shot_prompt("src/prompts/default_prompt.jinja2")
    print(system_prompt)

    main(
        modality=args.modality,
        system_prompt=system_prompt,
        instruction_audio=args.instruction_audio,
    )
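The `__main__` block above is equivalent to the following programmatic call, assuming the package directory is importable (the space in `src copy` makes that awkward, so treat this purely as an illustration of the flow):

```python
import json

from helpers.session import Session
from run import main  # the entry point defined above

with open("questions.json", "r") as f:
    questions = json.load(f)

prompt = Session(questions=questions).zero_shot_prompt(
    "src/prompts/default_prompt.jinja2"
)
main(modality="text", system_prompt=prompt)  # text mode needs no mic or speakers
```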
src copy/tools/__init__.py
ADDED
@@ -0,0 +1,14 @@
"""Tools package for API integrations."""

from .functions import validate_answer, validate_answer_tool, store_input, store_input_tool

# Map of function names to their implementations
FUNCTION_MAP = {
    "validate_answer": validate_answer,
    "store_input": store_input,
}

# List of all available tools
# TOOLS = [validate_answer_tool, store_input_tool]
TOOLS = [validate_answer_tool]
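`FUNCTION_MAP` is what turns a model tool call (a name plus keyword arguments) back into a Python call. A minimal dispatch sketch with a stand-in registry (the `greet` function is hypothetical):

```python
# Stand-in registry mirroring FUNCTION_MAP's shape.
def greet(name: str) -> str:
    return f"Hello, {name}!"


FUNCTION_MAP = {"greet": greet}

# A live-session tool call arrives as a name plus an args dict:
tool_name, tool_args = "greet", {"name": "Ada"}
print(FUNCTION_MAP[tool_name](**tool_args))  # -> Hello, Ada!
```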
src copy/tools/functions.py
ADDED
@@ -0,0 +1,148 @@
import json
import logging
import os

logger = logging.getLogger(__name__)


def fetch_next_question() -> str:
    """Fetch the next question.

    Returns:
        str: The next question.
    """
    questions = [
        "What is the capital of France?",
        "What is 2 + 2?",
        "Who wrote Romeo and Juliet?",
        "What is the chemical symbol for gold?",
        "Which planet is known as the Red Planet?",
    ]
    question = questions[0]

    return f"You need to ask the candidate the following question: `{question}`. Allow the candidate some time to respond."


fetch_next_question_tool = {
    "name": "fetch_next_question",
    "description": "Fetch the next question",
}


def validate_answer(question_id: int, answer: str, answer_type: str) -> str:
    """Validate the user's answer against an expected answer type.

    Args:
        question_id (int): The identifier of the question being validated
        answer (str): The user's provided answer to validate
        answer_type (str): The expected python type, as a string (e.g. "str", "int", "list")

    Returns:
        str: "Answer is valid" if the answer matches the expected type

    Raises:
        ValueError: If the answer's type does not match the expected answer_type

    Example:
        >>> validate_answer(1, "42", "str")
        'Answer is valid'
        >>> validate_answer(1, 42, "str")
        Traceback (most recent call last):
            ...
        ValueError: Invalid answer type
    """
    logging.info(
        {
            "question_id": question_id,
            "answer": answer,
            "answer_type": answer_type,
        }
    )
    # Resolve the schema's type string to a real type and reject mismatches;
    # the original raised on a *match* (`if type(answer) is answer_type`),
    # inverting the intended check.
    expected = {"str": str, "int": int, "list": list}.get(answer_type, str)
    if not isinstance(answer, expected):
        raise ValueError("Invalid answer type")

    # Create or load the answers file (note: hardcoded local path), keyed by
    # question id; the original indexed into a plain list, which raises
    # IndexError on first use.
    answers_file = "/Users/georgeslorre/ML6/internal/gemini-voice-agents/answers.json"
    answers = {}

    if os.path.exists(answers_file):
        with open(answers_file, "r") as f:
            answers = json.load(f)

    # Record the new answer
    answers[str(question_id)] = {"question_id": question_id, "answer": answer}

    # Write back to file
    with open(answers_file, "w") as f:
        json.dump(answers, f, indent=2)

    return "Answer is valid"


validate_answer_tool = {
    "name": "validate_answer",
    "description": "Validate the user's answer against an expected answer type",
    "parameters": {
        "type": "OBJECT",
        "properties": {
            "question_id": {
                "type": "INTEGER",
                "description": "The identifier of the question being validated"
            },
            "answer": {
                "type": "STRING",
                "description": "The user's provided answer to validate"
            },
            "answer_type": {
                "type": "STRING",
                "description": "The expected python type that the answer should match (e.g. str, int, list)"
            }
        },
        "required": ["question_id", "answer", "answer_type"]
    }
}


def store_input(role: str, input: str) -> str:
    """Store conversation input in a JSON file.

    Args:
        role (str): The role of the speaker (user or assistant)
        input (str): The text input to store

    Returns:
        str: Confirmation message
    """
    conversation_file = "/Users/georgeslorre/ML6/internal/gemini-voice-agents/conversation.json"
    conversation = []

    if os.path.exists(conversation_file):
        with open(conversation_file, "r") as f:
            conversation = json.load(f)

    conversation.append({"role": role, "content": input})

    with open(conversation_file, "w") as f:
        json.dump(conversation, f, indent=2)

    return "Input stored successfully"


store_input_tool = {
    "name": "store_input",
    "description": "Store user input in conversation history",
    "parameters": {
        "type": "OBJECT",
        "properties": {
            "role": {
                "type": "STRING",
                "description": "The role of the speaker (user or assistant)"
            },
            "input": {
                "type": "STRING",
                "description": "The text input to store"
            }
        }
    }
}
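With the inverted check corrected, the type-resolution step is worth seeing in isolation. A standalone sketch of the same logic, minus the file I/O (function name is illustrative):

```python
def check(answer, answer_type: str) -> str:
    # Same resolution step as validate_answer above, without the file I/O.
    expected = {"str": str, "int": int, "list": list}[answer_type]
    if not isinstance(answer, expected):
        raise ValueError("Invalid answer type")
    return "Answer is valid"


print(check("Paris", "str"))      # -> Answer is valid
print(check(["a", "b"], "list"))  # -> Answer is valid
try:
    check("four", "int")
except ValueError as e:
    print(e)                      # -> Invalid answer type
```

One caveat: the tool schema declares `answer` as a STRING, so answers arriving from the model will only ever pass the `"str"` case; anything stricter depends on the model coercing values before calling the tool.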
src copy/tts.py
ADDED
@@ -0,0 +1,103 @@
#!/usr/bin/env python
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Google Cloud Text-To-Speech API streaming sample with input/output streams."""

import itertools
import queue
import threading

from google.cloud import texttospeech


class TTSStreamer:
    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()
        self.text_queue = queue.Queue()
        self.audio_queue = queue.Queue()
        self.processor_thread = None  # set in start_stream(); checked in stop()

    def start_stream(self):
        streaming_config = texttospeech.StreamingSynthesizeConfig(
            voice=texttospeech.VoiceSelectionParams(
                name="en-US-Journey-D",
                language_code="en-US"
            )
        )
        config_request = texttospeech.StreamingSynthesizeRequest(
            streaming_config=streaming_config
        )

        def request_generator():
            while True:
                try:
                    text = self.text_queue.get()
                    if text is None:  # Poison pill to stop
                        break
                    yield texttospeech.StreamingSynthesizeRequest(
                        input=texttospeech.StreamingSynthesisInput(text=text)
                    )
                except queue.Empty:
                    continue

        def audio_processor():
            responses = self.client.streaming_synthesize(
                itertools.chain([config_request], request_generator())
            )

            for response in responses:
                self.audio_queue.put(response.audio_content)

        self.processor_thread = threading.Thread(target=audio_processor)
        self.processor_thread.start()

    def send_text(self, text: str):
        """Send text to be synthesized."""
        self.text_queue.put(text)

    def get_audio(self):
        """Get the next chunk of audio bytes."""
        try:
            return self.audio_queue.get_nowait()
        except queue.Empty:
            return None

    def stop(self):
        """Stop the streaming synthesis."""
        self.text_queue.put(None)  # Send poison pill
        if self.processor_thread:
            self.processor_thread.join()


def main():
    tts = TTSStreamer()
    tts.start_stream()

    # Example usage
    try:
        while True:
            text = input("Enter text (or 'q' to quit): ")
            if text.lower() == 'q':
                break
            tts.send_text(text)

            # Get and print audio bytes
            while True:
                audio_chunk = tts.get_audio()
                if audio_chunk is None:
                    break
                print(f"Received audio chunk of {len(audio_chunk)} bytes")
    finally:
        tts.stop()


if __name__ == "__main__":
    main()
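The chunks coming off `audio_queue` are raw PCM bytes; to actually listen to them you need a container. A minimal sketch using the standard-library `wave` module, assuming 16-bit mono output at 24 kHz (the sample rate is an assumption; `tts.py` itself does not pin one down):

```python
import wave

# Chunks as collected from tts.get_audio(): raw 16-bit PCM byte strings.
chunks = [b"\x00\x00" * 2400]  # placeholder: ~0.1 s of silence

with wave.open("out.wav", "wb") as wf:
    wf.setnchannels(1)      # mono
    wf.setsampwidth(2)      # 16-bit samples
    wf.setframerate(24000)  # assumed rate, not specified by tts.py
    wf.writeframes(b"".join(chunks))
```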
src/app.py
ADDED
@@ -0,0 +1,302 @@
import asyncio
import base64
import json
import os
from typing import Literal

import gradio as gr
import numpy as np
from fastrtc import AsyncStreamHandler, WebRTC, wait_for_item
from google import genai
from google.cloud import texttospeech
from google.genai.types import FunctionDeclaration, LiveConnectConfig, Tool

import helpers.datastore as datastore
from helpers.prompts import load_prompt
from tools import FUNCTION_MAP, TOOLS

with open("questions.json", "r") as f:
    questions_dict = json.load(f)


datastore.DATA_STORE["questions"] = questions_dict

SYSTEM_PROMPT = load_prompt(
    "src/prompts/default_prompt.jinja2", questions=questions_dict
)


class TTSConfig:
    def __init__(self):
        self.client = texttospeech.TextToSpeechClient()
        self.voice = texttospeech.VoiceSelectionParams(
            name="en-US-Chirp3-HD-Charon", language_code="en-US"
        )
        self.audio_config = texttospeech.AudioConfig(
            audio_encoding=texttospeech.AudioEncoding.LINEAR16
        )


class AsyncGeminiHandler(AsyncStreamHandler):
    """Simple Async Gemini Handler"""

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.text_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.chunk_size = 1024

        self.tts_config: TTSConfig | None = TTSConfig()
        self.text_buffer = ""

    def copy(self) -> "AsyncGeminiHandler":
        return AsyncGeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def _encode_audio(self, data: np.ndarray) -> str:
        """Encode Audio data to send to the server"""
        return base64.b64encode(data.tobytes()).decode("UTF-8")

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Receives and processes audio frames asynchronously."""
        _, array = frame
        array = array.squeeze()
        audio_message = self._encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Asynchronously emits items from the output queue."""
        return await wait_for_item(self.output_queue)

    async def start_up(self) -> None:
        """Initialize and start the voice agent application.

        This asynchronous method sets up the Gemini API client, configures the live connection,
        and starts three concurrent tasks for receiving, processing and sending information.

        Returns:
            None

        Raises:
            ValueError: If GEMINI_API_KEY is not provided when required.

        """
        if not os.getenv("GOOGLE_GENAI_USE_VERTEXAI") == "True":
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("API Key is required")

            client = genai.Client(
                api_key=api_key,
                http_options={"api_version": "v1alpha"},
            )
        else:
            client = genai.Client(http_options={"api_version": "v1beta1"})

        config = LiveConnectConfig(
            system_instruction={
                "parts": [{"text": SYSTEM_PROMPT}],
                "role": "user",
            },
            tools=[
                Tool(
                    function_declarations=[
                        FunctionDeclaration(**tool) for tool in TOOLS
                    ]
                )
            ],
            response_modalities=["AUDIO"],
        )

        async with (
            client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session,  # setup the live connection session (websocket)
            asyncio.TaskGroup() as tg,  # create a task group to run multiple tasks concurrently
        ):
            self.session = session

            # these tasks will run concurrently and continuously
            [
                tg.create_task(self.process()),
                tg.create_task(self.send_realtime()),
                tg.create_task(self.tts()),
            ]

    async def process(self) -> None:
        """Process responses from the session in a continuous loop.

        This asynchronous method handles different types of responses from the session:
        - Audio data: Processes and queues audio data with the specified sample rate
        - Text data: Accumulates received text in a buffer
        - Tool calls: Executes registered functions and sends their responses back
        - Server content: Handles turn completion and stores conversation history

        The method runs indefinitely until interrupted, handling any exceptions that occur
        during processing by logging them and continuing after a brief delay.

        Returns:
            None

        Raises:
            Exception: Any exceptions during processing are caught and logged
        """
        while True:
            try:
                turn = self.session.receive()
                async for response in turn:
                    if data := response.data:
                        # audio data
                        array = np.frombuffer(data, dtype=np.int16)
                        self.output_queue.put_nowait((self.output_sample_rate, array))
                        continue

                    if text := response.text:
                        # text data
                        print(f"Received text: {text}")
                        self.text_buffer += text

                    if response.tool_call is not None:
                        # function calling
                        for tool in response.tool_call.function_calls:
                            try:
                                tool_response = FUNCTION_MAP[tool.name](**tool.args)
                                print(f"Calling tool: {tool.name}")
                                print(f"Tool response: {tool_response}")
                                await self.session.send(
                                    input=tool_response, end_of_turn=True
                                )
                                await asyncio.sleep(0.1)
                            except Exception as e:
                                print(f"Error in tool call: {e}")
                                await asyncio.sleep(0.1)

                    if sc := response.server_content:
                        # check if bot's turn is complete
                        if sc.turn_complete and self.text_buffer:
                            self.text_queue.put_nowait(self.text_buffer)
                            FUNCTION_MAP["store_input"](
                                role="bot", input=self.text_buffer
                            )
                            self.text_buffer = ""

            except Exception as e:
                print(f"Error in processing: {e}")
                await asyncio.sleep(0.1)

    async def send_realtime(self) -> None:
        """Send real-time audio data to model.

        This method continuously reads audio data from an input queue and sends it to a model
        session in real-time. It runs in an infinite loop until interrupted.

        The audio data is sent with mime type 'audio/pcm'. If an error occurs during sending,
        it will be printed and the method will sleep briefly before retrying.

        Returns:
            None

        Raises:
            Exception: Any exceptions during queue access or session sending will be caught and logged.
        """
        while True:
            try:
                data = await self.input_queue.get()
                msg = {"data": data, "mime_type": "audio/pcm"}
                await self.session.send(input=msg)
            except Exception as e:
                print(f"Error in real-time sending: {e}")
                await asyncio.sleep(0.1)

    async def tts(self) -> None:
        while True:
            try:
                text = await self.text_queue.get()
                # Get response in a single request
                if text:
                    response = self.tts_config.client.synthesize_speech(
                        input=texttospeech.SynthesisInput(text=text),
                        voice=self.tts_config.voice,
                        audio_config=self.tts_config.audio_config,
                    )
                    array = np.frombuffer(response.audio_content, dtype=np.int16)
                    self.output_queue.put_nowait((self.output_sample_rate, array))

            except Exception as e:
                print(f"Error in TTS: {e}")
                await asyncio.sleep(0.1)

    def shutdown(self) -> None:
        self.quit.set()


# Main Gradio Interface
def registry(*args, **kwargs):
    """Sets up and returns the Gradio interface."""

    interface = gr.Blocks()
    with interface:
        with gr.Tabs():
            with gr.TabItem("Voice Chat"):
                gr.HTML(
                    """
                    <div style='text-align: left'>
                        <h1>ML6 Voice Demo</h1>
                    </div>
                    """
                )

                gemini_handler = AsyncGeminiHandler()

                with gr.Row():
                    audio = WebRTC(
                        label="Voice Chat",
                        modality="audio",
                        mode="send-receive",
                    )

                # Add display components for questions and answers
                with gr.Row():
                    with gr.Column():
                        gr.JSON(
                            label="Questions",
                            value=datastore.DATA_STORE["questions"],
                        )
                    with gr.Column():
                        gr.JSON(
                            label="Answers",
                            value=lambda: datastore.DATA_STORE["answers"],
                            every=1,
                        )

                audio.stream(
                    gemini_handler,
                    inputs=[audio],
                    outputs=[audio],
                    time_limit=600,
                    concurrency_limit=10,
                )

    return interface


# Launch the Gradio interface
gr.load(
    name="demo",
    src=registry,
).launch()
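One detail worth isolating from `process()` above: streamed text arrives in fragments and is only flushed to TTS and conversation storage when the server marks the turn complete. The same pattern in miniature, with an illustrative event stream standing in for the live session:

```python
text_buffer = ""
flushed: list[str] = []

# (fragment, turn_complete) pairs standing in for session.receive() output.
events = [("Hel", False), ("lo there.", False), ("", True), ("Bye.", True)]

for fragment, turn_complete in events:
    text_buffer += fragment
    if turn_complete and text_buffer:
        flushed.append(text_buffer)  # in app.py: text_queue for TTS + store_input
        text_buffer = ""

print(flushed)  # -> ['Hello there.', 'Bye.']
```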
src/helpers/datastore.py
ADDED
@@ -0,0 +1,5 @@
DATA_STORE = {
    "questions": [],
    "answers": [],
    "conversation": [],  # key had a stray colon ("conversation:"), so store_input never found it
}
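`datastore` works as a process-wide singleton: every importer mutates the same module-level dict, which is how tool-call writes become visible to the Gradio poller. A tiny sketch of the idea (a plain dict standing in for `helpers.datastore`):

```python
# Stand-in for helpers/datastore.py: one module-level dict shared by all importers.
DATA_STORE = {"questions": [], "answers": [], "conversation": []}

# Any module that does `import helpers.datastore as datastore` mutates the
# same object, so this write shows up everywhere the module is imported.
DATA_STORE["answers"].append({"question_id": 0, "answer": "Paris"})
print(DATA_STORE["answers"])  # -> [{'question_id': 0, 'answer': 'Paris'}]
```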
src/helpers/prompts.py
ADDED
@@ -0,0 +1,12 @@
"""This module contains the prompts for the application."""

import json

from jinja2 import Template


def load_prompt(prompt_path: str, **kwargs) -> str:
    """Load the prompt from the given path."""
    with open(prompt_path, "r", encoding="utf-8") as file:
        prompt = Template(file.read())
    return prompt.render(**{k: json.dumps(v) for k, v in kwargs.items()})
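Unlike the earlier `src copy/` variant, this `load_prompt` JSON-encodes every keyword before rendering, so `{{ questions }}` lands in the prompt as literal JSON. Demonstrated with an inline template and hypothetical data, same logic:

```python
import json

from jinja2 import Template

# Inline template standing in for default_prompt.jinja2's final lines.
template = Template("QUESTIONS SEQUENCE:\n{{ questions }}")
print(template.render(questions=json.dumps([{"id": 0, "text": "2 + 2?"}])))
# QUESTIONS SEQUENCE:
# [{"id": 0, "text": "2 + 2?"}]
```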
src/prompts/default_prompt.jinja2
ADDED
@@ -0,0 +1,41 @@
# Personality and Tone
## Identity
You are a friendly recruiter who conducts initial screening calls with candidates. You speak clear, professional English.

YOU ARE THE RECRUITER AND THE USER IS THE CANDIDATE, THE USER MUST ANSWER THE QUESTIONS.

## Tone and Language
- You are polite and professional.
- Use complete sentences
- Maintain a formal but warm demeanor
- Avoid slang or casual language

## Task
Your sole responsibility is to conduct brief initial screenings with candidates by following these exact steps:

# Strict Interview Protocol

1. ANSWER PROCESSING AND VALIDATION:
- ESSENTIAL INFO: Extract only the key information from candidate's response
- you MUST store the extracted information using validate_answer_tool
- VALIDATION: Use validate_answer_tool with the distilled answer ONLY
- ACKNOWLEDGE: Briefly acknowledge the candidate's response
- IMPORTANT: Never reveal validation process to candidates
- If validation fails, repeat question

2. ANSWER VALIDATION PROTOCOL:
- If answer is VALID: Proceed to next question
- If answer is INVALID: Repeat the same question
- No exceptions to this rule

3. INTERVIEW CONCLUSION:
- Only conclude after ALL questions are asked and validated
- End with a professional thank you message
- No additional commentary or questions allowed

DO NOT deviate from these protocols under any circumstances.


QUESTIONS SEQUENCE:
- You MUST ask questions in the exact order provided in:
{{ questions }}
src/tools/__init__.py
ADDED
@@ -0,0 +1,17 @@
"""Tools package for API integrations."""

from .functions import (
    store_input,
    store_input_tool,
    validate_answer,
    validate_answer_tool,
)

# Map of function names to their implementations
FUNCTION_MAP = {
    "validate_answer": validate_answer,
    "store_input": store_input,
}

# List of all available tools
TOOLS = [store_input_tool, validate_answer_tool]
src/tools/functions.py
ADDED
@@ -0,0 +1,103 @@
import logging

import helpers.datastore as datastore

logger = logging.getLogger(__name__)


def validate_answer(question_id: int, answer: str, answer_type: str) -> str:
    """Validate the user's answer against an expected answer type.

    Args:
        question_id (int): The identifier of the question being validated
        answer (str): The user's provided answer to validate
        answer_type (str): The expected python type, as a string (e.g. "str", "int", "list")

    Returns:
        str: "Answer is valid" if the answer matches the expected type

    Raises:
        ValueError: If the answer's type does not match the expected answer_type

    Example:
        >>> validate_answer(1, "42", "str")
        'Answer is valid'
        >>> validate_answer(1, 42, "str")
        Traceback (most recent call last):
            ...
        ValueError: Invalid answer type
    """

    logging.info(
        {
            "question_id": question_id,
            "answer": answer,
            "answer_type": answer_type,
        }
    )
    # Resolve the schema's type string to a real type and reject mismatches;
    # the original raised on a *match*, inverting the intended check.
    expected = {"str": str, "int": int, "list": list}.get(answer_type, str)
    if not isinstance(answer, expected):
        raise ValueError("Invalid answer type")

    datastore.DATA_STORE["answers"].append(
        {"question_id": question_id, "answer": answer}
    )

    return "Answer is valid"


validate_answer_tool = {
    "name": "validate_answer",
    "description": "Validate the user's answer against an expected answer type",
    "parameters": {
        "type": "OBJECT",
        "properties": {
            "question_id": {
                "type": "INTEGER",
                "description": "The identifier of the question being validated",
            },
            "answer": {
                "type": "STRING",
                "description": "The user's provided answer to validate",
            },
            "answer_type": {
                "type": "STRING",
                "description": "The expected python type that the answer should match (e.g. str, int, list)",
            },
        },
        "required": ["question_id", "answer", "answer_type"],
    },
}


def store_input(role: str, input: str) -> str:
    """Store conversation input in the shared in-memory datastore.

    Args:
        role (str): The role of the speaker (user or assistant)
        input (str): The text input to store

    Returns:
        str: Confirmation message
    """
    print(datastore.DATA_STORE)
    conversation = datastore.DATA_STORE.get("conversation")
    if conversation is None:
        datastore.DATA_STORE["conversation"] = [{"role": role, "input": input}]
    else:
        datastore.DATA_STORE["conversation"].append({"role": role, "input": input})

    return "Input stored successfully"


store_input_tool = {
    "name": "store_input",
    "description": "Store user input in conversation history",
    "parameters": {
        "type": "OBJECT",
        "properties": {
            "role": {
                "type": "STRING",
                "description": "The role of the speaker (user or assistant)",
            },
            "input": {"type": "STRING", "description": "The text input to store"},
        },
    },
}
uv.lock
ADDED
The diff for this file is too large to render.
See raw diff