In [14]:
import open_clip
import torch
from tqdm.notebook import tqdm
import pandas as pd
import os

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Prompt templates; '{0}' is substituted with a person's full name ("first last").
PROMPTS = [
    '{0}',
    'an image of {0}',
    'a photo of {0}',
    '{0} on a photo',
    'a photo of a person named {0}',
    'a person named {0}',
    'a man named {0}',
    'a woman named {0}',
    'the name of the person is {0}',
    'a photo of a person with the name {0}',
    '{0} at a gala',
    'a photo of the celebrity {0}',
    'actor {0}',
    'actress {0}',
    'a colored photo of {0}',
    'a black and white photo of {0}',
    'a cool photo of {0}',
    'a cropped photo of {0}',
    'a cropped image of {0}',
    '{0} in a suit',
    '{0} in a dress'
]
# open_clip architecture names to evaluate.
MODEL_NAMES = ['ViT-B-32', 'ViT-B-16', 'ViT-L-14']
SEED = 42
# Apply the seed so any stochastic torch operation in the notebook is reproducible
# (torch.manual_seed seeds the RNGs of all devices).
torch.manual_seed(SEED)

In [2]:
# Load every CLIP backbone listed in MODEL_NAMES from the `laion400m_e32` pretrained
# checkpoint. For each architecture, keep the model (switched to eval mode), its
# image preprocessing transform and its tokenizer, all keyed by architecture name.
models, preprocessings, tokenizers = {}, {}, {}
for model_name in MODEL_NAMES:
    clip_model, _, image_transform = open_clip.create_model_and_transforms(
        model_name, pretrained='laion400m_e32'
    )
    models[model_name] = clip_model.eval()
    preprocessings[model_name] = image_transform
    tokenizers[model_name] = open_clip.get_tokenizer(model_name)

In [3]:
# Helper: encode tokenized prompts with a CLIP text encoder in mini-batches.
@torch.no_grad()
def get_text_embeddings(model, context, context_batchsize=1_000, use_tqdm=False):
    """Encode tokenized prompts into normalized CLIP text embeddings.

    Args:
        model: object exposing ``encode_text(tokens, normalize=...)`` (e.g. an
            open_clip model). It is called on 2-D ``(batch, seq_len)`` slices.
        context: integer token tensor whose LAST dimension is the sequence length,
            e.g. ``(num_prompts, num_names, seq_len)``. Tensors with fewer than three
            dims get a leading dimension of size 1 prepended. The tensor is not
            moved, so it must already live on the same device as ``model``.
        context_batchsize: sequences per forward pass, multiplied by the number of
            visible CUDA devices (treated as 1 on CPU-only machines). Must be >= 1.
            NOTE(review): the model is not wrapped in DataParallel by this notebook,
            so on multi-GPU hosts this only enlarges the batch on one device.
        use_tqdm: show a progress bar over the batches.

    Returns:
        CPU tensor of shape ``context.shape[:-1] + (embed_dim,)`` (after the optional
        unsqueeze) holding the embeddings returned with ``normalize=True``.

    Raises:
        ValueError: if ``context_batchsize`` < 1 or ``context`` holds no sequences.
    """
    if context_batchsize < 1:
        raise ValueError(f"context_batchsize must be >= 1, got {context_batchsize}")
    # torch.cuda.device_count() is 0 without a GPU; clamp it so the step is never 0
    # (range() with step 0 raises).
    step = context_batchsize * max(1, torch.cuda.device_count())

    # A single set of prompts (no leading prompt/batch axis) gets one added.
    if context.dim() < 3:
        context = context.unsqueeze(0)

    # Flatten all leading dims into one batch axis; reshape (unlike view) also
    # accepts non-contiguous inputs.
    seq_len = context.shape[-1]
    flat_context = context.reshape(-1, seq_len)

    batch_starts = range(0, len(flat_context), step)
    if use_tqdm:
        batch_starts = tqdm(batch_starts, desc="Calculating Text Embeddings")

    # Move each batch's result to host memory right away to keep device memory low.
    text_features = [
        model.encode_text(flat_context[start:start + step], normalize=True).cpu()
        for start in batch_starts
    ]
    if not text_features:
        raise ValueError("context contains no token sequences to encode")

    # Restore the leading dims; the trailing dim is the embedding size.
    return torch.cat(text_features).view(*context.shape[:-1], -1)

In [4]:
# Load the candidate names: one row per person with `first_name`, `sex` and
# `last_name` columns. The first CSV column is used as the index.
possible_names = pd.read_csv('./full_names.csv', index_col=0)
# The prompt-building step below combines these two columns into full names.
assert {'first_name', 'last_name'} <= set(possible_names.columns), \
    f'full_names.csv is missing name columns; found {list(possible_names.columns)}'
possible_names

Unnamed: 0,first_name,sex,last_name
0,Eliana,f,Cardenas
1,Meghann,f,Daniels
2,Ada,f,Stevenson
3,Elsa,f,Leblanc
4,Avah,f,Lambert
...,...,...,...
9995,Kasen,m,Barker
9996,Camryn,m,Roberts
9997,Henry,m,Whitaker
9998,Adin,m,Richards


In [5]:
# Fill every prompt template with each person's full name ("first last"). The result
# keeps the original name columns and adds one `prompt_{i}` column per template in
# PROMPTS, with a fresh 0..n-1 index. Built column-wise instead of with a per-row
# `iterrows` loop, which is slow for ~10k names x 21 templates.
full_names = possible_names['first_name'].astype(str) + ' ' + possible_names['last_name'].astype(str)
prompts = possible_names.assign(
    **{
        f'prompt_{prompt_idx}': full_names.map(template.format)
        for prompt_idx, template in enumerate(PROMPTS)
    }
).reset_index(drop=True)
prompts

Unnamed: 0,first_name,sex,last_name,prompt_0,prompt_1,prompt_2,prompt_3,prompt_4,prompt_5,prompt_6,...,prompt_11,prompt_12,prompt_13,prompt_14,prompt_15,prompt_16,prompt_17,prompt_18,prompt_19,prompt_20
0,Eliana,f,Cardenas,Eliana Cardenas,an image of Eliana Cardenas,a photo of Eliana Cardenas,Eliana Cardenas on a photo,a photo of a person named Eliana Cardenas,a person named Eliana Cardenas,a man named Eliana Cardenas,...,a photo of the celebrity Eliana Cardenas,actor Eliana Cardenas,actress Eliana Cardenas,a colored photo of Eliana Cardenas,a black and white photo of Eliana Cardenas,a cool photo of Eliana Cardenas,a cropped photo of Eliana Cardenas,a cropped image of Eliana Cardenas,Eliana Cardenas in a suit,Eliana Cardenas in a dress
1,Meghann,f,Daniels,Meghann Daniels,an image of Meghann Daniels,a photo of Meghann Daniels,Meghann Daniels on a photo,a photo of a person named Meghann Daniels,a person named Meghann Daniels,a man named Meghann Daniels,...,a photo of the celebrity Meghann Daniels,actor Meghann Daniels,actress Meghann Daniels,a colored photo of Meghann Daniels,a black and white photo of Meghann Daniels,a cool photo of Meghann Daniels,a cropped photo of Meghann Daniels,a cropped image of Meghann Daniels,Meghann Daniels in a suit,Meghann Daniels in a dress
2,Ada,f,Stevenson,Ada Stevenson,an image of Ada Stevenson,a photo of Ada Stevenson,Ada Stevenson on a photo,a photo of a person named Ada Stevenson,a person named Ada Stevenson,a man named Ada Stevenson,...,a photo of the celebrity Ada Stevenson,actor Ada Stevenson,actress Ada Stevenson,a colored photo of Ada Stevenson,a black and white photo of Ada Stevenson,a cool photo of Ada Stevenson,a cropped photo of Ada Stevenson,a cropped image of Ada Stevenson,Ada Stevenson in a suit,Ada Stevenson in a dress
3,Elsa,f,Leblanc,Elsa Leblanc,an image of Elsa Leblanc,a photo of Elsa Leblanc,Elsa Leblanc on a photo,a photo of a person named Elsa Leblanc,a person named Elsa Leblanc,a man named Elsa Leblanc,...,a photo of the celebrity Elsa Leblanc,actor Elsa Leblanc,actress Elsa Leblanc,a colored photo of Elsa Leblanc,a black and white photo of Elsa Leblanc,a cool photo of Elsa Leblanc,a cropped photo of Elsa Leblanc,a cropped image of Elsa Leblanc,Elsa Leblanc in a suit,Elsa Leblanc in a dress
4,Avah,f,Lambert,Avah Lambert,an image of Avah Lambert,a photo of Avah Lambert,Avah Lambert on a photo,a photo of a person named Avah Lambert,a person named Avah Lambert,a man named Avah Lambert,...,a photo of the celebrity Avah Lambert,actor Avah Lambert,actress Avah Lambert,a colored photo of Avah Lambert,a black and white photo of Avah Lambert,a cool photo of Avah Lambert,a cropped photo of Avah Lambert,a cropped image of Avah Lambert,Avah Lambert in a suit,Avah Lambert in a dress
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,Kasen,m,Barker,Kasen Barker,an image of Kasen Barker,a photo of Kasen Barker,Kasen Barker on a photo,a photo of a person named Kasen Barker,a person named Kasen Barker,a man named Kasen Barker,...,a photo of the celebrity Kasen Barker,actor Kasen Barker,actress Kasen Barker,a colored photo of Kasen Barker,a black and white photo of Kasen Barker,a cool photo of Kasen Barker,a cropped photo of Kasen Barker,a cropped image of Kasen Barker,Kasen Barker in a suit,Kasen Barker in a dress
9996,Camryn,m,Roberts,Camryn Roberts,an image of Camryn Roberts,a photo of Camryn Roberts,Camryn Roberts on a photo,a photo of a person named Camryn Roberts,a person named Camryn Roberts,a man named Camryn Roberts,...,a photo of the celebrity Camryn Roberts,actor Camryn Roberts,actress Camryn Roberts,a colored photo of Camryn Roberts,a black and white photo of Camryn Roberts,a cool photo of Camryn Roberts,a cropped photo of Camryn Roberts,a cropped image of Camryn Roberts,Camryn Roberts in a suit,Camryn Roberts in a dress
9997,Henry,m,Whitaker,Henry Whitaker,an image of Henry Whitaker,a photo of Henry Whitaker,Henry Whitaker on a photo,a photo of a person named Henry Whitaker,a person named Henry Whitaker,a man named Henry Whitaker,...,a photo of the celebrity Henry Whitaker,actor Henry Whitaker,actress Henry Whitaker,a colored photo of Henry Whitaker,a black and white photo of Henry Whitaker,a cool photo of Henry Whitaker,a cropped photo of Henry Whitaker,a cropped image of Henry Whitaker,Henry Whitaker in a suit,Henry Whitaker in a dress
9998,Adin,m,Richards,Adin Richards,an image of Adin Richards,a photo of Adin Richards,Adin Richards on a photo,a photo of a person named Adin Richards,a person named Adin Richards,a man named Adin Richards,...,a photo of the celebrity Adin Richards,actor Adin Richards,actress Adin Richards,a colored photo of Adin Richards,a black and white photo of Adin Richards,a cool photo of Adin Richards,a cropped photo of Adin Richards,a cropped image of Adin Richards,Adin Richards in a suit,Adin Richards in a dress


In [7]:
# Tokenize each prompt column separately and stack the results, so the leading axis
# indexes the prompt template and the second axis indexes the person.
label_context_vecs = torch.stack([
    open_clip.tokenize(prompts[f'prompt_{template_idx}'].to_numpy())
    for template_idx in range(len(PROMPTS))
])

In [8]:
# Encode all tokenized prompts with each CLIP model in turn. A model is moved to
# `device` only while it is being used, so at most one model sits there at a time.
label_context_vecs = label_context_vecs.to(device)

text_embeddings_per_model = {}
for model_name, model in models.items():
    model.to(device)  # nn.Module.to / .cpu move the parameters in place
    text_embeddings_per_model[model_name] = get_text_embeddings(
        model, label_context_vecs, context_batchsize=1_000, use_tqdm=True
    )
    model.cpu()

# Return the token tensor to host memory once every model is done.
label_context_vecs = label_context_vecs.cpu()

Calculating Text Embeddings:   0%|          | 0/210 [00:00<?, ?it/s]

Calculating Text Embeddings:   0%|          | 0/210 [00:00<?, ?it/s]

Calculating Text Embeddings:   0%|          | 0/210 [00:00<?, ?it/s]

In [18]:
# Create the directory the embeddings are saved to. exist_ok avoids the race of a
# separate exists()-then-create check and makes the cell safe to re-run.
os.makedirs('./prompt_text_embeddings', exist_ok=True)

In [20]:
# Save each model's prompt embeddings to its own file:
# ./prompt_text_embeddings/<model_name>_prompt_text_embeddings.pt
for model_name in models:
    output_path = f'./prompt_text_embeddings/{model_name}_prompt_text_embeddings.pt'
    torch.save(text_embeddings_per_model[model_name], output_path)