File size: 3,945 Bytes
0e8466f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import mimetypes
import os
import gradio as gr
import requests
import json, time

# Base URL of the local inference backend that serves the /api/* endpoints
# used below (/api/upload, /api/generate, /api/generate_provider, /api/stop).
base_url = "http://127.0.0.1:8000"

def upload_image(file_path):
    """Upload a local image file to the backend's /api/upload endpoint.

    Parameters
    ----------
    file_path : str | None
        Path of the image on disk, or None when the user attached nothing.

    Returns
    -------
    str | None
        The server-side path reported by the backend under the 'file_path'
        key of its JSON reply, or None when no file was given.

    Raises
    ------
    requests.HTTPError
        If the upload endpoint responds with an error status.
    """
    if file_path is None:
        return None

    name = os.path.basename(file_path)
    # Fall back to a generic binary type when the extension is unknown.
    guessed, _ = mimetypes.guess_type(name)
    content_type = guessed if guessed else 'application/octet-stream'

    # Read the whole file up front so the multipart body is self-contained.
    with open(file_path, 'rb') as handle:
        blob = handle.read()

    # Single multipart/form-data field named 'image'.
    multipart = {'image': (name, blob, content_type)}
    reply = requests.post(f'{base_url}/api/upload', files=multipart)
    reply.raise_for_status()
    return reply.json().get('file_path')

def stop_generation():
    """Ask the backend to halt the in-flight generation (best-effort).

    Any network failure is deliberately swallowed so the Stop button can
    never crash the UI; only requests-level errors are caught, not
    unrelated exceptions like KeyboardInterrupt.
    """
    try:
        # Timeout so a dead/hung backend doesn't block the UI callback.
        requests.get(f'{base_url}/api/stop', timeout=5)
    except requests.RequestException:
        # Best-effort: stopping is advisory only.
        pass

def respond(prompt, image:gr.Image, temp, rep_penalty, tp, tk, history=None):
    """Stream a chat completion from the backend into the Gradio chatbot.

    This is a generator: it yields successive `history` lists (Gradio
    chatbot tuples) as response chunks arrive from the backend.

    Parameters
    ----------
    prompt : str | None
        User message; blank/None prompts re-emit history unchanged.
    image : filepath from gr.Image, or None
        Optional image; uploaded to the backend before generation.
    temp, rep_penalty, tp, tk :
        Sampling parameters forwarded to the backend.
    history : list[tuple] | None
        Existing chatbot history; a fresh list is used when None
        (never a mutable default argument).
    """
    if history is None:
        history = []
    if not prompt or not prompt.strip():
        # BUG FIX: this function is a generator, so `return history` would
        # discard the value and emit nothing — yield the unchanged history
        # so the UI still receives an update, then stop.
        yield history
        return

    if image is None:
        file_path = None
    else:
        file_path = upload_image(image)
        # Show the uploaded image inline as a markdown image entry.
        history.append((f'![]({file_path})', None))

    # Placeholder entry that the streaming loop fills in incrementally.
    history.append((prompt, ""))
    yield history

    payload = {
        "prompt": prompt,
        "temperature": temp,
        "repetition_penalty": rep_penalty,
        # NOTE(review): hyphenated keys kept as-is for backend compatibility —
        # confirm the server really expects "top-p"/"top-k", not "top_p"/"top_k".
        "top-p": tp,
        "top-k": tk
    }
    if file_path:
        payload["file_path"] = file_path

    # Kick off generation; raise early if the backend rejects the request.
    start = requests.post(f'{base_url}/api/generate', json=payload)
    start.raise_for_status()

    # Poll the provider endpoint until the backend signals completion.
    while True:
        time.sleep(0.01)
        poll = requests.get(f'{base_url}/api/generate_provider')
        data = poll.json()
        chunk: str = data.get("response", "")
        if data.get("done", False):
            break
        if not chunk.strip():
            continue
        # Append the chunk to the assistant half of the last history entry.
        history[-1] = (prompt, history[-1][1] + chunk)
        yield history

    print("end")


def chat_interface():
    """Build the Gradio Blocks UI and launch the web server.

    Layout: an image-upload column on the left, the chatbot + prompt box in
    the middle, and sampling-parameter sliders on the right. Blocks until
    the server is shut down.
    """
    with gr.Blocks(theme=gr.themes.Soft(font="Consolas"), fill_width=True) as demo:
        gr.Markdown("## Chat with LLM\nUpload an image and chat with the model!")
        with gr.Row():
            # type="filepath" so respond() receives a path it can re-upload
            # to the backend via upload_image().
            image = gr.Image(label="Upload Image", type="filepath")
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(height=600)
                # Default prompt is Chinese for "describe this picture".
                prompt = gr.Textbox(placeholder="Type your message...", label="Prompt", value="描述一下这张图片")
                with gr.Row():
                    btn_chat = gr.Button("Chat", variant="primary")
                    btn_stop = gr.Button("Stop", variant="stop")

            with gr.Column(scale=1):
                temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.7, label="Temperature")
                repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=1.0, label="Repetition Penalty")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Top-p Sampling")
                top_k = gr.Slider(minimum=0, maximum=100, step=1, value=40, label="Top-k Sampling")

            # Stop fires a best-effort request to the backend; Chat streams
            # respond()'s yielded histories into the chatbot component.
            btn_stop.click(fn=stop_generation, inputs=None, outputs=None)
            btn_chat.click(
                fn=respond,
                inputs=[prompt, image, temperature, repetition_penalty, top_p, top_k, chatbot],
                outputs=chatbot
            )

        # Bind on all interfaces so the UI is reachable from other machines.
        demo.launch(server_name="0.0.0.0", server_port=7860)

# Entry point: build the UI and start the Gradio server (blocks until exit).
if __name__ == "__main__":
    chat_interface()