from typing import Dict, Any
import torch
from PIL import Image
import base64
import io
import os
import sys


class EndpointHandler:
    def __init__(self, path="."):
        """
        Initialize the model and tokenizer for inference.

        Args:
            path (str): Path to the model directory
        """
        try:
            # Add the model's directory to the Python path
            if path not in sys.path:
                sys.path.append(path)

            # Import transformers
            from transformers import AutoModelForCausalLM, AutoTokenizer

            print(f"Loading model from {path}")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                path,
                trust_remote_code=True
            )

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )

            # Set model to evaluation mode
            self.model.eval()
            print("Model loaded successfully")
        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and return the model's output.

        Expected keys in `data`:
            prompt (str, optional): text prompt, defaults to "Describe this image."
            image (str or bytes): base64-encoded image (optionally a data URI) or raw image bytes.
            max_new_tokens (int, optional): maximum number of tokens to generate, defaults to 100.
        """
        try:
            # Extract inputs from data
            prompt = data.get("prompt", "Describe this image.")
            image_data = data.get("image", None)
            max_new_tokens = data.get("max_new_tokens", 100)

            # Check if image is provided
            if not image_data:
                return {"error": "No image provided"}

            # Decode base64 image
            try:
                if isinstance(image_data, str):
                    if image_data.startswith("data:image"):
                        image_data = image_data.split(",")[1]
                    image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
                elif isinstance(image_data, bytes):
                    image = Image.open(io.BytesIO(image_data)).convert("RGB")
                else:
                    return {"error": "Invalid image format"}
            except Exception as e:
                return {"error": f"Error processing image: {str(e)}"}

            try:
                # Prepare messages for the model
                messages = [
                    {"role": "user", "content": f"<|image|> {prompt}"},
                    {"role": "assistant", "content": ""}
                ]

                # For mPLUG-Owl3, the processor is built by the model itself;
                # inspect the model structure to locate init_processor
                print("Model structure:", dir(self.model))

                # Try different ways to access the processor
                if hasattr(self.model, "init_processor"):
                    processor = self.model.init_processor(self.tokenizer)
                elif hasattr(self.model, "model") and hasattr(self.model.model, "init_processor"):
                    processor = self.model.model.init_processor(self.tokenizer)
                else:
                    # Fall back to searching the model's attributes for a processor
                    for attr_name in dir(self.model):
                        if attr_name.startswith("_"):
                            continue
                        attr = getattr(self.model, attr_name)
                        if hasattr(attr, "init_processor"):
                            processor = attr.init_processor(self.tokenizer)
                            print(f"Found processor in {attr_name}")
                            break
                    else:
                        # for/else: no attribute exposed init_processor
                        return {"error": "Could not find processor in model"}

                # Process inputs
                model_inputs = processor(messages, images=[image], videos=None)

                # Move inputs to the same device as the model
                device = next(self.model.parameters()).device
                model_inputs = model_inputs.to(device)

                # Add the extra arguments mPLUG-Owl3's generate() expects
                model_inputs.update({
                    'tokenizer': self.tokenizer,
                    'max_new_tokens': max_new_tokens,
                    'decode_text': True,
                })

                # Generate output
                with torch.no_grad():
                    output = self.model.generate(**model_inputs)

                return {"generated_text": output}
            except Exception as e:
                print(f"Error during model inference: {str(e)}")
                import traceback
                traceback.print_exc()
                return {"error": f"Error during model inference: {str(e)}"}
        except Exception as e:
            print(f"General error: {str(e)}")
            import traceback
            traceback.print_exc()
            return {"error": f"General error: {str(e)}"}