import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
from datasets import load_dataset

def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
    # Load the butterfly dataset, sort it by the precomputed `sim_score` column,
    # and return the train split.
    dataset = load_dataset(dataset_name)
    dataset = dataset.sort("sim_score")
    return dataset["train"]

from transformers import BeitFeatureExtractor, BeitForImageClassification

# BEiT backbone used to embed images; the pooled last hidden state is the embedding.
feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')

def embed(images):
    # Preprocess a batch of PIL images, run BEiT, and pool the last hidden state
    # into one embedding vector per image (returned as a NumPy array).
    inputs = feature_extractor(images=images, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]
    pooler = model.base_model.pooler
    final_emb = pooler(last_hidden).detach().numpy()
    return final_emb
    
def build_index():
    # Embed every training image with BEiT, attach the embeddings as a column,
    # then build a FAISS index over that column and save it to disk.
    dataset = get_train_data()
    ds_with_embeddings = dataset.map(
        lambda x: {"beit_embeddings": embed(x["image"])}, batched=True, batch_size=20
    )
    ds_with_embeddings.add_faiss_index(column='beit_embeddings')
    ds_with_embeddings.save_faiss_index('beit_embeddings', 'beit_index.faiss')

def get_dataset():
    # Reload the train split and attach the previously saved FAISS index to it.
    dataset = get_train_data()
    dataset.load_faiss_index('beit_embeddings', 'beit_index.faiss')
    return dataset
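
# Illustrative usage (not part of the original file): find the training butterflies
# closest to a query image using the FAISS index built above. `get_nearest_examples`
# is the standard datasets API for a dataset with a FAISS index; the helper name and
# default k are our own.
def find_similar(image, k=5):
    dataset = get_dataset()
    query = embed([image])[0]  # 1-D BEiT embedding for the query image
    scores, examples = dataset.get_nearest_examples('beit_embeddings', query, k=k)
    return scores, examples["image"]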

def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512'):
    # Load the pretrained LightweightGAN butterfly generator and put it in eval mode.
    gan = LightweightGAN.from_pretrained(model_name)
    gan.eval()
    return gan
    
def generate(gan, batch_size=1):
    # Sample random latent vectors, run the generator, and return the images as a
    # (batch_size, H, W, 3) float array with values clipped to [0, 1].
    with torch.no_grad():
        ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)
        ims = ims.permute(0, 2, 3, 1).detach().cpu().numpy()
    return ims
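
# Illustrative helper (assumption; not part of the original file): convert the float
# arrays returned by generate() into PIL images for display or saving,
# e.g. to_pil(generate(load_model(), batch_size=4)).
def to_pil(ims):
    from PIL import Image
    # Scale [0, 1] floats to 8-bit RGB and wrap each frame as a PIL image.
    return [Image.fromarray((im * 255).astype("uint8")) for im in ims]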

def interpolate():
    # Placeholder left unimplemented in the original; see the sketch below.
    pass
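
# Sketch of what interpolate() might do (assumption; the original leaves it as a stub):
# blend two random latent vectors and decode each intermediate point with the generator,
# producing a smooth morph between two butterflies. The name and step count are our own.
def interpolate_latents(gan, steps=8):
    with torch.no_grad():
        z1 = torch.randn(1, gan.latent_dim)
        z2 = torch.randn(1, gan.latent_dim)
        frames = []
        for alpha in torch.linspace(0.0, 1.0, steps):
            z = (1 - alpha) * z1 + alpha * z2  # linear interpolation in latent space
            im = gan.G(z).clamp_(0.0, 1.0)
            frames.append(im.permute(0, 2, 3, 1).cpu().numpy()[0])
    return frames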