miaoyibo committed on
Commit
46a0b0f
·
1 Parent(s): 8cf3ee6
.gitignore CHANGED
@@ -3,3 +3,6 @@
3
  __pycache__
4
  *.pyc
5
  *.pyo
3
  __pycache__
4
  *.pyc
5
  *.pyo
6
+
7
+ .gradio
8
+ local_path/
app.py CHANGED
@@ -1,44 +1,38 @@
1
  import argparse
2
  import gradio as gr
3
  import os
4
- from PIL import Image
5
  import spaces
6
  import copy
7
  import time
8
 
9
- from kimi_vl.serve.frontend import reload_javascript
10
- from kimi_vl.serve.utils import (
 
 
11
  configure_logger,
12
- pil_to_base64,
13
- parse_ref_bbox,
14
- strip_stop_words,
15
- is_variable_assigned,
16
  )
17
- from kimi_vl.serve.gradio_utils import (
18
- cancel_outputing,
19
- delete_last_conversation,
20
  reset_state,
21
  reset_textbox,
22
  transfer_input,
23
  wrap_gen_fn,
24
  )
25
- from kimi_vl.serve.chat_utils import (
26
- generate_prompt_with_history,
27
- convert_conversation_to_prompts,
28
- to_gradio_chatbot,
29
- to_gradio_history,
30
- )
31
- from kimi_vl.serve.inference import kimi_dev_generate, load_model
32
- from kimi_vl.serve.examples import get_examples
33
 
34
- TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-Dev-72BπŸ€” </h1>"""
35
- DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-VL" target="_blank">Kimi-Dev-72B</a> is a multi-modal LLM that can understand text and images, and generate text with thinking processes. For non-thinking version, please try [Kimi-VL-A3B](https://huggingface.co/spaces/moonshotai/Kimi-VL-A3B)."""
36
- DESCRIPTION = """"""
37
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
38
  DEPLOY_MODELS = dict()
39
  logger = configure_logger()
40
 
41
-
42
  def parse_args():
43
  parser = argparse.ArgumentParser()
44
  parser.add_argument("--model", type=str, default="Kimi-Dev-72B")
@@ -73,16 +67,6 @@ def fetch_model(model_name: str):
73
  return model_info
74
 
75
 
76
- def preview_images(files) -> list[str]:
77
- if files is None:
78
- return []
79
-
80
- image_paths = []
81
- for file in files:
82
- image_paths.append(file.name)
83
- return image_paths
84
-
85
-
86
  def get_prompt(conversation) -> str:
87
  """
88
  Get the prompt for the conversation.
@@ -103,30 +87,29 @@ def highlight_thinking(msg: str) -> str:
103
  @spaces.GPU(duration=180)
104
  def predict(
105
  text,
106
- images,
107
  chatbot,
108
  history,
109
  top_p,
110
  temperature,
111
  max_length_tokens,
112
- max_context_length_tokens,
113
  chunk_size: int = 512,
114
  ):
115
  """
116
- Predict the response for the input text and images.
117
  Args:
118
  text (str): The input text.
119
- images (list[PIL.Image.Image]): The input images.
120
  chatbot (list): The chatbot.
121
  history (list): The history.
122
  top_p (float): The top-p value.
123
  temperature (float): The temperature value.
124
  repetition_penalty (float): The repetition penalty value.
125
  max_length_tokens (int): The max length tokens.
126
- max_context_length_tokens (int): The max context length tokens.
127
  chunk_size (int): The chunk size.
128
  """
129
  print("running the prediction function")
 
130
  try:
131
  model, tokenizer = fetch_model(args.model)
132
 
@@ -137,131 +120,161 @@ def predict(
137
  yield [[text, "No Model Found"]], [], "No Model Found"
138
  return
139
 
140
 
141
- prompt = "Give me a short introduction to large language model."
142
  messages = [
143
  {"role": "system", "content": "You are a helpful assistant."},
144
- {"role": "user", "content": prompt}
145
  ]
146
- text = tokenizer.apply_chat_template(
147
  messages,
148
  tokenize=False,
149
  add_generation_prompt=True
150
  )
151
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
152
 
153
- generated_ids = model.generate(
154
- **model_inputs,
155
- max_new_tokens=512
156
- )
157
- generated_ids = [
158
- output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
159
- ]
160
 
161
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
162
 
 
 
163
  print(response)
164
- time.sleep(2600)
165
-
166
-
167
- if images is None:
168
- images = []
169
-
170
- # load images
171
- pil_images = []
172
- for img_or_file in images:
173
- try:
174
- # load as pil image
175
- if isinstance(images, Image.Image):
176
- pil_images.append(img_or_file)
177
- else:
178
- image = Image.open(img_or_file.name).convert("RGB")
179
- pil_images.append(image)
180
- except Exception as e:
181
- print(f"Error loading image: {e}")
182
-
183
- # generate prompt
184
- conversation = generate_prompt_with_history(
185
- text,
186
- pil_images,
187
- history,
188
- max_length=max_context_length_tokens,
 
 
189
  )
190
- print(conversation)
191
- all_conv, last_image = convert_conversation_to_prompts(conversation)
192
- stop_words = conversation.stop_str
193
- gradio_chatbot_output = to_gradio_chatbot(conversation)
194
-
195
- full_response = ""
196
- for x in kimi_dev_generate(
197
- conversations=all_conv,
198
- model=model,
199
- tokneizer=tokenizer,
200
- # processor=processor,
201
- stop_words=stop_words,
202
- max_length=max_length_tokens,
203
  temperature=temperature,
204
  top_p=top_p,
205
- ):
206
- full_response += x
207
- response = strip_stop_words(full_response, stop_words)
208
- conversation.update_last_message(response)
209
- gradio_chatbot_output[-1][1] = highlight_thinking(response)
210
-
211
- yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
212
-
213
- if last_image is not None:
214
- vg_image = parse_ref_bbox(response, last_image)
215
- if vg_image is not None:
216
- vg_base64 = pil_to_base64(vg_image, "vg", max_size=800, min_size=400)
217
- gradio_chatbot_output[-1][1] += vg_base64
218
- yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
219
-
220
- logger.info("flushed result to gradio")
221
-
222
- if is_variable_assigned("x"):
223
- print(
224
- f"temperature: {temperature}, "
225
- f"top_p: {top_p}, "
226
- f"max_length_tokens: {max_length_tokens}"
227
  )
228
 
229
- yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"
 
 
 
230
 
231
 
232
  def retry(
233
  text,
234
- images,
235
  chatbot,
236
  history,
237
  top_p,
238
  temperature,
239
  max_length_tokens,
240
- max_context_length_tokens,
241
  chunk_size: int = 512,
242
  ):
243
  """
244
- Retry the response for the input text and images.
245
  """
246
  if len(history) == 0:
247
  yield (chatbot, history, "Empty context")
248
  return
249
 
250
- chatbot.pop()
251
- history.pop()
252
- text = history.pop()[-1]
253
  if type(text) is tuple:
254
  text, _ = text
255
 
256
  yield from predict(
257
  text,
258
- images,
259
  chatbot,
260
  history,
261
  top_p,
262
  temperature,
263
  max_length_tokens,
264
- max_context_length_tokens,
265
  chunk_size,
266
  )
267
 
@@ -270,12 +283,13 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
270
  with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
271
  history = gr.State([])
272
  input_text = gr.State()
273
- input_images = gr.State()
274
 
275
  with gr.Row():
276
  gr.HTML(TITLE)
277
  status_display = gr.Markdown("Success", elem_id="status_display")
278
  gr.Markdown(DESCRIPTION_TOP)
 
279
 
280
  with gr.Row(equal_height=True):
281
  with gr.Column(scale=4):
@@ -284,63 +298,59 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
284
  elem_id="Kimi-Dev-72B",
285
  show_share_button=True,
286
  bubble_full_width=False,
287
- height=600,
 
288
  )
289
  with gr.Row():
290
  with gr.Column(scale=4):
291
- text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
292
  with gr.Column(min_width=70):
293
  submit_btn = gr.Button("Send")
294
- with gr.Column(min_width=70):
295
- cancel_btn = gr.Button("Stop")
296
  with gr.Row():
297
  empty_btn = gr.Button("🧹 New Conversation")
298
  retry_btn = gr.Button("🔄 Regenerate")
299
- del_last_btn = gr.Button("🗑️ Remove Last Turn")
300
-
 
301
  with gr.Column():
302
- # add note no more than 2 images once
303
- gr.Markdown("Note: you can upload no more than 2 images once")
304
- upload_images = gr.Files(file_types=["image"], show_label=True)
305
- gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
306
- upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
307
  # Parameter Setting Tab for control the generation parameters
308
  with gr.Tab(label="Parameter Setting"):
309
- top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p")
310
  temperature = gr.Slider(
311
- minimum=0, maximum=1.0, value=0.6, step=0.1, interactive=True, label="Temperature"
312
  )
313
  max_length_tokens = gr.Slider(
314
- minimum=512, maximum=8192, value=2048, step=64, interactive=True, label="Max Length Tokens"
315
- )
316
- max_context_length_tokens = gr.Slider(
317
- minimum=512, maximum=8192, value=2048, step=64, interactive=True, label="Max Context Length Tokens"
318
  )
319
 
320
- show_images = gr.HTML(visible=False)
321
-
322
  gr.Examples(
323
  examples=get_examples(ROOT_DIR),
324
- inputs=[upload_images, show_images, text_box],
325
  )
326
- gr.Markdown()
327
 
328
  input_widgets = [
329
  input_text,
330
- input_images,
331
  chatbot,
332
  history,
333
  top_p,
334
  temperature,
335
  max_length_tokens,
336
- max_context_length_tokens,
337
  ]
338
  output_widgets = [chatbot, history, status_display]
339
 
340
  transfer_input_args = dict(
341
  fn=transfer_input,
342
- inputs=[text_box, upload_images],
343
- outputs=[input_text, input_images, text_box, upload_images, submit_btn],
344
  show_progress=True,
345
  )
346
 
@@ -356,8 +366,6 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
356
  empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
357
  empty_btn.click(**reset_args)
358
  retry_btn.click(**retry_args)
359
- del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
360
- cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)
361
 
362
  demo.title = "Kimi-Dev-72B"
363
  return demo
@@ -367,8 +375,7 @@ def main(args: argparse.Namespace):
367
  demo = build_demo(args)
368
  reload_javascript()
369
 
370
- # concurrency_count=CONCURRENT_COUNT, max_size=MAX_EVENTS
371
- favicon_path = os.path.join("kimi_vl/serve/assets/favicon.ico")
372
  # demo.queue().launch(
373
  # favicon_path=favicon_path,
374
  # server_name=args.ip,
@@ -378,7 +385,7 @@ def main(args: argparse.Namespace):
378
  favicon_path=favicon_path,
379
  server_name=args.ip,
380
  server_port=args.port,
381
- share=True # for local debugging
382
  )
383
 
384
  if __name__ == "__main__":
 
1
  import argparse
2
  import gradio as gr
3
  import os
 
4
  import spaces
5
  import copy
6
  import time
7
+ import json
8
+ import subprocess
9
+ import ast
10
+ import pdb
11
+ from transformers import TextIteratorStreamer
12
 
13
+ import threading
14
+
15
+ from kimi_dev.serve.frontend import reload_javascript
16
+ from kimi_dev.serve.utils import (
17
  configure_logger,
18
  )
19
+ from kimi_dev.serve.gradio_utils import (
 
 
20
  reset_state,
21
  reset_textbox,
22
  transfer_input,
23
  wrap_gen_fn,
24
  )
25
+ from kimi_dev.serve.inference import load_model
26
+ from kimi_dev.serve.examples import get_examples
27
+ from kimi_dev.serve.templates import post_process, get_loc_prompt, clone_github_repo, build_repo_structure, show_project_structure, get_repair_prompt, get_repo_files, get_full_file_paths_and_classes_and_functions, correct_file_path_in_structure
28
 
29
+ TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-Dev-72BπŸ”₯ </h1>"""
30
+ DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-VL" target="_blank">Kimi-Dev-72B</a> is a strong and open-source coding LLM for software engineering tasks."""
31
+ USAGE_TOP = """Usage: 1. Input a Github url like "https://github.com/astropy/astropy" and submit it. \n2. Input your issue description and chat with Kimi-Dev-72B!"""
32
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
33
  DEPLOY_MODELS = dict()
34
  logger = configure_logger()
35
 
 
36
  def parse_args():
37
  parser = argparse.ArgumentParser()
38
  parser.add_argument("--model", type=str, default="Kimi-Dev-72B")
 
67
  return model_info
68
 
69
 
70
  def get_prompt(conversation) -> str:
71
  """
72
  Get the prompt for the conversation.
 
87
  @spaces.GPU(duration=180)
88
  def predict(
89
  text,
90
+ url,
91
  chatbot,
92
  history,
93
  top_p,
94
  temperature,
95
  max_length_tokens,
 
96
  chunk_size: int = 512,
97
  ):
98
  """
99
+ Predict the response for the input text and url.
100
  Args:
101
  text (str): The input text.
102
+ url (str): The input url.
103
  chatbot (list): The chatbot.
104
  history (list): The history.
105
  top_p (float): The top-p value.
106
  temperature (float): The temperature value.
107
  repetition_penalty (float): The repetition penalty value.
108
  max_length_tokens (int): The max length tokens.
 
109
  chunk_size (int): The chunk size.
110
  """
111
  print("running the prediction function")
112
+
113
  try:
114
  model, tokenizer = fetch_model(args.model)
115
 
 
120
  yield [[text, "No Model Found"]], [], "No Model Found"
121
  return
122
 
123
+ prompt = text
124
+ repo_name = url.split("/")[-1]
125
+
126
+ repo_path = './local_path/'+repo_name # Local clone path
127
+
128
+ clone_github_repo(url, repo_path)
129
+ structure = build_repo_structure(repo_path)
130
+ string_structure = show_project_structure(structure)
131
+
132
+ loc_prompt = get_loc_prompt(prompt, string_structure)
133
+
134
 
 
135
  messages = [
136
  {"role": "system", "content": "You are a helpful assistant."},
137
+ {"role": "user", "content": loc_prompt}
138
  ]
139
+ text_for_model = tokenizer.apply_chat_template(
140
  messages,
141
  tokenize=False,
142
  add_generation_prompt=True
143
  )
144
+ model_inputs = tokenizer([text_for_model], return_tensors="pt").to(model.device)
145
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
146
+ # print("start generating")
147
+ if temperature > 0:
148
+ generation_kwargs = dict(
149
+ **model_inputs,
150
+ do_sample=True,
151
+ temperature=temperature,
152
+ top_p=top_p,
153
+ max_new_tokens=max_length_tokens,
154
+ streamer=streamer
155
+ )
156
+ else:
157
+ generation_kwargs = dict(
158
+ **model_inputs,
159
+ do_sample=False,
160
+ max_new_tokens=max_length_tokens,
161
+ streamer=streamer
162
+ )
163
+ gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
164
+ gen_thread.start()
165
 
166
 
167
+ partial_output = "Start Locating...\n"
168
+
169
+ for new_text in streamer:
170
+ partial_output += new_text
171
+ highlight_response = highlight_thinking(partial_output)
172
+ yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generating file locations..."
173
+
174
+ gen_thread.join()
175
+
176
+ response = partial_output
177
 
178
+ raw_answer=post_process(response)
179
+ model_found_files = raw_answer.strip().split("\n")
180
  print(response)
181
+
182
+ highlight_response = highlight_thinking(response)
183
+ yield [[prompt,highlight_response]], [["null test","null test2"]], "Generate: Success"
184
+
185
+ # reading file content
186
+ contents = ""
187
+ for file_path in model_found_files:
188
+ file_name = file_path.replace("```","")
189
+ print(file_name)
190
+ # pdb.set_trace()
191
+ to_open_path = repo_path + "/" + file_name
192
+ print("to_open_path,",to_open_path)
193
+ with open(to_open_path, "r", encoding="utf-8") as f:
194
+ content = f.read()
195
+ contents += f"{file_name}\n{content}\n\n"
196
+
197
+
198
+ repair_prompt = get_repair_prompt(prompt,contents)
199
+
200
+ messages = [
201
+ {"role": "system", "content": "You are a helpful assistant."},
202
+ {"role": "user", "content": repair_prompt}
203
+ ]
204
+ text = tokenizer.apply_chat_template(
205
+ messages,
206
+ tokenize=False,
207
+ add_generation_prompt=True
208
  )
209
+ model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
210
+
211
+ subprocess.run(["rm", "-rf", repo_path], check=True)
212
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
213
+ if temperature > 0:
214
+ generation_kwargs = dict(
215
+ **model_inputs,
216
+ do_sample=True,
217
  temperature=temperature,
218
  top_p=top_p,
219
+ max_new_tokens=max_length_tokens,
220
+ streamer=streamer
221
  )
222
+ else:
223
+ generation_kwargs = dict(
224
+ **model_inputs,
225
+ do_sample=False,
226
+ max_new_tokens=max_length_tokens,
227
+ streamer=streamer
228
+ )
229
+ gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
230
+ gen_thread.start()
231
+
232
+ partial_output_repair = "Start Repairing...\n"
233
+ yield [[prompt,highlight_response],[repair_prompt,partial_output_repair]], [["null test","null test2"]], "Starting repair..."
234
+ time.sleep(5)
235
+ for new_text in streamer:
236
+ partial_output_repair += new_text
237
+ highlight_response = highlight_thinking(partial_output)
238
+ highlight_response_repair = highlight_thinking(partial_output_repair)
239
+ yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generating repair suggestion..."
240
 
241
+ gen_thread.join()
242
+
243
+ # yield response, "null test", "Generate: Success"
244
+ yield [[prompt,highlight_response],[repair_prompt,highlight_response_repair]], [["null test","null test2"]], "Generate: Success"
245
 
246
 
247
  def retry(
248
  text,
249
+ url,
250
  chatbot,
251
  history,
252
  top_p,
253
  temperature,
254
  max_length_tokens,
 
255
  chunk_size: int = 512,
256
  ):
257
  """
258
+ Retry the response for the input text and url.
259
  """
260
  if len(history) == 0:
261
  yield (chatbot, history, "Empty context")
262
  return
263
 
264
+ # chatbot.pop()
265
+ # history.pop()
266
+ # text = history.pop()[-1]
267
  if type(text) is tuple:
268
  text, _ = text
269
 
270
  yield from predict(
271
  text,
272
+ url,
273
  chatbot,
274
  history,
275
  top_p,
276
  temperature,
277
  max_length_tokens,
 
278
  chunk_size,
279
  )
280
 
 
283
  with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
284
  history = gr.State([])
285
  input_text = gr.State()
286
+ upload_url = gr.State()
287
 
288
  with gr.Row():
289
  gr.HTML(TITLE)
290
  status_display = gr.Markdown("Success", elem_id="status_display")
291
  gr.Markdown(DESCRIPTION_TOP)
292
+ gr.Markdown(USAGE_TOP)
293
 
294
  with gr.Row(equal_height=True):
295
  with gr.Column(scale=4):
 
298
  elem_id="Kimi-Dev-72B",
299
  show_share_button=True,
300
  bubble_full_width=False,
301
+ height=400,
302
+ # render_markdown=False
303
  )
304
  with gr.Row():
305
  with gr.Column(scale=4):
306
+ text_box = gr.Textbox(label="Issue Description", placeholder="Enter issue description", container=False)
307
  with gr.Column(min_width=70):
308
  submit_btn = gr.Button("Send")
309
+ # with gr.Column(min_width=70):
310
+ # cancel_btn = gr.Button("Stop")
311
  with gr.Row():
312
  empty_btn = gr.Button("🧹 New Conversation")
313
  retry_btn = gr.Button("🔄 Regenerate")
314
+ # del_last_btn = gr.Button("🗑️ Remove Last Turn")
315
+ def respond(message):
316
+ return f"Url submitted!"
317
  with gr.Column():
318
+ url_box = gr.Textbox(label="Please input a Github url here",placeholder="Input your url", lines=1)
319
+ url_submit_btn = gr.Button("Submit")
320
+ output = gr.Textbox(label="Submitted url")
321
+ url_submit_btn.click(fn=respond, inputs=upload_url, outputs=output)
322
+
323
  # Parameter Setting Tab for control the generation parameters
324
  with gr.Tab(label="Parameter Setting"):
325
+ top_p = gr.Slider(minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p")
326
  temperature = gr.Slider(
327
+ minimum=0, maximum=1.0, value=1.0, step=0.1, interactive=True, label="Temperature"
328
  )
329
  max_length_tokens = gr.Slider(
330
+ minimum=512, maximum=16384, value=8192, step=64, interactive=True, label="Max Length Tokens"
 
 
 
331
  )
332
 
 
 
333
  gr.Examples(
334
  examples=get_examples(ROOT_DIR),
335
+ inputs=[url_box, text_box],
336
  )
337
+ # gr.Markdown()
338
 
339
  input_widgets = [
340
  input_text,
341
+ upload_url,
342
  chatbot,
343
  history,
344
  top_p,
345
  temperature,
346
  max_length_tokens,
 
347
  ]
348
  output_widgets = [chatbot, history, status_display]
349
 
350
  transfer_input_args = dict(
351
  fn=transfer_input,
352
+ inputs=[text_box, url_box],
353
+ outputs=[input_text, upload_url, text_box, url_box, submit_btn],
354
  show_progress=True,
355
  )
356
 
 
366
  empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
367
  empty_btn.click(**reset_args)
368
  retry_btn.click(**retry_args)
 
 
369
 
370
  demo.title = "Kimi-Dev-72B"
371
  return demo
 
375
  demo = build_demo(args)
376
  reload_javascript()
377
 
378
+ favicon_path = os.path.join("kimi_dev/serve/assets/favicon.ico")
 
379
  # demo.queue().launch(
380
  # favicon_path=favicon_path,
381
  # server_name=args.ip,
 
385
  favicon_path=favicon_path,
386
  server_name=args.ip,
387
  server_port=args.port,
388
+ share=True
389
  )
390
 
391
  if __name__ == "__main__":
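For context on the new generation path above: app.py now streams tokens by running `model.generate` on a background thread and reading decoded chunks from a `TextIteratorStreamer`. Below is a minimal, self-contained sketch of that pattern; the prompt text and token budget are illustrative placeholders, not the Space's exact configuration.

```python
# Minimal sketch of the streaming pattern app.py uses: run model.generate on a
# background thread and consume decoded chunks from a TextIteratorStreamer.
# The prompt and max_new_tokens are illustrative placeholders.
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_path = "moonshotai/Kimi-Dev-72B"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", device_map="auto")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain what a SEARCH/REPLACE edit is."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=512, do_sample=False, streamer=streamer),
)
thread.start()

partial = ""
for new_text in streamer:  # yields text chunks as soon as they are decoded
    partial += new_text
    print(new_text, end="", flush=True)
thread.join()
```

Running `generate` off the main thread is what allows the Gradio callback to `yield` partial chatbot updates while decoding is still in progress.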
{kimi_vl → kimi_dev}/__init__.py RENAMED
File without changes
{kimi_vl → kimi_dev}/serve/__init__.py RENAMED
File without changes
{kimi_vl → kimi_dev}/serve/assets/Kelpy-Codos.js RENAMED
File without changes
{kimi_vl → kimi_dev}/serve/assets/avatar.png RENAMED
File without changes
{kimi_vl → kimi_dev}/serve/assets/custom.css RENAMED
File without changes
{kimi_vl → kimi_dev}/serve/assets/custom.js RENAMED
File without changes
{kimi_vl → kimi_dev}/serve/assets/favicon.ico RENAMED
File without changes
kimi_dev/serve/examples.py ADDED
@@ -0,0 +1,26 @@
1
+ import os
2
+ import io
3
+ import base64
4
+
5
+ EXAMPLES_LIST = [
6
+ [
7
+ "https://github.com/astropy/astropy",
8
+ "units.quantity_input decorator fails for constructors with type hinted return value -> None\n### Summary\r\nI am using the `units.quantity_input` decorator with typing hints for constructors, however when I add the correct return value for the constructor (`None`) then I get an exception, because `None` has no attribute `to`.\r\n\r\n### Reproducer\r\nThe issue can be reproduced with the following file:\r\n``` Python\r\nimport astropy.units as u\r\n\r\n\r\nclass PoC(object):\r\n\r\n @u.quantity_input\r\n def __init__(self, voltage: u.V) -> None:\r\n pass\r\n\r\n\r\nif __name__ == '__main__':\r\n poc = PoC(1.*u.V)\r\n```\r\nwhich results in the following error:\r\n```\r\n$ python3 poc.py\r\nTraceback (most recent call last):\r\n File \"poc.py\", line 12, in <module>\r\n poc = PoC(1.*u.V)\r\n File \"/usr/lib64/python3.6/site-packages/astropy/utils/decorators.py\", line 868, in __init__\r\n func = make_function_with_signature(func, name=name, **wrapped_args)\r\n File \"/usr/lib64/python3.6/site-packages/astropy/units/decorators.py\", line 225, in wrapper\r\n return return_.to(wrapped_signature.return_annotation)\r\nAttributeError: 'NoneType' object has no attribute 'to'\r\n```\r\n\r\nThis has been tested on Fedora 27 with python 3.6.3, astropy 2.0.2 and numpy 1.13.3 all from Fedora's repository.\r\n\r\n### Workaround\r\nThe issue can be circumvented by not adding the return type typing hint. Unfortunately, then a static type checker cannot infer that this function returns nothing.\r\n\r\n### Possible fix\r\nMaybe the decorator could explicitly check whether None is returned and then omit the unit check.\n\n\n",
9
+ ],
10
+ [
11
+ "https://github.com/sympy/sympy",
12
+ "evalf does not call _imp_ recursively\nExample from https://stackoverflow.com/questions/41818842/why-cant-i-evaluate-a-composition-of-implemented-functions-in-sympy-at-a-point:\r\n\r\n```\r\n>>> from sympy.utilities.lambdify import implemented_function\r\n>>> f = implemented_function('f', lambda x: x ** 2)\r\n>>> g = implemented_function('g', lambda x: 2 * x)\r\n>>> print(f( 2 ).evalf())\r\n4.00000000000000\r\n>>> print( g(2) .evalf())\r\n4.00000000000000\r\n>>> print(f(g(2)).evalf())\r\nf(g(2))\r\n```\r\n\r\nThe code for this is in `Function._eval_evalf`. It isn't calling evalf recursively on the return of `_imp_`. \n\n\n",
13
+ ],
14
+ [
15
+ "https://github.com/matplotlib/matplotlib",
16
+ "[ENH]: ContourSet.set_paths\n### Problem\n\nTo get contour labelling working with its special transforms, Cartopy has a [workaround](https://github.com/SciTools/cartopy/blob/2ed668c17b4e52421f15c5be3761719c75c5311a/lib/cartopy/mpl/contour.py#L89-L108) where it replaces all the paths on the `ContourSet` with transformed versions. This currently looks like\r\n\r\n```python\r\npaths = cs.get_paths()\r\npaths[:] = transformed_paths\r\n``` \r\n\r\nwhich doesn’t smell very good.\n\n### Proposed solution\n\nThe above would smell better as \r\n\r\n```python\r\ncs.set_paths(transformed_paths)\r\n``` \n\n\n"
17
+ ]
18
+ ]
19
+
20
+
21
+ def get_examples(root_dir: str = None):
22
+ examples = []
23
+ for github_url, issue_text in EXAMPLES_LIST:
24
+ examples.append([github_url, issue_text])
25
+
26
+ return examples
{kimi_vl → kimi_dev}/serve/frontend.py RENAMED
File without changes
{kimi_vl → kimi_dev}/serve/gradio_utils.py RENAMED
File without changes
kimi_dev/serve/inference.py ADDED
@@ -0,0 +1,26 @@
1
+ import logging
2
+
3
+ from transformers import (
4
+ AutoModelForCausalLM,
5
+ AutoConfig,
6
+ AutoTokenizer
7
+ )
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ def load_model(model_path: str = "moonshotai/Kimi-Dev-72B"):
13
+ # load the model configuration
14
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
15
+
16
+ model = AutoModelForCausalLM.from_pretrained(
17
+ model_path,
18
+ config=config,
19
+ torch_dtype="auto",
20
+ device_map="auto",
21
+ trust_remote_code=True,
22
+ )
23
+
24
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
25
+
26
+ return model, tokenizer
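A possible usage sketch for this helper, assuming enough GPU memory for the 72B checkpoint; the prompt is a placeholder and mirrors how app.py drives the tokenizer's chat template.

```python
# Hypothetical usage of kimi_dev.serve.inference.load_model (the prompt is illustrative).
from kimi_dev.serve.inference import load_model

model, tokenizer = load_model("moonshotai/Kimi-Dev-72B")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this GitHub issue in one sentence: ..."},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt").to(model.device)

output_ids = model.generate(**inputs, max_new_tokens=128)
new_tokens = output_ids[0][inputs.input_ids.shape[1]:]  # drop the echoed prompt
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```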
kimi_dev/serve/templates.py ADDED
@@ -0,0 +1,337 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import subprocess
5
+ import ast
6
+
7
+ def show_project_structure(structure, spacing=0) -> str:
8
+ """pprint the project structure"""
9
+
10
+ pp_string = ''
11
+
12
+ for key, value in structure.items():
13
+ if '.' in key and '.py' not in key:
14
+ continue # skip non-Python files
15
+
16
+ # TODO: maybe we should skip the test files...
17
+ if key.startswith('test'):
18
+ continue # skip the test files as well...
19
+
20
+ if '.' in key:
21
+ pp_string += ' ' * spacing + str(key) + '\n'
22
+ else:
23
+ pp_string += ' ' * spacing + str(key) + '/' + '\n'
24
+ if 'classes' not in value:
25
+ pp_string += show_project_structure(value, spacing + 4)
26
+
27
+ return pp_string
28
+
33
+ def clone_github_repo(github_url, local_path):
34
+ """Clone GitHub repository to local path"""
35
+ try:
36
+ subprocess.run(['git', 'clone', github_url, local_path], check=True)
37
+ print(f"Successfully cloned repository to: {local_path}")
38
+ except subprocess.CalledProcessError as e:
39
+ print(f"Warning: Repository cloning may have failed: {e}")
40
+
41
+ def parse_python_file(file_path, file_content=None):
42
+ """Parse a Python file to extract class and function definitions with their line numbers.
43
+ :param file_path: Path to the Python file.
44
+ :return: Class names, function names, and file contents
45
+ """
46
+ if file_content is None:
47
+ try:
48
+ with open(file_path, "r") as file:
49
+ file_content = file.read()
50
+ parsed_data = ast.parse(file_content)
51
+ except Exception as e: # Catch all types of exceptions
52
+ print(f"Error in file {file_path}: {e}")
53
+ return [], [], ""
54
+ else:
55
+ try:
56
+ parsed_data = ast.parse(file_content)
57
+ except Exception as e: # Catch all types of exceptions
58
+ print(f"Error in file {file_path}: {e}")
59
+ return [], [], ""
60
+ class_info = []
61
+ function_names = []
62
+ class_methods = set()
63
+ for node in ast.walk(parsed_data):
64
+ if isinstance(node, ast.ClassDef):
65
+ methods = []
66
+ for n in node.body:
67
+ if isinstance(n, ast.FunctionDef):
68
+ methods.append(
69
+ {
70
+ "name": n.name,
71
+ "start_line": n.lineno,
72
+ "end_line": n.end_lineno,
73
+ "text": file_content.splitlines()[
74
+ n.lineno - 1 : n.end_lineno
75
+ ],
76
+ }
77
+ )
78
+ class_methods.add(n.name)
79
+ class_info.append(
80
+ {
81
+ "name": node.name,
82
+ "start_line": node.lineno,
83
+ "end_line": node.end_lineno,
84
+ "text": file_content.splitlines()[
85
+ node.lineno - 1 : node.end_lineno
86
+ ],
87
+ "methods": methods,
88
+ }
89
+ )
90
+ elif isinstance(node, ast.FunctionDef) and not isinstance(
91
+ node, ast.AsyncFunctionDef
92
+ ):
93
+ if node.name not in class_methods:
94
+ function_names.append(
95
+ {
96
+ "name": node.name,
97
+ "start_line": node.lineno,
98
+ "end_line": node.end_lineno,
99
+ "text": file_content.splitlines()[
100
+ node.lineno - 1 : node.end_lineno
101
+ ],
102
+ }
103
+ )
104
+ return class_info, function_names, file_content.splitlines()
105
+
106
+ def create_structure(directory_path):
107
+ """Create the structure of the repository directory by parsing Python files.
108
+ :param directory_path: Path to the repository directory.
109
+ :return: A dictionary representing the structure.
110
+ """
111
+ structure = {}
112
+ for root, _, files in os.walk(directory_path):
113
+ repo_name = os.path.basename(directory_path)
114
+ relative_root = os.path.relpath(root, directory_path)
115
+ if relative_root == ".":
116
+ relative_root = repo_name
117
+ curr_struct = structure
118
+ for part in relative_root.split(os.sep):
119
+ if part not in curr_struct:
120
+ curr_struct[part] = {}
121
+ curr_struct = curr_struct[part]
122
+ for file_name in files:
123
+ if file_name.endswith(".py"):
124
+ file_path = os.path.join(root, file_name)
125
+ class_info, function_names, file_lines = parse_python_file(file_path)
126
+ curr_struct[file_name] = {
127
+ "classes": class_info,
128
+ "functions": function_names,
129
+ "text": file_lines,
130
+ }
131
+ else:
132
+ curr_struct[file_name] = {}
133
+ return structure
134
+
135
+ def build_repo_structure(root_path):
136
+ """Build repository structure using improved parsing method"""
137
+ return create_structure(root_path)
138
+
139
+
140
+
141
+ def get_loc_prompt(issue_text,repo_structure):
142
+ obtain_relevant_files_prompt = """
143
+ Please look through the following GitHub problem description and Repository structure and provide a list of files that one would need to edit to fix the problem.
144
+
145
+ ### GitHub Problem Description ###
146
+ {problem_statement}
147
+
148
+ ###
149
+
150
+ ### Repository Structure ###
151
+ {structure}
152
+
153
+ ###
154
+
155
+ Please only provide the full path and return at most 5 files.
156
+ The returned files should be separated by new lines ordered by most to least important and wrapped with ```
157
+ For example:
158
+ ```
159
+ file1.py
160
+ file2.py
161
+ ```
162
+ """
163
+ prompt_content = obtain_relevant_files_prompt.format(problem_statement=issue_text,structure=repo_structure)
164
+ return prompt_content
165
+
166
+ def get_repair_prompt(issue_text,file_content):
167
+ repair_prompt_combine_topn_cot_diff = """
168
+ We are currently solving the following issue within our repository. Here is the issue text:
169
+ --- BEGIN ISSUE ---
170
+ {problem_statement}
171
+ --- END ISSUE ---
172
+
173
+ Below are some code segments, each from a relevant file. One or more of these files may contain bugs.
174
+ --- BEGIN FILE ---
175
+ ```
176
+ {content}
177
+ ```
178
+ --- END FILE ---
179
+
180
+ Please first localize the bug based on the issue statement, and then generate *SEARCH/REPLACE* edits to fix the issue.
181
+
182
+ Every *SEARCH/REPLACE* edit must use this format:
183
+ 1. The file path
184
+ 2. The start of search block: <<<<<<< SEARCH
185
+ 3. A contiguous chunk of lines to search for in the existing source code
186
+ 4. The dividing line: =======
187
+ 5. The lines to replace into the source code
188
+ 6. The end of the replace block: >>>>>>> REPLACE
189
+
190
+ Here is an example:
191
+
192
+ ```python
193
+ ### mathweb/flask/app.py
194
+ <<<<<<< SEARCH
195
+ from flask import Flask
196
+ =======
197
+ import math
198
+ from flask import Flask
199
+ >>>>>>> REPLACE
200
+ ```
201
+
202
+ Please note that the *SEARCH/REPLACE* edit REQUIRES PROPER INDENTATION. If you would like to add the line ' print(x)', you must fully write that out, with all those spaces before the code!
203
+ Wrap the *SEARCH/REPLACE* edit in blocks ```python...```.
204
+ """
205
+ prompt_content = repair_prompt_combine_topn_cot_diff.format(problem_statement=issue_text,content=file_content.rstrip())
206
+ return prompt_content
207
+
208
+ def get_repo_files(structure, filepaths: list[str]):
209
+ files, classes, functions = get_full_file_paths_and_classes_and_functions(structure)
210
+ file_contents = dict()
211
+ for filepath in filepaths:
212
+ content = None
213
+
214
+ for file_content in files:
215
+ if file_content[0] == filepath:
216
+ content = '\n'.join(file_content[1])
217
+ file_contents[filepath] = content
218
+ break
219
+
220
+ # assert content is not None, "file not found"
221
+ return file_contents
222
+
223
+ def correct_file_path_in_structure(file_name, structure):
224
+ """
225
+ Search for the correct file path in the structure, mainly checking first-level subdirectories
226
+
227
+ Args:
228
+ file_name (str): File name to search for
229
+ structure (dict): Repository structure
230
+
231
+ Returns:
232
+ str: Correct file path if found, otherwise returns original file_name
233
+ """
234
+ # Search in current directory
235
+ file_contents = get_repo_files(structure, [file_name])
236
+ if file_contents != {}:
237
+ return file_name
238
+
239
+ # Only check first-level subdirectories
240
+ for sub_dir in structure.keys():
241
+ if isinstance(structure[sub_dir], dict):
242
+ file_contents = get_repo_files(structure[sub_dir], [file_name])
243
+ if file_contents != {}:
244
+ return f'{sub_dir}/{file_name}'
245
+
246
+ return file_name
247
+
248
+ def get_full_file_paths_and_classes_and_functions(structure, current_path=''):
249
+ """
250
+ Recursively retrieve all file paths, classes, and functions within a directory structure.
251
+
252
+ Arguments:
253
+ structure -- a dictionary representing the directory structure
254
+ current_path -- the path accumulated so far, used during recursion (default="")
255
+
256
+ Returns:
257
+ A tuple containing:
258
+ - files: list of full file paths
259
+ - classes: list of class details with file paths
260
+ - functions: list of function details with file paths
261
+ """
262
+ files = []
263
+ classes = []
264
+ functions = []
265
+ for name, content in structure.items():
266
+ if isinstance(content, dict):
267
+ if (
268
+ (
269
+ 'functions' not in content.keys()
270
+ and 'classes' not in content.keys()
271
+ and 'text' not in content.keys()
272
+ )
273
+ or not len(content.keys()) == 3
274
+ or (
275
+ isinstance(content.get('text', []), dict)
276
+ or isinstance(content.get('functions', []), dict)
277
+ or isinstance(content.get('classes', []), dict)
278
+ )
279
+ ):
280
+ # or guards against case where functions and classes are somehow part of the structure.
281
+ next_path = f'{current_path}/{name}' if current_path else name
282
+ (
283
+ sub_files,
284
+ sub_classes,
285
+ sub_functions,
286
+ ) = get_full_file_paths_and_classes_and_functions(content, next_path)
287
+ files.extend(sub_files)
288
+ classes.extend(sub_classes)
289
+ functions.extend(sub_functions)
290
+ else:
291
+ next_path = f'{current_path}/{name}' if current_path else name
292
+ files.append((next_path, content.get('text', [])))
293
+ if content.get('text', []) == []:
294
+ continue
295
+ if 'classes' in content:
296
+ for clazz in content['classes']:
297
+ classes.append(
298
+ {
299
+ 'file': next_path,
300
+ 'name': clazz['name'],
301
+ 'start_line': clazz['start_line'],
302
+ 'end_line': clazz['end_line'],
303
+ 'methods': [
304
+ {
305
+ 'name': method['name'],
306
+ 'start_line': method['start_line'],
307
+ 'end_line': method['end_line'],
308
+ }
309
+ for method in clazz.get('methods', [])
310
+ ],
311
+ },
312
+ )
313
+ if 'functions' in content:
314
+ for function in content['functions']:
315
+ try:
316
+ function['file'] = next_path
317
+ except TypeError:
318
+ continue
319
+ functions.append(function)
320
+ else:
321
+ next_path = f'{current_path}/{name}' if current_path else name
322
+ files.append(next_path)
323
+ return files, classes, functions
324
+
325
+ def post_process(response: str) -> str:
326
+ content = response
327
+ if "◁/thinkβ–·" in content:
328
+ content = content.replace("◁thinkβ–·", "")
329
+ parts = content.split("◁/thinkβ–·")
330
+ content = parts[-1]
331
+ # Extract content between triple backticks (```)
332
+ matches = re.findall(r"```.*?```", content, re.DOTALL)
333
+
334
+ if matches:
335
+ matches = [item.replace("```","") for item in matches]
336
+ return "\n".join(matches) # Return all matched code blocks joined by new lines
337
+ return content # If no match, return the full response
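Taken together, these helpers implement the two-stage flow that app.py drives from Gradio: localize candidate files from the repository tree, then request SEARCH/REPLACE edits over their contents. A rough sketch of that flow outside the UI is shown below; `generate_fn` is a hypothetical stand-in for a model call, and the URL and paths are examples.

```python
# Sketch of the localize-then-repair flow built on kimi_dev.serve.templates.
# `generate_fn(prompt) -> str` is a hypothetical model-call stand-in.
from kimi_dev.serve.templates import (
    build_repo_structure,
    clone_github_repo,
    get_loc_prompt,
    get_repair_prompt,
    post_process,
    show_project_structure,
)


def localize_and_repair(issue_text: str, repo_url: str, repo_path: str, generate_fn):
    # Stage 0: clone the repository and render its tree as text.
    clone_github_repo(repo_url, repo_path)
    structure = build_repo_structure(repo_path)
    tree_text = show_project_structure(structure)

    # Stage 1: ask the model which files need editing for this issue.
    loc_answer = generate_fn(get_loc_prompt(issue_text, tree_text))
    located = post_process(loc_answer)  # drops the ◁think▷ block and ``` fences
    file_paths = [p for p in located.strip().split("\n") if p]

    # Stage 2: feed the located files back and request SEARCH/REPLACE edits.
    contents = ""
    for rel_path in file_paths:
        with open(f"{repo_path}/{rel_path}", "r", encoding="utf-8") as f:
            contents += f"{rel_path}\n{f.read()}\n\n"
    return generate_fn(get_repair_prompt(issue_text, contents))
```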
{kimi_vl → kimi_dev}/serve/utils.py RENAMED
File without changes
kimi_vl/serve/chat_utils.py DELETED
@@ -1,379 +0,0 @@
1
- """
2
- From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
3
- """
4
-
5
- import dataclasses
6
- import logging
7
- import copy
8
- from enum import IntEnum, auto
9
- from typing import Dict, List
10
- import base64
11
-
12
- import gradio as gr
13
- import torch
14
-
15
- from .utils import pil_to_base64
16
-
17
- IMAGE_TOKEN = "<image>"
18
- logger = logging.getLogger("gradio_logger")
19
-
20
-
21
- class SeparatorStyle(IntEnum):
22
- """Separator styles."""
23
-
24
- PLAIN = auto()
25
- ALIGNMENT = auto()
26
- KIMI_VL = auto()
27
-
28
-
29
- @dataclasses.dataclass
30
- class Conversation:
31
- """A class that manages prompt templates and keeps all conversation history."""
32
-
33
- # The name of this template
34
- name: str
35
- # The template of the system prompt
36
- system_template: str = "{system_message}"
37
- # The system message
38
- system_message: str = ""
39
- # The names of two roles
40
- roles: List[str] = (("USER", "ASSISTANT"),)
41
- # All messages. Each item is (role, message).
42
- messages: List[List[str]] = ()
43
- # The number of few shot examples
44
- offset: int = 0
45
- # The separator style and configurations
46
- sep_style: SeparatorStyle = SeparatorStyle.PLAIN
47
- sep: str = "\n"
48
- sep2: str = None
49
- # Stop criteria (the default one is EOS token)
50
- stop_str: str = None
51
- # Stops generation if meeting any token in this list
52
- stop_token_ids: List[int] = None
53
-
54
- def get_prompt(self) -> str:
55
- """Get the prompt for generation."""
56
- system_prompt = self.system_template.format(system_message=self.system_message)
57
- if self.sep_style == SeparatorStyle.PLAIN:
58
- seps = [self.sep, self.sep2]
59
- ret = ""
60
- for i, (role, message) in enumerate(self.messages):
61
- if message:
62
- if type(message) is tuple:
63
- message = message[0]
64
- if i % 2 == 0:
65
- ret += message + seps[i % 2]
66
- else:
67
- ret += message + seps[i % 2]
68
- else:
69
- ret += ""
70
- return ret
71
- elif self.sep_style == SeparatorStyle.ALIGNMENT:
72
- seps = [self.sep, self.sep2]
73
- ret = ""
74
- for i, (role, message) in enumerate(self.messages):
75
- if message:
76
- if type(message) is tuple:
77
- message, _, _ = message
78
- if i % 2 == 0:
79
- ret += '<image>\n' + seps[i % 2]
80
- else:
81
- ret += message + seps[i % 2]
82
- else:
83
- ret += ""
84
- return ret
85
- elif self.sep_style == SeparatorStyle.KIMI_VL:
86
- seps = [self.sep, self.sep2]
87
- if system_prompt == "" or system_prompt is None:
88
- ret = ""
89
- else:
90
- ret = system_prompt + seps[0]
91
- for i, (role, message) in enumerate(self.messages):
92
- if message:
93
- if type(message) is tuple:
94
- message = message[0]
95
-
96
- if role == "user":
97
- ret += message + self.sep
98
- else:
99
- if self.sep2 is not None:
100
- ret += message + self.sep2
101
- else:
102
- ret += message
103
- else:
104
- ret = ret
105
- return ret
106
- else:
107
- raise ValueError(f"Invalid style: {self.sep_style}")
108
-
109
- def set_system_message(self, system_message: str):
110
- """Set the system message."""
111
- self.system_message = system_message
112
-
113
- def append_message(self, role: str, message: str):
114
- """Append a new message."""
115
- self.messages.append([role, message])
116
-
117
- def update_last_message(self, message: str):
118
- """Update the last output.
119
-
120
- The last message is typically set to be None when constructing the prompt,
121
- so we need to update it in-place after getting the response from a model.
122
- """
123
- self.messages[-1][1] = message
124
-
125
- def reset_message(self):
126
- """Reset a new message."""
127
- self.messages = []
128
-
129
- def to_gradio_chatbot(self):
130
- """Convert the conversation to gradio chatbot format."""
131
- ret = []
132
- for i, (role, msg) in enumerate(self.messages[self.offset :]):
133
- if i % 2 == 0:
134
- ret.append([msg, None])
135
- else:
136
- ret[-1][-1] = msg
137
- return ret
138
-
139
- def to_openai_api_messages(self):
140
- """Convert the conversation to OpenAI chat completion format."""
141
- system_prompt = self.system_template.format(system_message=self.system_message)
142
- ret = [{"role": "system", "content": system_prompt}]
143
-
144
- for i, (_, msg) in enumerate(self.messages[self.offset :]):
145
- if i % 2 == 0:
146
- ret.append({"role": "user", "content": msg})
147
- else:
148
- if msg is not None:
149
- ret.append({"role": "assistant", "content": msg})
150
- return ret
151
-
152
- def copy(self):
153
- return Conversation(
154
- name=self.name,
155
- system_template=self.system_template,
156
- system_message=self.system_message,
157
- roles=self.roles,
158
- messages=[[x, y] for x, y in self.messages],
159
- offset=self.offset,
160
- sep_style=self.sep_style,
161
- sep=self.sep,
162
- sep2=self.sep2,
163
- stop_str=self.stop_str,
164
- stop_token_ids=self.stop_token_ids,
165
- )
166
-
167
- def dict(self):
168
- return {
169
- "template_name": self.name,
170
- "system_message": self.system_message,
171
- "roles": self.roles,
172
- "messages": self.messages,
173
- "offset": self.offset,
174
- }
175
-
176
-
177
- # A global registry for all conversation templates
178
- conv_templates: Dict[str, Conversation] = {}
179
-
180
-
181
- def register_conv_template(template: Conversation, override: bool = False):
182
- """Register a new conversation template."""
183
- if not override:
184
- assert template.name not in conv_templates, f"{template.name} has been registered."
185
-
186
- conv_templates[template.name] = template
187
-
188
-
189
- def get_conv_template(name: str) -> Conversation:
190
- """Get a conversation template."""
191
- return conv_templates[name].copy()
192
-
193
-
194
- register_conv_template(
195
- Conversation(
196
- name="plain",
197
- system_template="",
198
- system_message="",
199
- roles=("", ""),
200
- messages=(),
201
- offset=0,
202
- sep_style=SeparatorStyle.PLAIN,
203
- sep="",
204
- sep2="",
205
- stop_token_ids=[100001],
206
- stop_str=['</s>'],
207
- )
208
- )
209
-
210
-
211
- register_conv_template(
212
- Conversation(
213
- name="alignment",
214
- system_template="",
215
- system_message="",
216
- roles=("", ""),
217
- messages=(),
218
- offset=0,
219
- sep_style=SeparatorStyle.ALIGNMENT,
220
- sep="",
221
- sep2="",
222
- stop_token_ids=[100001],
223
- stop_str=['</s>'],
224
- )
225
- )
226
-
227
- register_conv_template(
228
- Conversation(
229
- name="kimi-vl",
230
- system_template="{system_message}",
231
- system_message="You are a helpful assistant",
232
- roles=("user", "assistant"),
233
- messages=(),
234
- offset=0,
235
- sep_style=SeparatorStyle.KIMI_VL,
236
- sep="<|im_end|>",
237
- sep2=None,
238
- stop_token_ids=None,
239
- stop_str=["<|im_end|>"],
240
- )
241
- )
242
-
243
-
244
- def new_chat_template(sft_format: str = "kimi-vl"):
245
- return get_conv_template(sft_format)
246
-
247
-
248
- def get_prompt(conv: Conversation) -> str:
249
- """Get the prompt for generation."""
250
- return conv.get_prompt()
251
-
252
-
253
- def generate_prompt_with_history(text, images, history, processor, max_length=2048):
254
- """
255
- Generate a prompt with the chat history.
256
-
257
- Args:
258
- text (str): The text prompt.
259
- images (list[PIL.Image.Image]): The image prompt.
260
- history (list): List of previous conversation messages.
261
- processor (KimiVLProcessor): The chat processor used for encoding the prompt.
262
- max_length (int): The maximum length of the prompt.
263
- """
264
- global IMAGE_TOKEN
265
-
266
- user_role_ind = 0
267
- bot_role_ind = 1
268
-
269
- # Initialize conversation
270
- conversation = new_chat_template(sft_format="plain")
271
-
272
- if history:
273
- conversation.messages = history
274
-
275
- if images is not None and len(images) > 0:
276
- # num_image_tags = text.count(IMAGE_TOKEN)
277
- # num_images = len(images)
278
- # if num_images > num_image_tags:
279
- # pad_image_tags = num_images - num_image_tags
280
- # image_tokens = "\n".join([IMAGE_TOKEN] * pad_image_tags)
281
-
282
- # # append the <image> in a new line after the text prompt
283
- # text = image_tokens + "\n" + text
284
- # elif num_images < num_image_tags:
285
- # remove_image_tags = num_image_tags - num_images
286
- # text = text.replace(IMAGE_TOKEN, "", remove_image_tags)
287
-
288
- print(f"prompt = {text}, len(images) = {len(images)}")
289
- text = (text, images)
290
-
291
- conversation.append_message(conversation.roles[user_role_ind], text)
292
- conversation.append_message(conversation.roles[bot_role_ind], "")
293
-
294
- # Create a copy of the conversation to avoid history truncation in the UI
295
- conversation_copy = conversation.copy()
296
- logger.info("=" * 80)
297
- logger.info(get_prompt(conversation))
298
-
299
- rounds = len(conversation.messages) // 2
300
-
301
- for _ in range(rounds):
302
- current_prompt = get_prompt(conversation)
303
- assert isinstance(current_prompt, str) and len(current_prompt) > 0, f"current_prompt = {current_prompt}"
304
- if torch.tensor(processor.tokenizer.encode(current_prompt)).size(-1) <= max_length:
305
- return conversation_copy
306
-
307
- if len(conversation.messages) % 2 != 0:
308
- gr.Error("The messages between user and assistant are not paired.")
309
- return
310
-
311
- try:
312
- for _ in range(2): # pop out two messages in a row
313
- conversation.messages.pop(0)
314
- except IndexError:
315
- gr.Error("Input text processing failed, unable to respond in this round.")
316
- return None
317
-
318
- gr.Error("Prompt could not be generated within max_length limit.")
319
- return None
320
-
321
-
322
- def convert_conversation_to_prompts(conversation: Conversation):
323
- """
324
- Convert the conversation to prompts.
325
- """
326
- conv_prompts = []
327
- last_image = None
328
-
329
- messages = conversation.messages
330
- for i in range(0, len(messages), 2):
331
- if isinstance(messages[i][1], tuple):
332
- text, images = messages[i][1]
333
- last_image = images[-1]
334
- else:
335
- text, images = messages[i][1], []
336
-
337
- prompt = {"role": messages[i][0], "content": text, "images": images}
338
- response = {"role": messages[i + 1][0], "content": messages[i + 1][1]}
339
- conv_prompts.extend([prompt, response])
340
-
341
- return conv_prompts, last_image
342
-
343
-
344
- def to_gradio_chatbot(conversation: Conversation) -> list:
345
- """Convert the conversation to gradio chatbot format."""
346
- ret = []
347
- for i, (_, msg) in enumerate(conversation.messages[conversation.offset :]):
348
- if i % 2 == 0:
349
- if type(msg) is tuple:
350
- msg, images = copy.deepcopy(msg)
351
-
352
- if isinstance(images, list):
353
- img_str = ""
354
- for j, image in enumerate(images):
355
- if isinstance(image, str):
356
- with open(image, "rb") as f:
357
- data = f.read()
358
- img_b64_str = base64.b64encode(data).decode()
359
- image_str = (
360
- f'<img src="data:image/png;base64,{img_b64_str}" '
361
- f'alt="user upload image" style="max-width: 300px; height: auto;" />'
362
- )
363
- else:
364
- image_str = pil_to_base64(image, f"user upload image_{j}", max_size=800, min_size=400)
365
-
366
- img_str += image_str
367
- msg = img_str + msg
368
- else:
369
- pass
370
-
371
- ret.append([msg, None])
372
- else:
373
- ret[-1][-1] = msg
374
- return ret
375
-
376
-
377
- def to_gradio_history(conversation: Conversation):
378
- """Convert the conversation to gradio history format."""
379
- return conversation.messages[conversation.offset :]
kimi_vl/serve/examples.py DELETED
@@ -1,54 +0,0 @@
1
- import os
2
- import io
3
- import base64
4
- from PIL import Image
5
-
6
- EXAMPLES_LIST = [
7
- [
8
- ["images/demo1.jpeg"],
9
- "Where am I?",
10
- ],
11
- [
12
- ["images/demo2.jpeg", "images/demo3.jpeg"],
13
- "Based on the abstract and introduction above, write a concise and elegant Twitter post that highlights key points and figures without sounding overly promotional. Use English, include emojis and hashtags.",
14
- ],
15
- [
16
- ["images/demo6.jpeg"],
17
- "Create a role play modeled after this cat."
18
- ],
19
- # mulit-frames example
20
- [
21
- ["images/demo4.jpeg", "images/demo5.jpeg"],
22
- "Please infer step by step who this manuscript belongs to and what it records."
23
- ]
24
- ]
25
-
26
-
27
- def display_example(image_list, root_dir: str = None):
28
- images_html = ""
29
- for _, img_path in enumerate(image_list):
30
- if root_dir is not None:
31
- img_path = os.path.join(root_dir, img_path)
32
-
33
- image = Image.open(img_path)
34
- buffered = io.BytesIO()
35
- image.save(buffered, format="PNG", quality=100)
36
- img_b64_str = base64.b64encode(buffered.getvalue()).decode()
37
- img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="{img_path}" style="height:80px; margin-right: 10px;" />'
38
- images_html += img_str
39
-
40
- result_html = f"""
41
- <div style="display: flex; align-items: center; margin-bottom: 10px;">
42
- <div style="flex: 1; margin-right: 10px;">{images_html}</div>
43
- </div>
44
- """
45
-
46
- return result_html
47
-
48
-
49
- def get_examples(root_dir: str = None):
50
- examples = []
51
- for images, texts in EXAMPLES_LIST:
52
- examples.append([images, display_example(images, root_dir), texts])
53
-
54
- return examples
kimi_vl/serve/inference.py DELETED
@@ -1,145 +0,0 @@
1
- import logging
2
- import re
3
- from threading import Thread
4
- from typing import List, Optional
5
-
6
- import torch
7
- import spaces
8
- from transformers import (
9
- AutoModelForCausalLM,
10
- AutoProcessor,
11
- AutoConfig,
12
- StoppingCriteria,
13
- StoppingCriteriaList,
14
- TextIteratorStreamer,
15
- AutoTokenizer
16
- )
17
-
18
- from .chat_utils import Conversation, get_conv_template
19
-
20
- logger = logging.getLogger(__name__)
21
-
22
-
23
- def load_model(model_path: str = "moonshotai/Kimi-Dev-72B"):
24
- # hotfix the model to use flash attention 2
25
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
26
- # config._attn_implementation = "flash_attention_2"
27
- # config.vision_config._attn_implementation = "flash_attention_2"
28
- # config.text_config._attn_implementation = "flash_attention_2"
29
- # print("Successfully set the attn_implementation to flash_attention_2")
30
-
31
- model = AutoModelForCausalLM.from_pretrained(
32
- model_path,
33
- config=config,
34
- torch_dtype="auto",
35
- device_map="auto",
36
- trust_remote_code=True,
37
- )
38
- # processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True)
39
- tokenizer = AutoTokenizer.from_pretrained(model_path)
40
-
41
- return model, tokenizer
42
-
43
-
44
- class StoppingCriteriaSub(StoppingCriteria):
45
- def __init__(self, stops=[], encounters=1):
46
- super().__init__()
47
- self.stops = [stop.to("cuda") for stop in stops]
48
-
49
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
50
- for stop in self.stops:
51
- if input_ids.shape[-1] < len(stop):
52
- continue
53
- if torch.all((stop == input_ids[0][-len(stop) :])).item():
54
- return True
55
-
56
- return False
57
-
58
-
59
- def format_messages(
60
- conversations: list[Conversation],
61
- system_prompt: Optional[str] = "",
62
- sft_format: Optional[str] = "kimi-vl",
63
- ):
64
- """
65
- Format the conversations to the input format of the model.
66
- """
67
- converstion = get_conv_template(sft_format)
68
- converstion.set_system_message(system_prompt)
69
- for message in conversations:
70
- converstion.append_message(message["role"], message["content"])
71
- return converstion
72
-
73
-
74
-
75
-
76
- @torch.no_grad()
77
- @torch.inference_mode()
78
- def kimi_dev_generate(
79
- model: torch.nn.Module,
80
- tokenizer,
81
- # processor: AutoProcessor,
82
- conversations: list[Conversation],
83
- stop_words: list,
84
- max_length: int = 256,
85
- temperature: float = 1.0,
86
- top_p: float = 1.0,
87
- chunk_size: int = -1,
88
- ):
89
- # convert conversation to inputs
90
- print(f"conversations = {conversations}")
91
- # inputs = preprocess(conversations)
92
- inputs = tokenizer.tokenize(conversations)
93
- inputs = inputs.to(model.device)
94
-
95
- return generate(
96
- model,
97
- tokenizer,
98
- inputs,
99
- max_gen_len=max_length,
100
- temperature=temperature,
101
- top_p=top_p,
102
- stop_words=stop_words,
103
- chunk_size=chunk_size,
104
- )
105
-
106
-
107
- def generate(
108
- model,
109
- tokenizer,
110
- inputs,
111
- max_gen_len: int = 256,
112
- temperature: float = 0,
113
- top_p: float = 0.95,
114
- stop_words: List[str] = [],
115
- chunk_size: int = -1,
116
- ):
117
- """Stream the text output from the multimodality model with prompt and image inputs."""
118
- stop_words_ids = [torch.tensor(tokenizer.encode(stop_word)) for stop_word in stop_words]
119
- stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
120
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
121
-
122
- kwargs = dict(
123
- **inputs,
124
- max_new_tokens=max_gen_len,
125
- do_sample=True,
126
- use_cache=True,
127
- streamer=streamer,
128
- stopping_criteria=stopping_criteria,
129
- )
130
-
131
- if temperature > 0:
132
- kwargs.update(
133
- {
134
- "do_sample": True,
135
- "top_p": top_p,
136
- "temperature": temperature,
137
- }
138
- )
139
- else:
140
- kwargs["do_sample"] = False
141
-
142
- thread = Thread(target=model.generate, kwargs=kwargs)
143
- thread.start()
144
-
145
- yield from streamer