File size: 3,945 Bytes
0e8466f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import mimetypes
import os
import gradio as gr
import requests
import json, time
# Base URL of the local inference backend that serves the upload/generate/stop API.
base_url = "http://127.0.0.1:8000"
def upload_image(file_path):
    """Upload a local image file to the backend and return its server-side path.

    Args:
        file_path: Local filesystem path to the image (the gradio Image
            component with type="filepath" supplies a path string), or None.

    Returns:
        The 'file_path' value from the server's JSON response, or None when
        no file was given or the server omitted the key.

    Raises:
        requests.HTTPError: if the upload endpoint returns an error status.
    """
    if file_path is None:
        return None
    filename = os.path.basename(file_path)
    # Guess the MIME type from the extension; fall back to generic binary.
    mime_type, _ = mimetypes.guess_type(filename)
    mime_type = mime_type or 'application/octet-stream'
    # Pass the open file handle so requests streams it, instead of reading
    # the whole file into memory first.
    with open(file_path, 'rb') as f:
        files = {
            'image': (filename, f, mime_type)
        }
        resp = requests.post(
            f'{base_url}/api/upload',
            files=files,
            timeout=30,  # don't hang the UI forever if the backend is down
        )
    resp.raise_for_status()
    data = resp.json()
    return data.get('file_path')
def stop_generation():
    """Ask the backend to abort the in-flight generation (best effort).

    Network failures are deliberately swallowed: stopping is advisory and
    the UI must not crash if the backend is unreachable.
    """
    try:
        requests.get(f'{base_url}/api/stop', timeout=5)
    except requests.RequestException:
        # Only network-level errors are ignored; a bare `except:` would also
        # swallow KeyboardInterrupt/SystemExit and hide real bugs.
        pass
def respond(prompt, image, temp, rep_penalty, tp, tk, history=None):
    """Stream a chat response from the backend, yielding updated history.

    Args:
        prompt: User message; blank/whitespace-only prompts end the
            generator immediately without yielding.
        image: Local path string of an uploaded image (gradio Image with
            type="filepath"), or None for a text-only turn.
        temp: Sampling temperature.
        rep_penalty: Repetition penalty.
        tp: Top-p (nucleus) sampling threshold.
        tk: Top-k sampling cutoff.
        history: Existing chatbot history as (user, bot) tuples; a fresh
            list is created when None.

    Yields:
        The history list after each streamed chunk is appended, so the
        gradio Chatbot re-renders incrementally.
    """
    if history is None:
        history = []
    if not prompt.strip():
        return history
    # Upload the image first (if any) so the backend can reference it.
    # BUG FIX: the old code called os.path.relpath(file_path) unconditionally,
    # which raised TypeError for text-only chats (file_path is None); the
    # result was dead code anyway (its only use was commented out).
    file_path = None if image is None else upload_image(image)
    history.append(('', None))  # placeholder row (image preview lived here once)
    history.append((prompt, ""))
    yield history
    payload = {
        "prompt": prompt,
        "temperature": temp,
        "repetition_penalty": rep_penalty,
        # NOTE(review): hyphenated keys ("top-p"/"top-k") look unusual —
        # confirm the backend expects these exact names, not top_p/top_k.
        "top-p": tp,
        "top-k": tk,
    }
    if file_path:
        payload["file_path"] = file_path
    response = requests.post(
        f'{base_url}/api/generate',
        json=payload,
        timeout=60,  # fail fast instead of hanging the event handler
    )
    response.raise_for_status()
    # Poll the provider endpoint until the backend reports completion.
    while True:
        time.sleep(0.01)
        response = requests.get(
            f'{base_url}/api/generate_provider',
            timeout=60,
        )
        data = response.json()
        chunk: str = data.get("response", "")
        done = data.get("done", False)
        if done:
            break
        if chunk.strip() == "":
            continue
        # Append the chunk to the bot half of the last (prompt, reply) pair.
        history[-1] = (prompt, history[-1][1] + chunk)
        yield history
    print("end")
def chat_interface():
    """Build and launch the gradio chat UI.

    The context-manager nesting defines the page layout:
      - a top Row holding the image uploader on the left,
      - a wide Column (scale=3) with the chatbot, prompt box, and
        Chat/Stop buttons,
      - a narrow Column (scale=1) with the sampling-parameter sliders.

    Blocks the calling thread: demo.launch() runs the web server.
    """
    with gr.Blocks(theme=gr.themes.Soft(font="Consolas"), fill_width=True) as demo:
        gr.Markdown("## Chat with LLM\nUpload an image and chat with the model!")
        with gr.Row():
            # type="filepath" makes the component pass callbacks a local
            # path string rather than decoded image data.
            image = gr.Image(label="Upload Image", type="filepath")
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600)
                # Default prompt text is user-facing and intentionally kept as-is.
                prompt = gr.Textbox(placeholder="Type your message...", label="Prompt", value="描述一下这张图片")
                with gr.Row():
                    btn_chat = gr.Button("Chat", variant="primary")
                    btn_stop = gr.Button("Stop", variant="stop")
            with gr.Column(scale=1):
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.7, label="Temperature")
                repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=1.0, label="Repetition Penalty")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Top-p Sampling")
                top_k = gr.Slider(minimum=0, maximum=100, step=1, value=40, label="Top-k Sampling")
        # Wire the buttons: Stop fires the best-effort abort request; Chat
        # streams `respond`'s yielded histories into the chatbot.
        btn_stop.click(fn=stop_generation, inputs=None, outputs=None)
        btn_chat.click(
            fn=respond,
            inputs=[prompt, image, temperature, repetition_penalty, top_p, top_k, chatbot],
            outputs=chatbot
        )
    # 0.0.0.0 exposes the UI on all interfaces, port 7860 (gradio's default).
    demo.launch(server_name="0.0.0.0", server_port=7860)
if __name__ == "__main__":
    # Launch the UI only when run as a script, not on import.
    chat_interface()
|