sreejith8100 committed on
Commit 7cfa330 · verified · 1 Parent(s): bb0622c

Upload 4 files

Files changed (4)
  1. Dockerfile +23 -0
  2. endpoint_handler.py +86 -0
  3. main.py +90 -0
  4. requirements.txt +21 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
+ FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
+
+ RUN apt-get update && apt-get install -y --no-install-recommends wget && rm -rf /var/lib/apt/lists/*
+ RUN useradd -m -u 1000 user
+
+ USER user
+ WORKDIR /app
+
+ ENV PATH="/home/user/.local/bin:$PATH"
+ ENV TRANSFORMERS_CACHE=/home/user/.cache/huggingface
+ ENV TORCH_CUDA_ARCH_LIST="8.0+PTX"
+
+ # Prebuilt flash-attn wheel matching the base image: CUDA 12.1, torch 2.3, Python 3.10
+ RUN wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.4/flash_attn-2.7.3+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+ RUN pip install ./flash_attn-2.7.3+cu121torch2.3-cp310-cp310-linux_x86_64.whl && rm flash_attn-2.7.3+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+
+ COPY --chown=user requirements.txt .
+ RUN pip install --upgrade pip setuptools wheel
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY --chown=user . .
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
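
A quick way to confirm that the prebuilt flash-attn wheel really matches the base image (CUDA 12.1, torch 2.3, Python 3.10, as encoded in the wheel filename) is to import it inside the container. A minimal sketch; the script name is hypothetical and not part of this commit:

    # check_flash_attn.py -- hypothetical helper, not part of this commit
    import torch
    import flash_attn

    # The wheel name encodes cu121 + torch2.3 + cp310; the runtime should agree.
    print("torch:", torch.__version__)            # expect 2.3.1
    print("cuda:", torch.version.cuda)            # expect 12.1
    print("flash_attn:", flash_attn.__version__)  # expect 2.7.3
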
endpoint_handler.py ADDED
@@ -0,0 +1,86 @@
+ import base64
+ import os
+ from io import BytesIO
+
+ from huggingface_hub import login
+ from PIL import Image
+ from transformers import AutoModel, AutoTokenizer
+
+
+ class EndpointHandler:
+     def __init__(self, model_dir=None):
+         print("[Init] Initializing EndpointHandler...")
+         self.load_model()
+
+     def load_model(self):
+         hf_token = os.getenv("HF_TOKEN")
+         # Official int4-quantized build of MiniCPM-V 2.6 (loads via bitsandbytes)
+         model_path = "openbmb/MiniCPM-V-2_6-int4"
+
+         if hf_token:
+             print("[Auth] Logging into Hugging Face Hub with token...")
+             login(token=hf_token)
+
+         print(f"[Model Load] Loading quantized model from: {model_path}")
+         try:
+             self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+             self.model = AutoModel.from_pretrained(
+                 model_path,
+                 trust_remote_code=True,
+             ).eval()
+             print("[Model Load] Quantized model successfully loaded.")
+         except Exception as e:
+             print(f"[Model Load Error] {e}")
+             raise RuntimeError(f"Failed to load quantized model: {e}")
+
+     def load_image(self, image_base64):
+         try:
+             print("[Image Load] Decoding base64 image...")
+             image_bytes = base64.b64decode(image_base64)
+             image = Image.open(BytesIO(image_bytes)).convert("RGB")
+             print("[Image Load] Image successfully decoded and converted to RGB.")
+             return image
+         except Exception as e:
+             print(f"[Image Load Error] {e}")
+             raise ValueError(f"Failed to open image from base64 string: {e}")
+
+     def predict(self, request):
+         print(f"[Predict] Received request: {request}")
+
+         inputs = request.get("inputs", {})
+         image_base64 = inputs.get("image")
+         question = inputs.get("question")
+         stream = inputs.get("stream", False)
+
+         if not image_base64 or not question:
+             print("[Predict Error] Missing 'image' or 'question' in the request.")
+             return {"error": "Missing 'image' or 'question' in inputs."}
+
+         try:
+             image = self.load_image(image_base64)
+             msgs = [{"role": "user", "content": [image, question]}]
+
+             print(f"[Predict] Asking model with question: {question}")
+             print("[Predict] Starting chat inference...")
+
+             res = self.model.chat(
+                 image=None,
+                 msgs=msgs,
+                 tokenizer=self.tokenizer,
+                 sampling=True,
+                 stream=stream
+             )
+
+             if stream:
+                 # Wrap the token iterator in a dedicated generator. A bare
+                 # `yield` inside predict() would turn the whole method into a
+                 # generator, so the non-streaming `return {...}` branches
+                 # would never reach the caller.
+                 def token_stream():
+                     for new_text in res:
+                         yield {"output": new_text}
+                 return token_stream()
+
+             generated_text = "".join(res)
+             print("[Predict] Inference complete.")
+             return {"output": generated_text}
+
+         except Exception as e:
+             print(f"[Predict Error] {e}")
+             return {"error": str(e)}
+
+     def __call__(self, data):
+         print("[__call__] Invoked handler with data.")
+         return self.predict(data)
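
For local testing, the handler can be exercised directly without the FastAPI layer. A minimal sketch, assuming a GPU is available and the model can be downloaded; the image path and question are placeholders:

    import base64
    from endpoint_handler import EndpointHandler

    handler = EndpointHandler()  # downloads openbmb/MiniCPM-V-2_6-int4 on first run

    with open("example.jpg", "rb") as f:  # hypothetical local image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    result = handler({"inputs": {"image": image_b64, "question": "What is in this image?"}})
    print(result)  # {"output": "..."} on success, {"error": "..."} on failure
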
main.py ADDED
@@ -0,0 +1,90 @@
+ import base64
+ import json
+ import types
+
+ from fastapi import FastAPI
+ from fastapi.responses import JSONResponse, StreamingResponse
+ from pydantic import BaseModel, validator
+
+ from endpoint_handler import EndpointHandler  # your handler file
+
+ app = FastAPI()
+
+ handler = None
+
+
+ @app.on_event("startup")
+ async def load_handler():
+     global handler
+     handler = EndpointHandler()
+
+
+ class PredictInput(BaseModel):
+     image: str  # base64-encoded image string
+     question: str
+     stream: bool = False
+
+     @validator("question")
+     def question_not_empty(cls, v):
+         if not v.strip():
+             raise ValueError("Question must not be empty")
+         return v
+
+     @validator("image")
+     def valid_base64_and_size(cls, v):
+         try:
+             decoded = base64.b64decode(v, validate=True)
+         except Exception:
+             raise ValueError("`image` must be valid base64")
+         if len(decoded) > 10 * 1024 * 1024:  # 10 MB limit
+             raise ValueError("Image exceeds 10 MB after decoding")
+         return v
+
+
+ class PredictRequest(BaseModel):
+     inputs: PredictInput
+
+
+ @app.get("/")
+ async def root():
+     return {"message": "FastAPI app is running on Hugging Face"}
+
+
+ @app.post("/predict")
+ async def predict_endpoint(payload: PredictRequest):
+     """
+     Run inference on the validated payload and return the result.
+
+     Returns:
+         JSONResponse: the prediction for non-streaming requests, or an error
+             with status 400 (bad input) / 500 (unexpected failure).
+         StreamingResponse: an SSE stream of prediction chunks when `stream`
+             is true; a structured JSON message marks the end of the stream
+             or an error raised mid-stream.
+     """
+     print(f"[Request] Received question: {payload.inputs.question}")
+
+     data = {
+         "inputs": {
+             "image": payload.inputs.image,
+             "question": payload.inputs.question,
+             "stream": payload.inputs.stream
+         }
+     }
+
+     try:
+         result = handler.predict(data)
+     except ValueError as ve:
+         return JSONResponse({"error": str(ve)}, status_code=400)
+     except Exception:
+         return JSONResponse({"error": "Internal server error"}, status_code=500)
+
+     if isinstance(result, types.GeneratorType):
+         def event_stream():
+             try:
+                 for chunk in result:
+                     yield f"data: {json.dumps(chunk)}\n\n"
+                 # Structured JSON to signal the end of the stream
+                 yield f"data: {json.dumps({'end': True})}\n\n"
+             except Exception as e:
+                 # Structured JSON to signal an error during streaming
+                 yield f"data: {json.dumps({'error': str(e)})}\n\n"
+         return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+     # Non-streaming path: the handler returned a plain dict.
+     if "error" in result:
+         return JSONResponse(result, status_code=400)
+     return JSONResponse(result)
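
From the client side, the non-streaming path returns a single JSON object, while the streaming path emits Server-Sent Events ("data: {...}" frames). A sketch of both, assuming `requests` is installed (it is not in requirements.txt) and the service is reachable on the port from the Dockerfile's CMD:

    import base64
    import json
    import requests  # assumed installed; not in requirements.txt

    URL = "http://localhost:7860/predict"  # port comes from the Dockerfile CMD

    with open("example.jpg", "rb") as f:  # hypothetical local image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Non-streaming: one JSON object back.
    resp = requests.post(URL, json={"inputs": {"image": image_b64,
                                               "question": "Describe the image.",
                                               "stream": False}})
    print(resp.json())

    # Streaming: iterate over SSE frames until the {'end': True} sentinel.
    with requests.post(URL, json={"inputs": {"image": image_b64,
                                             "question": "Describe the image.",
                                             "stream": True}}, stream=True) as r:
        for line in r.iter_lines():
            if line and line.startswith(b"data: "):
                chunk = json.loads(line[len(b"data: "):])
                if chunk.get("end"):
                    break
                print(chunk.get("output", ""), end="", flush=True)
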
requirements.txt ADDED
@@ -0,0 +1,21 @@
+ Pillow==10.1.0
+ torch==2.3.1
+ torchaudio==2.3.1
+ torchvision==0.18.1
+ transformers==4.44.2
+ librosa==0.9.0
+ soundfile==0.12.1
+ vector-quantize-pytorch==1.18.5
+ vocos==0.1.0
+ decord
+ moviepy
+ einops
+ fastapi
+ uvicorn[standard]
+ timm>=0.6.13
+ sentencepiece>=0.1.99
+ python-multipart
+ bitsandbytes==0.45.5
+ accelerate==0.30.0