|
from typing import Dict, Any |
|
import torch |
|
from PIL import Image |
|
import base64 |
|
import io |
|
import sys |
|
|
|
class EndpointHandler:
    """Inference endpoint wrapper around a multimodal (image + text) chat model.

    Loads the model, tokenizer and processor from ``path`` at construction
    time, then serves image-captioning / VQA-style requests via ``__call__``.
    """

    def __init__(self, path="."):
        """Load config, tokenizer, model and processor from ``path``.

        Args:
            path: Directory containing the checkpoint and its
                ``trust_remote_code`` implementation. It is appended to
                ``sys.path`` so the checkpoint's custom modules are importable.

        Raises:
            Exception: any loading failure is logged with a traceback and
                re-raised so the serving runtime can fail fast.
        """
        try:
            # The remote modeling code may import sibling modules shipped in
            # the checkpoint directory, so make that directory importable.
            if path not in sys.path:
                sys.path.append(path)

            # Imported lazily so this module can be inspected without
            # modelscope installed.
            from modelscope import AutoConfig, AutoModel, AutoTokenizer

            print(f"Loading model from {path}")

            self.config = AutoConfig.from_pretrained(path, trust_remote_code=True)

            self.tokenizer = AutoTokenizer.from_pretrained(
                path,
                trust_remote_code=True,
            )

            self.model = AutoModel.from_pretrained(
                path,
                attn_implementation='sdpa',
                torch_dtype=torch.bfloat16,
                device_map="auto",  # may shard the model across devices
                trust_remote_code=True,
            )

            # NOTE(review): init_processor is a model-specific API provided by
            # the checkpoint's remote code — confirm against the model card.
            self.processor = self.model.init_processor(self.tokenizer)

            self.model.eval()

            print("Model loaded successfully")

        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            traceback.print_exc()
            raise

    @staticmethod
    def _decode_image(image_data):
        """Decode a request image payload into an RGB PIL image.

        Accepts a base64 string (optionally a ``data:image/...`` URL) or raw
        bytes. Returns a PIL ``Image`` on success, or ``None`` when the
        payload is neither ``str`` nor ``bytes``. Decoding errors propagate
        to the caller.
        """
        if isinstance(image_data, str):
            # Strip a data-URL header such as "data:image/png;base64,".
            if image_data.startswith("data:image"):
                image_data = image_data.split(",")[1]
            raw = base64.b64decode(image_data)
        elif isinstance(image_data, bytes):
            raw = image_data
        else:
            return None
        return Image.open(io.BytesIO(raw)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one inference request.

        Args:
            data: Request payload. Recognized keys:
                ``prompt`` (str, optional) — user question/instruction;
                ``image`` (base64 str, data URL, or raw bytes — required);
                ``max_new_tokens`` (int-like, default 100).

        Returns:
            ``{"generated_text": ...}`` on success, otherwise
            ``{"error": <message>}``. Never raises to the caller.
        """
        try:
            prompt = data.get("prompt", "Describe this image.")
            image_data = data.get("image", None)
            # Coerce so JSON string payloads such as "200" also work.
            max_new_tokens = int(data.get("max_new_tokens", 100))

            if not image_data:
                return {"error": "No image provided"}

            try:
                image = self._decode_image(image_data)
                if image is None:
                    return {"error": "Invalid image format"}
            except Exception as e:
                return {"error": f"Error processing image: {str(e)}"}

            try:
                # Chat template expected by the remote model: an <|image|>
                # placeholder followed by the prompt, plus an empty assistant
                # turn to prime generation.
                messages = [
                    {"role": "user", "content": f"<|image|>\n{prompt}"},
                    {"role": "assistant", "content": ""}
                ]

                model_inputs = self.processor(messages, images=[image], videos=None)

                # device_map="auto" may place weights on any device; move the
                # inputs to wherever the first parameter lives.
                device = next(self.model.parameters()).device
                model_inputs = model_inputs.to(device)

                # NOTE(review): passing tokenizer/decode_text through generate
                # kwargs is a remote-code convention of this checkpoint —
                # confirm against the model card.
                model_inputs.update({
                    'tokenizer': self.tokenizer,
                    'max_new_tokens': max_new_tokens,
                    'decode_text': True
                })

                with torch.no_grad():
                    output = self.model.generate(**model_inputs)

                return {"generated_text": output}

            except Exception as e:
                print(f"Error during model inference: {str(e)}")
                traceback.print_exc()
                return {"error": f"Error during model inference: {str(e)}"}

        except Exception as e:
            print(f"General error: {str(e)}")
            traceback.print_exc()
            return {"error": f"General error: {str(e)}"}