File size: 3,874 Bytes
55866f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import numpy as np
from sklearn.metrics import average_precision_score

# Utils for concept encoding
def embed_concepts(
    clip,
    t5,
    concepts: list[str],
    batch_size=1
):
    """
    Encode a list of concept strings for use alongside the text embeddings.

    Args:
        clip: callable applied once to the space-joined concepts; only the
            shape/dtype/device of its output is used.
        t5: callable applied to each concept string on its own.
            (assumes the output is indexable as (prompt, token, dim) -- TODO confirm)
        concepts: concept strings, embedded one at a time.
        batch_size: leading dimension of the returned ``concept_ids``.

    Returns:
        A tuple ``(concept_embeddings, concept_ids, vec)``:
        * ``concept_embeddings``: first-token embeddings of every concept,
          stacked and given a leading dimension of size 1.
        * ``concept_ids``: zeros of shape ``(batch_size, num_concepts, 3)``,
          created with torch's default dtype and device.
        * ``vec``: zeros shaped like the ``clip`` output.
    """
    # Adapted from concept_attention.flux/sampling.py: prepare()
    # One t5 call per concept; keep only token 0 of prompt 0 from each result.
    first_tokens = [t5(concept)[0, 0, :] for concept in concepts]
    concept_embeddings = torch.stack(first_tokens)[None]

    # Placeholder (all-zero) ids, one row of 3 per concept.
    num_concepts = concept_embeddings.shape[1]
    concept_ids = torch.zeros(batch_size, num_concepts, 3)

    # The clip output is only a template: the returned vector is zeroed out.
    # zeros_like already keeps the template's dtype and device.
    pooled = clip(" ".join(concepts))
    vec = torch.zeros_like(pooled)

    return concept_embeddings, concept_ids, vec

def linear_normalization(x, dim):
    """
    Min-shift ``x`` along ``dim`` and rescale so each slice sums to one.

    The minimum along ``dim`` is subtracted first, so every value becomes
    non-negative. Slices whose shifted values sum to zero (for example a
    constant slice) come back as all zeros instead of dividing by zero.
    """
    lowest = x.min(dim=dim, keepdim=True).values
    offset = x - lowest
    total = offset.sum(dim=dim, keepdim=True)
    # Swap zero totals for one so the division below is always defined.
    safe_total = torch.where(total == 0, torch.ones_like(total), total)
    return offset / safe_total

################################## Metrics ##################################

def get_ap_scores(predict, target, ignore_index=-1):
    """
    Compute one average-precision score per (prediction, target) pair.

    Args:
        predict: iterable of score tensors.
            (assumes each has a leading class dimension -- TODO confirm)
        target: iterable of label tensors; each is unsqueezed at dim 0 and
            expanded to the shape of its prediction.
        ignore_index: label value whose positions are left out of the score.

    Returns:
        A list with one ``average_precision_score`` result per pair, with NaN
        replaced via ``np.nan_to_num``.
    """
    scores = []
    for pred, tgt in zip(predict, target):
        labels = tgt.unsqueeze(0)
        labels_per_channel = labels.expand_as(pred)

        # Positions to keep: wherever the (expanded) label is not ignore_index.
        keep = labels_per_channel.data.cpu().numpy().reshape(-1) != ignore_index

        # One-hot along dim 0; labels are clamped to >= 0 first so that
        # negative labels are still valid scatter indices.
        one_hot = torch.zeros_like(labels_per_channel).scatter_(
            0, labels.clamp(min=0).long(), 1
        )

        flat_scores = np.nan_to_num(pred.data.cpu().numpy().reshape(-1))
        flat_one_hot = one_hot.data.cpu().numpy().reshape(-1)

        ap = average_precision_score(flat_one_hot[keep], flat_scores[keep])
        scores.append(np.nan_to_num(ap))

    return scores

def batch_pix_accuracy(predict, target):
    """Count correctly labelled pixels and labelled pixels.

    Both inputs are shifted by +1 so that a target of -1 becomes 0 and is
    treated as unlabelled.

    Args:
        predict: tensor of predicted labels, compared element-wise to target
        target: tensor of ground-truth labels

    Returns:
        (pixel_correct, pixel_labeled): number of labelled pixels whose
        prediction equals the target, and number of labelled pixels.
    """
    pred_np = predict.cpu().numpy() + 1
    label_np = target.cpu().numpy() + 1

    labeled_mask = label_np > 0
    pixel_labeled = np.sum(labeled_mask)
    # A pixel counts as correct only if it is labelled and the classes match.
    pixel_correct = np.sum((pred_np == label_np) * labeled_mask)

    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"

    return pixel_correct, pixel_labeled


def batch_intersection_union(predict, target, nclass):
    """Per-class intersection and union areas.

    Both inputs are shifted by +1, so shifted values 1..nclass are counted
    and a target of -1 becomes 0, which falls outside the histogram range.

    Args:
        predict: tensor of predicted labels, compared element-wise to target
        target: tensor of ground-truth labels
        nclass: number of categories (int)

    Returns:
        (area_inter, area_union): two arrays of length ``nclass`` holding the
        per-class intersection and union pixel counts.
    """
    pred_np = predict.cpu().numpy() + 1
    label_np = target.cpu().numpy() + 1

    # Zero out predictions on unlabelled pixels so they are never counted.
    labeled = (label_np > 0).astype(pred_np.dtype)
    pred_np = pred_np * labeled
    # Keeps the class id where prediction and label agree, 0 elsewhere.
    matched = pred_np * (pred_np == label_np)

    # One bin per class over the shifted id range [1, nclass].
    hist_kwargs = dict(bins=nclass, range=(1, nclass))
    area_inter = np.histogram(matched, **hist_kwargs)[0]
    area_pred = np.histogram(pred_np, **hist_kwargs)[0]
    area_lab = np.histogram(label_np, **hist_kwargs)[0]

    area_union = area_pred + area_lab - area_inter
    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union