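"""
Custom inference handler for an mPLUG-Owl3 vision-language model, following the
Hugging Face Inference Endpoints convention (an `EndpointHandler` class exposing
`__init__` and `__call__`).

The handler expects a JSON payload with a text `prompt`, a base64-encoded
`image`, and an optional `max_new_tokens` limit, and returns the generated text.
"""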
from typing import Dict, Any
import torch
from PIL import Image
import base64
import io
import os
import sys


class EndpointHandler:
    def __init__(self, path="."):
        """
        Initialize the model and tokenizer for inference.

        Args:
            path (str): Path to the model directory
        """
        try:
            # Add the model's directory to the Python path
            if path not in sys.path:
                sys.path.append(path)

            # Import transformers
            from transformers import AutoModelForCausalLM, AutoTokenizer

            print(f"Loading model from {path}")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                path,
                trust_remote_code=True
            )

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )

            # Set model to evaluation mode
            self.model.eval()
            print("Model loaded successfully")
        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            import traceback
            traceback.print_exc()
            raise
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and return the model's output.

        Expected keys in `data`:
            prompt (str): text prompt, defaults to "Describe this image."
            image (str | bytes): base64-encoded image (optionally a data URI) or raw image bytes
            max_new_tokens (int): maximum number of tokens to generate, defaults to 100
        """
        try:
            # Extract inputs from data
            prompt = data.get("prompt", "Describe this image.")
            image_data = data.get("image", None)
            max_new_tokens = data.get("max_new_tokens", 100)

            # Check if an image is provided
            if not image_data:
                return {"error": "No image provided"}

            # Decode base64 image
            try:
                if isinstance(image_data, str):
                    if image_data.startswith("data:image"):
                        image_data = image_data.split(",")[1]
                    image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
                elif isinstance(image_data, bytes):
                    image = Image.open(io.BytesIO(image_data)).convert("RGB")
                else:
                    return {"error": "Invalid image format"}
            except Exception as e:
                return {"error": f"Error processing image: {str(e)}"}

            try:
                # Prepare messages for the model
                messages = [
                    {"role": "user", "content": f"<|image|> {prompt}"},
                    {"role": "assistant", "content": ""}
                ]

                # For mPLUG-Owl3, the processor is initialized from the model itself;
                # inspect the model structure to find where init_processor lives
                print("Model structure:", dir(self.model))

                # Try different ways to access the processor
                if hasattr(self.model, "init_processor"):
                    processor = self.model.init_processor(self.tokenizer)
                elif hasattr(self.model, "model") and hasattr(self.model.model, "init_processor"):
                    processor = self.model.model.init_processor(self.tokenizer)
                else:
                    # Fall back to searching the model's attributes
                    for attr_name in dir(self.model):
                        if attr_name.startswith("_"):
                            continue
                        attr = getattr(self.model, attr_name)
                        if hasattr(attr, "init_processor"):
                            processor = attr.init_processor(self.tokenizer)
                            print(f"Found processor in {attr_name}")
                            break
                    else:
                        return {"error": "Could not find processor in model"}

                # Process inputs
                model_inputs = processor(messages, images=[image], videos=None)

                # Move inputs to the same device as the model
                device = next(self.model.parameters()).device
                model_inputs = model_inputs.to(device)

                # Add additional generation parameters
                model_inputs.update({
                    'tokenizer': self.tokenizer,
                    'max_new_tokens': max_new_tokens,
                    'decode_text': True,
                })

                # Generate output
                with torch.no_grad():
                    output = self.model.generate(**model_inputs)

                return {"generated_text": output}
            except Exception as e:
                print(f"Error during model inference: {str(e)}")
                import traceback
                traceback.print_exc()
                return {"error": f"Error during model inference: {str(e)}"}
        except Exception as e:
            print(f"General error: {str(e)}")
            import traceback
            traceback.print_exc()
            return {"error": f"General error: {str(e)}"}