# Hugging Face Space: running on ZeroGPU ("Running on Zero").
import gradio as gr | |
import torch | |
from PIL import Image | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import spaces | |
# Model configuration
MID = "apple/FastVLM-0.5B"  # repo id handed to from_pretrained() in load_model()
IMAGE_TOKEN_INDEX = -200  # placeholder id spliced into input_ids where the image goes

# Tokenizer and model start out unloaded; load_model() fills both in on first use.
tok = model = None
def load_model():
    """Return the shared ``(tok, model)`` pair, loading it on the first call.

    The tokenizer and the causal LM are both fetched from ``MID`` with
    ``trust_remote_code=True``; the model is loaded in float16 with
    ``device_map="cuda"``. The loaded objects are stored in the module-level
    ``tok`` / ``model`` globals so later calls reuse them.

    Returns:
        Tuple of (tokenizer, model).
    """
    global tok, model

    # Fast path: both halves are already cached from an earlier call.
    if tok is not None and model is not None:
        return tok, model

    print("Loading model...")
    tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MID,
        torch_dtype=torch.float16,
        device_map="cuda",
        trust_remote_code=True,
    )
    print("Model loaded successfully!")
    return tok, model
def _extract_response(generated_text):
    """Return the assistant's reply from decoded model output.

    If the decoded text still contains the rendered chat prompt, the reply is
    whatever follows the last line consisting solely of the role name
    ``assistant``. Matching a whole line (rather than the bare substring)
    keeps captions such as "a shop assistant ..." intact.
    NOTE(review): assumes the chat template puts the role name on its own
    line -- confirm against the tokenizer's template.
    """
    import re  # function-scope import keeps this block self-contained

    role_lines = list(
        re.finditer(r"(?m)^[ \t]*assistant[ \t\r]*$", generated_text)
    )
    if role_lines:
        return generated_text[role_lines[-1].end():].strip()
    return generated_text


def caption_image(image, custom_prompt=None):
    """Generate a caption for the input image.

    Args:
        image: PIL image from Gradio, or None when nothing is uploaded.
        custom_prompt: Optional prompt; empty or whitespace-only values fall
            back to "Describe this image in detail."

    Returns:
        The generated caption, a hint asking for an image when ``image`` is
        None, or an "Error generating caption: ..." message if anything
        raised while captioning.
    """
    if image is None:
        return "Please upload an image first."

    try:
        # Load model if not already loaded
        tok, model = load_model()

        # Convert image to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Use the custom prompt unless it is missing or blank
        prompt = (custom_prompt or "").strip() or "Describe this image in detail."

        # Build chat message
        messages = [{"role": "user", "content": f"<image>\n{prompt}"}]

        # Render to string so the <image> placeholder position is known
        rendered = tok.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )

        # Split at the first <image>, i.e. the placeholder added above (a
        # literal "<image>" inside the user's prompt stays in `post`)
        pre, post = rendered.split("<image>", 1)

        # Tokenize text around the image token
        pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

        # Insert IMAGE token id at placeholder position
        img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)
        input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)

        # Preprocess image using model's vision tower
        px = model.get_vision_tower().image_processor(
            images=image, return_tensors="pt"
        )["pixel_values"]
        px = px.to(model.device, dtype=model.dtype)

        # Generate caption
        with torch.no_grad():
            out = model.generate(
                inputs=input_ids,
                attention_mask=attention_mask,
                images=px,
                max_new_tokens=128,
                do_sample=False,  # Deterministic generation
                temperature=1.0,
            )

        # Decode, then keep only the assistant's part of the text
        generated_text = tok.decode(out[0], skip_special_tokens=True)
        return _extract_response(generated_text)

    except Exception as e:
        # UI boundary: the user gets a readable message, the logs get the
        # full traceback so failures can still be diagnosed.
        import traceback

        traceback.print_exc()
        return f"Error generating caption: {str(e)}"
# Create Gradio interface
with gr.Blocks(title="FastVLM Image Captioning") as demo:
    gr.Markdown(
        """
# 🖼️ FastVLM Image Captioning
Upload an image to generate a detailed caption using Apple's FastVLM-0.5B model.
You can use the default prompt or provide your own custom prompt.
"""
    )
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                elem_id="image-upload"
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder="Leave empty for default: 'Describe this image in detail.'",
                lines=2
            )
            with gr.Row():
                clear_btn = gr.ClearButton([image_input, custom_prompt])
                generate_btn = gr.Button("Generate Caption", variant="primary")
        with gr.Column():
            output = gr.Textbox(
                label="Generated Caption",
                lines=8,
                max_lines=15,
                show_copy_button=True
            )

    # On ZeroGPU a GPU is only attached while a function wrapped by
    # spaces.GPU is running, so the captioning entry point (which loads the
    # model onto "cuda" and calls generate) has to go through that wrapper.
    gpu_caption_image = spaces.GPU(caption_image)

    def caption_on_upload(img, prompt):
        """Auto-caption a newly uploaded image when no custom prompt is set.

        Returns None (which empties the output box) when the image was
        cleared or a custom prompt is present; the GPU-wrapped function is
        only invoked when a caption will actually be generated.
        """
        if img is None or prompt:
            return None
        return gpu_caption_image(img, prompt)

    # Event handlers
    generate_btn.click(
        fn=gpu_caption_image,
        inputs=[image_input, custom_prompt],
        outputs=output
    )
    # Also generate on image upload if no custom prompt
    image_input.change(
        fn=caption_on_upload,
        inputs=[image_input, custom_prompt],
        outputs=output
    )
    gr.Markdown(
        """
---
**Model:** [apple/FastVLM-0.5B](https://huggingface.co/apple/FastVLM-0.5B)
**Note:** This Space uses ZeroGPU for dynamic GPU allocation.
"""
    )
if __name__ == "__main__":
    # Launch settings: no public share link, surface errors in the UI, and
    # listen on all interfaces at port 7860.
    launch_options = {
        "share": False,
        "show_error": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)