saurabh-straive commited on
Commit
7c2538e
·
verified ·
1 Parent(s): 647f969

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +31 -0
handler.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
2
+ import torch
3
+ from PIL import Image
4
+
5
+ class EndpointHandler():
6
+ def __init__(self, path=""):
7
+ disable_torch_init()
8
+ device = torch.cuda_device
9
+ self.processor = LlavaNextProcessor.from_pretrained(path, use_fast=False)
10
+ self.model = LlavaNextForConditionalGeneration.from_pretrained(
11
+ path,
12
+ torch_dtype=torch.float16,
13
+ low_cpu_mem_usage=True,
14
+ load_in_4bit=True
15
+ )
16
+ self.model.to("cuda:0")
17
+
18
+ def __call__(self, data):
19
+ image_encoded = data.pop("inputs", data)
20
+ prompt = data["text"]
21
+
22
+ image = self.decode_base64_image(image_encoded)
23
+ if image.mode != "RGB":
24
+ image = image.convert("RGB")
25
+
26
+ inputs = self.processor(prompt, image, return_tensors="pt").to("cuda:0")
27
+
28
+ # autoregressively complete prompt
29
+ output = self.model.generate(**inputs, max_new_tokens=500)
30
+
31
+ return processor.decode(output[0], skip_special_tokens=True)