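"""
Custom inference handler for an mPLUG-Owl3 vision-language model.

Loads the model, tokenizer, and processor via ModelScope's Auto* classes,
decodes a base64-encoded image from the request payload, and returns the
generated text (or an error dictionary).
"""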
from typing import Dict, Any
import torch
from PIL import Image
import base64
import io
import sys

class EndpointHandler:
    def __init__(self, path="."):
        """
        Initialize the model and tokenizer for inference.
        """
        try:
            if path not in sys.path:
                sys.path.append(path)

            # Import from modelscope instead of transformers
            from modelscope import AutoConfig, AutoModel, AutoTokenizer

            print(f"Loading model from {path}")

            # Load config first
            self.config = AutoConfig.from_pretrained(path, trust_remote_code=True)

            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                path,
                trust_remote_code=True
            )

            # Load model with correct parameters
            self.model = AutoModel.from_pretrained(
                path,
                attn_implementation='sdpa',  # or 'flash_attention_2'
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True
            )

            # Initialize processor
            self.processor = self.model.init_processor(self.tokenizer)

            # Set model to evaluation mode
            self.model.eval()

            print("Model loaded successfully")

        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            import traceback
            traceback.print_exc()
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data and return the model's output.
        """
        try:
            # Extract inputs from data
            prompt = data.get("prompt", "Describe this image.")
            image_data = data.get("image", None)
            max_new_tokens = data.get("max_new_tokens", 100)

            if not image_data:
                return {"error": "No image provided"}

            # Decode base64 image
            try:
                if isinstance(image_data, str):
                    if image_data.startswith("data:image"):
                        image_data = image_data.split(",")[1]
                    image = Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")
                elif isinstance(image_data, bytes):
                    image = Image.open(io.BytesIO(image_data)).convert("RGB")
                else:
                    return {"error": "Invalid image format"}
            except Exception as e:
                return {"error": f"Error processing image: {str(e)}"}

            try:
                # Prepare messages following mPLUG-Owl3 format
                messages = [
                    {"role": "user", "content": f"<|image|>\n{prompt}"},
                    {"role": "assistant", "content": ""}
                ]

                # Process inputs using the processor
                model_inputs = self.processor(messages, images=[image], videos=None)

                # Move inputs to the correct device
                device = next(self.model.parameters()).device
                model_inputs = model_inputs.to(device)

                # Add required parameters
                model_inputs.update({
                    'tokenizer': self.tokenizer,
                    'max_new_tokens': max_new_tokens,
                    'decode_text': True
                })

                # Generate output
                with torch.no_grad():
                    output = self.model.generate(**model_inputs)

                return {"generated_text": output}

            except Exception as e:
                print(f"Error during model inference: {str(e)}")
                import traceback
                traceback.print_exc()
                return {"error": f"Error during model inference: {str(e)}"}

        except Exception as e:
            print(f"General error: {str(e)}")
            import traceback
            traceback.print_exc()
            return {"error": f"General error: {str(e)}"}