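"""Custom handler for a Hugging Face Inference Endpoint serving mPLUG-Owl3.

Loads the checkpoint with trust_remote_code, then answers requests carrying
a text prompt plus a base64-encoded (or raw-bytes) image.
"""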
from typing import Dict, Any
import torch
from PIL import Image
import base64
import io
import sys
import traceback

class EndpointHandler:
    def __init__(self, path="."):
        """
        Initialize the model and tokenizer for inference.
        Args:
            path (str): Path to the model directory
        """
        try:
            # Add the model's directory to the Python path
            if path not in sys.path:
                sys.path.append(path)

            # Import transformers
            from transformers import AutoModelForCausalLM, AutoTokenizer

            print(f"Loading model from {path}")

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                path,
                trust_remote_code=True
            )

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )

            # Set model to evaluation mode
            self.model.eval()

            print("Model loaded successfully")

        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            traceback.print_exc()
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and return the model's output.
        """
        try:
            # Extract inputs from data
            prompt = data.get("prompt", "Describe this image.")
            image_data = data.get("image", None)
            max_new_tokens = data.get("max_new_tokens", 100)

            # Check if image is provided
            if not image_data:
                return {"error": "No image provided"}

            # Decode the image: accept a base64 string (optionally a data URI)
            # or raw image bytes
            try:
                if isinstance(image_data, str):
                    if image_data.startswith("data:image"):
                        image_data = image_data.split(",")[1]
                    image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
                elif isinstance(image_data, bytes):
                    image = Image.open(io.BytesIO(image_data)).convert("RGB")
                else:
                    return {"error": "Invalid image format"}
            except Exception as e:
                return {"error": f"Error processing image: {str(e)}"}

            try:
                # Prepare the chat messages in mPLUG-Owl3's format: the
                # <|image|> placeholder marks where the image is injected,
                # and the empty assistant turn leaves room for the reply
                messages = [
                    {"role": "user", "content": f"<|image|> {prompt}"},
                    {"role": "assistant", "content": ""}
                ]

                # mPLUG-Owl3 builds its processor from the model via an
                # init_processor(tokenizer) method, but where that method
                # lives varies between checkpoint revisions, so probe a few
                # known spots before falling back to a scan
                if hasattr(self.model, "init_processor"):
                    processor = self.model.init_processor(self.tokenizer)
                elif hasattr(self.model, "model") and hasattr(self.model.model, "init_processor"):
                    processor = self.model.model.init_processor(self.tokenizer)
                else:
                    # Fall back to scanning the model's public attributes
                    # for one that exposes init_processor
                    for attr_name in dir(self.model):
                        if attr_name.startswith("_"):
                            continue
                        attr = getattr(self.model, attr_name)
                        if hasattr(attr, "init_processor"):
                            processor = attr.init_processor(self.tokenizer)
                            print(f"Found processor in {attr_name}")
                            break
                    else:
                        return {"error": "Could not find processor in model"}

                # Process inputs
                model_inputs = processor(messages, images=[image], videos=None)

                # Move inputs to the same device as the model
                device = next(self.model.parameters()).device
                model_inputs = model_inputs.to(device)

                # mPLUG-Owl3's custom generate expects the tokenizer and
                # decoding options to be passed along with the model inputs;
                # decode_text=True asks it to return decoded text rather
                # than raw token ids
                model_inputs.update({
                    'tokenizer': self.tokenizer,
                    'max_new_tokens': max_new_tokens,
                    'decode_text': True,
                })

                # Generate output
                with torch.no_grad():
                    output = self.model.generate(**model_inputs)

                return {"generated_text": output}

            except Exception as e:
                print(f"Error during model inference: {str(e)}")
                traceback.print_exc()
                return {"error": f"Error during model inference: {str(e)}"}

        except Exception as e:
            print(f"General error: {str(e)}")
            traceback.print_exc()
            return {"error": f"General error: {str(e)}"}
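
# ---------------------------------------------------------------------------
# Minimal local smoke test -- a sketch, not part of the endpoint contract.
# Assumes a checkpoint in the current directory and a hypothetical sample
# image at ./test.jpg; adjust both for your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("test.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler(path=".")
    result = handler({
        "prompt": "Describe this image.",
        "image": image_b64,
        "max_new_tokens": 100,
    })
    print(result)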