# Kimi-Dev-72B / app.py
import argparse
import gradio as gr
import os
import spaces
import copy
import time
import json
import subprocess
import ast
import pdb
from transformers import TextIteratorStreamer
import threading
from kimi_dev.serve.frontend import reload_javascript
from kimi_dev.serve.utils import (
    configure_logger,
)
from kimi_dev.serve.gradio_utils import (
    reset_state,
    reset_textbox,
    transfer_input,
    wrap_gen_fn,
)
from kimi_dev.serve.inference import load_model
from kimi_dev.serve.examples import get_examples
from kimi_dev.serve.templates import (
    post_process,
    get_loc_prompt,
    clone_github_repo,
    build_repo_structure,
    show_project_structure,
    get_repair_prompt,
    get_repo_files,
    get_full_file_paths_and_classes_and_functions,
    correct_file_path_in_structure,
)

TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-Dev-72B🔥 </h1>"""
DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-Dev" target="_blank">Kimi-Dev-72B</a> is a strong and open-source coding LLM for software engineering tasks."""
USAGE_TOP = """Usage: 1. Enter a GitHub URL like "https://github.com/astropy/astropy" and submit it. \n2. Enter your issue description and chat with Kimi-Dev-72B!"""
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DEPLOY_MODELS = dict()
logger = configure_logger()

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Kimi-Dev-72B")
    parser.add_argument(
        "--local-path",
        type=str,
        default="",
        help="Path to a local Hugging Face checkpoint (optional).",
    )
    parser.add_argument("--ip", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    return parser.parse_args()
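

# A possible launch command, assuming this file is run directly as app.py
# (all flags below are defined in parse_args; the checkpoint path is hypothetical):
#   python app.py --model Kimi-Dev-72B --port 7860
#   python app.py --local-path /path/to/checkpoint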

def fetch_model(model_name: str):
    global args, DEPLOY_MODELS
    if args.local_path:
        model_path = args.local_path
    else:
        model_path = f"moonshotai/{args.model}"
    if model_name in DEPLOY_MODELS:
        model_info = DEPLOY_MODELS[model_name]
        print(f"{model_name} has been loaded.")
    else:
        print(f"{model_name} is loading...")
        DEPLOY_MODELS[model_name] = load_model(model_path)
        print(f"Loaded {model_name} successfully.")
        model_info = DEPLOY_MODELS[model_name]
    return model_info
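

# get_prompt renders the conversation's system template into a system prompt.
# highlight_thinking rewrites the model's ◁think▷ / ◁/think▷ markers into styled
# HTML, e.g. highlight_thinking("◁think▷step 1◁/think▷answer") keeps the text but
# swaps the markers for the "🤔Thinking..." and "💡Summary" labels.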

def get_prompt(conversation) -> str:
    """
    Get the system prompt for the conversation.
    """
    system_prompt = conversation.system_template.format(system_message=conversation.system_message)
    return system_prompt


def highlight_thinking(msg: str) -> str:
    msg = copy.deepcopy(msg)
    if "◁think▷" in msg:
        msg = msg.replace("◁think▷", "<b style='color:blue;'>🤔Thinking...</b>\n")
    if "◁/think▷" in msg:
        msg = msg.replace("◁/think▷", "\n<b style='color:purple;'>💡Summary</b>\n")
    return msg
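

# predict runs a two-stage pipeline: stage 1 (localization) streams a list of
# candidate files given the issue text and the repository tree; stage 2 (repair)
# streams a fix suggestion given the issue and the contents of those files.
# Each stage runs model.generate in a background thread and streams tokens back
# to the UI through a TextIteratorStreamer.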

@wrap_gen_fn
@spaces.GPU(duration=180)
def predict(
    text,
    url,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    chunk_size: int = 512,
):
    """
    Predict the response for the input text and URL.

    Args:
        text (str): The issue description.
        url (str): The GitHub repository URL.
        chatbot (list): The chatbot messages.
        history (list): The conversation history.
        top_p (float): The top-p sampling value.
        temperature (float): The sampling temperature.
        max_length_tokens (int): The maximum number of new tokens to generate.
        chunk_size (int): The chunk size.
    """
print("running the prediction function")
try:
model, tokenizer = fetch_model(args.model)
if text == "":
yield chatbot, history, "Empty context."
return
except KeyError:
yield [[text, "No Model Found"]], [], "No Model Found"
return
prompt = text
repo_name = url.split("/")[-1]
repo_path = './local_path/'+repo_name # Local clone path
clone_github_repo(url, repo_path)
structure = build_repo_structure(repo_path)
string_struture = show_project_structure(structure)
loc_prompt = get_loc_prompt(prompt,string_struture)
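
    # Stage 1 (localization): ask the model which files in the repository tree
    # are relevant to the issue, using the tokenizer's chat template.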
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": loc_prompt},
    ]
    text_for_model = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text_for_model], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Greedy decoding when temperature is 0, otherwise nucleus sampling.
    if temperature > 0:
        generation_kwargs = dict(
            **model_inputs,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_length_tokens,
            streamer=streamer,
        )
    else:
        generation_kwargs = dict(
            **model_inputs,
            do_sample=False,
            max_new_tokens=max_length_tokens,
            streamer=streamer,
        )
    gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    gen_thread.start()
    partial_output = "Start Locating...\n"
    for new_text in streamer:
        partial_output += new_text
        highlight_response = highlight_thinking(partial_output)
        yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generating file locations..."
    gen_thread.join()

    response = partial_output
    raw_answer = post_process(response)
    model_found_files = raw_answer.strip().split("\n")
    print(response)
    highlight_response = highlight_thinking(response)
    yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generate: Success"
    # Read the content of each file the model located.
    contents = ""
    for file_path in model_found_files:
        file_name = file_path.replace("```", "").strip()
        if not file_name:
            # Skip blank lines left over from stripped code fences.
            continue
        print(file_name)
        to_open_path = repo_path + "/" + file_name
        print("to_open_path,", to_open_path)
        with open(to_open_path, "r", encoding="utf-8") as f:
            content = f.read()
        contents += f"{file_name}\n{content}\n\n"
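
    # Stage 2 (repair): build a new prompt from the issue plus the retrieved
    # file contents; the local clone is removed once the prompt is tokenized.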
    repair_prompt = get_repair_prompt(prompt, contents)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": repair_prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    subprocess.run(["rm", "-rf", repo_path], check=True)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    if temperature > 0:
        generation_kwargs = dict(
            **model_inputs,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_length_tokens,
            streamer=streamer,
        )
    else:
        generation_kwargs = dict(
            **model_inputs,
            do_sample=False,
            max_new_tokens=max_length_tokens,
            streamer=streamer,
        )
    gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    gen_thread.start()

    partial_output_repair = "Start Repairing...\n"
    yield [[prompt, highlight_response], [repair_prompt, partial_output_repair]], [["null test", "null test2"]], "Generate: Success"
    time.sleep(5)
    for new_text in streamer:
        partial_output_repair += new_text
        # highlight_response (the localization turn) is already final; only the repair turn updates here.
        highlight_response_repair = highlight_thinking(partial_output_repair)
        yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generating repair suggestion..."
    gen_thread.join()

    yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generate: Success"
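

# retry replays predict with the same inputs; it only short-circuits when there
# is no history to regenerate from.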

def retry(
    text,
    url,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    chunk_size: int = 512,
):
    """
    Retry the response for the input text and URL.
    """
    if len(history) == 0:
        yield (chatbot, history, "Empty context")
        return
    # chatbot.pop()
    # history.pop()
    # text = history.pop()[-1]
    if isinstance(text, tuple):
        text, _ = text
    yield from predict(
        text,
        url,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        chunk_size,
    )
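

# build_demo assembles the Gradio Blocks UI: a chatbot with the issue textbox
# and Send / New Conversation / Regenerate buttons, a column for submitting the
# GitHub URL, a "Parameter Setting" tab with sampling sliders, and example
# inputs loaded via get_examples.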

def build_demo(args: argparse.Namespace) -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
        history = gr.State([])
        input_text = gr.State()
        upload_url = gr.State()

        with gr.Row():
            gr.HTML(TITLE)
            status_display = gr.Markdown("Success", elem_id="status_display")
        gr.Markdown(DESCRIPTION_TOP)
        gr.Markdown(USAGE_TOP)

        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="Kimi-Dev-72B",
                        show_share_button=True,
                        bubble_full_width=False,
                        height=400,
                        # render_markdown=False
                    )
                with gr.Row():
                    with gr.Column(scale=4):
                        text_box = gr.Textbox(label="Issue Description", placeholder="Enter issue description", container=False)
                    with gr.Column(min_width=70):
                        submit_btn = gr.Button("Send")
                    # with gr.Column(min_width=70):
                    #     cancel_btn = gr.Button("Stop")
                with gr.Row():
                    empty_btn = gr.Button("🧹 New Conversation")
                    retry_btn = gr.Button("🔄 Regenerate")
                    # del_last_btn = gr.Button("🗑️ Remove Last Turn")

            def respond(message):
                return "URL submitted!"

            with gr.Column():
                url_box = gr.Textbox(label="Please input a GitHub URL here", placeholder="Input your URL", lines=1)
                url_submit_btn = gr.Button("Submit")
                output = gr.Textbox(label="Submitted URL")
                url_submit_btn.click(fn=respond, inputs=upload_url, outputs=output)
                # Parameter Setting tab for controlling the generation parameters
                with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p")
                    temperature = gr.Slider(
                        minimum=0, maximum=1.0, value=1.0, step=0.1, interactive=True, label="Temperature"
                    )
                    max_length_tokens = gr.Slider(
                        minimum=512, maximum=16384, value=8192, step=64, interactive=True, label="Max Length Tokens"
                    )

        gr.Examples(
            examples=get_examples(ROOT_DIR),
            inputs=[url_box, text_box],
        )
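
        # Event wiring: submitting the textbox or clicking Send copies the inputs
        # into State via transfer_input and then streams predict's output into the
        # chatbot; the Regenerate button replays the same inputs through retry.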
        input_widgets = [
            input_text,
            upload_url,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
        ]
        output_widgets = [chatbot, history, status_display]

        transfer_input_args = dict(
            fn=transfer_input,
            inputs=[text_box, url_box],
            outputs=[input_text, upload_url, text_box, upload_url, submit_btn],
            show_progress=True,
        )
        predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])

        predict_events = [
            text_box.submit(**transfer_input_args).then(**predict_args),
            submit_btn.click(**transfer_input_args).then(**predict_args),
        ]
        empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
        empty_btn.click(**reset_args)
        retry_btn.click(**retry_args)

    demo.title = "Kimi-Dev-72B"
    return demo
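

# Entry point: build the Blocks app, reload the custom frontend JavaScript,
# and launch it through a request queue so the streaming handlers can yield
# incremental updates.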

def main(args: argparse.Namespace):
    demo = build_demo(args)
    reload_javascript()
    favicon_path = os.path.join("kimi_dev/serve/assets/favicon.ico")
    demo.queue().launch(
        favicon_path=favicon_path,
        server_name=args.ip,
        server_port=args.port,
    )
    # Alternative launch with a public share link:
    # demo.queue().launch(
    #     favicon_path=favicon_path,
    #     server_name=args.ip,
    #     server_port=args.port,
    #     share=True,
    # )


if __name__ == "__main__":
    args = parse_args()
    print(args)
    main(args)