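"""Gradio demo for Kimi-Dev-72B (runs on an L40S Space).

Clones a GitHub repository at a given commit, asks the model to locate the
files relevant to an issue, then streams a suggested repair for those files.
The model is served locally through a vLLM OpenAI-compatible endpoint.
"""
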
import argparse
import os
import shutil
import subprocess
import time

import gradio as gr
import openai
import spaces  # Hugging Face Spaces SDK; imported for the Spaces runtime

from kimi_dev.serve.frontend import reload_javascript
from kimi_dev.serve.utils import (
    configure_logger,
)
from kimi_dev.serve.gradio_utils import (
    reset_state,
    reset_textbox,
    transfer_input,
    wrap_gen_fn,
)
from kimi_dev.serve.examples import get_examples
from kimi_dev.serve.templates import (
    post_process,
    get_loc_prompt,
    clone_github_repo,
    build_repo_structure,
    show_project_structure,
    get_repair_prompt,
    get_full_file_paths_and_classes_and_functions,
    correct_file_path_in_structure,
    correct_file_paths,
)
TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-Dev-72B🔥 </h1>""" | |
DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-Dev" target="_blank">Kimi-Dev-72B</a> is a strong and open-source coding LLM for software engineering tasks.""" | |
USAGE_TOP = """Usage: 1. Input a Github url like "https://github.com/astropy/astropy" and a commit id and submit them. \n2. Input your issue description and chat with Kimi-Dev-72B!""" | |
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
DEPLOY_MODELS = dict() | |
logger = configure_logger() | |
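
# OpenAI-compatible client pointed at the local vLLM endpoint that is
# launched from __main__ (see serve_vllm.sh).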
client = openai.OpenAI(
    base_url="http://localhost:8080/v1",  # local vLLM server address
    api_key="EMPTY",  # not validated; any non-None value works
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="Kimi-Dev-72B")
    parser.add_argument(
        "--local-path",
        type=str,
        default="",
        help="huggingface ckpt, optional",
    )
    parser.add_argument("--ip", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    return parser.parse_args()


def get_prompt(conversation) -> str:
    """
    Get the system prompt for the conversation.
    """
    system_prompt = conversation.system_template.format(system_message=conversation.system_message)
    return system_prompt


def highlight_thinking(msg: str) -> str:
    """Wrap the model's ◁think▷ / ◁/think▷ markers in styled HTML labels."""
    if "◁think▷" in msg:
        msg = msg.replace("◁think▷", "<b style='color:blue;'>🤔Thinking...</b>\n")
    if "◁/think▷" in msg:
        msg = msg.replace("◁/think▷", "\n<b style='color:purple;'>💡Summary</b>\n")
    return msg
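

# predict() drives a two-stage pipeline: (1) locate the files relevant to the
# issue from the rendered repo tree, then (2) feed those files back to the
# model to generate a repair. Both stages stream tokens into the chatbot.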
def predict(
    text,
    url,
    commit_hash,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    chunk_size: int = 512,
):
    """
    Predict the response for the input text and url.

    Args:
        text (str): The issue description.
        url (str): The GitHub repository URL.
        commit_hash (str): The commit to check out.
        chatbot (list): The chatbot state.
        history (list): The conversation history.
        top_p (float): The top-p sampling value.
        temperature (float): The sampling temperature.
        max_length_tokens (int): The maximum number of generated tokens.
        chunk_size (int): The streaming chunk size.
    """
print("running the prediction function") | |
openai.api_key = "EMPTY" | |
openai.base_url = "http://localhost:8080/v1" | |
prompt = text | |
repo_name = url.split("/")[-1] | |
print(url) | |
print(commit_hash) | |
repo_path = './local_path/'+repo_name # Local clone path | |
clone_github_repo(url, repo_path, commit_hash) | |
print("repo cloned") | |
structure = build_repo_structure(repo_path) | |
string_struture = show_project_structure(structure) | |
loc_prompt = get_loc_prompt(prompt,string_struture) | |
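
    # Stage 1 (localization): ask the model which files need to change.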
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": loc_prompt}
    ]
    response = client.chat.completions.create(
        model="kimi-dev",
        messages=messages,
        stream=True,
        temperature=temperature,
        max_tokens=max_length_tokens,
    )
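
    # Accumulate streamed deltas and re-render the chat bubble on every chunk.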
partial_output = "Start Locating...\n" | |
for chunk in response: | |
delta = chunk.choices[0].delta | |
if delta and delta.content: | |
partial_output += delta.content | |
highlight_response = highlight_thinking(partial_output) | |
yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generating file locations..." | |
response = partial_output | |
raw_answer=post_process(response) | |
model_found_files = raw_answer.strip().split("\n") | |
files, _, _ = get_full_file_paths_and_classes_and_functions(structure) | |
model_found_files = [correct_file_path_in_structure(file, structure) for file in model_found_files] | |
found_files = correct_file_paths(model_found_files, files) | |
highlight_response = highlight_thinking(response) | |
yield [[prompt,highlight_response]], [["null test","null test2"]], "Generate: Success" | |
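
    # Stage 2 (repair): read each located file and pack its contents into the
    # repair prompt.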
contents = "" | |
for file_path in found_files: | |
file_name = file_path.replace("```","") | |
print(file_name) | |
to_open_path = repo_path + "/" + file_name | |
with open(to_open_path, "r", encoding="utf-8") as f: | |
content = f.read() | |
contents += f"{file_name}\n{content}\n\n" | |
repair_prompt = get_repair_prompt(prompt,contents) | |
messages = [ | |
{"role": "system", "content": "You are a helpful assistant."}, | |
{"role": "user", "content": repair_prompt} | |
] | |
subprocess.run(["rm", "-rf", repo_path], check=True) | |
time.sleep(5) | |

    response = client.chat.completions.create(
        model="kimi-dev",
        messages=messages,
        stream=True,
        temperature=temperature,
        max_tokens=max_length_tokens,
    )
    partial_output_repair = "Start Repairing...\n"
    for chunk in response:
        delta = chunk.choices[0].delta
        if delta and delta.content:
            partial_output_repair += delta.content
            highlight_response_repair = highlight_thinking(partial_output_repair)
            yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generating file repairs..."

    yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generate: Success"
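

# retry() guards against an empty history, then re-runs predict() with the
# previous inputs.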
def retry(
    text,
    url,
    commit_hash,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    chunk_size: int = 512,
):
    """
    Retry the response for the input text and url.
    """
    if len(history) == 0:
        yield (chatbot, history, "Empty context")
        return
    if isinstance(text, tuple):
        text, _ = text

    yield from predict(
        text,
        url,
        commit_hash,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        chunk_size,
    )
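

# build_demo wires the UI: a chat column for the issue description, a side
# column for the GitHub URL and commit hash, sampling sliders, and examples.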
def build_demo(args: argparse.Namespace) -> gr.Blocks:
    with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
        history = gr.State([])
        input_text = gr.State()
        upload_url = gr.State()
        commit_hash = gr.State()

        with gr.Row():
            gr.HTML(TITLE)
            status_display = gr.Markdown("Success", elem_id="status_display")
        gr.Markdown(DESCRIPTION_TOP)
        gr.Markdown(USAGE_TOP)

        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="Kimi-Dev-72B",
                        show_share_button=True,
                        bubble_full_width=False,
                        height=400,
                        # render_markdown=False
                    )
                with gr.Row():
                    with gr.Column(scale=4):
                        text_box = gr.Textbox(label="Issue Description", placeholder="Enter issue description", container=False)
                    with gr.Column(min_width=70):
                        submit_btn = gr.Button("Send")
                with gr.Row():
                    empty_btn = gr.Button("🧹 New Conversation")
                    retry_btn = gr.Button("🔄 Regenerate")

                def respond(message):
                    return "URL and commit hash submitted!"

            with gr.Column():
                url_box = gr.Textbox(label="Please input a GitHub URL here", placeholder="Input your URL", lines=1)
                commit_hash_box = gr.Textbox(label="Please input a commit hash here", placeholder="Input your commit hash", lines=1)
                url_submit_btn = gr.Button("Submit")
                output = gr.Textbox(label="Submitted URL and commit")
                url_submit_btn.click(fn=respond, inputs=upload_url, outputs=output)

                # Parameter Setting tab for controlling the generation parameters
                with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p")
                    temperature = gr.Slider(
                        minimum=0, maximum=1.0, value=1.0, step=0.1, interactive=True, label="Temperature"
                    )
                    max_length_tokens = gr.Slider(
                        minimum=512, maximum=32768, value=16384, step=64, interactive=True, label="Max Length Tokens"
                    )

        gr.Examples(
            examples=get_examples(ROOT_DIR),
            inputs=[url_box, text_box, commit_hash_box],
        )
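
        # Event wiring: transfer_input stages the textbox/URL/commit values
        # into the State holders, then predict streams into
        # [chatbot, history, status_display].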
        input_widgets = [
            input_text,
            upload_url,
            commit_hash,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
        ]
        output_widgets = [chatbot, history, status_display]

        transfer_input_args = dict(
            fn=transfer_input,
            inputs=[text_box, url_box, commit_hash_box],
            outputs=[input_text, upload_url, text_box, commit_hash, submit_btn],
            show_progress=True,
        )
        predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])

        predict_events = [
            text_box.submit(**transfer_input_args).then(**predict_args),
            submit_btn.click(**transfer_input_args).then(**predict_args),
        ]

        empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
        empty_btn.click(**reset_args)
        retry_btn.click(**retry_args)

    demo.title = "Kimi-Dev-72B"
    return demo


def main(args: argparse.Namespace):
    demo = build_demo(args)
    reload_javascript()

    favicon_path = os.path.join(ROOT_DIR, "kimi_dev", "serve", "assets", "favicon.ico")
    demo.queue().launch(
        favicon_path=favicon_path,
        server_name=args.ip,
        server_port=args.port,
        share=True,
    )
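

# Launch the vLLM server in the background, give it time to load the 72B
# checkpoint, then start the Gradio UI.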
if __name__ == "__main__":
    print("Start serving vllm...")
    script_path = os.path.join(os.path.dirname(__file__), "serve_vllm.sh")
    subprocess.Popen(["bash", script_path])
    time.sleep(450)  # wait for the vLLM server to finish loading the model
    print("Served vllm!")

    args = parse_args()
    print(args)
    main(args)