# Qwen-VL-7B-2 / handler.py
import torch
from PIL import Image
import requests
from io import BytesIO
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
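
# Expected request payload for this custom Inference Endpoints handler:
#   {"inputs": {"image": "<publicly reachable image URL>", "text": "<prompt>"}}
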
class EndpointHandler:
    def __init__(self, path=""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load model and processor. device_map="auto" lets accelerate place the
        # weights itself, so the model must not be moved again with .to();
        # calling .to() on a dispatched model raises an error.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto",
        )
        self.processor = AutoProcessor.from_pretrained(path)

def __call__(self, data):
# Extract image and text from the input data
image_url = data.get("inputs", {}).get("image", "")
text_prompt = data.get("inputs", {}).get("text", "")
if not image_url or not text_prompt:
return {"error": "Both 'image' and 'text' must be provided in the input data."}
# Download and process the image
try:
            response = requests.get(image_url, timeout=10)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")
except Exception as e:
return {"error": f"Failed to load image from URL: {e}"}
# Prepare the input in the format expected by the model
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": text_prompt},
],
}
]
# Process the input
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
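        # process_vision_info pulls the image (and any video) entries out of the
        # chat messages in the form the processor expects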
image_inputs, video_inputs = process_vision_info(messages)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
        # Move the processed inputs onto the same device as the model
        inputs = inputs.to(self.device)
        # Generate output
        with torch.no_grad():
            output_ids = self.model.generate(**inputs, max_new_tokens=128)

        # generate() returns the prompt tokens followed by the completion, so
        # trim the prompt before decoding to return only the newly generated text
        trimmed_ids = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, output_ids)
        ]
        output_text = self.processor.batch_decode(
            trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return {"generated_text": output_text}