Hsquad committed on
Commit
a198bf3
·
verified ·
1 Parent(s): 5c3eb3f

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +56 -0
handler.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+ from transformers import JanusForConditionalGeneration, JanusProcessor
3
+ import torch, base64, io, PIL.Image as Image
4
+
5
class EndpointHandler:
    """Request handler wrapping a Janus model and its processor.

    Works for:
      • text → text chat completions
      • text → image generation (pass {"generation_mode": "image"})

    Request fields read by ``__call__``:
      * ``"prompt"`` or ``"inputs"``: the user message (non-empty string).
      * ``"generation_mode"``: ``"text"`` (default) or ``"image"``.
      * ``"max_new_tokens"``: token budget for text mode (default 128).

    ``generation_mode`` and ``max_new_tokens`` may be given either at the top
    level of the request or inside a nested ``"parameters"`` dict; the
    top-level value wins when both are present.
    """

    # Modes accepted for the "generation_mode" request field.
    _GENERATION_MODES = ("text", "image")

    def __init__(self, model_path: str):
        """Load the processor and the model from ``model_path``."""
        self.processor = JanusProcessor.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model = JanusForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,  # fp16 also fine
            device_map="auto",
            load_in_4bit=True,  # comment out on bigger GPUs
        )

    # ---- each request lands here ----
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one generation request.

        Args:
            data: Request payload; see the class docstring for the fields.

        Returns:
            ``{"generated_text": str}`` in text mode, or
            ``{"images": [<base64 PNG>, ...]}`` in image mode.

        Raises:
            ValueError: If no non-empty string prompt is supplied, or if
                ``generation_mode`` is not ``"text"`` or ``"image"``.
        """
        prompt = data.get("prompt") or data.get("inputs")
        if not isinstance(prompt, str) or not prompt.strip():
            # Fail early with a clear message instead of feeding None into
            # the chat template.
            raise ValueError(
                'Request must contain a non-empty string under "prompt" or "inputs".'
            )

        params = data.get("parameters")
        if not isinstance(params, dict):
            params = {}

        gen_mode = data.get(
            "generation_mode", params.get("generation_mode", "text")
        )
        if gen_mode not in self._GENERATION_MODES:
            raise ValueError(
                f"generation_mode must be one of {self._GENERATION_MODES}, "
                f"got {gen_mode!r}."
            )

        templ = self.processor.apply_chat_template(
            [{"role": "user",
              "content": [{"type": "text", "text": prompt}]}],
            add_generation_prompt=True,
        )

        inputs = self.processor(
            text=templ,
            generation_mode=gen_mode,
            return_tensors="pt",
        ).to(self.model.device)

        if gen_mode == "image":
            # NOTE(review): follows the image-generation example in the
            # transformers Janus documentation (sampling with the KV cache,
            # no max_new_tokens) -- confirm against the installed version.
            out = self.model.generate(
                **inputs,
                generation_mode=gen_mode,
                do_sample=True,
                use_cache=True,
            )
            # Generated ids are image tokens, not text: turn them into pixel
            # tensors first, then let the processor convert those to PIL.
            pixel_batch = self.model.decode_image_tokens(out)
            decoded = self.processor.postprocess(
                list(pixel_batch.float()),
                return_tensors="PIL.Image.Image",
            )
            return {
                "images": [
                    self._pil_to_base64(img)
                    for img in decoded["pixel_values"]
                ]
            }

        out = self.model.generate(
            **inputs,
            generation_mode=gen_mode,
            max_new_tokens=data.get(
                "max_new_tokens", params.get("max_new_tokens", 128)
            ),
        )
        return {
            "generated_text":
                self.processor.decode(out[0], skip_special_tokens=True)
        }

    @staticmethod
    def _pil_to_base64(img: "Image.Image") -> str:
        """Encode ``img`` as PNG and return the bytes as a base64 string."""
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode()