|
import mimetypes |
|
import os |
|
from typing import Any |
|
import uuid |
|
|
|
import gradio as gr |
|
from dotenv import load_dotenv |
|
from google import genai |
|
from google.genai import types |
|
from gradio.components.multimodal_textbox import MultimodalValue |
|
|
|
# Load variables from a local .env file (if one exists) into the environment.
load_dotenv()

# Name of the Gemini model passed to every generation request.
model = "gemini-2.0-flash-exp-image-generation"

# Generation settings shared by all requests: sampling parameters, the output
# token budget, and the modalities the model is asked to return.
generate_content_config = types.GenerateContentConfig(
    temperature=1,
    top_p=0.95,
    top_k=40,
    max_output_tokens=8192,
    response_modalities=["image", "text"],
    response_mime_type="text/plain",
)
|
|
|
|
|
def save_binary_file(file_name, data):
    """Write ``data`` to ``file_name`` in binary mode.

    An existing file at that path is overwritten.

    Args:
        file_name: Path of the file to create.
        data: Bytes-like payload to write.
    """
    # The context manager closes the handle even when the write fails.
    with open(file_name, "wb") as f:
        f.write(data)
|
|
|
|
|
def response(
    message: MultimodalValue,
    history: list[dict[str, Any]],
    chat_history_gemini,
    client: genai.Client | None,
):
    """Stream a Gemini reply to a multimodal chat message.

    Uploads any attached files through ``client``, appends the user turn to
    ``chat_history_gemini`` and streams the model output, yielding the
    partial reply after every processed chunk.

    Args:
        message: Incoming chat message. Its ``"text"`` and ``"files"``
            entries are read; either may be missing or empty.
        history: Chat history supplied by the caller. It is not used here;
            the request is built from ``chat_history_gemini`` instead.
        chat_history_gemini: Mutable list of ``types.Content`` turns sent as
            the request contents. The new user turn is appended in place.
        client: Gemini client, or ``None`` when no API key has been set.

    Yields:
        A dict with ``"text"`` (the text accumulated so far) and ``"files"``
        (a one-element list holding the path of the most recently saved
        generated file, or an empty list). The same dict object is updated
        and yielded on every iteration.

    Raises:
        gr.Error: If ``client`` is ``None``.
    """
    if client is None:
        raise gr.Error("Client not initialized. Please enter your API key.")

    # ``or []`` covers both a missing key and an explicit ``None`` value.
    uploaded = [client.files.upload(file=path) for path in message.get("files") or []]
    parts = [
        types.Part.from_uri(file_uri=f.uri, mime_type=f.mime_type) for f in uploaded
    ]
    parts.append(types.Part.from_text(text=message.get("text") or ""))
    chat_history_gemini.append(types.Content(role="user", parts=parts))
    # NOTE(review): only user turns are appended to ``chat_history_gemini``;
    # the model's reply is never recorded there. Confirm whether follow-up
    # requests are expected to include previous model output.

    msg = {"files": [], "text": ""}
    stream = client.models.generate_content_stream(
        model=model,
        contents=chat_history_gemini,
        config=generate_content_config,
    )
    for chunk in stream:
        candidates = chunk.candidates
        # Skip chunks that carry no content parts.
        if (
            not candidates
            or not candidates[0].content
            or not candidates[0].content.parts
        ):
            continue

        # Only the first part of the first candidate is inspected for data.
        inline_data = candidates[0].content.parts[0].inline_data
        if inline_data:
            mime_type = inline_data.mime_type
            # guess_extension() returns None for unknown types; fall back to
            # no extension instead of a name ending in the text "None".
            extension = (
                mimetypes.guess_extension(mime_type) if mime_type else None
            ) or ""
            file_path = f"{uuid.uuid4()}{extension}"
            save_binary_file(file_path, inline_data.data)
            # NOTE(review): a later image replaces an earlier one in the
            # reply; confirm whether multi-image responses should accumulate.
            msg["files"] = [file_path]
        else:
            # ``chunk.text`` may be None; treat that as an empty fragment.
            msg["text"] += chunk.text or ""
        yield msg
|
|
|
|
|
# API key field. It is created before the layout so that it can be placed
# explicitly with ``render()`` underneath the chat interface.
# NOTE(review): pre-filling the value from GOOGLE_API_KEY presumably exposes
# the server-side key to every visitor of a deployed app — confirm that this
# is intended.
api_key = gr.Textbox(
    label="Enter your API key", type="password", value=os.getenv("GOOGLE_API_KEY")
)

with gr.Blocks() as demo:
    # State holders passed to ``response`` as additional inputs: the Gemini
    # conversation turns and the API client (None until a key is available).
    chat_history_gemini = gr.State([])
    client = gr.State(value=None)
    gr.ChatInterface(
        response,
        title="Gemini Image Generation",
        description="Edit images and share chat with deep links",
        additional_inputs=[chat_history_gemini, client],
        type="messages",
        multimodal=True,
    )
    # The share button is created inside the layout so that the instance
    # bound to ``deep_link`` is the one that is actually displayed.
    deep_link = gr.DeepLinkButton()
    api_key.render()
    # Submitting the textbox builds a client from the entered key.
    api_key.submit(
        lambda api_key: genai.Client(api_key=api_key),
        inputs=api_key,
        outputs=client,
    )
    # On page load, build a client from the environment key when one is set.
    demo.load(
        lambda: genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))
        if os.getenv("GOOGLE_API_KEY")
        else None,
        inputs=[],
        outputs=[client],
    )


if __name__ == "__main__":
    demo.launch()
|
|