miaoyibo committed on
Commit bfa25fc · 1 Parent(s): f563de6
Files changed (1)
  1. app.py +342 -54
app.py CHANGED
@@ -1,64 +1,352 @@
+import argparse
 import gradio as gr
-from huggingface_hub import InferenceClient
+import os
+from PIL import Image
+import spaces
+import copy
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+from kimi_vl.serve.frontend import reload_javascript
+from kimi_vl.serve.utils import (
+    configure_logger,
+    pil_to_base64,
+    parse_ref_bbox,
+    strip_stop_words,
+    is_variable_assigned,
+)
+from kimi_vl.serve.gradio_utils import (
+    cancel_outputing,
+    delete_last_conversation,
+    reset_state,
+    reset_textbox,
+    transfer_input,
+    wrap_gen_fn,
+)
+from kimi_vl.serve.chat_utils import (
+    generate_prompt_with_history,
+    convert_conversation_to_prompts,
+    to_gradio_chatbot,
+    to_gradio_history,
+)
+from kimi_vl.serve.inference import kimi_vl_generate, load_model
+from kimi_vl.serve.examples import get_examples
+
+TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-VL-A3B-Thinking🤔</h1>"""
+DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-VL" target="_blank">Kimi-VL-A3B-Thinking</a> is a multi-modal LLM that can understand text and images, and generate text with thinking processes. For the non-thinking version, please try [Kimi-VL-A3B](https://huggingface.co/spaces/moonshotai/Kimi-VL-A3B)."""
+DESCRIPTION = """"""
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+DEPLOY_MODELS = dict()
+logger = configure_logger()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", type=str, default="Kimi-VL-A3B-Thinking")
+    parser.add_argument(
+        "--local-path",
+        type=str,
+        default="",
+        help="huggingface ckpt, optional",
+    )
+    parser.add_argument("--ip", type=str, default="0.0.0.0")
+    parser.add_argument("--port", type=int, default=7860)
+    return parser.parse_args()
+
+
+def fetch_model(model_name: str):
+    global args, DEPLOY_MODELS
+
+    if args.local_path:
+        model_path = args.local_path
+    else:
+        model_path = f"moonshotai/{args.model}"
+
+    if model_name in DEPLOY_MODELS:
+        model_info = DEPLOY_MODELS[model_name]
+        print(f"{model_name} has been loaded.")
+    else:
+        print(f"{model_name} is loading...")
+        DEPLOY_MODELS[model_name] = load_model(model_path)
+        print(f"Loaded {model_name} successfully.")
+        model_info = DEPLOY_MODELS[model_name]
+
+    return model_info
+
+
+def preview_images(files) -> list[str]:
+    if files is None:
+        return []
 
+    image_paths = []
+    for file in files:
+        image_paths.append(file.name)
+    return image_paths
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
+
+def get_prompt(conversation) -> str:
+    """
+    Get the prompt for the conversation.
+    """
+    system_prompt = conversation.system_template.format(system_message=conversation.system_message)
+    return system_prompt
+
+def highlight_thinking(msg: str) -> str:
+    msg = copy.deepcopy(msg)
+    if "◁think▷" in msg:
+        msg = msg.replace("◁think▷", "<b style='color:blue;'>🤔Thinking...</b>\n")
+    if "◁/think▷" in msg:
+        msg = msg.replace("◁/think▷", "\n<b style='color:purple;'>💡Summary</b>\n")
+
+    return msg
+
+@wrap_gen_fn
+@spaces.GPU(duration=180)
+def predict(
+    text,
+    images,
+    chatbot,
+    history,
+    top_p,
     temperature,
+    max_length_tokens,
+    max_context_length_tokens,
+    chunk_size: int = 512,
+):
+    """
+    Predict the response for the input text and images.
+    Args:
+        text (str): The input text.
+        images (list[PIL.Image.Image]): The input images.
+        chatbot (list): The chatbot.
+        history (list): The history.
+        top_p (float): The top-p value.
+        temperature (float): The temperature value.
+        max_length_tokens (int): The max length tokens.
+        max_context_length_tokens (int): The max context length tokens.
+        chunk_size (int): The chunk size.
+    """
+    print("running the prediction function")
+    try:
+        model, processor = fetch_model(args.model)
+
+        if text == "":
+            yield chatbot, history, "Empty context."
+            return
+    except KeyError:
+        yield [[text, "No Model Found"]], [], "No Model Found"
+        return
+
+    if images is None:
+        images = []
+
+    # load images
+    pil_images = []
+    for img_or_file in images:
+        try:
+            # load as pil image
+            if isinstance(img_or_file, Image.Image):
+                pil_images.append(img_or_file)
+            else:
+                image = Image.open(img_or_file.name).convert("RGB")
+                pil_images.append(image)
+        except Exception as e:
+            print(f"Error loading image: {e}")
+
+    # generate prompt
+    conversation = generate_prompt_with_history(
+        text,
+        pil_images,
+        history,
+        processor,
+        max_length=max_context_length_tokens,
+    )
+    all_conv, last_image = convert_conversation_to_prompts(conversation)
+    stop_words = conversation.stop_str
+    gradio_chatbot_output = to_gradio_chatbot(conversation)
+
+    full_response = ""
+    for x in kimi_vl_generate(
+        conversations=all_conv,
+        model=model,
+        processor=processor,
+        stop_words=stop_words,
+        max_length=max_length_tokens,
+        temperature=temperature,
+        top_p=top_p,
+    ):
+        full_response += x
+        response = strip_stop_words(full_response, stop_words)
+        conversation.update_last_message(response)
+        gradio_chatbot_output[-1][1] = highlight_thinking(response)
+
+        yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
+
+    if last_image is not None:
+        vg_image = parse_ref_bbox(response, last_image)
+        if vg_image is not None:
+            vg_base64 = pil_to_base64(vg_image, "vg", max_size=800, min_size=400)
+            gradio_chatbot_output[-1][1] += vg_base64
+            yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
+
+    logger.info("flushed result to gradio")
+
+    if is_variable_assigned("x"):
+        print(
+            f"temperature: {temperature}, "
+            f"top_p: {top_p}, "
+            f"max_length_tokens: {max_length_tokens}"
+        )
+
+    yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"
+
+
+def retry(
+    text,
+    images,
+    chatbot,
+    history,
     top_p,
+    temperature,
+    max_length_tokens,
+    max_context_length_tokens,
+    chunk_size: int = 512,
 ):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
+    """
+    Retry the response for the input text and images.
+    """
+    if len(history) == 0:
+        yield (chatbot, history, "Empty context")
+        return
+
+    chatbot.pop()
+    history.pop()
+    text = history.pop()[-1]
+    if isinstance(text, tuple):
+        text, _ = text
+
+    yield from predict(
+        text,
+        images,
+        chatbot,
+        history,
+        top_p,
+        temperature,
+        max_length_tokens,
+        max_context_length_tokens,
+        chunk_size,
+    )
+
+
+def build_demo(args: argparse.Namespace) -> gr.Blocks:
+    with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
+        history = gr.State([])
+        input_text = gr.State()
+        input_images = gr.State()
+
+        with gr.Row():
+            gr.HTML(TITLE)
+            status_display = gr.Markdown("Success", elem_id="status_display")
+        gr.Markdown(DESCRIPTION_TOP)
+
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=4):
+                with gr.Row():
+                    chatbot = gr.Chatbot(
+                        elem_id="Kimi-VL-A3B-Thinking-chatbot",
+                        show_share_button=True,
+                        bubble_full_width=False,
+                        height=600,
+                    )
+                with gr.Row():
+                    with gr.Column(scale=4):
+                        text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
+                    with gr.Column(min_width=70):
+                        submit_btn = gr.Button("Send")
+                    with gr.Column(min_width=70):
+                        cancel_btn = gr.Button("Stop")
+                with gr.Row():
+                    empty_btn = gr.Button("🧹 New Conversation")
+                    retry_btn = gr.Button("🔄 Regenerate")
+                    del_last_btn = gr.Button("🗑️ Remove Last Turn")
+
+            with gr.Column():
+                # note: no more than 2 images at a time
+                gr.Markdown("Note: you can upload no more than 2 images at a time")
+                upload_images = gr.Files(file_types=["image"], show_label=True)
+                gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
+                upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
+                # Parameter Setting tab for controlling the generation parameters
+                with gr.Tab(label="Parameter Setting"):
+                    top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p")
+                    temperature = gr.Slider(
+                        minimum=0, maximum=1.0, value=0.6, step=0.1, interactive=True, label="Temperature"
+                    )
+                    max_length_tokens = gr.Slider(
+                        minimum=512, maximum=8192, value=2048, step=64, interactive=True, label="Max Length Tokens"
+                    )
+                    max_context_length_tokens = gr.Slider(
+                        minimum=512, maximum=8192, value=2048, step=64, interactive=True, label="Max Context Length Tokens"
+                    )
+
+        show_images = gr.HTML(visible=False)
+
+        gr.Examples(
+            examples=get_examples(ROOT_DIR),
+            inputs=[upload_images, show_images, text_box],
+        )
+        gr.Markdown()
+
+        input_widgets = [
+            input_text,
+            input_images,
+            chatbot,
+            history,
+            top_p,
+            temperature,
+            max_length_tokens,
+            max_context_length_tokens,
+        ]
+        output_widgets = [chatbot, history, status_display]
+
+        transfer_input_args = dict(
+            fn=transfer_input,
+            inputs=[text_box, upload_images],
+            outputs=[input_text, input_images, text_box, upload_images, submit_btn],
+            show_progress=True,
+        )
+
+        predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
+        retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
+        reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])
+
+        predict_events = [
+            text_box.submit(**transfer_input_args).then(**predict_args),
+            submit_btn.click(**transfer_input_args).then(**predict_args),
+        ]
+
+        empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
+        empty_btn.click(**reset_args)
+        retry_btn.click(**retry_args)
+        del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
+        cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)
+
+    demo.title = "Kimi-VL-A3B-Thinking Chatbot"
+    return demo
+
+
+def main(args: argparse.Namespace):
+    demo = build_demo(args)
+    reload_javascript()
+
+    # concurrency_count=CONCURRENT_COUNT, max_size=MAX_EVENTS
+    favicon_path = os.path.join("kimi_vl/serve/assets/favicon.ico")
+    demo.queue().launch(
+        favicon_path=favicon_path,
+        server_name=args.ip,
+        server_port=args.port,
+    )
 
 
 if __name__ == "__main__":
-    demo.launch()
+    args = parse_args()
+    main(args)
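
For a quick sanity check of the streaming display logic added above, here is a minimal, self-contained sketch of how `highlight_thinking` progressively re-renders a streamed answer. The `◁think▷`/`◁/think▷` marker strings are taken verbatim from the diff; the streamed chunks are invented sample data, not real model output:

```python
# Minimal sketch of the thinking-marker rewrite used by the chatbot above.
# Marker strings come from highlight_thinking() in app.py; the chunks below
# are made-up sample data, not real Kimi-VL output.
def highlight_thinking(msg: str) -> str:
    msg = msg.replace("◁think▷", "<b style='color:blue;'>🤔Thinking...</b>\n")
    msg = msg.replace("◁/think▷", "\n<b style='color:purple;'>💡Summary</b>\n")
    return msg

chunks = ["◁think▷The user asks", " about the image.◁/think▷", "It shows a tabby cat."]
response = ""
for chunk in chunks:
    response += chunk  # accumulate the stream, as predict() does
    print(highlight_thinking(response))  # what the chatbot bubble would display
```

To run the app locally (assuming the `kimi_vl` serving package and model weights are available), use the flags defined in `parse_args`, e.g. `python app.py --port 7860`, or `python app.py --local-path /path/to/checkpoint` to load a local checkpoint instead of `moonshotai/Kimi-VL-A3B-Thinking` from the Hub.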