from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image
from typing import Dict, List, Any
import io
import requests
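
# Custom handler for a Hugging Face Inference Endpoint: the platform
# instantiates EndpointHandler once at startup and calls it with each
# request payload, returning captions for a batch of image URLs.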


class EndpointHandler():
    def __init__(self, path=""):
        # Load the captioning model, image processor, and tokenizer once at
        # startup. `path` (the endpoint repository) is unused here; the
        # weights are pulled from the Hub directly.
        model = VisionEncoderDecoderModel.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning")
        feature_extractor = ViTImageProcessor.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning")
        tokenizer = AutoTokenizer.from_pretrained(
            "nlpconnect/vit-gpt2-image-captioning")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(self.device)
        self.model = model
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
    def __call__(self, data: Dict[str, Any]) -> List[str]:
        """
        data args:
            image_paths (:obj:`list` of :obj:`str`): URLs of the images to caption
        Return:
            A :obj:`list` of caption strings; will be serialized and returned
        """
        # Beam-search generation settings.
        max_length = 128
        num_beams = 4
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
        # Accept either {"image_paths": [...]} or a bare list of URLs.
        image_paths = data.pop("image_paths", data)
        images = []
        for image_path in image_paths:
            response = requests.get(image_path)
            response.raise_for_status()  # Raise an exception if the request failed
            # Decode the bytes in memory; a shared temp file would be
            # overwritten between iterations before PIL lazily reads it.
            i_image = Image.open(io.BytesIO(response.content))
            if i_image.mode != "RGB":
                i_image = i_image.convert(mode="RGB")
            images.append(i_image)
        # Batch-preprocess the images, generate caption token ids, and
        # decode them to text.
        pixel_values = self.feature_extractor(
            images=images, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(self.device)
        output_ids = self.model.generate(pixel_values, **gen_kwargs)
        preds = self.tokenizer.batch_decode(
            output_ids, skip_special_tokens=True)
        preds = [pred.strip() for pred in preds]
        return preds
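

# A minimal local smoke test, assuming network access and that the URL below
# (a hypothetical placeholder) points at a reachable image. On an Inference
# Endpoint this block never runs; the platform drives EndpointHandler itself.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {"image_paths": ["https://example.com/cat.jpg"]}  # hypothetical URL
    print(handler(payload))  # a list of caption strings, one per image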