# freddyaboulton -- "Add code" (commit b45f262)
import mimetypes
import os
from typing import Any
import uuid
import gradio as gr
from dotenv import load_dotenv
from google import genai
from google.genai import types
from gradio.components.multimodal_textbox import MultimodalValue
# Load variables from a .env file before anything below reads the environment.
load_dotenv()

# Generation settings passed to every streaming request made by `response`.
generate_content_config = types.GenerateContentConfig(
    temperature=1,
    top_p=0.95,
    top_k=40,
    max_output_tokens=8192,
    # Request both image and text output from the model.
    response_modalities=["image", "text"],
    response_mime_type="text/plain",
)

# Model identifier used for every generation request.
model = "gemini-2.0-flash-exp-image-generation"
def save_binary_file(file_name: str, data: bytes) -> None:
    """Write ``data`` to ``file_name`` in binary mode, replacing any existing file.

    Args:
        file_name: Path of the file to create or overwrite.
        data: Bytes-like payload to write.
    """
    # The context manager closes the handle even if the write raises.
    with open(file_name, "wb") as f:
        f.write(data)
def response(
    message: MultimodalValue,
    history: list[dict[str, Any]],
    chat_history_gemini,
    client: genai.Client | None,
):
    """Stream a Gemini reply (text and/or generated images) for one chat turn.

    Args:
        message: The submitted multimodal message. Each entry of its
            ``"files"`` list is uploaded with ``client.files.upload`` and its
            ``"text"`` is sent as a text part.
        history: Chat history supplied by the chat UI. Not used here; the
            conversation sent to the model lives in ``chat_history_gemini``.
        chat_history_gemini: Mutable list of ``types.Content`` turns. The user
            turn is appended before generation and the model turn is appended
            once the stream has been fully consumed.
        client: Client used for uploads and generation, or ``None`` when no
            API key has been provided yet.

    Yields:
        The same ``{"files": [...], "text": "..."}`` dict after each streamed
        chunk that carried content, with newly saved image paths appended to
        ``"files"`` and new text appended to ``"text"``.

    Raises:
        gr.Error: If ``client`` is ``None``.
    """
    if client is None:
        raise gr.Error("Client not initialized. Please enter your API key.")

    # `or []` / `or ""` also cover keys that are present but set to None.
    uploads = [client.files.upload(file=f) for f in message.get("files") or []]
    file_parts = [
        types.Part.from_uri(file_uri=f.uri, mime_type=f.mime_type) for f in uploads
    ]
    user_msg = types.Content(
        role="user",
        parts=file_parts + [types.Part.from_text(text=message.get("text") or "")],
    )
    chat_history_gemini.append(user_msg)

    msg = {"files": [], "text": ""}
    model_parts = []  # parts of the reply, kept so the turn can be recorded
    for chunk in client.models.generate_content_stream(
        model=model,
        contents=chat_history_gemini,
        config=generate_content_config,
    ):
        if (
            not chunk.candidates
            or not chunk.candidates[0].content
            or not chunk.candidates[0].content.parts
        ):
            continue
        # Look at every part, not only the first, so a chunk that mixes text
        # and image data is handled completely.
        for part in chunk.candidates[0].content.parts:
            inline_data = part.inline_data
            if inline_data:
                # guess_extension returns None for unknown types; fall back to
                # no extension instead of a literal "None" suffix.
                file_extension = (
                    mimetypes.guess_extension(inline_data.mime_type or "") or ""
                )
                file_path = f"{uuid.uuid4()}{file_extension}"
                save_binary_file(file_path, inline_data.data)
                msg["files"].append(file_path)
            elif part.text:
                msg["text"] += part.text
            else:
                continue
            model_parts.append(part)
        yield msg

    # Record the model's turn so later requests include the full conversation
    # (including generated images) rather than user turns only.
    if model_parts:
        chat_history_gemini.append(types.Content(role="model", parts=model_parts))
# Built ahead of the layout; the textbox is placed later with `api_key.render()`.
# NOTE(review): the textbox is pre-filled from the server's GOOGLE_API_KEY --
# confirm that exposing this value through the UI is acceptable for every
# deployment of this demo.
api_key = gr.Textbox(
    label="Enter your API key",
    type="password",
    value=os.getenv("GOOGLE_API_KEY"),
)
deep_link = gr.DeepLinkButton()
with gr.Blocks() as demo:
    # State values handed to `response` as additional inputs: the conversation
    # sent to Gemini and the API client (None until a key is available).
    chat_history_gemini = gr.State([])
    client = gr.State(value=None)

    gr.ChatInterface(
        response,
        title="Gemini Image Generation",
        description="Edit images and share chat with deep links",
        additional_inputs=[chat_history_gemini, client],
        type="messages",
        multimodal=True,
    )

    # Place the components that were constructed at module level. Rendering
    # `deep_link` (instead of constructing a second DeepLinkButton) keeps a
    # single button instance, mirroring how `api_key` is handled.
    deep_link.render()
    api_key.render()

    # Submitting the textbox builds a client from the entered key.
    api_key.submit(
        lambda key: genai.Client(api_key=key),
        inputs=api_key,
        outputs=client,
    )

    def _client_from_env():
        """Return a client built from GOOGLE_API_KEY, or None when it is unset or empty."""
        env_key = os.getenv("GOOGLE_API_KEY")
        return genai.Client(api_key=env_key) if env_key else None

    # On page load, initialise the client from the environment when possible.
    demo.load(
        _client_from_env,
        inputs=[],
        outputs=[client],
    )
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()