File size: 4,163 Bytes
44d8da2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

class MaskExtractor(nn.Module):
    """Turn per-region masks into projected region feature tokens.

    Masked feature pixels are gathered by ``MaskPooling`` and then projected
    from ``mm_hidden_size`` to ``config.hidden_size`` by an MLP of ``depth``
    linear layers separated by GELU activations.
    """

    def __init__(self, config, mm_hidden_size, depth=2):
        super().__init__()
        self.mask_pooling = MaskPooling()
        # Keep the layer order unchanged: nn.Sequential names parameters by
        # index (feat_linear.0, feat_linear.2, ...).
        layers = [nn.Linear(mm_hidden_size, config.hidden_size)]
        for _ in range(depth - 1):
            layers.append(nn.GELU())
            layers.append(nn.Linear(config.hidden_size, config.hidden_size))
        self.feat_linear = nn.Sequential(*layers)

    def forward(self, feats, masks):
        """Pool features under every region mask and project them.

        Args:
            feats: Feature tensor indexed once per region mask, in order (the
                index advances per mask, not per image). Each ``feats[i]`` is
                permuted from (H, W, C) to (C, H, W) — presumably channel-last
                patch features; TODO confirm against the caller.
            masks: ``None`` (inference without region prompts), or a sequence
                with one entry per image. Each entry is ``None`` (skipped) or an
                iterable of 2-D region masks; values > 0 count as inside.

        Returns:
            ``None`` when ``masks`` is ``None`` or contains no region masks;
            otherwise a ``[num_tokens, config.hidden_size]`` tensor in the dtype
            of ``feats``.
        """
        if masks is None:
            return None

        query_feats = []
        feat_idx = 0  # one feature entry is consumed per region mask
        for image_masks in masks:
            if image_masks is None:
                continue
            for region_mask in image_masks:
                mask = region_mask.unsqueeze(0).unsqueeze(0).float()
                # An empty mask cannot be pooled/interpolated; substitute an
                # all-zero mask, which yields no tokens downstream.
                if mask.numel() == 0:
                    print('mask error')
                    mask = torch.zeros((1, 1, 336, 336), device=feats.device)

                feat = feats[feat_idx].unsqueeze(0).permute(0, 3, 1, 2)
                feat_idx += 1

                # Pool in the mask's dtype (float32); the result is cast back
                # to the feature dtype after concatenation.
                feat = feat.to(mask.dtype)
                query_feats.append(self.mask_pooling(feat, mask))

        if not query_feats:
            return None
        mask_feats = torch.cat(query_feats, dim=0).to(feats[0].dtype)
        return self.feat_linear(mask_feats)

def kmeans_fast(tokens, num_clusters=10, num_iterations=20):
    """Compress ``tokens`` into at most ``num_clusters`` centroids with K-means.

    Args:
        tokens: 2-D tensor of shape ``[n, d]``.
        num_clusters: Number of centroids to produce; clamped to ``n`` so that
            having fewer tokens than clusters does not fail.
        num_iterations: Maximum number of K-means iterations; stops early once
            the centroids stop moving (``atol=1e-6``).

    Returns:
        Tensor of shape ``[min(num_clusters, n), d]``. The initial centroids are
        a random subset of ``tokens`` drawn from torch's global RNG, so results
        are only reproducible under a fixed seed.
    """
    n, _ = tokens.shape  # unpacking also enforces a 2-D input
    num_clusters = min(num_clusters, n)
    if num_clusters <= 0:
        return tokens[:0]

    # Initialise centroids with distinct randomly chosen tokens.
    centroids = tokens[torch.randperm(n)[:num_clusters]]

    # Loop-invariant [n, 1, d] view used to broadcast against the centroids.
    tokens_expand = tokens.unsqueeze(1)
    for _ in range(num_iterations):
        # Squared Euclidean distance from every token to every centroid: [n, k].
        distances = torch.sum((tokens_expand - centroids.unsqueeze(0)) ** 2, dim=2)
        labels = torch.argmin(distances, dim=1)  # nearest centroid per token, [n]

        # New centroid = mean of its members; an empty cluster keeps its old one.
        updated = []
        for i in range(num_clusters):
            members = tokens[labels == i]
            updated.append(members.mean(dim=0) if members.size(0) > 0 else centroids[i])
        new_centroids = torch.stack(updated)

        if torch.allclose(centroids, new_centroids, atol=1e-6):
            break
        centroids = new_centroids

    return centroids

class MaskPooling(nn.Module):
    """Gather the feature vectors that lie under a binary region mask.

    Args:
        max_tokens: When more than this many feature vectors are gathered, they
            are compressed to ``max_tokens`` K-means centroids. Defaults to 10,
            the value that was previously hard-coded.
    """

    def __init__(self, max_tokens=10):
        super().__init__()
        self.max_tokens = max_tokens

    def forward(self, x, mask):
        """Return the features of ``x`` selected by ``mask``.

        Args:
            x: Feature map indexed as ``(b, c, h, w)``. The broadcast below
                assumes ``b == 1`` (as ``MaskExtractor`` passes) — TODO confirm
                for other callers.
            mask: 4-D mask ``(b, q, h', w')``; values > 0 count as inside. If its
                spatial size differs, ``x`` is resized (bilinearly) to match it.

        Returns:
            ``[k, c]`` tensor: the feature vectors at the selected pixels of the
            first query, with ``k <= max_tokens`` after optional K-means
            compression. ``k`` is 0 when nothing is selected.
        """
        if x.shape[-2:] != mask.shape[-2:]:
            # Resize the feature map (not the mask) to the mask's resolution.
            x = F.interpolate(x, size=mask.shape[-2:], mode='bilinear', align_corners=False)
        if x.device != mask.device:
            mask = mask.to(x.device)

        # Binarise, then move the query axis first: [b, q, h, w] -> [q, b, h, w].
        mask = (mask > 0).to(mask.dtype).permute(1, 0, 2, 3)
        masked = x * mask  # [q, c, h, w] when b == 1

        # Keep pixels where any masked feature is non-zero (union over queries
        # and channels); equivalent to torch.any over dims (0, 1).
        keep = (masked != 0).flatten(0, 1).any(dim=0)
        # [q, c, n] -> first query -> [n, c]
        tokens = masked[:, :, keep][0].permute(1, 0)

        if len(tokens) > self.max_tokens:
            tokens = kmeans_fast(tokens, num_clusters=self.max_tokens)

        return tokens


def build_region_encoder(config, mm_hidden_size):
    """Create the region encoder with its default projection depth.

    Args:
        config: Object exposing ``hidden_size`` (the projection output width).
        mm_hidden_size: Width of the pooled input features.

    Returns:
        A new ``MaskExtractor`` instance.
    """
    encoder = MaskExtractor(config, mm_hidden_size)
    return encoder