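"""Custom handler.py for a Hugging Face Inference Endpoint serving mPLUG-Owl3.

Loads the model, tokenizer, and multimodal processor via modelscope and
answers image + text prompts sent as JSON payloads.
"""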
import base64
import io
import sys
import traceback
from typing import Dict, Any

import torch
from PIL import Image


class EndpointHandler:
    def __init__(self, path: str = "."):
        """
        Load the mPLUG-Owl3 config, tokenizer, model, and processor from `path`.
        """
        try:
            # Make the repository directory importable so the model's
            # remote code (modeling/processing modules) can be resolved.
            if path not in sys.path:
                sys.path.append(path)

            # Import from modelscope instead of transformers.
            from modelscope import AutoConfig, AutoModel, AutoTokenizer

            print(f"Loading model from {path}")

            # Load the config first.
            self.config = AutoConfig.from_pretrained(path, trust_remote_code=True)

            # Load the tokenizer.
            self.tokenizer = AutoTokenizer.from_pretrained(
                path,
                trust_remote_code=True,
            )

            # Load the model (sdpa attention, bfloat16, automatic device placement).
            self.model = AutoModel.from_pretrained(
                path,
                attn_implementation="sdpa",  # or "flash_attention_2"
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )

            # The model ships its own multimodal processor.
            self.processor = self.model.init_processor(self.tokenizer)

            # Set the model to evaluation mode for inference.
            self.model.eval()

            print("Model loaded successfully")
        except Exception as e:
            print(f"Error during initialization: {e}")
            traceback.print_exc()
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run inference on a single request and return the model's output.

        Expected keys in `data`:
          - "prompt": text prompt (defaults to "Describe this image.")
          - "image": base64-encoded image (optionally a data URI) or raw bytes
          - "max_new_tokens": generation budget (defaults to 100)
        """
        try:
            # Extract inputs from the request payload.
            prompt = data.get("prompt", "Describe this image.")
            image_data = data.get("image", None)
            max_new_tokens = data.get("max_new_tokens", 100)

            if not image_data:
                return {"error": "No image provided"}

            # Decode the image: base64 string (with or without a data-URI
            # prefix) or raw bytes.
            try:
                if isinstance(image_data, str):
                    if image_data.startswith("data:image"):
                        image_data = image_data.split(",")[1]
                    image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
                elif isinstance(image_data, bytes):
                    image = Image.open(io.BytesIO(image_data)).convert("RGB")
                else:
                    return {"error": "Invalid image format"}
            except Exception as e:
                return {"error": f"Error processing image: {e}"}

            try:
                # Prepare messages following the mPLUG-Owl3 chat format;
                # <|image|> marks where the image is injected.
                messages = [
                    {"role": "user", "content": f"<|image|>\n{prompt}"},
                    {"role": "assistant", "content": ""},
                ]

                # Process inputs using the model's multimodal processor.
                model_inputs = self.processor(messages, images=[image], videos=None)

                # Move inputs to the same device as the model.
                device = next(self.model.parameters()).device
                model_inputs = model_inputs.to(device)

                # This model's generate() takes the tokenizer and decoding
                # flags alongside the tensor inputs.
                model_inputs.update({
                    "tokenizer": self.tokenizer,
                    "max_new_tokens": max_new_tokens,
                    "decode_text": True,
                })

                # Generate output without tracking gradients.
                with torch.no_grad():
                    output = self.model.generate(**model_inputs)

                return {"generated_text": output}
            except Exception as e:
                print(f"Error during model inference: {e}")
                traceback.print_exc()
                return {"error": f"Error during model inference: {e}"}
        except Exception as e:
            print(f"General error: {e}")
            traceback.print_exc()
            return {"error": f"General error: {e}"}