# Kimi-Dev-72B: kimi_dev/serve/inference.py
import logging

from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    AutoTokenizer,
)

logger = logging.getLogger(__name__)
def load_model(model_path: str = "moonshotai/Kimi-Dev-72B"):
    # Hotfix: request FlashAttention 2 so the 72B model runs with a faster,
    # more memory-efficient attention kernel (requires the flash-attn package).
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer
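

# Usage sketch (not part of the original file): loads the model and runs a
# single generation. The prompt and max_new_tokens below are illustrative
# assumptions, not values taken from this repository.
if __name__ == "__main__":
    model, tokenizer = load_model()
    prompt = "Write a Python function that reverses a string."
    # Tokenize the prompt and move the tensors to the model's first device
    # (device_map="auto" may shard the model across several GPUs).
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=128)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))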