from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
from PIL import Image
import io

app = FastAPI()

model_path = "lmms-lab/LLaVA-OneVision-1.5-8B-Instruct"

# Load the model and processor once at import time; device_map="auto" places
# the weights on whatever accelerator is available, and torch_dtype="auto"
# keeps the checkpoint's native precision.
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
def process_vision_info(messages):
    # Minimal stand-in for qwen_vl_utils.process_vision_info: collect every
    # content item tagged "image" instead of assuming the image is always the
    # first entry. Videos are not handled here.
    image_inputs = [
        item["image"]
        for msg in messages
        for item in msg["content"]
        if item.get("type") == "image"
    ]
    video_inputs = None
    return image_inputs, video_inputs
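
# If the qwen_vl_utils package is installed (an assumption; it is shipped
# separately via `pip install qwen-vl-utils`), the real helper can replace
# the stand-in above:
#
#   from qwen_vl_utils import process_vision_info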

@app.post("/analyze-image")
async def analyze_image(file: UploadFile = File(...)):
    image_bytes = await file.read()
    # Convert to RGB so palette or RGBA uploads don't trip up the processor.
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own placement instead of hard-coding "cuda", so the
    # endpoint also runs on CPU-only hosts.
    inputs = inputs.to(model.device)
    # Generate without tracking gradients, then strip the prompt tokens so
    # only the newly generated answer is decoded.
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=1024)
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    # batch_decode returns a list with one string per input sequence.
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return JSONResponse(content={"result": output_text})
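
# Local smoke test, as a sketch: assumes uvicorn is installed, this file is
# named app.py, and example.jpg is an illustrative path (7860 is the usual
# Hugging Face Spaces port; adjust as needed):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#   curl -X POST http://localhost:7860/analyze-image -F "file=@example.jpg"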