vlm_clone_2 / recognize-anything /inference_ram.py
tuandunghcmut's picture
Add files using upload-large-folder tool
4617407 verified
raw
history blame
1.49 kB
'''
* The Recognize Anything Model (RAM)
* Written by Xinyu Huang
'''
import argparse
import numpy as np
import random
import torch
from PIL import Image
from ram.models import ram
from ram import inference_ram as inference
from ram import get_transform
# Command-line interface for tagging a single image with RAM.
parser = argparse.ArgumentParser(
    description='RAM (Recognize Anything Model) inference for image tagging')
parser.add_argument('--image',
                    metavar='PATH',
                    help='path to the input image (default: %(default)s)',
                    default='images/demo/demo1.jpg')
parser.add_argument('--pretrained',
                    metavar='PATH',
                    help='path to pretrained model (default: %(default)s)',
                    default='pretrained/ram_swin_large_14m.pth')
parser.add_argument('--image-size',
                    default=384,
                    type=int,
                    metavar='N',
                    help='input image size (default: 384)')
def main(argv=None):
    """Tag a single image with RAM and print the English and Chinese tags.

    Args:
        argv: Optional list of command-line arguments passed to
            ``parser.parse_args``; ``None`` reads ``sys.argv`` as usual.
    """
    args = parser.parse_args(argv)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = get_transform(image_size=args.image_size)

    # Load model (Swin-L backbone) from the given checkpoint.
    model = ram(pretrained=args.pretrained,
                image_size=args.image_size,
                vit='swin_l')
    model.eval()
    model = model.to(device)

    # Open the image in a context manager so the file handle is released
    # deterministically. convert('RGB') loads the pixels and guarantees a
    # 3-channel image. NOTE(review): get_transform may already convert to
    # RGB (its body is not visible here). If so, this is redundant but harmless.
    with Image.open(args.image) as img:
        rgb_image = img.convert('RGB')
    image = transform(rgb_image).unsqueeze(0).to(device)

    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        res = inference(image, model)

    # res[0] is printed as the English tags, res[1] as the Chinese tags.
    print("Image Tags: ", res[0])
    print("图像标签: ", res[1])


if __name__ == "__main__":
    main()