Wondering how I can deploy this myself with an inference API?

by Navkun

Hey, love your work! What's the best way to deploy this model? Thanks!

Hi,

Thanks for the support! Please refer to the following code, which uses FastAPI and vLLM for an OpenAI-style deployment.

You can find more details about the environment requirements on our GitHub: https://github.com/AriaUI/Aria-UI. Please give us a star if you haven't already!

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from PIL import Image
import base64
import io
import time
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize model and tokenizer
MODEL_PATH = "Aria-UI/Aria-UI-base"
llm = LLM(
    model=MODEL_PATH,
    tokenizer_mode="slow",
    dtype="bfloat16",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH, 
    trust_remote_code=True, 
    use_fast=False
)

class Message(BaseModel):
    role: str
    content: List[dict]

class ChatCompletionRequest(BaseModel):
    messages: List[Message]
    model: str
    max_tokens: Optional[int] = 512
    stop: Optional[List[str]] = ["<|im_end|>"]
    extra_body: Optional[dict] = {
        "split_image": True,
        "image_max_size": 980
    }

class Choice(BaseModel):
    message: Message
    finish_reason: str
    index: int

class ChatCompletionResponse(BaseModel):
    id: str
    object: str
    created: int
    model: str
    choices: List[Choice]

def base64_to_pil(base64_string: str) -> Image.Image:
    """Convert base64 image to PIL Image."""
    # Remove data URL prefix if present
    if 'base64,' in base64_string:
        base64_string = base64_string.split('base64,')[1]
    image_data = base64.b64decode(base64_string)
    return Image.open(io.BytesIO(image_data))



@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    try:
        # Process the last message which should contain the image and prompt
        last_message = request.messages[-1]
        
        # Extract text prompt and image from the message content
        text_content = None
        image_content = None
        for content in last_message.content:
            if content["type"] == "text":
                text_content = content["text"]
            elif content["type"] == "image_url":
                image_url = content["image_url"]["url"]
                image_content = base64_to_pil(image_url)

        if not text_content or not image_content:
            raise HTTPException(status_code=400, detail="Missing text prompt or image")

        # Format messages for the model
        messages = [{
            "role": "user",
            "content": [
                {"type": "image"},
                {
                    "type": "text",
                    "text": text_content,
                }
            ],
        }]

        # Generate message using tokenizer
        message = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

        # Get model output
        outputs = llm.generate(
            {
                "prompt_token_ids": message,
                "multi_modal_data": {
                    "image": [image_content],
                    "max_image_size": request.extra_body.get("image_max_size", 980),
                    "split_image": request.extra_body.get("split_image", True),
                },
            },
            sampling_params=SamplingParams(
                max_tokens=request.max_tokens,
                top_k=1,
                stop=request.stop
            ),
        )

        # Process the response
        generated_tokens = outputs[0].outputs[0].token_ids
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Create response object
        return ChatCompletionResponse(
            id="chatcmpl-" + base64.b64encode(str(hash(response)).encode()).decode()[:10],
            object="chat.completion",
            created=int(time.time()),
            model=request.model,
            choices=[
                Choice(
                    message=Message(
                        role="assistant",
                        content=[{"type": "text", "text": response}]
                    ),
                    finish_reason="stop",
                    index=0
                )
            ]
        )

    except HTTPException:
        # Re-raise client errors (e.g. the 400 above) instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
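
For reference, here is a minimal sketch of how a client could call this endpoint with the requests library once the server is running. The URL, the screenshot.png path, and the prompt text are placeholders you would replace with your own.

import base64
import requests

# Encode a local screenshot as base64 for the image_url content part
with open("screenshot.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "model": "Aria-UI/Aria-UI-base",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Where is the search button?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                },
            ],
        }
    ],
    "max_tokens": 512,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
# The server returns the assistant message content as a list of text parts
print(resp.json()["choices"][0]["message"]["content"][0]["text"])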

Thank you! Sorry, I'm a bit new to this and don't have enough compute to run this locally. Which hosting provider would you recommend for easily hosting this as an inference API?
