Spaces:
Runtime error
Runtime error
import torch | |
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN | |
from datasets import load_dataset | |
def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
    """Return the ``train`` split of ``dataset_name``, ordered by ``sim_score``.

    The whole ``DatasetDict`` is loaded with ``load_dataset`` and every split is
    sorted on the ``sim_score`` column before the training split is picked out.
    """
    splits = load_dataset(dataset_name).sort("sim_score")
    return splits["train"]
from transformers import BeitFeatureExtractor, BeitForImageClassification

# Hub checkpoint shared by the BEiT preprocessor and the classifier below, so
# the two can never drift apart.
_BEIT_CHECKPOINT = 'microsoft/beit-base-patch16-224'

# Module-level singletons used by ``embed``; loaded once at import time.
feature_extractor = BeitFeatureExtractor.from_pretrained(_BEIT_CHECKPOINT)
model = BeitForImageClassification.from_pretrained(_BEIT_CHECKPOINT)
def embed(images):
    """Compute one pooled BEiT embedding per input image.

    The images are preprocessed with the module-level ``feature_extractor`` and
    run through ``model`` with hidden states enabled. The backbone's pooler is
    then applied to the final hidden layer.

    Args:
        images: anything ``feature_extractor`` accepts as ``images`` -- e.g. the
            batch of images that ``build_index`` passes in from ``Dataset.map``.

    Returns:
        numpy.ndarray: the pooled embeddings, one row per image (presumably
        shaped (batch, hidden_size) -- TODO confirm against the checkpoint).
    """
    inputs = feature_extractor(images=images, return_tensors="pt")
    # Inference only: without no_grad, autograd records the full forward pass
    # and keeps every intermediate activation alive, even though the result is
    # converted to numpy and never back-propagated.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        pooled = model.base_model.pooler(last_hidden)
    # .cpu() is a no-op for CPU tensors and is required before .numpy() for
    # tensors that live on another device.
    return pooled.detach().cpu().numpy()
def build_index():
    """Embed every training image with BEiT and persist a FAISS index on disk.

    Embeddings are written to a ``beit_embeddings`` column and indexed with
    ``add_faiss_index``. The index is saved to ``beit_index.faiss``, which is
    the file ``get_dataset`` loads.
    """
    def _embed_batch(batch):
        # map() runs batched, so batch["image"] holds up to 20 images at once.
        return {"beit_embeddings": embed(batch["image"])}

    embedded = get_train_data().map(_embed_batch, batched=True, batch_size=20)
    embedded.add_faiss_index(column='beit_embeddings')
    embedded.save_faiss_index('beit_embeddings', 'beit_index.faiss')
def get_dataset():
    """Return the sorted training split with the saved BEiT FAISS index attached.

    Reads ``beit_index.faiss`` (as written by ``build_index``) and registers it
    under the index name ``beit_embeddings``.
    """
    train_split = get_train_data()
    train_split.load_faiss_index('beit_embeddings', 'beit_index.faiss')
    return train_split
def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
    """Load the pretrained LightweightGAN ``model_name`` and switch it to eval mode.

    Returns:
        The loaded ``LightweightGAN`` instance, after ``.eval()`` has been called.
    """
    pretrained = LightweightGAN.from_pretrained(model_name)
    pretrained.eval()
    return pretrained
def generate(gan, batch_size=1):
    """Sample ``batch_size`` images from the GAN's generator.

    Args:
        gan: object exposing a generator module ``gan.G`` and an integer
            ``gan.latent_dim``, such as the model returned by ``load_model``.
        batch_size: number of images to sample.

    Returns:
        numpy.ndarray of generated images with values clamped to [0, 1]. Axes
        are permuted (0, 2, 3, 1), i.e. channels-last if the generator emits
        NCHW (assumed -- TODO confirm for the generator in use).
    """
    # Draw the noise on the generator's own device. torch.randn defaults to the
    # CPU, which breaks with a device mismatch once the GAN is moved to a GPU.
    # A generator without parameters falls back to the CPU.
    first_param = next(gan.G.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        latents = torch.randn(batch_size, gan.latent_dim, device=device)
        ims = gan.G(latents).clamp_(0.0, 1.0)
        # NCHW -> NHWC for image libraries; numpy needs a host-side tensor.
        ims = ims.permute(0, 2, 3, 1).cpu().numpy()
    return ims
def interpolate():
    """Placeholder for latent-space interpolation; not implemented yet.

    Currently does nothing and returns ``None``.
    """
    return None