import timm
import torch
from torch import nn
from loguru import logger
from torch.utils.checkpoint import checkpoint
# from sbp.nn.model_paths import MODEL_PATHS
class ImageEncoder(nn.Module):
    """Truncated EVA02 backbone that turns a variable number of images per sample
    into a fixed-length sequence of projected patch tokens."""

    def __init__(self, output_dim, base_model='eva02_base_patch14_224.mim_in22k', layer_num=6, seq_len=3, device='cpu'):
        super().__init__()
        self.output_dim = output_dim
        if base_model == 'eva02_base_patch14_224.mim_in22k':
            self.img_seq = 257   # 16x16 patches + 1 cls token
        elif base_model == 'eva02_large_patch14_448.mim_in22k_ft_in1k':
            self.img_seq = 1025  # 32x32 patches + 1 cls token
        else:
            supported = ['eva02_base_patch14_224.mim_in22k', 'eva02_large_patch14_448.mim_in22k_ft_in1k']
            raise ValueError(f"unknown {base_model}, supported: {supported}")
        self.base_model = timm.create_model(base_model, pretrained=False)
        # Use the backbone as a feature extractor: drop the classification head
        # and keep only the first `layer_num` transformer blocks.
        del self.base_model.norm, self.base_model.fc_norm, self.base_model.head, self.base_model.head_drop
        del self.base_model.blocks[layer_num:]
        self.project = nn.Linear(self.base_model.num_features, output_dim)
        self.final_norm = nn.LayerNorm(output_dim)
        self.seq_len = seq_len
        self.device = device
    def forward(self, image_list):
        # `image_list` holds one tensor of shape [n_i, 3, H, W] per sample; n_i may be 0.
        splits = [len(lst) for lst in image_list]
        if sum(splits) == 0:
            # No images anywhere in the batch: return an all-zero token sequence.
            return torch.zeros([len(splits), self.seq_len * self.img_seq, self.output_dim], device=self.device, dtype=torch.bfloat16)
        x = torch.concat(image_list, dim=0).to(device=self.device, dtype=torch.bfloat16)
        x = self.base_model.patch_embed(x)
        x, rot_pos_embed = self.base_model._pos_embed(x)
        for blk in self.base_model.blocks:
            x = blk(x, rope=rot_pos_embed)
        x = self.project(x)
        x = self.final_norm(x)
        b, seq_len, c = x.shape
        # Regroup the patch tokens per sample and zero-pad each sample up to `seq_len` images.
        split_patches = torch.split(x, splits, dim=0)
        split_patches = [nn.functional.pad(sample, (0, 0, 0, 0, 0, self.seq_len - len(sample))) for sample in split_patches]
        x = torch.stack(split_patches, dim=0)
        x = x.reshape((len(splits), self.seq_len * seq_len, c))
        return x
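

# --- Usage sketch (not part of the original file) ---
# A minimal example of how the encoder might be called, assuming an output_dim
# of 768 and CPU execution; the module is cast to bfloat16 to match the dtype
# used inside forward(). Image sizes follow the base EVA02 model's 224x224 input.
if __name__ == '__main__':
    encoder = ImageEncoder(output_dim=768, layer_num=6, seq_len=3, device='cpu').to(torch.bfloat16)
    image_list = [
        torch.randn(2, 3, 224, 224),  # sample 1: two 224x224 RGB images
        torch.randn(0, 3, 224, 224),  # sample 2: no images; zero-padded up to seq_len
    ]
    tokens = encoder(image_list)
    logger.info(f"output shape: {tuple(tokens.shape)}")  # (2, 3 * 257, 768) = (2, 771, 768)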