import os
import io
import glob
import math
import warnings
from functools import partial

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision import transforms
from PIL import Image
import ipywidgets as widgets
from IPython.display import display  # used by Loader; `display` is only a builtin inside notebooks
from thop import profile

# Toggle FLOP/parameter profiling of the ViT forward pass (uses thop.profile).
is_flop_cal = False

warnings.filterwarnings("ignore")


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill `tensor` in-place with values from a normal distribution truncated to [a, b]."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Standard timm/DINO implementation: sample uniformly between the CDF bounds of
    # [a, b], then invert the CDF to obtain truncated-normal values.
    def norm_cdf(x):
        # CDF of the standard normal distribution
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        low = norm_cdf((a - mean) / std)
        up = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * low - 1, 2 * up - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero the residual branch per sample (Stochastic Depth)."""
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # Broadcastable mask shape: one Bernoulli draw per sample, shared across all other dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = x.div(keep_prob) * random_tensor  # rescale so the expected value is unchanged
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


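# Illustration (not part of the original pipeline): with drop_prob=0.1, each sample's
# residual branch is zeroed with probability 0.1 during training and the survivors are
# scaled by 1/0.9, so the expectation matches eval mode, where DropPath is the identity.
#
#   dp = DropPath(drop_prob=0.1)
#   dp.train()
#   y = dp(torch.ones(8, 4))   # some rows are all zeros, the rest are 1/0.9
#   dp.eval()
#   y = dp(torch.ones(8, 4))   # identity: all ones

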
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        # The attention map is returned alongside the output so callers can visualize it.
        return x, attn


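# Shape walkthrough for illustration, assuming ViT-Base defaults (dim=768, num_heads=12)
# and a 224x224 input with patch size 16 (196 patches + 1 CLS token, so N = 197):
#
#   attn_layer = Attention(dim=768, num_heads=12, qkv_bias=True)
#   x = torch.randn(2, 197, 768)
#   out, attn = attn_layer(x)
#   # out.shape  == (2, 197, 768)
#   # attn.shape == (2, 12, 197, 197); attn[:, h, 0, 1:] is head h's CLS-to-patch attention

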
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

    def forward(self, x, return_attention=False):
        y, attn = self.attn(self.norm1(x))
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


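# Each Block is a pre-norm transformer layer: x + DropPath(Attn(LN(x))) followed by
# x + DropPath(MLP(LN(x))). With return_attention=True it short-circuits and returns
# only the softmaxed attention map, which get_last_selfattention() below relies on.

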
class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        # A conv with kernel_size == stride == patch_size splits the image into
        # non-overlapping patches and linearly projects each one to embed_dim.
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x


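# Worked example: a 224x224 RGB image with patch_size=16 yields a 14x14 grid,
# i.e. 196 patch tokens of dimension embed_dim.
#
#   pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
#   tokens = pe(torch.randn(1, 3, 224, 224))   # tokens.shape == (1, 196, 768)

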
class VisionTransformer(nn.Module):
    """Vision Transformer."""

    def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbed(
            img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: drop probability increases linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head; identity when the model is used purely as a feature extractor.
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        # Add a small number to avoid floating point errors in the interpolation
        # (same trick as the official DINO implementation).
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(
                math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

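    # Example of when this matters (for illustration): with the ViT-Base/16 weights the
    # stored positional grid is 14x14 (N = 196). Feeding a 240x224 crop gives w0, h0 = 15, 14,
    # so the grid is bicubically resized from 14x14 to 15x14 before being re-flattened.
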
    def prepare_tokens(self, x):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)  # patch linear embedding

        # Prepend the [CLS] token to the patch tokens.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional encoding to each token (interpolated if the resolution changed).
        x = x + self.interpolate_pos_encoding(x, w, h)

        return self.pos_drop(x)

    def forward(self, x):
        x = self.prepare_tokens(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Return the [CLS] token and the patch tokens separately.
        return x[:, 0], x[:, 1:]

    def get_last_selfattention(self, x):
        x = self.prepare_tokens(x)
        for i, blk in enumerate(self.blocks):
            if i < len(self.blocks) - 1:
                x = blk(x)
            else:
                # Return the attention map of the last block instead of its output.
                return blk(x, return_attention=True)

    def get_intermediate_layers(self, x, n=1):
        x = self.prepare_tokens(x)
        # Collect the (normalized) outputs of the last n blocks.
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x))
        return output


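# Usage sketch (illustrative shapes assume ViT-Base/16 and a 224x224 input):
#
#   vit = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=True)
#   cls_tok, patch_toks = vit(torch.randn(1, 3, 224, 224))
#   # cls_tok.shape    == (1, 768)
#   # patch_toks.shape == (1, 196, 768)
#   attn = vit.get_last_selfattention(torch.randn(1, 3, 224, 224))
#   # attn.shape == (1, 12, 197, 197)

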
class VitGenerator(object):
    def __init__(self, name_model, patch_size, device, evaluate=True, random=False, verbose=False):
        self.name_model = name_model
        self.patch_size = patch_size
        self.evaluate = evaluate
        self.device = device
        self.verbose = verbose
        self.model = self._getModel()
        self._initializeModel()
        if not random:
            self._loadPretrainedWeights()

    def _getModel(self):
        if self.name_model == 'vit_tiny':
            model = VisionTransformer(patch_size=self.patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
                                      qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        elif self.name_model == 'vit_small':
            model = VisionTransformer(patch_size=self.patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
                                      qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        elif self.name_model == 'vit_base':
            model = VisionTransformer(patch_size=self.patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                                      qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
        else:
            raise ValueError(f"No model found with name {self.name_model}")

        return model

    def _initializeModel(self):
        if self.evaluate:
            # Freeze all parameters and switch to inference mode.
            for p in self.model.parameters():
                p.requires_grad = False
            self.model.eval()

        self.model.to(self.device)

    def _loadPretrainedWeights(self):
        # Checkpoint paths on the official DINO release server.
        url = None
        if self.name_model == 'vit_small' and self.patch_size == 16:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"

        elif self.name_model == 'vit_small' and self.patch_size == 8:
            url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"

        elif self.name_model == 'vit_base' and self.patch_size == 16:
            url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"

        elif self.name_model == 'vit_base' and self.patch_size == 8:
            url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"

        if url is None:
            # No pretrained checkpoint for this configuration; the model keeps its random init.
            pass
        else:
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/" + url)
            self.model.load_state_dict(state_dict, strict=True)

    def get_last_selfattention(self, img):
        return self.model.get_last_selfattention(img.to(self.device))

    def __call__(self, x):
        # Move the input to the model's device, mirroring get_last_selfattention above.
        return self.model(x.to(self.device))


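# Construction sketch: a frozen DINO ViT-S/8 feature extractor (set random=True to skip
# the checkpoint download).
#
#   device = torch.device('cpu')
#   vit_gen = VitGenerator('vit_small', 8, device, evaluate=True, random=False)
#   cls_tok, patch_toks = vit_gen(torch.randn(1, 3, 224, 224))

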
def transform(img, img_size):
    img = transforms.Resize(img_size)(img)
    img = transforms.ToTensor()(img)
    return img


def visualize_predict(model, img_tensor, patch_size, device, video_name, frame_number, fig_name, combined_name):
    if img_tensor.dim() == 3:
        img_tensor = img_tensor.unsqueeze(0)
    attention = visualize_attention(model, img_tensor, patch_size, device)

    mean_attention, frame_npy_path = get_activation_npy(video_name, frame_number, fig_name, combined_name, attention)
    return mean_attention, frame_npy_path


def visualize_attention(model, img_tensor, patch_size, device):
    # Crop so that height and width are exact multiples of the patch size.
    w, h = img_tensor.shape[2] - img_tensor.shape[2] % patch_size, img_tensor.shape[3] - img_tensor.shape[3] % patch_size
    img_tensor = img_tensor[:, :, :w, :h]

    w_featmap = img_tensor.shape[-2] // patch_size
    h_featmap = img_tensor.shape[-1] // patch_size

    attentions = model.get_last_selfattention(img_tensor.to(device))
    nh = attentions.shape[1]  # number of heads

    # Keep only each head's CLS-to-patch attention and reshape it into the patch grid.
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    # Upsample the per-head maps back to pixel resolution.
    attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()

    return attentions


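# The returned array has shape (num_heads, H, W) at the cropped input resolution; for a
# 224x224 input to ViT-Base/16 that is (12, 224, 224), one saliency-style map per head.

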
def get_activation_png(img, png_path, fig_name, attention):
    n_heads = attention.shape[0]

    # One figure per attention head.
    for i in range(n_heads):
        plt.imshow(attention[i], cmap='viridis')
        plt.title(f"Head n: {i + 1}")
        plt.axis('off')

        fig_path = f'{png_path}{fig_name}_head_{i + 1}.png'
        print(fig_path)
        plt.savefig(fig_path)
        plt.close()

    # Side-by-side figure: the input image and the head-averaged attention map.
    plt.figure(figsize=(10, 10))
    image_name = fig_name.replace('vit_feature_map_', '')
    text = [f"{image_name}", "Head Mean"]
    for i, fig in enumerate([img, np.mean(attention, 0)]):
        plt.subplot(1, 2, i + 1)
        plt.imshow(fig, cmap='viridis')
        plt.title(text[i])
        plt.axis('off')
    fig_path1 = f'{png_path}{fig_name}_head_mean.png'
    print(fig_path1)
    print("----------------" + '\n')
    plt.savefig(fig_path1)
    plt.close()


def get_activation_npy(video_name, frame_number, fig_name, combined_name, attention):
    # Average the attention maps across heads. Note that this function only builds the
    # target path; it does not write the array to disk itself.
    mean_attention = attention.mean(axis=0)
    frame_npy_path = f'../features/vit/{video_name}/frame_{frame_number}_{combined_name}.npy'

    return mean_attention, frame_npy_path


class Loader(object):
    def __init__(self):
        self.uploader = widgets.FileUpload(accept='image/*', multiple=False)
        self._start()

    def _start(self):
        display(self.uploader)

    def getLastImage(self):
        try:
            # Dict-style uploader value (ipywidgets < 8 API).
            for uploaded_filename in self.uploader.value:
                img = Image.open(io.BytesIO(
                    bytes(self.uploader.value[uploaded_filename]['content'])))

            return img
        except Exception:
            return None

    def saveImage(self, path):
        with open(path, 'wb') as output_file:
            for uploaded_filename in self.uploader.value:
                content = self.uploader.value[uploaded_filename]['content']
                output_file.write(content)


def process_video_frame(video_name, frame, frame_number, model, patch_size, device):
    # Expect a (C, H, W) or (B, C, H, W) float tensor; resize to the model's 224x224 input.
    if frame.dim() == 3:
        frame = frame.unsqueeze(0)
    if frame.shape[2:] != (224, 224):
        frame_tensor = torch.nn.functional.interpolate(frame, size=(224, 224), mode='bicubic', align_corners=False)
    else:
        frame_tensor = frame
    frame_tensor = frame_tensor.to(device)

    if is_flop_cal:
        total_flops, total_params = profile(model.model, inputs=(frame_tensor,), verbose=False)
        print(f"total FLOPs for ViT layerstack: {total_flops}, Params: {total_params}")
    else:
        total_flops, total_params = None, None

    fig_name = "vit_feature_map"
    combined_name = "vit_feature_map"

    attention_features, frame_feature_npy_path = extract_features(model, frame_tensor, video_name, frame_number, combined_name)
    return attention_features, frame_feature_npy_path, total_flops, total_params


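# Usage sketch with a hypothetical frame file ('frame.png' and 'vid001' are placeholders):
#
#   frame = transform(Image.open('frame.png').convert('RGB'), (224, 224))
#   feats, npy_path, flops, params = process_video_frame('vid001', frame, 1, model, 16, device)
#
# With is_flop_cal = True, thop.profile traces one forward pass of the underlying
# VisionTransformer; for ViT-Base/16 at 224x224 the commonly quoted figures are roughly
# 86M parameters and on the order of 17 GFLOPs (exact numbers depend on thop's counting
# conventions), which makes a useful sanity check on the configuration.

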
def extract_features(model, img_tensor, video_name, frame_number, combined_name):
    if img_tensor.dim() == 3:
        img_tensor = img_tensor.unsqueeze(0)
    # The generator returns (CLS token, patch tokens); the patch tokens are the features.
    cls_token, attention_features = model(img_tensor)

    attention_features = attention_features.squeeze(0)
    frame_feature_npy_path = f'../features/vit/{video_name}/frame_attention_{frame_number}_{combined_name}.npy'
    return attention_features, frame_feature_npy_path


if __name__ == '__main__':
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if device.type == "cuda":
        torch.cuda.set_device(0)

    name_model = 'vit_base'
    patch_size = 16

    model = VitGenerator(name_model, patch_size,
                         device, evaluate=True, random=False, verbose=True)

    video_type = 'test'

    if video_type == 'test':
        metadata_path = "../../metadata/test_videos.csv"

    elif video_type == 'resolution_ugc':
        resolution = '360P'
        metadata_path = f"../../metadata/YOUTUBE_UGC_{resolution}_metadata.csv"
    else:
        metadata_path = f'../../metadata/{video_type.upper()}_metadata.csv'

    ugcdata = pd.read_csv(metadata_path)
    for i in range(len(ugcdata)):
        video_name = ugcdata['vid'][i]
        sampled_frame_path = os.path.join('../..', 'video_sampled_frame', 'sampled_frame', f'{video_name}')

        print(f"Processing video: {video_name}")
        image_paths = glob.glob(os.path.join(sampled_frame_path, f'{video_name}_*.png'))
        frame_number = 0
        for image in image_paths:
            print(f"{image}")
            frame_number += 1
            # Load the sampled frame as a (C, H, W) tensor before handing it to the model.
            frame_tensor = transform(Image.open(image).convert('RGB'), (224, 224))
            process_video_frame(video_name, frame_tensor, frame_number, model, patch_size, device)