import torch
from PIL import Image
from transformers import (AutoTokenizer, VisionEncoderDecoderModel,
                          ViTFeatureExtractor)
import gradio as gr
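
# Run inference on the GPU when available, otherwise fall back to the CPU.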
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
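
# Checkpoints: the ViT encoder (used for the feature extractor), the GPT-2
# decoder (used for the tokenizer), and the fine-tuned captioning model.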
encoder_checkpoint = "google/vit-base-patch16-224-in21k"
decoder_checkpoint = "gpt2"
model_checkpoint = "gagan3012/ViTGPT2I2A"
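
# Load the image preprocessor, the tokenizer, and the captioning model.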
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
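
# Generate a caption for a single PIL image and return it as plain text.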
def predict(image):
    # Strip the GPT-2 end-of-text token and keep only the first line of output.
    clean_text = lambda x: x.replace("<|endoftext|>", "").split("\n")[0]
    sample = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    caption_ids = model.generate(sample, max_length=50)[0]
    caption_text = clean_text(tokenizer.decode(caption_ids))
    return caption_text
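
# Gradio input and output components for the demo UI.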
inputs = [
    gr.inputs.Image(type="pil", label="Original Image")
]
outputs = [
    gr.outputs.Textbox(label="Caption")
]
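
# Metadata and example images shown on the demo page.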
title = "Image Captioning using ViT + GPT2"
description = "ViT and GPT2 are used to generate captions for uploaded images."
article = "<a href='https://huggingface.co/gagan3012/ViTGPT2_vizwiz'>Model Repo on Hugging Face Model Hub</a>"
examples = [
    ["duck.jpg"],
    ["dice.jpg"],
    ["banana.jpg"],
    ["avacado.jpg"]
]
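
# Build the Gradio interface and launch it with the request queue enabled.
# Note: the example image files listed above are assumed to sit next to this script.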
gr.Interface(
    predict,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    theme="huggingface",
).launch(debug=True, enable_queue=True)