import os

import numpy as np
from PIL import Image
from transformers import CLIPProcessor

# Local module shipped with the Space: the Flax hybrid CLIP implementation
# from the transformers JAX/Flax research projects.
from modeling_hybrid_clip import FlaxHybridCLIP
def main():
    # Load the fine-tuned hybrid CLIP model and the CLIP image processor
    # that matches its vision tower.
    model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
    vision_model_name = "openai/clip-vit-base-patch32"
    processor = CLIPProcessor.from_pretrained(vision_model_name)

    img_dir = "/Users/kaumad/Documents/coding/hf-flax/demo/medclip-roco/images"
    img_list = os.listdir(img_dir)

    embeddings = []
    for idx, img_path in enumerate(img_list):
        if idx % 10 == 0:
            print(f"{idx} images processed")
        img = Image.open(os.path.join(img_dir, img_path)).convert("RGB")
        inputs = processor(images=img, return_tensors="jax", padding=True)
        # This hybrid CLIP implementation expects channels-last (NHWC) pixel
        # values, while the processor returns NCHW, so transpose first.
        inputs["pixel_values"] = inputs["pixel_values"].transpose(0, 2, 3, 1)
        img_vec = model.get_image_features(**inputs)
        embeddings.append(np.array(img_vec).reshape(-1).tolist())

    # Return the file names alongside the embeddings so callers can map
    # each vector back to its image.
    return img_list, embeddings
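
# A minimal retrieval sketch on top of the embeddings collected by main().
# It is not part of the original script: the `retrieve` name, the tokenizer
# repo mentioned in the usage comment, and the cosine-similarity ranking are
# assumptions about how these vectors might be queried, not the Space's
# actual search code.
def retrieve(query, model, tokenizer, img_list, embeddings, top_k=5):
    # Embed the query text with the text tower of the hybrid CLIP model.
    tokens = tokenizer(query, return_tensors="jax", padding=True)
    text_vec = np.array(model.get_text_features(**tokens)).reshape(-1)

    # Cosine similarity between the query and every image embedding.
    img_mat = np.array(embeddings)
    sims = img_mat @ text_vec / (
        np.linalg.norm(img_mat, axis=1) * np.linalg.norm(text_vec) + 1e-8
    )

    # File names and scores of the top-k most similar images, best first.
    top = np.argsort(-sims)[:top_k]
    return [(img_list[i], float(sims[i])) for i in top]

# Example usage (assumes the Space's text tokenizer can be loaded from the
# same repo, which this script does not confirm):
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("flax-community/medclip-roco")
#     model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
#     img_list, embeddings = main()
#     print(retrieve("chest x-ray with pleural effusion", model, tokenizer,
#                    img_list, embeddings))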
if __name__ == "__main__":
    main()