Spaces:
Runtime error
Runtime error
File size: 1,689 Bytes
b0b9e1f a3be375 b0b9e1f 9fbe234 b0b9e1f 9fbe234 b0b9e1f 9fbe234 b0b9e1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
from datasets import load_dataset
def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
    """Load a dataset from the Hub and return its sorted ``train`` split.

    Args:
        dataset_name: Identifier passed straight to ``load_dataset``.

    Returns:
        The ``"train"`` entry of the loaded dataset after sorting by the
        ``"sim_score"`` column.
    """
    # Sorting gives a fixed row order, so callers that load the data
    # separately see the rows in the same sequence.
    sorted_splits = load_dataset(dataset_name).sort("sim_score")
    return sorted_splits["train"]
from transformers import BeitFeatureExtractor, BeitForImageClassification
# Both the preprocessing and the classifier weights come from the same
# checkpoint; bind the identifier once so the two cannot drift apart.
_BEIT_CHECKPOINT = 'microsoft/beit-base-patch16-224'
feature_extractor = BeitFeatureExtractor.from_pretrained(_BEIT_CHECKPOINT)
model = BeitForImageClassification.from_pretrained(_BEIT_CHECKPOINT)
def embed(images):
    """Compute pooled BEiT embeddings for the given images.

    The images are preprocessed with the module-level ``feature_extractor``,
    passed through the module-level ``model`` with hidden states enabled, and
    the model's own pooler is applied to the last hidden state.

    Args:
        images: Image or batch of images in a form accepted by
            ``feature_extractor`` (presumably PIL images when called from
            ``Dataset.map`` -- confirm against the dataset's ``image`` column).

    Returns:
        A NumPy array containing the pooler output for the batch.
    """
    inputs = feature_extractor(images=images, return_tensors="pt")
    # Inference only: disabling autograd avoids building a graph that would
    # be discarded immediately, which lowers peak memory during indexing.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        pooled = model.base_model.pooler(last_hidden)
    # No ``detach`` needed: tensors created under ``no_grad`` carry no graph.
    return pooled.numpy()
def build_index():
    """Embed the training split and write a FAISS index for it to disk.

    Every batch of 20 rows is embedded via ``embed`` into a
    ``"beit_embeddings"`` column; a FAISS index is then built on that column
    and saved to ``beit_index.faiss`` in the working directory.
    """

    def _embed_batch(batch):
        # ``batched=True`` hands over a dict of columns; only images are used.
        return {"beit_embeddings": embed(batch["image"])}

    train_split = get_train_data()
    embedded = train_split.map(_embed_batch, batched=True, batch_size=20)
    embedded.add_faiss_index(column='beit_embeddings')
    embedded.save_faiss_index('beit_embeddings', 'beit_index.faiss')
def get_dataset():
    """Return the training split with the saved FAISS index attached.

    The index is read from ``beit_index.faiss`` and registered under the
    name ``"beit_embeddings"``; the file must already exist (it is written
    by ``build_index``).
    """
    train_split = get_train_data()
    train_split.load_faiss_index('beit_embeddings', 'beit_index.faiss')
    return train_split
def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
    """Load a pretrained ``LightweightGAN`` and put it in evaluation mode.

    Args:
        model_name: Identifier passed to ``LightweightGAN.from_pretrained``.

    Returns:
        The loaded GAN after ``eval()`` has been called on it.
    """
    network = LightweightGAN.from_pretrained(model_name)
    network.eval()
    return network
def generate(gan, batch_size=1):
    """Sample a batch of images from the GAN's generator.

    Args:
        gan: Object exposing a callable generator ``G`` and an integer
            ``latent_dim`` (e.g. the model returned by ``load_model``).
        batch_size: Number of latent vectors to sample.

    Returns:
        A NumPy array with the generator output clamped to ``[0, 1]`` and its
        axes permuted with ``(0, 2, 3, 1)`` (presumably channels-last images,
        assuming ``G`` emits ``(N, C, H, W)`` -- confirm against the model).
    """
    # Sample the latents on the generator's own device so the call keeps
    # working after the model has been moved off the CPU. When ``G`` exposes
    # no parameters, ``device=None`` falls back to torch's default device.
    parameters = getattr(gan.G, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = None if first_param is None else first_param.device

    with torch.no_grad():
        latents = torch.randn(batch_size, gan.latent_dim, device=device)
        ims = gan.G(latents).clamp_(0., 1.)
    return ims.permute(0, 2, 3, 1).cpu().numpy()
def interpolate():
    """Placeholder; no behaviour is implemented yet.

    Returns:
        None.
    """
    # TODO: implement. Kept as a no-op so existing callers are unaffected.
    return None