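"""OpenAI Realtime voice chat over WebRTC with Gradio.

Streams microphone audio to the OpenAI Realtime API via gradio_webrtc, plays back
the model's audio responses, and shows transcripts in a Chatbot component.
Requires an OpenAI API key (entered in the UI; OPENAI_API_KEY pre-fills the textbox
if set) and the `openai-logo.svg` asset next to this script.
"""
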
import asyncio
import base64
import os
from threading import Event, Thread
import gradio as gr
import numpy as np
import openai
from dotenv import load_dotenv
from gradio_webrtc import (
    AdditionalOutputs,
    StreamHandler,
    WebRTC,
    get_twilio_turn_credentials,
)
from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
from pydub import AudioSegment

load_dotenv()

SAMPLE_RATE = 24000


def encode_audio(sample_rate, data):
    """Convert an incoming audio chunk to 24 kHz, 16-bit mono PCM and base64-encode it."""
    segment = AudioSegment(
        data.tobytes(),
        frame_rate=sample_rate,
        sample_width=data.dtype.itemsize,
        channels=1,
    )
    pcm_audio = (
        segment.set_frame_rate(SAMPLE_RATE).set_channels(1).set_sample_width(2).raw_data
    )
    return base64.b64encode(pcm_audio).decode("utf-8")
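
# gradio_webrtc drives the StreamHandler hooks below (roughly): copy() creates a
# per-connection handler, receive() is called with each incoming audio frame,
# emit() is polled for audio to send back, and shutdown() runs when the stream closes.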


class OpenAIHandler(StreamHandler):
    def __init__(
        self,
        expected_layout="mono",
        output_sample_rate=SAMPLE_RATE,
        output_frame_size=480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=SAMPLE_RATE,
        )
        self.connection = None
        self.all_output_data = None
        self.args_set = Event()  # set once the API key has arrived from the client
        self.quit = Event()  # signals the connection thread to shut down
        self.connected = Event()  # set once the realtime connection is established
        self.thread = None
        self._generator = None

    def copy(self):
        return OpenAIHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def _initialize_connection(self, api_key: str):
        """Connect to the Realtime API. Runs in a separate thread to keep the connection open."""
        self.client = openai.Client(api_key=api_key)
        with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
            conn.session.update(session={"turn_detection": {"type": "server_vad"}})
            self.connection = conn
            self.connected.set()
            self.quit.wait()

    async def fetch_args(self):
        if self.channel:
            self.channel.send("tick")

    def set_args(self, args):
        super().set_args(args)
        self.args_set.set()

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        if not self.channel:
            return
        if not self.connection:
            # On the first frame, ask the client for the current component args
            # (the API key), then open the realtime connection in a background thread.
            asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
            self.args_set.wait()
            self.thread = Thread(
                target=self._initialize_connection, args=(self.latest_args[-1],)
            )
            self.thread.start()
            self.connected.wait()
        try:
            assert self.connection, "Connection not initialized"
            sample_rate, array = frame
            array = array.squeeze()
            audio_message = encode_audio(sample_rate, array)
            self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            print(f"Error in receive: {str(e)}")
            # Print the full traceback for easier debugging.
            import traceback

            traceback.print_exc()

    def generator(self):
        while True:
            if not self.connection:
                yield None
                continue
            for event in self.connection:
                if event.type == "response.audio_transcript.done":
                    # Surface the finished transcript to the UI via AdditionalOutputs.
                    yield AdditionalOutputs(event)
                if event.type == "response.audio.delta":
                    yield (
                        self.output_sample_rate,
                        np.frombuffer(
                            base64.b64decode(event.delta), dtype=np.int16
                        ).reshape(1, -1),
                    )

    def emit(self) -> tuple[int, np.ndarray] | None:
        if not self.connection:
            return None
        if not self._generator:
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            # The event stream ended; start a fresh generator for the next connection.
            self._generator = self.generator()
            return None

    def reset_state(self):
        """Reset connection state for a new recording session."""
        self.connection = None
        self.args_set.clear()
        self.quit.clear()
        self.connected.clear()
        self.thread = None
        self._generator = None
        self.current_session = None

    def shutdown(self) -> None:
        if self.connection:
            self.connection.close()
        self.quit.set()
        if self.thread:
            self.thread.join(timeout=5)
        self.reset_state()


def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    chatbot.append({"role": "assistant", "content": response.transcript})
    return chatbot


with gr.Blocks() as demo:
    gr.HTML("""
    <div style='display: flex; align-items: center; justify-content: center; gap: 20px'>
        <div style="background-color: var(--block-background-fill); border-radius: 8px">
            <img src="/gradio_api/file=openai-logo.svg" style="width: 100px; height: 100px;">
        </div>
        <div>
            <h1>OpenAI Realtime Voice Chat</h1>
            <p>Speak with OpenAI's latest model over the real-time audio streaming API.</p>
            <p>Powered by <a href="https://gradio.app/">Gradio</a> and <a href="https://freddyaboulton.github.io/gradio-webrtc/">WebRTC</a>⚡️</p>
            <p>Get an API key from <a href="https://platform.openai.com/">OpenAI</a>.</p>
        </div>
    </div>
    """)
    with gr.Row(visible=True) as api_key_row:
        api_key = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter your OpenAI API Key",
            value=os.getenv("OPENAI_API_KEY", ""),
            type="password",
        )
    with gr.Row(visible=False) as row:
        with gr.Column(scale=1):
            webrtc = WebRTC(
                label="Conversation",
                modality="audio",
                mode="send-receive",
                rtc_configuration=get_twilio_turn_credentials(),
                icon="openai-logo.svg",
            )
        with gr.Column(scale=5):
            chatbot = gr.Chatbot(label="Conversation", value=[], type="messages")

    webrtc.stream(
        OpenAIHandler(),
        inputs=[webrtc, api_key],
        outputs=[webrtc],
        time_limit=90,
        concurrency_limit=2,
    )
    webrtc.on_additional_outputs(
        update_chatbot,
        inputs=[chatbot],
        outputs=[chatbot],
        show_progress="hidden",
        queue=True,
    )
    # Hide the API key row and reveal the conversation once a key is submitted.
    api_key.submit(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        None,
        [api_key_row, row],
    )

if __name__ == "__main__":
    demo.launch(allowed_paths=["openai-logo.svg"])