Spaces: Running on L40S

miaoyibo committed · Commit 46a0b0f
1 Parent(s): 8cf3ee6

kimi_dev
Browse files
- .gitignore +3 -0
- app.py +156 -149
- {kimi_vl → kimi_dev}/__init__.py +0 -0
- {kimi_vl → kimi_dev}/serve/__init__.py +0 -0
- {kimi_vl → kimi_dev}/serve/assets/Kelpy-Codos.js +0 -0
- {kimi_vl → kimi_dev}/serve/assets/avatar.png +0 -0
- {kimi_vl → kimi_dev}/serve/assets/custom.css +0 -0
- {kimi_vl → kimi_dev}/serve/assets/custom.js +0 -0
- {kimi_vl → kimi_dev}/serve/assets/favicon.ico +0 -0
- kimi_dev/serve/examples.py +26 -0
- {kimi_vl → kimi_dev}/serve/frontend.py +0 -0
- {kimi_vl → kimi_dev}/serve/gradio_utils.py +0 -0
- kimi_dev/serve/inference.py +26 -0
- kimi_dev/serve/templates.py +337 -0
- {kimi_vl → kimi_dev}/serve/utils.py +0 -0
- kimi_vl/serve/chat_utils.py +0 -379
- kimi_vl/serve/examples.py +0 -54
- kimi_vl/serve/inference.py +0 -145
.gitignore CHANGED
@@ -3,3 +3,6 @@
 __pycache__
 *.pyc
 *.pyo
+
+.gradio
+local_path/
app.py CHANGED

Old side of the diff (removed lines are marked "-"; several removed lines were truncated or dropped by the page capture and are left as captured):

@@ -1,44 +1,38 @@
 import argparse
 import gradio as gr
 import os
-from PIL import Image
 import spaces
 import copy
 import time
 
-
-
     configure_logger,
-    pil_to_base64,
-    parse_ref_bbox,
-    strip_stop_words,
-    is_variable_assigned,
 )
-from
-    cancel_outputing,
-    delete_last_conversation,
     reset_state,
     reset_textbox,
     transfer_input,
     wrap_gen_fn,
 )
-from
-
-
-    to_gradio_chatbot,
-    to_gradio_history,
-)
-from kimi_vl.serve.inference import kimi_dev_generate, load_model
-from kimi_vl.serve.examples import get_examples
 
-TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-Dev-72B
-DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-VL" target="_blank">Kimi-Dev-72B</a> is a
-
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 DEPLOY_MODELS = dict()
 logger = configure_logger()
 
-
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="Kimi-Dev-72B")
@@ -73,16 +67,6 @@ def fetch_model(model_name: str):
     return model_info
 
 
-def preview_images(files) -> list[str]:
-    if files is None:
-        return []
-
-    image_paths = []
-    for file in files:
-        image_paths.append(file.name)
-    return image_paths
-
-
 def get_prompt(conversation) -> str:
     """
     Get the prompt for the conversation.
@@ -103,30 +87,29 @@ def highlight_thinking(msg: str) -> str:
 @spaces.GPU(duration=180)
 def predict(
     text,
-
     chatbot,
     history,
     top_p,
     temperature,
     max_length_tokens,
-    max_context_length_tokens,
     chunk_size: int = 512,
 ):
     """
-    Predict the response for the input text and
     Args:
         text (str): The input text.
-
         chatbot (list): The chatbot.
         history (list): The history.
         top_p (float): The top-p value.
         temperature (float): The temperature value.
         repetition_penalty (float): The repetition penalty value.
         max_length_tokens (int): The max length tokens.
-        max_context_length_tokens (int): The max context length tokens.
         chunk_size (int): The chunk size.
     """
     print("running the prediction function")
     try:
         model, tokenizer = fetch_model(args.model)
 
@@ -137,131 +120,161 @@ def predict(
         yield [[text, "No Model Found"]], [], "No Model Found"
         return
 
 
-    prompt = "Give me a short introduction to large language model."
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content":
     ]
-
         messages,
         tokenize=False,
         add_generation_prompt=True
     )
-    model_inputs = tokenizer([
 
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=512
-    )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
 
-
 
     print(response)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
     )
-
-
-
-
-
-
-
-
-        model=model,
-        tokneizer=tokenizer,
-        # processor=processor,
-        stop_words=stop_words,
-        max_length=max_length_tokens,
         temperature=temperature,
         top_p=top_p,
-
-
-    response = strip_stop_words(full_response, stop_words)
-    conversation.update_last_message(response)
-    gradio_chatbot_output[-1][1] = highlight_thinking(response)
-
-    yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
-
-    if last_image is not None:
-        vg_image = parse_ref_bbox(response, last_image)
-        if vg_image is not None:
-            vg_base64 = pil_to_base64(vg_image, "vg", max_size=800, min_size=400)
-            gradio_chatbot_output[-1][1] += vg_base64
-            yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
-
-    logger.info("flushed result to gradio")
-
-    if is_variable_assigned("x"):
-        print(
-            f"temperature: {temperature}, "
-            f"top_p: {top_p}, "
-            f"max_length_tokens: {max_length_tokens}"
         )
 
-
 
 
 def retry(
     text,
-
     chatbot,
     history,
     top_p,
     temperature,
     max_length_tokens,
-    max_context_length_tokens,
     chunk_size: int = 512,
 ):
     """
-    Retry the response for the input text and
     """
     if len(history) == 0:
         yield (chatbot, history, "Empty context")
         return
 
-    chatbot.pop()
-    history.pop()
-    text = history.pop()[-1]
     if type(text) is tuple:
         text, _ = text
 
     yield from predict(
         text,
-
         chatbot,
         history,
         top_p,
         temperature,
         max_length_tokens,
-        max_context_length_tokens,
         chunk_size,
     )
 
@@ -270,12 +283,13 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
     with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
         history = gr.State([])
         input_text = gr.State()
-
 
         with gr.Row():
             gr.HTML(TITLE)
             status_display = gr.Markdown("Success", elem_id="status_display")
         gr.Markdown(DESCRIPTION_TOP)
 
         with gr.Row(equal_height=True):
             with gr.Column(scale=4):
@@ -284,63 +298,59 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
                     elem_id="Kimi-Dev-72B",
                     show_share_button=True,
                     bubble_full_width=False,
-                    height=
                 )
                 with gr.Row():
                     with gr.Column(scale=4):
-                        text_box = gr.Textbox(
                     with gr.Column(min_width=70):
                         submit_btn = gr.Button("Send")
-                    with gr.Column(min_width=70):
-                        cancel_btn = gr.Button("Stop")
                 with gr.Row():
                     empty_btn = gr.Button("🧹 New Conversation")
                     retry_btn = gr.Button("🔄 Regenerate")
-                    del_last_btn = gr.Button("🗑️ Remove Last Turn")
-
             with gr.Column():
-
-                gr.
-
-
-
                 # Parameter Setting Tab for control the generation parameters
                 with gr.Tab(label="Parameter Setting"):
-                    top_p = gr.Slider(minimum=-0, maximum=1.0, value=
                     temperature = gr.Slider(
-                        minimum=0, maximum=1.0, value=0
                     )
                     max_length_tokens = gr.Slider(
-                        minimum=512, maximum=
-                    )
-                    max_context_length_tokens = gr.Slider(
-                        minimum=512, maximum=8192, value=2048, step=64, interactive=True, label="Max Context Length Tokens"
                     )
 
-        show_images = gr.HTML(visible=False)
-
         gr.Examples(
             examples=get_examples(ROOT_DIR),
-            inputs=[
         )
-        gr.Markdown()
 
         input_widgets = [
             input_text,
-
             chatbot,
             history,
             top_p,
             temperature,
             max_length_tokens,
-            max_context_length_tokens,
         ]
         output_widgets = [chatbot, history, status_display]
 
         transfer_input_args = dict(
            fn=transfer_input,
-            inputs=[text_box,
-            outputs=[input_text,
            show_progress=True,
        )
 
@@ -356,8 +366,6 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
         empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
         empty_btn.click(**reset_args)
         retry_btn.click(**retry_args)
-        del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
-        cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)
 
         demo.title = "Kimi-Dev-72B"
         return demo
@@ -367,8 +375,7 @@ def main(args: argparse.Namespace):
     demo = build_demo(args)
     reload_javascript()
 
-
-    favicon_path = os.path.join("kimi_vl/serve/assets/favicon.ico")
     # demo.queue().launch(
     #     favicon_path=favicon_path,
     #     server_name=args.ip,
@@ -378,7 +385,7 @@ def main(args: argparse.Namespace):
         favicon_path=favicon_path,
         server_name=args.ip,
         server_port=args.port,
-        share=True
     )
 
 if __name__ == "__main__":

New side of the diff (added lines are marked "+"):

@@ -1,44 +1,38 @@
 import argparse
 import gradio as gr
 import os
 import spaces
 import copy
 import time
+import json
+import subprocess
+import ast
+import pdb
+from transformers import TextIteratorStreamer
 
+import threading
+
+from kimi_dev.serve.frontend import reload_javascript
+from kimi_dev.serve.utils import (
     configure_logger,
 )
+from kimi_dev.serve.gradio_utils import (
     reset_state,
     reset_textbox,
     transfer_input,
     wrap_gen_fn,
 )
+from kimi_dev.serve.inference import load_model
+from kimi_dev.serve.examples import get_examples
+from kimi_dev.serve.templates import post_process,get_loc_prompt, clone_github_repo, build_repo_structure, show_project_structure,get_repair_prompt,get_repo_files,get_full_file_paths_and_classes_and_functions,correct_file_path_in_structure
 
+TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-Dev-72B🔥 </h1>"""
+DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-VL" target="_blank">Kimi-Dev-72B</a> is a strong and open-source coding LLM for software engineering tasks."""
+USAGE_TOP = """Usage: 1. Input a Github url like "https://github.com/astropy/astropy" and submit it. \n2. Input your issue description and chat with Kimi-Dev-72B!"""
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 DEPLOY_MODELS = dict()
 logger = configure_logger()
 
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="Kimi-Dev-72B")
@@ -73,16 +67,6 @@ def fetch_model(model_name: str):
     return model_info
 
 
 def get_prompt(conversation) -> str:
     """
     Get the prompt for the conversation.
@@ -103,30 +87,29 @@ def highlight_thinking(msg: str) -> str:
 @spaces.GPU(duration=180)
 def predict(
     text,
+    url,
     chatbot,
     history,
     top_p,
     temperature,
     max_length_tokens,
     chunk_size: int = 512,
 ):
     """
+    Predict the response for the input text and url.
     Args:
         text (str): The input text.
+        url (str): The input url.
         chatbot (list): The chatbot.
         history (list): The history.
         top_p (float): The top-p value.
         temperature (float): The temperature value.
         repetition_penalty (float): The repetition penalty value.
         max_length_tokens (int): The max length tokens.
         chunk_size (int): The chunk size.
     """
     print("running the prediction function")
+
     try:
         model, tokenizer = fetch_model(args.model)
 
@@ -137,131 +120,161 @@ def predict(
         yield [[text, "No Model Found"]], [], "No Model Found"
         return
 
+    prompt = text
+    repo_name = url.split("/")[-1]
+
+    repo_path = './local_path/'+repo_name # Local clone path
+
+    clone_github_repo(url, repo_path)
+    structure = build_repo_structure(repo_path)
+    string_struture = show_project_structure(structure)
+
+    loc_prompt = get_loc_prompt(prompt,string_struture)
+
 
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": loc_prompt}
     ]
+    text_for_model = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
         add_generation_prompt=True
     )
+    model_inputs = tokenizer([text_for_model], return_tensors="pt").to(model.device)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    # print("start generating")
+    if temperature > 0:
+        generation_kwargs = dict(
+            **model_inputs,
+            do_sample=True,
+            temperature=temperature,
+            top_p=top_p,
+            max_new_tokens=max_length_tokens,
+            streamer=streamer
+        )
+    else:
+        generation_kwargs = dict(
+            **model_inputs,
+            do_sample=False,
+            max_new_tokens=max_length_tokens,
+            streamer=streamer
+        )
+    gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+    gen_thread.start()
 
 
+    partial_output = "Start Locating...\n"
+
+    for new_text in streamer:
+        partial_output += new_text
+        highlight_response = highlight_thinking(partial_output)
+        yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generating file locations..."
+
+    gen_thread.join()
+
+    response = partial_output
 
+    raw_answer=post_process(response)
+    model_found_files = raw_answer.strip().split("\n")
     print(response)
+
+    highlight_response = highlight_thinking(response)
+    yield [[prompt,highlight_response]], [["null test","null test2"]], "Generate: Success"
+
+    # reading file content
+    contents = ""
+    for file_path in model_found_files:
+        file_name = file_path.replace("```","")
+        print(file_name)
+        # pdb.set_trace()
+        to_open_path = repo_path + "/" + file_name
+        print("to_open_path,",to_open_path)
+        with open(to_open_path, "r", encoding="utf-8") as f:
+            content = f.read()
+        contents += f"{file_name}\n{content}\n\n"
+
+
+    repair_prompt = get_repair_prompt(prompt,contents)
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": repair_prompt}
+    ]
+    text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
     )
+    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+    subprocess.run(["rm", "-rf", repo_path], check=True)
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    if temperature > 0:
+        generation_kwargs = dict(
+            **model_inputs,
+            do_sample=True,
            temperature=temperature,
            top_p=top_p,
+            max_new_tokens=max_length_tokens,
+            streamer=streamer
        )
+    else:
+        generation_kwargs = dict(
+            **model_inputs,
+            do_sample=False,
+            max_new_tokens=max_length_tokens,
+            streamer=streamer
+        )
+    gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
+    gen_thread.start()
+
+    partial_output_repair = "Start Repairing...\n"
+    yield [[prompt,highlight_response],[repair_prompt,partial_output_repair]], [["null test","null test2"]], "Generate: Success"
+    time.sleep(5)
+    for new_text in streamer:
+        partial_output_repair += new_text
+        highlight_response = highlight_thinking(partial_output)
+        highlight_response_repair = highlight_thinking(partial_output_repair)
+        yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generating repair suggestion..."
 
+    gen_thread.join()
+
+    # yield response, "null test", "Generate: Success"
+    yield [[prompt,highlight_response],[repair_prompt,highlight_response_repair]], [["null test","null test2"]], "Generate: Success"
 
 
 def retry(
     text,
+    url,
     chatbot,
     history,
     top_p,
     temperature,
     max_length_tokens,
     chunk_size: int = 512,
 ):
     """
+    Retry the response for the input text and url.
     """
     if len(history) == 0:
         yield (chatbot, history, "Empty context")
         return
 
+    # chatbot.pop()
+    # history.pop()
+    # text = history.pop()[-1]
     if type(text) is tuple:
         text, _ = text
 
     yield from predict(
         text,
+        url,
         chatbot,
         history,
         top_p,
         temperature,
         max_length_tokens,
         chunk_size,
     )
 
@@ -270,12 +283,13 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
     with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
         history = gr.State([])
         input_text = gr.State()
+        upload_url = gr.State()
 
         with gr.Row():
             gr.HTML(TITLE)
             status_display = gr.Markdown("Success", elem_id="status_display")
         gr.Markdown(DESCRIPTION_TOP)
+        gr.Markdown(USAGE_TOP)
 
         with gr.Row(equal_height=True):
             with gr.Column(scale=4):
@@ -284,63 +298,59 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
                     elem_id="Kimi-Dev-72B",
                     show_share_button=True,
                     bubble_full_width=False,
+                    height=400,
+                    # render_markdown=False
                 )
                 with gr.Row():
                     with gr.Column(scale=4):
+                        text_box = gr.Textbox(label="Issue Description", placeholder="Enter issue description", container=False)
                     with gr.Column(min_width=70):
                         submit_btn = gr.Button("Send")
+                    # with gr.Column(min_width=70):
+                    #     cancel_btn = gr.Button("Stop")
                 with gr.Row():
                     empty_btn = gr.Button("🧹 New Conversation")
                     retry_btn = gr.Button("🔄 Regenerate")
+                    # del_last_btn = gr.Button("🗑️ Remove Last Turn")
+                def respond(message):
+                    return f"Url submitted!"
             with gr.Column():
+                url_box = gr.Textbox(label="Please input a Github url here",placeholder="Input your url", lines=1)
+                url_submit_btn = gr.Button("Submit")
+                output = gr.Textbox(label="Submitted url")
+                url_submit_btn.click(fn=respond, inputs=upload_url, outputs=output)
+
                 # Parameter Setting Tab for control the generation parameters
                 with gr.Tab(label="Parameter Setting"):
+                    top_p = gr.Slider(minimum=-0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p")
                     temperature = gr.Slider(
+                        minimum=0, maximum=1.0, value=1.0, step=0.1, interactive=True, label="Temperature"
                     )
                     max_length_tokens = gr.Slider(
+                        minimum=512, maximum=16384, value=8192, step=64, interactive=True, label="Max Length Tokens"
                     )
 
         gr.Examples(
             examples=get_examples(ROOT_DIR),
+            inputs=[url_box, text_box],
         )
+        # gr.Markdown()
 
         input_widgets = [
             input_text,
+            upload_url,
             chatbot,
             history,
             top_p,
             temperature,
             max_length_tokens,
         ]
         output_widgets = [chatbot, history, status_display]
 
         transfer_input_args = dict(
            fn=transfer_input,
+            inputs=[text_box, url_box],
+            outputs=[input_text, upload_url, text_box, upload_url, submit_btn],
            show_progress=True,
        )
 
@@ -356,8 +366,6 @@ def build_demo(args: argparse.Namespace) -> gr.Blocks:
         empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
         empty_btn.click(**reset_args)
         retry_btn.click(**retry_args)
 
         demo.title = "Kimi-Dev-72B"
         return demo
@@ -367,8 +375,7 @@ def main(args: argparse.Namespace):
     demo = build_demo(args)
     reload_javascript()
 
+    favicon_path = os.path.join("kimi_dev/serve/assets/favicon.ico")
     # demo.queue().launch(
     #     favicon_path=favicon_path,
     #     server_name=args.ip,
@@ -378,7 +385,7 @@ def main(args: argparse.Namespace):
         favicon_path=favicon_path,
         server_name=args.ip,
         server_port=args.port,
+        share=True
     )
 
 if __name__ == "__main__":
{kimi_vl → kimi_dev}/__init__.py RENAMED (file without changes)
{kimi_vl → kimi_dev}/serve/__init__.py RENAMED (file without changes)
{kimi_vl → kimi_dev}/serve/assets/Kelpy-Codos.js RENAMED (file without changes)
{kimi_vl → kimi_dev}/serve/assets/avatar.png RENAMED (file without changes)
{kimi_vl → kimi_dev}/serve/assets/custom.css RENAMED (file without changes)
{kimi_vl → kimi_dev}/serve/assets/custom.js RENAMED (file without changes)
{kimi_vl → kimi_dev}/serve/assets/favicon.ico RENAMED (file without changes)
kimi_dev/serve/examples.py ADDED
@@ -0,0 +1,26 @@
import os
import io
import base64

EXAMPLES_LIST = [
    [
        "https://github.com/astropy/astropy",
        "units.quantity_input decorator fails for constructors with type hinted return value -> None\n### Summary\r\nI am using the `units.quantity_input` decorator with typing hints for constructors, however when I add the correct return value for the constructor (`None`) then I get an exception, because `None` has no attribute `to`.\r\n\r\n### Reproducer\r\nThe issue can be reproduced with the following file:\r\n``` Python\r\nimport astropy.units as u\r\n\r\n\r\nclass PoC(object):\r\n\r\n    @u.quantity_input\r\n    def __init__(self, voltage: u.V) -> None:\r\n        pass\r\n\r\n\r\nif __name__ == '__main__':\r\n    poc = PoC(1.*u.V)\r\n```\r\nwhich results in the following error:\r\n```\r\n$ python3 poc.py\r\nTraceback (most recent call last):\r\n  File \"poc.py\", line 12, in <module>\r\n    poc = PoC(1.*u.V)\r\n  File \"/usr/lib64/python3.6/site-packages/astropy/utils/decorators.py\", line 868, in __init__\r\n    func = make_function_with_signature(func, name=name, **wrapped_args)\r\n  File \"/usr/lib64/python3.6/site-packages/astropy/units/decorators.py\", line 225, in wrapper\r\n    return return_.to(wrapped_signature.return_annotation)\r\nAttributeError: 'NoneType' object has no attribute 'to'\r\n```\r\n\r\nThis has been tested on Fedora 27 with python 3.6.3, astropy 2.0.2 and numpy 1.13.3 all from Fedora's repository.\r\n\r\n### Workaround\r\nThe issue can be circumvented by not adding the return type typing hint. Unfortunately, then a static type checker cannot infer that this function returns nothing.\r\n\r\n### Possible fix\r\nMaybe the decorator could explicitly check whether None is returned and then omit the unit check.\n\n\n",
    ],
    [
        "https://github.com/sympy/sympy",
        "evalf does not call _imp_ recursively\nExample from https://stackoverflow.com/questions/41818842/why-cant-i-evaluate-a-composition-of-implemented-functions-in-sympy-at-a-point:\r\n\r\n```\r\n>>> from sympy.utilities.lambdify import implemented_function\r\n>>> f = implemented_function('f', lambda x: x ** 2)\r\n>>> g = implemented_function('g', lambda x: 2 * x)\r\n>>> print(f(  2 ).evalf())\r\n4.00000000000000\r\n>>> print(  g(2) .evalf())\r\n4.00000000000000\r\n>>> print(f(g(2)).evalf())\r\nf(g(2))\r\n```\r\n\r\nThe code for this is in `Function._eval_evalf`. It isn't calling evalf recursively on the return of `_imp_`. \n\n\n",
    ],
    [
        "https://github.com/matplotlib/matplotlib",
        "[ENH]: ContourSet.set_paths\n### Problem\n\nTo get contour labelling working with its special transforms, Cartopy has a [workaround](https://github.com/SciTools/cartopy/blob/2ed668c17b4e52421f15c5be3761719c75c5311a/lib/cartopy/mpl/contour.py#L89-L108) where it replaces all the paths on the `ContourSet` with transformed versions. This currently looks like\r\n\r\n```python\r\npaths = cs.get_paths()\r\npaths[:] = transformed_paths\r\n``` \r\n\r\nwhich doesn't smell very good.\n\n### Proposed solution\n\nThe above would smell better as \r\n\r\n```python\r\ncs.set_paths(transformed_paths)\r\n``` \n\n\n"
    ]
]


def get_examples(root_dir: str = None):
    examples = []
    for github_url, instance_id in EXAMPLES_LIST:
        examples.append([github_url, instance_id])

    return examples
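A minimal usage sketch (not part of the commit) of how these (url, issue) example pairs are consumed; it mirrors the gr.Examples(...) wiring this commit adds to app.py, and the two textboxes here are illustrative stand-ins for the real UI widgets:

import gradio as gr
from kimi_dev.serve.examples import get_examples

with gr.Blocks() as demo:
    url_box = gr.Textbox(label="GitHub url")           # illustrative widget names
    text_box = gr.Textbox(label="Issue Description")
    # each example row fills [url_box, text_box], as in app.py
    gr.Examples(examples=get_examples("."), inputs=[url_box, text_box])
demo.launch()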
{kimi_vl → kimi_dev}/serve/frontend.py RENAMED (file without changes)
{kimi_vl → kimi_dev}/serve/gradio_utils.py RENAMED (file without changes)
kimi_dev/serve/inference.py ADDED
@@ -0,0 +1,26 @@
import logging

from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    AutoTokenizer
)

logger = logging.getLogger(__name__)


def load_model(model_path: str = "moonshotai/Kimi-Dev-72B"):
    # hotfix the model to use flash attention 2
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    return model, tokenizer
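A minimal sketch (not part of the commit) of driving this loader the way the new app.py does: build a chat prompt with the tokenizer's chat template, then stream tokens from model.generate via TextIteratorStreamer running in a background thread. It assumes the 72B weights are reachable and enough GPU memory is available:

import threading
from transformers import TextIteratorStreamer
from kimi_dev.serve.inference import load_model

model, tokenizer = load_model("moonshotai/Kimi-Dev-72B")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Give me a short introduction to large language models."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, do_sample=False, max_new_tokens=512, streamer=streamer),
)
thread.start()
for chunk in streamer:  # chunks arrive incrementally, as in predict() above
    print(chunk, end="", flush=True)
thread.join()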
kimi_dev/serve/templates.py ADDED
@@ -0,0 +1,337 @@
import os
import re
import json
import subprocess
import ast

def show_project_structure(structure, spacing=0) -> str:
    """pprint the project structure"""

    pp_string = ''

    for key, value in structure.items():
        if '.' in key and '.py' not in key:
            continue  # skip none python files

        # TODO: maybe we should skip the test files...
        if key.startswith('test'):
            continue  # skip the test files as well...

        if '.' in key:
            pp_string += ' ' * spacing + str(key) + '\n'
        else:
            pp_string += ' ' * spacing + str(key) + '/' + '\n'
            if 'classes' not in value:
                pp_string += show_project_structure(value, spacing + 4)

    return pp_string

import os
import json
import subprocess
import ast
def clone_github_repo(github_url, local_path):
    """Clone GitHub repository to local path"""
    try:
        subprocess.run(['git', 'clone', github_url, local_path], check=True)
        print(f"Successfully cloned repository to: {local_path}")
    except subprocess.CalledProcessError as e:
        print(f"Warning: Repository cloning may have failed: {e}")

def parse_python_file(file_path, file_content=None):
    """Parse a Python file to extract class and function definitions with their line numbers.
    :param file_path: Path to the Python file.
    :return: Class names, function names, and file contents
    """
    if file_content is None:
        try:
            with open(file_path, "r") as file:
                file_content = file.read()
                parsed_data = ast.parse(file_content)
        except Exception as e:  # Catch all types of exceptions
            print(f"Error in file {file_path}: {e}")
            return [], [], ""
    else:
        try:
            parsed_data = ast.parse(file_content)
        except Exception as e:  # Catch all types of exceptions
            print(f"Error in file {file_path}: {e}")
            return [], [], ""
    class_info = []
    function_names = []
    class_methods = set()
    for node in ast.walk(parsed_data):
        if isinstance(node, ast.ClassDef):
            methods = []
            for n in node.body:
                if isinstance(n, ast.FunctionDef):
                    methods.append(
                        {
                            "name": n.name,
                            "start_line": n.lineno,
                            "end_line": n.end_lineno,
                            "text": file_content.splitlines()[
                                n.lineno - 1 : n.end_lineno
                            ],
                        }
                    )
                    class_methods.add(n.name)
            class_info.append(
                {
                    "name": node.name,
                    "start_line": node.lineno,
                    "end_line": node.end_lineno,
                    "text": file_content.splitlines()[
                        node.lineno - 1 : node.end_lineno
                    ],
                    "methods": methods,
                }
            )
        elif isinstance(node, ast.FunctionDef) and not isinstance(
            node, ast.AsyncFunctionDef
        ):
            if node.name not in class_methods:
                function_names.append(
                    {
                        "name": node.name,
                        "start_line": node.lineno,
                        "end_line": node.end_lineno,
                        "text": file_content.splitlines()[
                            node.lineno - 1 : node.end_lineno
                        ],
                    }
                )
    return class_info, function_names, file_content.splitlines()

def create_structure(directory_path):
    """Create the structure of the repository directory by parsing Python files.
    :param directory_path: Path to the repository directory.
    :return: A dictionary representing the structure.
    """
    structure = {}
    for root, _, files in os.walk(directory_path):
        repo_name = os.path.basename(directory_path)
        relative_root = os.path.relpath(root, directory_path)
        if relative_root == ".":
            relative_root = repo_name
        curr_struct = structure
        for part in relative_root.split(os.sep):
            if part not in curr_struct:
                curr_struct[part] = {}
            curr_struct = curr_struct[part]
        for file_name in files:
            if file_name.endswith(".py"):
                file_path = os.path.join(root, file_name)
                class_info, function_names, file_lines = parse_python_file(file_path)
                curr_struct[file_name] = {
                    "classes": class_info,
                    "functions": function_names,
                    "text": file_lines,
                }
            else:
                curr_struct[file_name] = {}
    return structure

def build_repo_structure(root_path):
    """Build repository structure using improved parsing method"""
    return create_structure(root_path)



def get_loc_prompt(issue_text,repo_structure):
    obtain_relevant_files_prompt = """
Please look through the following GitHub problem description and Repository structure and provide a list of files that one would need to edit to fix the problem.

### GitHub Problem Description ###
{problem_statement}

###

### Repository Structure ###
{structure}

###

Please only provide the full path and return at most 5 files.
The returned files should be separated by new lines ordered by most to least important and wrapped with ```
For example:
```
file1.py
file2.py
```
"""
    prompt_content = obtain_relevant_files_prompt.format(problem_statement=issue_text,structure=repo_structure)
    return prompt_content

def get_repair_prompt(issue_text,file_content):
    repair_prompt_combine_topn_cot_diff = """
We are currently solving the following issue within our repository. Here is the issue text:
--- BEGIN ISSUE ---
{problem_statement}
--- END ISSUE ---

Below are some code segments, each from a relevant file. One or more of these files may contain bugs.
--- BEGIN FILE ---
```
{content}
```
--- END FILE ---

Please first localize the bug based on the issue statement, and then generate *SEARCH/REPLACE* edits to fix the issue.

Every *SEARCH/REPLACE* edit must use this format:
1. The file path
2. The start of search block: <<<<<<< SEARCH
3. A contiguous chunk of lines to search for in the existing source code
4. The dividing line: =======
5. The lines to replace into the source code
6. The end of the replace block: >>>>>>> REPLACE

Here is an example:

```python
### mathweb/flask/app.py
<<<<<<< SEARCH
from flask import Flask
=======
import math
from flask import Flask
>>>>>>> REPLACE
```

Please note that the *SEARCH/REPLACE* edit REQUIRES PROPER INDENTATION. If you would like to add the line '        print(x)', you must fully write that out, with all those spaces before the code!
Wrap the *SEARCH/REPLACE* edit in blocks ```python...```.
"""
    prompt_content = repair_prompt_combine_topn_cot_diff.format(problem_statement=issue_text,content=file_content.rstrip())
    return prompt_content

def get_repo_files(structure, filepaths: list[str]):
    files, classes, functions = get_full_file_paths_and_classes_and_functions(structure)
    file_contents = dict()
    for filepath in filepaths:
        content = None

        for file_content in files:
            if file_content[0] == filepath:
                content = '\n'.join(file_content[1])
                file_contents[filepath] = content
                break

        # assert content is not None, "file not found"
    return file_contents

def correct_file_path_in_structure(file_name, structure):
    """
    Search for the correct file path in the structure, mainly checking first-level subdirectories

    Args:
        file_name (str): File name to search for
        structure (dict): Repository structure

    Returns:
        str: Correct file path if found, otherwise returns original file_name
    """
    # Search in current directory
    file_contents = get_repo_files(structure, [file_name])
    if file_contents != {}:
        return file_name

    # Only check first-level subdirectories
    for sub_dir in structure.keys():
        if isinstance(structure[sub_dir], dict):
            file_contents = get_repo_files(structure[sub_dir], [file_name])
            if file_contents != {}:
                return f'{sub_dir}/{file_name}'

    return file_name

def get_full_file_paths_and_classes_and_functions(structure, current_path=''):
    """
    Recursively retrieve all file paths, classes, and functions within a directory structure.

    Arguments:
    structure -- a dictionary representing the directory structure
    current_path -- the path accumulated so far, used during recursion (default="")

    Returns:
    A tuple containing:
    - files: list of full file paths
    - classes: list of class details with file paths
    - functions: list of function details with file paths
    """
    files = []
    classes = []
    functions = []
    for name, content in structure.items():
        if isinstance(content, dict):
            if (
                (
                    'functions' not in content.keys()
                    and 'classes' not in content.keys()
                    and 'text' not in content.keys()
                )
                or not len(content.keys()) == 3
                or (
                    isinstance(content.get('text', []), dict)
                    or isinstance(content.get('functions', []), dict)
                    or isinstance(content.get('classes', []), dict)
                )
            ):
                # or guards against case where functions and classes are somehow part of the structure.
                next_path = f'{current_path}/{name}' if current_path else name
                (
                    sub_files,
                    sub_classes,
                    sub_functions,
                ) = get_full_file_paths_and_classes_and_functions(content, next_path)
                files.extend(sub_files)
                classes.extend(sub_classes)
                functions.extend(sub_functions)
            else:
                next_path = f'{current_path}/{name}' if current_path else name
                files.append((next_path, content.get('text', [])))
                if content.get('text', []) == []:
                    continue
                if 'classes' in content:
                    for clazz in content['classes']:
                        classes.append(
                            {
                                'file': next_path,
                                'name': clazz['name'],
                                'start_line': clazz['start_line'],
                                'end_line': clazz['end_line'],
                                'methods': [
                                    {
                                        'name': method['name'],
                                        'start_line': method['start_line'],
                                        'end_line': method['end_line'],
                                    }
                                    for method in clazz.get('methods', [])
                                ],
                            },
                        )
                if 'functions' in content:
                    for function in content['functions']:
                        try:
                            function['file'] = next_path
                        except TypeError:
                            continue
                        functions.append(function)
        else:
            next_path = f'{current_path}/{name}' if current_path else name
            files.append(next_path)
    return files, classes, functions

def post_process(response: str) -> str:
    content = response
    if "◁/think▷" in content:
        content = content.replace("◁think▷", "")
        parts = content.split("◁/think▷")
        content = parts[-1]
    # Extract content between triple backticks (```)
    matches = re.findall(r"```.*?```", content, re.DOTALL)

    if matches:
        matches = [item.replace("```","") for item in matches]
        return "\n".join(matches)  # Return all matched code blocks joined by new lines
    return content  # If no match, return the full response
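A minimal sketch (not part of the commit) of the locate-then-repair pipeline that app.py's predict() assembles from these helpers; generate() is a placeholder for the actual Kimi-Dev-72B call, and the repository URL and clone path are illustrative:

from kimi_dev.serve.templates import (
    clone_github_repo, build_repo_structure, show_project_structure,
    get_loc_prompt, post_process, get_repair_prompt,
)

def generate(prompt: str) -> str:
    # placeholder: send `prompt` to Kimi-Dev-72B and return the raw completion
    raise NotImplementedError

issue = "units.quantity_input decorator fails for constructors ..."   # issue text
repo_path = "./local_path/astropy"                                    # illustrative clone path
clone_github_repo("https://github.com/astropy/astropy", repo_path)

# Stage 1: ask the model which files to edit, given the repo tree
structure = build_repo_structure(repo_path)
loc_prompt = get_loc_prompt(issue, show_project_structure(structure))
files = post_process(generate(loc_prompt)).strip().split("\n")

# Stage 2: feed the located files back and ask for SEARCH/REPLACE edits
contents = ""
for name in files:
    with open(f"{repo_path}/{name.replace('```', '')}", encoding="utf-8") as f:
        contents += f"{name}\n{f.read()}\n\n"
print(generate(get_repair_prompt(issue, contents)))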
{kimi_vl → kimi_dev}/serve/utils.py RENAMED (file without changes)
kimi_vl/serve/chat_utils.py
DELETED
@@ -1,379 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
3 |
-
"""
|
4 |
-
|
5 |
-
import dataclasses
|
6 |
-
import logging
|
7 |
-
import copy
|
8 |
-
from enum import IntEnum, auto
|
9 |
-
from typing import Dict, List
|
10 |
-
import base64
|
11 |
-
|
12 |
-
import gradio as gr
|
13 |
-
import torch
|
14 |
-
|
15 |
-
from .utils import pil_to_base64
|
16 |
-
|
17 |
-
IMAGE_TOKEN = "<image>"
|
18 |
-
logger = logging.getLogger("gradio_logger")
|
19 |
-
|
20 |
-
|
21 |
-
class SeparatorStyle(IntEnum):
|
22 |
-
"""Separator styles."""
|
23 |
-
|
24 |
-
PLAIN = auto()
|
25 |
-
ALIGNMENT = auto()
|
26 |
-
KIMI_VL = auto()
|
27 |
-
|
28 |
-
|
29 |
-
@dataclasses.dataclass
|
30 |
-
class Conversation:
|
31 |
-
"""A class that manages prompt templates and keeps all conversation history."""
|
32 |
-
|
33 |
-
# The name of this template
|
34 |
-
name: str
|
35 |
-
# The template of the system prompt
|
36 |
-
system_template: str = "{system_message}"
|
37 |
-
# The system message
|
38 |
-
system_message: str = ""
|
39 |
-
# The names of two roles
|
40 |
-
roles: List[str] = (("USER", "ASSISTANT"),)
|
41 |
-
# All messages. Each item is (role, message).
|
42 |
-
messages: List[List[str]] = ()
|
43 |
-
# The number of few shot examples
|
44 |
-
offset: int = 0
|
45 |
-
# The separator style and configurations
|
46 |
-
sep_style: SeparatorStyle = SeparatorStyle.PLAIN
|
47 |
-
sep: str = "\n"
|
48 |
-
sep2: str = None
|
49 |
-
# Stop criteria (the default one is EOS token)
|
50 |
-
stop_str: str = None
|
51 |
-
# Stops generation if meeting any token in this list
|
52 |
-
stop_token_ids: List[int] = None
|
53 |
-
|
54 |
-
def get_prompt(self) -> str:
|
55 |
-
"""Get the prompt for generation."""
|
56 |
-
system_prompt = self.system_template.format(system_message=self.system_message)
|
57 |
-
if self.sep_style == SeparatorStyle.PLAIN:
|
58 |
-
seps = [self.sep, self.sep2]
|
59 |
-
ret = ""
|
60 |
-
for i, (role, message) in enumerate(self.messages):
|
61 |
-
if message:
|
62 |
-
if type(message) is tuple:
|
63 |
-
message = message[0]
|
64 |
-
if i % 2 == 0:
|
65 |
-
ret += message + seps[i % 2]
|
66 |
-
else:
|
67 |
-
ret += message + seps[i % 2]
|
68 |
-
else:
|
69 |
-
ret += ""
|
70 |
-
return ret
|
71 |
-
elif self.sep_style == SeparatorStyle.ALIGNMENT:
|
72 |
-
seps = [self.sep, self.sep2]
|
73 |
-
ret = ""
|
74 |
-
for i, (role, message) in enumerate(self.messages):
|
75 |
-
if message:
|
76 |
-
if type(message) is tuple:
|
77 |
-
message, _, _ = message
|
78 |
-
if i % 2 == 0:
|
79 |
-
ret += '<image>\n' + seps[i % 2]
|
80 |
-
else:
|
81 |
-
ret += message + seps[i % 2]
|
82 |
-
else:
|
83 |
-
ret += ""
|
84 |
-
return ret
|
85 |
-
elif self.sep_style == SeparatorStyle.KIMI_VL:
|
86 |
-
seps = [self.sep, self.sep2]
|
87 |
-
if system_prompt == "" or system_prompt is None:
|
88 |
-
ret = ""
|
89 |
-
else:
|
90 |
-
ret = system_prompt + seps[0]
|
91 |
-
for i, (role, message) in enumerate(self.messages):
|
92 |
-
if message:
|
93 |
-
if type(message) is tuple:
|
94 |
-
message = message[0]
|
95 |
-
|
96 |
-
if role == "user":
|
97 |
-
ret += message + self.sep
|
98 |
-
else:
|
99 |
-
if self.sep2 is not None:
|
100 |
-
ret += message + self.sep2
|
101 |
-
else:
|
102 |
-
ret += message
|
103 |
-
else:
|
104 |
-
ret = ret
|
105 |
-
return ret
|
106 |
-
else:
|
107 |
-
raise ValueError(f"Invalid style: {self.sep_style}")
|
108 |
-
|
109 |
-
def set_system_message(self, system_message: str):
|
110 |
-
"""Set the system message."""
|
111 |
-
self.system_message = system_message
|
112 |
-
|
113 |
-
def append_message(self, role: str, message: str):
|
114 |
-
"""Append a new message."""
|
115 |
-
self.messages.append([role, message])
|
116 |
-
|
117 |
-
def update_last_message(self, message: str):
|
118 |
-
"""Update the last output.
|
119 |
-
|
120 |
-
The last message is typically set to be None when constructing the prompt,
|
121 |
-
so we need to update it in-place after getting the response from a model.
|
122 |
-
"""
|
123 |
-
self.messages[-1][1] = message
|
124 |
-
|
125 |
-
def reset_message(self):
|
126 |
-
"""Reset a new message."""
|
127 |
-
self.messages = []
|
128 |
-
|
129 |
-
def to_gradio_chatbot(self):
|
130 |
-
"""Convert the conversation to gradio chatbot format."""
|
131 |
-
ret = []
|
132 |
-
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
133 |
-
if i % 2 == 0:
|
134 |
-
ret.append([msg, None])
|
135 |
-
else:
|
136 |
-
ret[-1][-1] = msg
|
137 |
-
return ret
|
138 |
-
|
139 |
-
def to_openai_api_messages(self):
|
140 |
-
"""Convert the conversation to OpenAI chat completion format."""
|
141 |
-
system_prompt = self.system_template.format(system_message=self.system_message)
|
142 |
-
ret = [{"role": "system", "content": system_prompt}]
|
143 |
-
|
144 |
-
for i, (_, msg) in enumerate(self.messages[self.offset :]):
|
145 |
-
if i % 2 == 0:
|
146 |
-
ret.append({"role": "user", "content": msg})
|
147 |
-
else:
|
148 |
-
if msg is not None:
|
149 |
-
ret.append({"role": "assistant", "content": msg})
|
150 |
-
return ret
|
151 |
-
|
152 |
-
def copy(self):
|
153 |
-
return Conversation(
|
154 |
-
name=self.name,
|
155 |
-
system_template=self.system_template,
|
156 |
-
system_message=self.system_message,
|
157 |
-
roles=self.roles,
|
158 |
-
messages=[[x, y] for x, y in self.messages],
|
159 |
-
offset=self.offset,
|
160 |
-
sep_style=self.sep_style,
|
161 |
-
sep=self.sep,
|
162 |
-
sep2=self.sep2,
|
163 |
-
stop_str=self.stop_str,
|
164 |
-
stop_token_ids=self.stop_token_ids,
|
165 |
-
)
|
166 |
-
|
167 |
-
def dict(self):
|
168 |
-
return {
|
169 |
-
"template_name": self.name,
|
170 |
-
"system_message": self.system_message,
|
171 |
-
"roles": self.roles,
|
172 |
-
"messages": self.messages,
|
173 |
-
"offset": self.offset,
|
174 |
-
}
|
175 |
-
|
176 |
-
|
177 |
-
# A global registry for all conversation templates
|
178 |
-
conv_templates: Dict[str, Conversation] = {}
|
179 |
-
|
180 |
-
|
181 |
-
def register_conv_template(template: Conversation, override: bool = False):
|
182 |
-
"""Register a new conversation template."""
|
183 |
-
if not override:
|
184 |
-
assert template.name not in conv_templates, f"{template.name} has been registered."
|
185 |
-
|
186 |
-
conv_templates[template.name] = template
|
187 |
-
|
188 |
-
|
189 |
-
def get_conv_template(name: str) -> Conversation:
|
190 |
-
"""Get a conversation template."""
|
191 |
-
return conv_templates[name].copy()
|
192 |
-
|
193 |
-
|
194 |
-
register_conv_template(
|
195 |
-
Conversation(
|
196 |
-
name="plain",
|
197 |
-
system_template="",
|
198 |
-
system_message="",
|
199 |
-
roles=("", ""),
|
200 |
-
messages=(),
|
201 |
-
offset=0,
|
202 |
-
sep_style=SeparatorStyle.PLAIN,
|
203 |
-
sep="",
|
204 |
-
sep2="",
|
205 |
-
stop_token_ids=[100001],
|
206 |
-
stop_str=['</s>'],
|
207 |
-
)
|
208 |
-
)
|
209 |
-
|
210 |
-
|
211 |
-
register_conv_template(
|
212 |
-
Conversation(
|
213 |
-
name="alignment",
|
214 |
-
system_template="",
|
215 |
-
system_message="",
|
216 |
-
roles=("", ""),
|
217 |
-
messages=(),
|
218 |
-
offset=0,
|
219 |
-
sep_style=SeparatorStyle.ALIGNMENT,
|
220 |
-
sep="",
|
221 |
-
sep2="",
|
222 |
-
stop_token_ids=[100001],
|
223 |
-
stop_str=['</s>'],
|
224 |
-
)
|
225 |
-
)
|
226 |
-
|
227 |
-
register_conv_template(
|
228 |
-
Conversation(
|
229 |
-
name="kimi-vl",
|
230 |
-
system_template="{system_message}",
|
231 |
-
system_message="You are a helpful assistant",
|
232 |
-
roles=("user", "assistant"),
|
233 |
-
messages=(),
|
234 |
-
offset=0,
|
235 |
-
sep_style=SeparatorStyle.KIMI_VL,
|
236 |
-
sep="<|im_end|>",
|
237 |
-
sep2=None,
|
238 |
-
stop_token_ids=None,
|
239 |
-
stop_str=["<|im_end|>"],
|
240 |
-
)
|
241 |
-
)
|
242 |
-
|
243 |
-
|
244 |
-
def new_chat_template(sft_format: str = "kimi-vl"):
|
245 |
-
return get_conv_template(sft_format)
|
246 |
-
|
247 |
-
|
248 |
-
def get_prompt(conv: Conversation) -> str:
|
249 |
-
"""Get the prompt for generation."""
|
250 |
-
return conv.get_prompt()
|
251 |
-
|
252 |
-
|
253 |
-
def generate_prompt_with_history(text, images, history, processor, max_length=2048):
|
254 |
-
"""
|
255 |
-
Generate a prompt with the chat history.
|
256 |
-
|
257 |
-
Args:
|
258 |
-
text (str): The text prompt.
|
259 |
-
images (list[PIL.Image.Image]): The image prompt.
|
260 |
-
history (list): List of previous conversation messages.
|
261 |
-
processor (KimiVLProcessor): The chat processor used for encoding the prompt.
|
262 |
-
max_length (int): The maximum length of the prompt.
|
263 |
-
"""
|
264 |
-
global IMAGE_TOKEN
|
265 |
-
|
266 |
-
user_role_ind = 0
|
267 |
-
bot_role_ind = 1
|
268 |
-
|
269 |
-
# Initialize conversation
|
270 |
-
conversation = new_chat_template(sft_format="plain")
|
271 |
-
|
272 |
-
if history:
|
273 |
-
conversation.messages = history
|
274 |
-
|
275 |
-
if images is not None and len(images) > 0:
|
276 |
-
# num_image_tags = text.count(IMAGE_TOKEN)
|
277 |
-
# num_images = len(images)
|
278 |
-
# if num_images > num_image_tags:
|
279 |
-
# pad_image_tags = num_images - num_image_tags
|
280 |
-
# image_tokens = "\n".join([IMAGE_TOKEN] * pad_image_tags)
|
281 |
-
|
282 |
-
# # append the <image> in a new line after the text prompt
|
283 |
-
# text = image_tokens + "\n" + text
|
284 |
-
# elif num_images < num_image_tags:
|
285 |
-
# remove_image_tags = num_image_tags - num_images
|
286 |
-
# text = text.replace(IMAGE_TOKEN, "", remove_image_tags)
|
287 |
-
|
288 |
-
print(f"prompt = {text}, len(images) = {len(images)}")
|
289 |
-
text = (text, images)
|
290 |
-
|
291 |
-
conversation.append_message(conversation.roles[user_role_ind], text)
|
292 |
-
conversation.append_message(conversation.roles[bot_role_ind], "")
|
293 |
-
|
294 |
-
# Create a copy of the conversation to avoid history truncation in the UI
|
295 |
-
conversation_copy = conversation.copy()
|
296 |
-
logger.info("=" * 80)
|
297 |
-
logger.info(get_prompt(conversation))
|
298 |
-
|
299 |
-
rounds = len(conversation.messages) // 2
|
300 |
-
|
301 |
-
for _ in range(rounds):
|
302 |
-
current_prompt = get_prompt(conversation)
|
303 |
-
assert isinstance(current_prompt, str) and len(current_prompt) > 0, f"current_prompt = {current_prompt}"
|
304 |
-
if torch.tensor(processor.tokenizer.encode(current_prompt)).size(-1) <= max_length:
|
305 |
-
return conversation_copy
|
306 |
-
|
307 |
-
if len(conversation.messages) % 2 != 0:
|
308 |
-
gr.Error("The messages between user and assistant are not paired.")
|
309 |
-
return
|
310 |
-
|
311 |
-
try:
|
312 |
-
for _ in range(2): # pop out two messages in a row
|
313 |
-
conversation.messages.pop(0)
|
314 |
-
except IndexError:
|
315 |
-
gr.Error("Input text processing failed, unable to respond in this round.")
|
316 |
-
return None
|
317 |
-
|
318 |
-
gr.Error("Prompt could not be generated within max_length limit.")
|
319 |
-
return None
|
320 |
-
|
321 |
-
|
322 |
-
def convert_conversation_to_prompts(conversation: Conversation):
|
323 |
-
"""
|
324 |
-
Convert the conversation to prompts.
|
325 |
-
"""
|
326 |
-
conv_prompts = []
|
327 |
-
last_image = None
|
328 |
-
|
329 |
-
messages = conversation.messages
|
330 |
-
for i in range(0, len(messages), 2):
|
331 |
-
if isinstance(messages[i][1], tuple):
|
332 |
-
text, images = messages[i][1]
|
333 |
-
last_image = images[-1]
|
334 |
-
else:
|
335 |
-
text, images = messages[i][1], []
|
336 |
-
|
337 |
-
prompt = {"role": messages[i][0], "content": text, "images": images}
|
338 |
-
response = {"role": messages[i + 1][0], "content": messages[i + 1][1]}
|
339 |
-
conv_prompts.extend([prompt, response])
|
340 |
-
|
341 |
-
return conv_prompts, last_image
|
342 |
-
|
343 |
-
|
344 |
-
def to_gradio_chatbot(conversation: Conversation) -> list:
|
345 |
-
"""Convert the conversation to gradio chatbot format."""
|
346 |
-
ret = []
|
347 |
-
for i, (_, msg) in enumerate(conversation.messages[conversation.offset :]):
|
348 |
-
if i % 2 == 0:
|
349 |
-
if type(msg) is tuple:
|
350 |
-
msg, images = copy.deepcopy(msg)
|
351 |
-
|
352 |
-
if isinstance(images, list):
|
353 |
-
img_str = ""
|
354 |
-
for j, image in enumerate(images):
|
355 |
-
if isinstance(image, str):
|
356 |
-
with open(image, "rb") as f:
|
357 |
-
data = f.read()
|
358 |
-
img_b64_str = base64.b64encode(data).decode()
|
359 |
-
image_str = (
|
360 |
-
f'<img src="data:image/png;base64,{img_b64_str}" '
|
361 |
-
f'alt="user upload image" style="max-width: 300px; height: auto;" />'
|
362 |
-
)
|
363 |
-
else:
|
364 |
-
image_str = pil_to_base64(image, f"user upload image_{j}", max_size=800, min_size=400)
|
365 |
-
|
366 |
-
img_str += image_str
|
367 |
-
msg = img_str + msg
|
368 |
-
else:
|
369 |
-
pass
|
370 |
-
|
371 |
-
ret.append([msg, None])
|
372 |
-
else:
|
373 |
-
ret[-1][-1] = msg
|
374 |
-
return ret
|
375 |
-
|
376 |
-
|
377 |
-
def to_gradio_history(conversation: Conversation):
|
378 |
-
"""Convert the conversation to gradio history format."""
|
379 |
-
return conversation.messages[conversation.offset :]
|
|
|
|
kimi_vl/serve/examples.py
DELETED
@@ -1,54 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import io
|
3 |
-
import base64
|
4 |
-
from PIL import Image
|
5 |
-
|
6 |
-
EXAMPLES_LIST = [
|
7 |
-
[
|
8 |
-
["images/demo1.jpeg"],
|
9 |
-
"Where am I?",
|
10 |
-
],
|
11 |
-
[
|
12 |
-
["images/demo2.jpeg", "images/demo3.jpeg"],
|
13 |
-
"Based on the abstract and introduction above, write a concise and elegant Twitter post that highlights key points and figures without sounding overly promotional. Use English, include emojis and hashtags.",
|
14 |
-
],
|
15 |
-
[
|
16 |
-
["images/demo6.jpeg"],
|
17 |
-
"Create a role play modeled after this cat."
|
18 |
-
],
|
19 |
-
# mulit-frames example
|
20 |
-
[
|
21 |
-
["images/demo4.jpeg", "images/demo5.jpeg"],
|
22 |
-
"Please infer step by step who this manuscript belongs to and what it records."
|
23 |
-
]
|
24 |
-
]
|
25 |
-
|
26 |
-
|
27 |
-
def display_example(image_list, root_dir: str = None):
|
28 |
-
images_html = ""
|
29 |
-
for _, img_path in enumerate(image_list):
|
30 |
-
if root_dir is not None:
|
31 |
-
img_path = os.path.join(root_dir, img_path)
|
32 |
-
|
33 |
-
image = Image.open(img_path)
|
34 |
-
buffered = io.BytesIO()
|
35 |
-
image.save(buffered, format="PNG", quality=100)
|
36 |
-
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
37 |
-
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="{img_path}" style="height:80px; margin-right: 10px;" />'
|
38 |
-
images_html += img_str
|
39 |
-
|
40 |
-
result_html = f"""
|
41 |
-
<div style="display: flex; align-items: center; margin-bottom: 10px;">
|
42 |
-
<div style="flex: 1; margin-right: 10px;">{images_html}</div>
|
43 |
-
</div>
|
44 |
-
"""
|
45 |
-
|
46 |
-
return result_html
|
47 |
-
|
48 |
-
|
49 |
-
def get_examples(root_dir: str = None):
|
50 |
-
examples = []
|
51 |
-
for images, texts in EXAMPLES_LIST:
|
52 |
-
examples.append([images, display_example(images, root_dir), texts])
|
53 |
-
|
54 |
-
return examples
|
kimi_vl/serve/inference.py
DELETED
@@ -1,145 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import re
|
3 |
-
from threading import Thread
|
4 |
-
from typing import List, Optional
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import spaces
|
8 |
-
from transformers import (
|
9 |
-
AutoModelForCausalLM,
|
10 |
-
AutoProcessor,
|
11 |
-
AutoConfig,
|
12 |
-
StoppingCriteria,
|
13 |
-
StoppingCriteriaList,
|
14 |
-
TextIteratorStreamer,
|
15 |
-
AutoTokenizer
|
16 |
-
)
|
17 |
-
|
18 |
-
from .chat_utils import Conversation, get_conv_template
|
19 |
-
|
20 |
-
logger = logging.getLogger(__name__)
|
21 |
-
|
22 |
-
|
23 |
-
def load_model(model_path: str = "moonshotai/Kimi-Dev-72B"):
|
24 |
-
# hotfix the model to use flash attention 2
|
25 |
-
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
|
26 |
-
# config._attn_implementation = "flash_attention_2"
|
27 |
-
# config.vision_config._attn_implementation = "flash_attention_2"
|
28 |
-
# config.text_config._attn_implementation = "flash_attention_2"
|
29 |
-
# print("Successfully set the attn_implementation to flash_attention_2")
|
30 |
-
|
31 |
-
model = AutoModelForCausalLM.from_pretrained(
|
32 |
-
model_path,
|
33 |
-
config=config,
|
34 |
-
torch_dtype="auto",
|
35 |
-
device_map="auto",
|
36 |
-
trust_remote_code=True,
|
37 |
-
)
|
38 |
-
# processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True)
|
39 |
-
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
40 |
-
|
41 |
-
return model, tokenizer
|
42 |
-
|
43 |
-
|
44 |
-
class StoppingCriteriaSub(StoppingCriteria):
|
45 |
-
def __init__(self, stops=[], encounters=1):
|
46 |
-
super().__init__()
|
47 |
-
self.stops = [stop.to("cuda") for stop in stops]
|
48 |
-
|
49 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
|
50 |
-
for stop in self.stops:
|
51 |
-
if input_ids.shape[-1] < len(stop):
|
52 |
-
continue
|
53 |
-
if torch.all((stop == input_ids[0][-len(stop) :])).item():
|
54 |
-
return True
|
55 |
-
|
56 |
-
return False
|
57 |
-
|
58 |
-
|
59 |
-
def format_messages(
|
60 |
-
conversations: list[Conversation],
|
61 |
-
system_prompt: Optional[str] = "",
|
62 |
-
sft_format: Optional[str] = "kimi-vl",
|
63 |
-
):
|
64 |
-
"""
|
65 |
-
Format the conversations to the input format of the model.
|
66 |
-
"""
|
67 |
-
converstion = get_conv_template(sft_format)
|
68 |
-
converstion.set_system_message(system_prompt)
|
69 |
-
for message in conversations:
|
70 |
-
converstion.append_message(message["role"], message["content"])
|
71 |
-
return converstion
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
@torch.no_grad()
|
77 |
-
@torch.inference_mode()
|
78 |
-
def kimi_dev_generate(
|
79 |
-
model: torch.nn.Module,
|
80 |
-
tokenizer,
|
81 |
-
# processor: AutoProcessor,
|
82 |
-
conversations: list[Conversation],
|
83 |
-
stop_words: list,
|
84 |
-
max_length: int = 256,
|
85 |
-
temperature: float = 1.0,
|
86 |
-
top_p: float = 1.0,
|
87 |
-
chunk_size: int = -1,
|
88 |
-
):
|
89 |
-
# convert conversation to inputs
|
90 |
-
print(f"conversations = {conversations}")
|
91 |
-
# inputs = preprocess(conversations)
|
92 |
-
inputs = tokenizer.tokenize(conversations)
|
93 |
-
inputs = inputs.to(model.device)
|
94 |
-
|
95 |
-
return generate(
|
96 |
-
model,
|
97 |
-
tokenizer,
|
98 |
-
inputs,
|
99 |
-
max_gen_len=max_length,
|
100 |
-
temperature=temperature,
|
101 |
-
top_p=top_p,
|
102 |
-
stop_words=stop_words,
|
103 |
-
chunk_size=chunk_size,
|
104 |
-
)
|
105 |
-
|
106 |
-
|
107 |
-
def generate(
|
108 |
-
model,
|
109 |
-
tokenizer,
|
110 |
-
inputs,
|
111 |
-
max_gen_len: int = 256,
|
112 |
-
temperature: float = 0,
|
113 |
-
top_p: float = 0.95,
|
114 |
-
stop_words: List[str] = [],
|
115 |
-
chunk_size: int = -1,
|
116 |
-
):
|
117 |
-
"""Stream the text output from the multimodality model with prompt and image inputs."""
|
118 |
-
stop_words_ids = [torch.tensor(tokenizer.encode(stop_word)) for stop_word in stop_words]
|
119 |
-
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
120 |
-
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
121 |
-
|
122 |
-
kwargs = dict(
|
123 |
-
**inputs,
|
124 |
-
max_new_tokens=max_gen_len,
|
125 |
-
do_sample=True,
|
126 |
-
use_cache=True,
|
127 |
-
streamer=streamer,
|
128 |
-
stopping_criteria=stopping_criteria,
|
129 |
-
)
|
130 |
-
|
131 |
-
if temperature > 0:
|
132 |
-
kwargs.update(
|
133 |
-
{
|
134 |
-
"do_sample": True,
|
135 |
-
"top_p": top_p,
|
136 |
-
"temperature": temperature,
|
137 |
-
}
|
138 |
-
)
|
139 |
-
else:
|
140 |
-
kwargs["do_sample"] = False
|
141 |
-
|
142 |
-
thread = Thread(target=model.generate, kwargs=kwargs)
|
143 |
-
thread.start()
|
144 |
-
|
145 |
-
yield from streamer
|