miaoyibo committed
Commit 5d6758b · 1 parent: f3a9564

Files changed (2):
  1. app.py +2 -2
  2. kimi_vl/serve/inference.py +7 -7
app.py CHANGED

@@ -127,7 +127,7 @@ def predict(
     """
     print("running the prediction function")
     try:
-        model, processor = fetch_model(args.model)
+        model = fetch_model(args.model)
 
         if text == "":
             yield chatbot, history, "Empty context."
@@ -157,9 +157,9 @@ def predict(
             text,
             pil_images,
             history,
-            processor,
             max_length=max_context_length_tokens,
         )
+        print(conversation)
         all_conv, last_image = convert_conversation_to_prompts(conversation)
         stop_words = conversation.stop_str
         gradio_chatbot_output = to_gradio_chatbot(conversation)
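
Taken together, the app.py changes remove the processor from the serving path (fetch_model now returns only the model) and add a debug print of the built conversation. A minimal sketch of the resulting call sites follows; the call wrapping lines 157-161 is not visible in the diff, so the name generate_prompt_with_history is an assumption, not confirmed by this commit:

    # Sketch only: the enclosing call is elided in the diff, so
    # generate_prompt_with_history is an assumed name.
    model = fetch_model(args.model)  # was: model, processor = fetch_model(...)

    conversation = generate_prompt_with_history(  # assumed name
        text,
        pil_images,
        history,
        max_length=max_context_length_tokens,  # processor argument dropped
    )
    print(conversation)  # debug print added by this commit
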
kimi_vl/serve/inference.py CHANGED

@@ -19,13 +19,13 @@ from .chat_utils import Conversation, get_conv_template
 logger = logging.getLogger(__name__)
 
 
-def load_model(model_path: str = "moonshotai/Kimi-VL-A3B-Thinking"):
+def load_model(model_path: str = "moonshotai/Kimi-Dev-72B"):
     # hotfix the model to use flash attention 2
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-    config._attn_implementation = "flash_attention_2"
-    config.vision_config._attn_implementation = "flash_attention_2"
-    config.text_config._attn_implementation = "flash_attention_2"
-    print("Successfully set the attn_implementation to flash_attention_2")
+    # config._attn_implementation = "flash_attention_2"
+    # config.vision_config._attn_implementation = "flash_attention_2"
+    # config.text_config._attn_implementation = "flash_attention_2"
+    # print("Successfully set the attn_implementation to flash_attention_2")
 
     model = AutoModelForCausalLM.from_pretrained(
         model_path,
@@ -34,9 +34,9 @@ def load_model(model_path: str = "moonshotai/Kimi-VL-A3B-Thinking"):
         device_map="auto",
         trust_remote_code=True,
     )
-    processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True)
+    # processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True)
 
-    return model, processor
+    return model
 
 
 class StoppingCriteriaSub(StoppingCriteria):
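
For readability, here is load_model as it reads after this commit, stitched together from the two hunks above. Lines 32-33 fall between the hunks and are not shown, so they are kept as a placeholder comment; the flash-attention overrides and processor loading are commented out rather than deleted, presumably because Kimi-Dev-72B is a text-only checkpoint with no vision_config:

    def load_model(model_path: str = "moonshotai/Kimi-Dev-72B"):
        # hotfix the model to use flash attention 2
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        # config._attn_implementation = "flash_attention_2"
        # config.vision_config._attn_implementation = "flash_attention_2"
        # config.text_config._attn_implementation = "flash_attention_2"
        # print("Successfully set the attn_implementation to flash_attention_2")

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # ... lines 32-33 are elided between the hunks ...
            device_map="auto",
            trust_remote_code=True,
        )
        # processor = AutoProcessor.from_pretrained(model_path, config=config, trust_remote_code=True)

        return model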