import logging
from typing import Optional

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)

logger = logging.getLogger(__name__)


def load_model(
    model_path: str = "moonshotai/Kimi-Dev-72B",
    attn_implementation: Optional[str] = None,
):
    """Load a causal language model and its tokenizer from a checkpoint.

    Args:
        model_path: Hugging Face hub id or local path of the checkpoint.
        attn_implementation: Optional attention backend override, e.g.
            ``"flash_attention_2"``. ``None`` keeps the checkpoint's
            default attention implementation (the original behavior).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # BUG FIX: the original comment claimed a "hotfix ... flash attention 2"
    # but never applied any attention override. The backend is now an
    # explicit, optional parameter passed through to from_pretrained.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True,
        attn_implementation=attn_implementation,
    )
    # Consistency fix: the tokenizer load must also trust remote code,
    # matching the config and model loads above — otherwise checkpoints
    # that ship a custom tokenizer class fail to load.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer