miaoyibo committed
Commit 5ce5804 · 1 Parent(s): 1d5f555
Files changed (4)
  1. app.py +40 -114
  2. kimi_dev/serve/inference.py +0 -26
  3. requirements.txt +3 -2
  4. start.sh +42 -0
app.py CHANGED
@@ -8,7 +8,8 @@ import json
 import subprocess
 import ast
 import pdb
-from transformers import TextIteratorStreamer
+
+import openai

 import threading

@@ -22,9 +23,8 @@ from kimi_dev.serve.gradio_utils import (
     transfer_input,
     wrap_gen_fn,
 )
-from kimi_dev.serve.inference import load_model
 from kimi_dev.serve.examples import get_examples
-from kimi_dev.serve.templates import post_process,get_loc_prompt, clone_github_repo, build_repo_structure, show_project_structure,get_repair_prompt,get_repo_files,get_full_file_paths_and_classes_and_functions,correct_file_path_in_structure
+from kimi_dev.serve.templates import post_process,get_loc_prompt, clone_github_repo, build_repo_structure, show_project_structure,get_repair_prompt

 TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-Dev-72B🔥 </h1>"""
 DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-Dev" target="_blank">Kimi-Dev-72B</a> is a strong and open-source coding LLM for software engineering tasks."""
@@ -33,6 +33,12 @@ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 DEPLOY_MODELS = dict()
 logger = configure_logger()

+
+client = openai.OpenAI(
+    base_url="http://localhost:8080/v1",  # vLLM server address
+    api_key="EMPTY"  # not validated; any non-None value works
+)
+
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="Kimi-Dev-72B")
@@ -47,26 +53,6 @@ def parse_args():
     return parser.parse_args()


-def fetch_model(model_name: str):
-    global args, DEPLOY_MODELS
-
-    if args.local_path:
-        model_path = args.local_path
-    else:
-        model_path = f"moonshotai/{args.model}"
-
-    if model_name in DEPLOY_MODELS:
-        model_info = DEPLOY_MODELS[model_name]
-        print(f"{model_name} has been loaded.")
-    else:
-        print(f"{model_name} is loading...")
-        DEPLOY_MODELS[model_name] = load_model(model_path)
-        print(f"Load {model_name} successfully...")
-        model_info = DEPLOY_MODELS[model_name]
-
-    return model_info
-
-
 def get_prompt(conversation) -> str:
     """
     Get the prompt for the conversation.
@@ -111,20 +97,12 @@ def predict(
     """
     print("running the prediction function")

-    try:
-        model, tokenizer = fetch_model(args.model)
-
-        if text == "":
-            yield chatbot, history, "Empty context."
-            return
-    except KeyError:
-        yield [[text, "No Model Found"]], [], "No Model Found"
-        return
-
+    openai.api_key = "EMPTY"
+    openai.base_url = "http://localhost:8080/v1"
     prompt = text
     repo_name = url.split("/")[-1]
     print(url)
-    print(commit_hash)
+    # print(commit_hash)

     repo_path = './local_path/'+repo_name # Local clone path

@@ -141,50 +119,22 @@ def predict(
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": loc_prompt}
     ]
-    text_for_model = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
+
+    response = client.chat.completions.create(
+        model="kimi-dev",  # must match the --served-model-name used when launching vLLM
+        messages=messages,
+        stream=True,
+        temperature=temperature,
+        max_tokens=max_length_tokens,
     )
-    model_inputs = tokenizer([text_for_model], return_tensors="pt").to(model.device)
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    # print("start generating")
-
-    loc_start_time = time.time()
-    if temperature > 0:
-        generation_kwargs = dict(
-            **model_inputs,
-            do_sample=True,
-            temperature=temperature,
-            top_p=top_p,
-            max_new_tokens=max_length_tokens,
-            streamer=streamer
-        )
-    else:
-        generation_kwargs = dict(
-            **model_inputs,
-            do_sample=False,
-            max_new_tokens=max_length_tokens,
-            streamer=streamer
-        )
-    gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-    gen_thread.start()
-

     partial_output = "Start Locating...\n"
-
-    for new_text in streamer:
-        partial_output += new_text
-        highlight_response = highlight_thinking(partial_output)
-        yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generating file locations..."
-
-    gen_thread.join()
-    loc_end_time = time.time()
-    loc_time = loc_end_time - loc_start_time
-
-    encoded_answer = tokenizer(partial_output, padding=True, truncation=True, return_tensors='pt')
-    print("loc token/s:",len(encoded_answer['input_ids'][0])/loc_time)
-
+    for chunk in response:
+        delta = chunk.choices[0].delta
+        if delta and delta.content:
+            partial_output += delta.content
+            highlight_response = highlight_thinking(partial_output)
+            yield [[prompt, highlight_response]], [["null test", "null test2"]], "Generating file locations..."
     response = partial_output

     raw_answer=post_process(response)
@@ -213,53 +163,29 @@ def predict(
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": repair_prompt}
     ]
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

     subprocess.run(["rm", "-rf", repo_path], check=True)
-    repair_start_time = time.time()
-    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    if temperature > 0:
-        generation_kwargs = dict(
-            **model_inputs,
-            do_sample=True,
-            temperature=temperature,
-            top_p=top_p,
-            max_new_tokens=max_length_tokens,
-            streamer=streamer
-        )
-    else:
-        generation_kwargs = dict(
-            **model_inputs,
-            do_sample=False,
-            max_new_tokens=max_length_tokens,
-            streamer=streamer
-        )
-    gen_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-    gen_thread.start()

-    partial_output_repair = "Start Repairing...\n"
-    yield [[prompt,highlight_response],[repair_prompt,partial_output_repair]], [["null test","null test2"]], "Generate: Success"
+
     time.sleep(5)
-    for new_text in streamer:
-        partial_output_repair += new_text
-        highlight_response = highlight_thinking(partial_output)
-        highlight_response_repair = highlight_thinking(partial_output_repair)
-        yield [[prompt, highlight_response], [repair_prompt, highlight_response_repair]], [["null test", "null test2"]], "Generating repair suggestion..."

-    gen_thread.join()
-    repair_end_time = time.time()

-    repair_time = repair_end_time - repair_start_time
+    response = client.chat.completions.create(
+        model="kimi-dev",  # must match the --served-model-name used when launching vLLM
+        messages=messages,
+        stream=True,
+        temperature=temperature,
+        max_tokens=max_length_tokens,
+    )

-    encoded_answer = tokenizer(partial_output_repair, padding=True, truncation=True, return_tensors='pt')
-    print("repair token/s:",len(encoded_answer['input_ids'][0])/repair_time)
+    partial_output_repair = "Start Repairing...\n"
+    for chunk in response:
+        delta = chunk.choices[0].delta
+        if delta and delta.content:
+            partial_output_repair += delta.content
+            highlight_response_repair = highlight_thinking(partial_output_repair)
+            yield [[prompt,highlight_response],[repair_prompt,highlight_response_repair]], [["null test","null test2"]], "Generating repair suggestion..."

-    # yield response, "null test", "Generate: Success"
     yield [[prompt,highlight_response],[repair_prompt,highlight_response_repair]], [["null test","null test2"]], "Generate: Success"
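Note on the app.py change above: inference moves out of the Gradio process. Instead of loading the checkpoint with transformers and streaming through TextIteratorStreamer, app.py now streams chat completions from a local vLLM server over its OpenAI-compatible API. A minimal standalone sketch of that path follows; it assumes the server from start.sh is already serving the model as "kimi-dev" at http://localhost:8080/v1, and the stream_completion helper is illustrative rather than part of app.py.

    import openai

    client = openai.OpenAI(base_url="http://localhost:8080/v1", api_key="EMPTY")

    def stream_completion(user_prompt, temperature=0.0, max_tokens=2048):
        """Stream a chat completion from the local vLLM server and return the full text."""
        response = client.chat.completions.create(
            model="kimi-dev",  # served model name configured in start.sh
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": user_prompt},
            ],
            stream=True,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        output = ""
        for chunk in response:
            delta = chunk.choices[0].delta
            if delta and delta.content:  # each chunk carries an incremental piece of text
                output += delta.content
        return output

The same create(..., stream=True) / delta.content loop is what predict() now uses for both the localization and the repair passes.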
kimi_dev/serve/inference.py DELETED
@@ -1,26 +0,0 @@
1
- import logging
2
-
3
- from transformers import (
4
- AutoModelForCausalLM,
5
- AutoConfig,
6
- AutoTokenizer
7
- )
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
-
12
- def load_model(model_path: str = "moonshotai/Kimi-Dev-72B"):
13
- # hotfix the model to use flash attention 2
14
- config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
15
-
16
- model = AutoModelForCausalLM.from_pretrained(
17
- model_path,
18
- config=config,
19
- torch_dtype="auto",
20
- device_map="auto",
21
- trust_remote_code=True,
22
- )
23
-
24
- tokenizer = AutoTokenizer.from_pretrained(model_path)
25
-
26
- return model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,3 @@
-torchvision==0.20.0
 transformers==4.51.1
 accelerate
 sentencepiece
@@ -17,4 +16,6 @@ tqdm
 colorama
 Pygments
 markdown
-SentencePiece
+SentencePiece
+vllm
+openai
start.sh ADDED
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+python -m vllm.entrypoints.openai.api_server \
+    --model moonshotai/Kimi-Dev-72B \
+    --tensor-parallel-size 4 \
+    --max-num-seqs 8 \
+    --max-model-len 131072 \
+    --gpu-memory-utilization 0.9 \
+    --host localhost \
+    --served-model-name kimi-dev \
+    --port 8080 &  # run the server in the background so the readiness check below can poll it
+
+SERVICE_URL="http://localhost:8080/v1/models"
+TIMEOUT=300   # maximum seconds to wait
+INTERVAL=5    # seconds between checks
+ELAPSED=0
+
+echo "[*] Waiting for the vLLM server to start, up to ${TIMEOUT}s ..."
+
+while true; do
+    # Query the model-list endpoint and check that it reports the expected model
+    if curl -s "$SERVICE_URL" | grep -q "moonshotai"; then
+        echo "✅ vLLM server started successfully!"
+        break
+    fi
+
+    if [ $ELAPSED -ge $TIMEOUT ]; then
+        echo "❌ Timed out: the vLLM server did not start."
+        exit 1
+    fi
+
+    echo "⏳ Server not ready yet; retrying in ${INTERVAL}s ..."
+    sleep $INTERVAL
+    ELAPSED=$((ELAPSED + INTERVAL))
+done
+
+# Commands to run once the server is up
+echo "[*] Running follow-up steps..."
+
+# e.g. start the frontend, run test scripts, etc.
+# ./start_frontend.sh
+python app.py
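For completeness, the same readiness check can be done from Python instead of the curl loop. This is a sketch under the same assumptions (server on http://localhost:8080/v1, model served as "kimi-dev"); the wait_for_vllm helper is not part of the commit.

    import time
    import openai

    def wait_for_vllm(base_url="http://localhost:8080/v1", timeout=300, interval=5):
        """Poll the OpenAI-compatible /models endpoint until the server responds."""
        client = openai.OpenAI(base_url=base_url, api_key="EMPTY")
        deadline = time.time() + timeout
        while time.time() < deadline:
            try:
                served = [m.id for m in client.models.list().data]
                if "kimi-dev" in served:  # name set via --served-model-name
                    return True
            except Exception:
                pass  # server still starting; keep polling
            time.sleep(interval)
        raise TimeoutError("vLLM server did not become ready within the timeout")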