# Kimi-Dev-72B / app.py
import argparse
import os
import shutil
import time

import gradio as gr
import openai
import spaces
from kimi_dev.serve.frontend import reload_javascript
from kimi_dev.serve.utils import (
configure_logger,
)
from kimi_dev.serve.gradio_utils import (
reset_state,
reset_textbox,
transfer_input,
wrap_gen_fn,
)
from kimi_dev.serve.examples import get_examples
from kimi_dev.serve.templates import (
    post_process,
    get_loc_prompt,
    get_repair_prompt,
    clone_github_repo,
    build_repo_structure,
    show_project_structure,
)

TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-Dev-72B🔥 </h1>"""
DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-Dev" target="_blank">Kimi-Dev-72B</a> is a strong, open-source coding LLM for software engineering tasks."""
USAGE_TOP = """Usage: 1. Input a GitHub URL like "https://github.com/astropy/astropy" and a commit id, then submit them. \n2. Input your issue description and chat with Kimi-Dev-72B!"""
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DEPLOY_MODELS = dict()
logger = configure_logger()
client = openai.OpenAI(
    base_url="http://localhost:8080/v1",  # vLLM server address
    api_key="EMPTY",  # vLLM does not validate the key; any non-None value works
)
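
# The client above assumes an OpenAI-compatible vLLM server is already running on
# localhost:8080 and serving the checkpoint under the name "kimi-dev" (the `model=`
# value used in the requests below). A typical launch command would look something
# like this (the checkpoint ID here is an assumption, not pinned by this file):
#
#   python -m vllm.entrypoints.openai.api_server \
#       --model moonshotai/Kimi-Dev-72B --served-model-name kimi-dev --port 8080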
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="Kimi-Dev-72B")
parser.add_argument(
"--local-path",
type=str,
default="",
        help="optional local path to a Hugging Face checkpoint",
)
parser.add_argument("--ip", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7860)
return parser.parse_args()
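
# Example invocation (defaults shown):
#
#   python app.py --ip 0.0.0.0 --port 7860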
def get_prompt(conversation) -> str:
"""
Get the prompt for the conversation.
"""
system_prompt = conversation.system_template.format(system_message=conversation.system_message)
return system_prompt
def highlight_thinking(msg: str) -> str:
    # str.replace is a no-op when a marker is absent, so no guards are needed
    msg = msg.replace("◁think▷", "<b style='color:blue;'>🤔Thinking...</b>\n")
    msg = msg.replace("◁/think▷", "\n<b style='color:purple;'>💡Summary</b>\n")
    return msg
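
# Usage example (hypothetical model output): highlight_thinking turns
#   "◁think▷reasoning...◁/think▷final patch"
# into
#   "<b style='color:blue;'>🤔Thinking...</b>\nreasoning...\n<b style='color:purple;'>💡Summary</b>\nfinal patch"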
@wrap_gen_fn
@spaces.GPU(duration=180)
def predict(
text,
url,
commit_hash,
chatbot,
history,
top_p,
temperature,
max_length_tokens,
chunk_size: int = 512,
):
"""
Predict the response for the input text and url.
Args:
text (str): The input text.
url (str): The input url.
chatbot (list): The chatbot.
history (list): The history.
top_p (float): The top-p value.
temperature (float): The temperature value.
repetition_penalty (float): The repetition penalty value.
max_length_tokens (int): The max length tokens.
chunk_size (int): The chunk size.
"""
print("running the prediction function")
    prompt = text
    repo_name = url.rstrip("/").split("/")[-1]  # tolerate a trailing slash in the URL
    print(url)

    repo_path = "./local_path/" + repo_name  # Local clone path
    clone_github_repo(url, repo_path, commit_hash)
    print("repo cloned")

    structure = build_repo_structure(repo_path)
    string_structure = show_project_structure(structure)
    loc_prompt = get_loc_prompt(prompt, string_structure)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": loc_prompt}
]
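
    # Stage 1 (localization): stream the model's answer identifying which files
    # in the repository are relevant to the issue.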
    response = client.chat.completions.create(
        model="kimi-dev",  # must match the model name configured when launching vLLM
        messages=messages,
        stream=True,
        temperature=temperature,
        max_tokens=max_length_tokens,
    )
    partial_output = "Start Locating...\n"
    for chunk in response:
        delta = chunk.choices[0].delta
        if delta and delta.content:
            partial_output += delta.content
            highlight_response = highlight_thinking(partial_output)
            yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generating file locations..."

    response = partial_output
    raw_answer = post_process(response)
    model_found_files = raw_answer.strip().split("\n")

    highlight_response = highlight_thinking(response)
    yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generate: Success"
    # Read the content of each file the model located
    contents = ""
    for file_path in model_found_files:
        file_name = file_path.replace("```", "")
        print(file_name)
        to_open_path = repo_path + "/" + file_name
        if not os.path.isfile(to_open_path):
            continue  # skip malformed or nonexistent paths instead of crashing
        with open(to_open_path, "r", encoding="utf-8") as f:
            content = f.read()
        contents += f"{file_name}\n{content}\n\n"
    repair_prompt = get_repair_prompt(prompt, contents)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": repair_prompt},
    ]
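
    # Stage 2 (repair): the clone is no longer needed once the file contents are
    # in the prompt, so clean it up before streaming the model's proposed fix.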
    shutil.rmtree(repo_path)
    time.sleep(5)
    response = client.chat.completions.create(
        model="kimi-dev",  # must match the model name configured when launching vLLM
        messages=messages,
        stream=True,
        temperature=temperature,
        max_tokens=max_length_tokens,
    )
    partial_output_repair = "Start Repairing...\n"
    for chunk in response:
        delta = chunk.choices[0].delta
        if delta and delta.content:
            partial_output_repair += delta.content
            highlight_response_repair = highlight_thinking(partial_output_repair)
            yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generating file repair..."

    yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generate: Success"
def retry(
text,
url,
commit_hash,
chatbot,
history,
top_p,
temperature,
max_length_tokens,
chunk_size: int = 512,
):
"""
Retry the response for the input text and url.
"""
    if len(history) == 0:
        yield (chatbot, history, "Empty context")
        return

    if isinstance(text, tuple):
        text, _ = text
yield from predict(
text,
url,
commit_hash,
chatbot,
history,
top_p,
temperature,
max_length_tokens,
chunk_size,
)
def build_demo(args: argparse.Namespace) -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
history = gr.State([])
input_text = gr.State()
upload_url = gr.State()
commit_hash = gr.State()
with gr.Row():
gr.HTML(TITLE)
status_display = gr.Markdown("Success", elem_id="status_display")
gr.Markdown(DESCRIPTION_TOP)
gr.Markdown(USAGE_TOP)
with gr.Row(equal_height=True):
with gr.Column(scale=4):
with gr.Row():
chatbot = gr.Chatbot(
elem_id="Kimi-Dev-72B",
show_share_button=True,
bubble_full_width=False,
height=400,
# render_markdown=False
)
with gr.Row():
with gr.Column(scale=4):
text_box = gr.Textbox(label="Issue Description", placeholder="Enter issue description", container=False)
with gr.Column(min_width=70):
submit_btn = gr.Button("Send")
# with gr.Column(min_width=70):
# cancel_btn = gr.Button("Stop")
with gr.Row():
empty_btn = gr.Button("🧹 New Conversation")
retry_btn = gr.Button("🔄 Regenerate")
# del_last_btn = gr.Button("🗑️ Remove Last Turn")
            def respond(message):
                return "URL and commit hash submitted!"
with gr.Column():
                url_box = gr.Textbox(label="Please input a GitHub URL here", placeholder="Input your URL", lines=1)
                commit_hash_box = gr.Textbox(label="Please input a commit hash here", placeholder="Input your commit hash", lines=1)
                url_submit_btn = gr.Button("Submit")
                output = gr.Textbox(label="Submitted URL and commit")
                url_submit_btn.click(fn=respond, inputs=upload_url, outputs=output)
                # Parameter Setting tab for controlling the generation parameters
                with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p")
temperature = gr.Slider(
minimum=0, maximum=1.0, value=1.0, step=0.1, interactive=True, label="Temperature"
)
max_length_tokens = gr.Slider(
minimum=512, maximum=16384, value=8192, step=64, interactive=True, label="Max Length Tokens"
)
gr.Examples(
examples=get_examples(ROOT_DIR),
inputs=[url_box, text_box, commit_hash_box],
)
input_widgets = [
input_text,
upload_url,
commit_hash,
chatbot,
history,
top_p,
temperature,
max_length_tokens,
]
output_widgets = [chatbot, history, status_display]
transfer_input_args = dict(
fn=transfer_input,
            inputs=[text_box, url_box, commit_hash_box],
outputs=[input_text, upload_url, text_box, commit_hash, submit_btn],
show_progress=True,
)
predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])
predict_events = [
text_box.submit(**transfer_input_args).then(**predict_args),
submit_btn.click(**transfer_input_args).then(**predict_args),
]
empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
empty_btn.click(**reset_args)
retry_btn.click(**retry_args)
demo.title = "Kimi-Dev-72B"
return demo
def main(args: argparse.Namespace):
demo = build_demo(args)
reload_javascript()
    favicon_path = os.path.join("kimi_dev", "serve", "assets", "favicon.ico")
demo.queue().launch(
favicon_path=favicon_path,
server_name=args.ip,
server_port=args.port,
share=True
)
if __name__ == "__main__":
args = parse_args()
print(args)
main(args)