# handler.py
from typing import Dict, Any
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

class EndpointHandler:
    def __init__(self, model_dir: str, **kw):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

        # ① Build an empty (meta-device) model shell: no weights are
        #    materialized yet, so this step is cheap even for large models.
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        with init_empty_weights():
            base = AutoModelForCausalLM.from_config(
                config, torch_dtype=torch.float16, trust_remote_code=True
            )
        base.tie_weights()  # retie shared embeddings before dispatch

        # ② Load the checkpoint shards and dispatch layers across available devices
        self.model = load_checkpoint_and_dispatch(
            base, checkpoint=model_dir, device_map="auto", dtype=torch.float16
        ).eval()
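        # NOTE (model-specific assumption): if the architecture has blocks that
        # must not be split across GPUs, also pass no_split_module_classes=[...]
        # to load_checkpoint_and_dispatch; the class name depends on the model.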

        # ③ Pin the "default GPU" to the one holding the input embeddings
        self.embed_device = self.model.get_input_embeddings().weight.device
        torch.cuda.set_device(self.embed_device)     # ← key point 1
        print(">>> embedding on", self.embed_device)

        # Generation parameters
        self.gen_kwargs = dict(max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        prompt = data["inputs"]

        # Move *all* input tensors onto embed_device, not just input_ids
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.embed_device)  # ← key point 2
        with torch.inference_mode():
            out_ids = self.model.generate(**inputs, **self.gen_kwargs)

        return {"generated_text": self.tokenizer.decode(out_ids[0], skip_special_tokens=True)}