import base64
import io
import sys
import traceback
from typing import Any, Dict

import torch
from PIL import Image


class EndpointHandler:
    def __init__(self, path="."):
        """
        Initialize the model, tokenizer, and processor for inference.
        """
        try:
            # Make the checkpoint directory importable so trust_remote_code
            # modules bundled with the model can be resolved
            if path not in sys.path:
                sys.path.append(path)

            # Import from modelscope instead of transformers
            from modelscope import AutoConfig, AutoModel, AutoTokenizer

            print(f"Loading model from {path}")

            # Load config first
            self.config = AutoConfig.from_pretrained(path, trust_remote_code=True)

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                path,
                trust_remote_code=True
            )

            # Load model with correct parameters
            self.model = AutoModel.from_pretrained(
                path,
                attn_implementation='sdpa',  # or 'flash_attention_2'
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True
            )

            # Initialize the multimodal processor shipped with the model
            self.processor = self.model.init_processor(self.tokenizer)

            # Set model to evaluation mode
            self.model.eval()

            print("Model loaded successfully")
        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            traceback.print_exc()
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and return the model's output.
        """
        try:
            # Extract inputs from data
            prompt = data.get("prompt", "Describe this image.")
            image_data = data.get("image", None)
            max_new_tokens = data.get("max_new_tokens", 100)

            if not image_data:
                return {"error": "No image provided"}

            # Decode the image: accept raw bytes, plain base64, or a data URL
            try:
                if isinstance(image_data, str):
                    if image_data.startswith("data:image"):
                        # Strip the "data:image/...;base64," prefix
                        image_data = image_data.split(",")[1]
                    image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
                elif isinstance(image_data, bytes):
                    image = Image.open(io.BytesIO(image_data)).convert("RGB")
                else:
                    return {"error": "Invalid image format"}
            except Exception as e:
                return {"error": f"Error processing image: {str(e)}"}

            try:
                # Prepare messages following the mPLUG-Owl3 chat format;
                # <|image|> marks where the image is injected into the prompt
                messages = [
                    {"role": "user", "content": f"<|image|>\n{prompt}"},
                    {"role": "assistant", "content": ""}
                ]

                # Process inputs using the processor
                model_inputs = self.processor(messages, images=[image], videos=None)

                # Move inputs to the same device as the model
                device = next(self.model.parameters()).device
                model_inputs = model_inputs.to(device)

                # Add generation parameters expected by the model's generate()
                model_inputs.update({
                    'tokenizer': self.tokenizer,
                    'max_new_tokens': max_new_tokens,
                    'decode_text': True
                })

                # Generate output without tracking gradients
                with torch.no_grad():
                    output = self.model.generate(**model_inputs)

                return {"generated_text": output}
            except Exception as e:
                print(f"Error during model inference: {str(e)}")
                traceback.print_exc()
                return {"error": f"Error during model inference: {str(e)}"}
        except Exception as e:
            print(f"General error: {str(e)}")
            traceback.print_exc()
            return {"error": f"General error: {str(e)}"}
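

# --- Usage sketch (assumption): a local smoke test, not part of the Inference
# --- Endpoints contract. It builds a base64 payload the way a client would and
# --- calls the handler directly. The 1x1 test image and the checkpoint path
# --- "." are placeholders; point `path` at a real mPLUG-Owl3 checkpoint.
if __name__ == "__main__":
    # Encode a tiny in-memory test image as base64, mimicking a client payload
    buffer = io.BytesIO()
    Image.new("RGB", (1, 1), color="red").save(buffer, format="PNG")
    payload = {
        "prompt": "Describe this image.",
        "image": base64.b64encode(buffer.getvalue()).decode("utf-8"),
        "max_new_tokens": 50,
    }

    handler = EndpointHandler(path=".")  # placeholder checkpoint directory
    print(handler(payload))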