File size: 3,452 Bytes
77fc260
9b164d1
f6721ff
9b164d1
 
 
77fc260
9b164d1
77fc260
9b164d1
 
 
 
 
 
 
 
 
f6721ff
 
9b164d1
 
 
 
 
77fc260
 
 
 
5249c43
 
77fc260
 
5249c43
0a36398
77fc260
0a36398
48a1000
 
77fc260
 
 
 
 
 
 
 
 
7020761
9b164d1
77fc260
 
 
 
 
 
 
9b164d1
 
77fc260
9b164d1
 
 
 
 
 
 
77fc260
 
9b164d1
7020761
77fc260
7020761
48a1000
7020761
77fc260
 
7020761
 
77fc260
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import Dict, Any
import torch
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import requests
from io import BytesIO
import base64

class EndpointHandler:
    """Inference handler that runs LLaVA-NeXT image+text generation.

    Each call receives a payload containing an image (by URL or as base64
    data) and a text prompt, generates text with the model, and returns a
    dict with the decoded output plus a list of progress log messages
    (written in Catalan).
    """

    # Hub id of the checkpoint loaded when the caller does not pass another one.
    DEFAULT_MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
    # Prompt format documented for the Mistral variant of LLaVA-NeXT. It is only
    # applied when the caller's prompt does not already contain "<image>",
    # so pre-formatted prompts are passed through untouched.
    PROMPT_TEMPLATE = "[INST] <image>\n{prompt} [/INST]"
    # Seconds to wait for a remote image before giving up.
    REQUEST_TIMEOUT = 30

    def __init__(self, path="", model_id=DEFAULT_MODEL_ID):
        """Load the processor and a 4-bit quantized model.

        Args:
            path: Directory passed in by the serving runtime. Not used; the
                model is loaded from ``model_id``.
            model_id: Hub id (or local path) of the LLaVA-NeXT checkpoint.
                Defaults to the Mistral-7B v1.6 checkpoint.
        """
        # 4-bit NF4 quantization with float16 compute to reduce GPU memory use.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        # Load processor and model once, at construction, so every request reuses them.
        self.processor = LlavaNextProcessor.from_pretrained(model_id)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map="auto",
        )

    @classmethod
    def _build_prompt(cls, prompt):
        """Return *prompt* wrapped in the model's chat format unless it already has an image slot."""
        if "<image>" in prompt:
            return prompt
        return cls.PROMPT_TEMPLATE.format(prompt=prompt)

    @staticmethod
    def _decode_base64_image(image_data):
        """Decode base64 image data, accepting an optional ``data:...;base64,`` URL prefix."""
        if isinstance(image_data, str) and image_data.startswith("data:"):
            _, _, image_data = image_data.partition(",")
        return base64.b64decode(image_data)

    def _load_image(self, image_url, image_data, logs):
        """Fetch or decode the request image and return it as an RGB PIL image.

        ``image_url`` takes precedence over ``image_data`` when both are given.
        Progress messages are appended to ``logs``. Any failure (network,
        HTTP status, invalid base64, undecodable image) propagates to the caller.
        """
        if image_url:
            logs.append(f"Carregant imatge des de URL: {image_url}")
            response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            raw_bytes = response.content
        else:
            logs.append("Carregant imatge des de dades d'imatge en brut.")
            raw_bytes = self._decode_base64_image(image_data)

        image = Image.open(BytesIO(raw_bytes))
        # Image.open is lazy; force decoding now so corrupt data fails here,
        # not later inside the model call.
        image.load()

        if image.mode != "RGB":
            # Covers RGBA/paletted PNGs, grayscale, CMYK, etc. Any alpha channel is dropped.
            logs.append(f"Convertint imatge a RGB (mode original: {image.mode}).")
            image = image.convert("RGB")
        return image

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. ``data["inputs"]`` must be a dict with:
                ``prompt`` (required): text prompt. If it lacks "<image>",
                    it is wrapped in ``PROMPT_TEMPLATE``.
                ``url`` or ``image_data`` (one required): image URL, or
                    base64-encoded image bytes (optionally as a data URL).
                ``max_tokens`` (optional, default 100): positive integer
                    passed as ``max_new_tokens`` to ``generate``.

        Returns:
            On success: ``{"input_prompt", "model_output", "logs"}``.
            On failure: ``{"error", "logs"}``; no exception is raised.

            NOTE(review): ``model_output`` decodes the whole generated
            sequence. For this decoder-only model it presumably includes
            the prompt text as well; confirm before relying on it.
        """
        logs = ["Iniciant processament de la petició."]

        inputs = data.get("inputs")
        if not inputs:
            logs.append("Format d'entrada invàlid. Manca la clau 'inputs'.")
            return {"error": "Invalid input format. 'inputs' key is missing.", "logs": logs}
        if not isinstance(inputs, dict):
            # e.g. {"inputs": "some text"}: would otherwise crash on inputs.get.
            logs.append("Format d'entrada invàlid. 'inputs' ha de ser un objecte.")
            return {"error": "Invalid input format. 'inputs' must be an object.", "logs": logs}

        image_url = inputs.get("url")
        image_data = inputs.get("image_data")
        prompt = inputs.get("prompt")

        if not prompt:
            logs.append("S'ha de proporcionar 'prompt' en 'inputs'.")
            return {"error": "The 'prompt' must be provided in 'inputs'.", "logs": logs}

        if not image_url and not image_data:
            logs.append("S'ha de proporcionar 'url' o 'image_data' en 'inputs'.")
            return {"error": "Either 'url' or 'image_data' must be provided in 'inputs'.", "logs": logs}

        # Validate up front so a bad value is reported clearly instead of failing inside generate().
        try:
            max_tokens = int(inputs.get("max_tokens", 100))
        except (TypeError, ValueError, OverflowError):
            max_tokens = 0
        if max_tokens <= 0:
            logs.append("'max_tokens' ha de ser un enter positiu.")
            return {"error": "'max_tokens' must be a positive integer.", "logs": logs}

        logs.append(f"Processant entrada: url={image_url}, image_data={'present' if image_data else 'absent'}, prompt={prompt}")

        try:
            image = self._load_image(image_url, image_data, logs)
        except Exception as e:
            logs.append(f"Error carregant imatge: {str(e)}")
            return {"error": str(e), "logs": logs}

        try:
            logs.append("Processant imatge amb el model.")
            model_inputs = self.processor(
                text=self._build_prompt(prompt),
                images=image,
                return_tensors="pt",
            ).to(self.model.device)  # follow device_map placement instead of hard-coding "cuda"
            output = self.model.generate(**model_inputs, max_new_tokens=max_tokens)
            result = self.processor.decode(output[0], skip_special_tokens=True)
        except Exception as e:
            logs.append(f"Error processant el model: {str(e)}")
            return {"error": str(e), "logs": logs}

        logs.append("Processament complet.")
        return {"input_prompt": prompt, "model_output": result, "logs": logs}