Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,874 Bytes
55866f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import torch
import numpy as np
from sklearn.metrics import average_precision_score
# Utils for concept encoding
def embed_concepts(
    clip,
    t5,
    concepts: list[str],
    batch_size=1
):
    """
    Encode a list of concept strings into the text-embedding inputs.

    Args:
        clip: callable that maps a prompt string to a tensor. Only the
            shape/dtype/device of its output is used; the values are
            replaced by zeros.
        t5: callable that maps a single string to a tensor that can be
            indexed as ``[0, 0, :]`` (first prompt, first token).
        concepts: the concept strings, encoded one at a time.
        batch_size: leading dimension of the returned ``concept_ids``.

    Returns:
        A tuple ``(concept_embeddings, concept_ids, vec)`` where
        ``concept_embeddings`` stacks one first-token embedding per concept
        behind a leading axis of size 1, ``concept_ids`` is an all-zero
        tensor of shape ``(batch_size, num_concepts, 3)``, and ``vec`` is an
        all-zero tensor shaped like ``clip(" ".join(concepts))``.
    """
    # Adapted from concept_attention.flux/sampling.py: prepare()
    # Each concept goes through T5 on its own; only the first token of the
    # first (and only) prompt is kept as that concept's embedding.
    first_tokens = [t5(concept)[0, 0, :] for concept in concepts]
    concept_embeddings = torch.stack(first_tokens).unsqueeze(0)

    # Placeholder ids: zeros, one row of 3 per concept.
    num_concepts = concept_embeddings.shape[1]
    concept_ids = torch.zeros(batch_size, num_concepts, 3)

    # CLIP is still invoked on the joined prompt, but its output only serves
    # as a template for an all-zero vector on the same device.
    clip_out = clip(" ".join(concepts))
    vec = torch.zeros_like(clip_out).to(clip_out.device)

    return concept_embeddings, concept_ids, vec
def linear_normalization(x, dim):
    """
    Rescale ``x`` along ``dim`` so the entries are non-negative and sum to one.

    The minimum along ``dim`` is subtracted first, then each entry is divided
    by the sum of the shifted values along ``dim``. Slices whose shifted sum
    is exactly zero (all entries equal) are returned as zeros rather than
    producing a division by zero.

    Args:
        x: input tensor.
        dim: dimension along which to normalize.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # Shift so the smallest entry along `dim` becomes zero.
    shifted = x - x.min(dim=dim, keepdim=True).values

    # Denominator: total of the shifted values, with zeros swapped for ones
    # so constant slices divide cleanly to zero.
    total = shifted.sum(dim=dim, keepdim=True)
    safe_total = torch.where(total != 0, total, torch.ones_like(total))

    return shifted / safe_total
################################## Metrics ##################################
def get_ap_scores(predict, target, ignore_index=-1):
    """
    Compute one average-precision score per (prediction, target) pair.

    For each pair, the target is broadcast to the prediction's shape and
    turned into a one-hot tensor by scattering ones along dim 0 at the
    (clamped to >= 0) label indices. Prediction scores and one-hot labels are
    flattened, positions whose label equals ``ignore_index`` are dropped, and
    ``sklearn.metrics.average_precision_score`` is evaluated on the rest.
    NaNs in the scores and in the resulting AP are replaced via
    ``np.nan_to_num``.

    Args:
        predict: iterable of score tensors.
            # assumes each is (num_classes, H, W) -- TODO confirm with caller
        target: iterable of integer label tensors, one per prediction.
            # assumes each is (H, W) -- TODO confirm with caller
        ignore_index: label value marking positions to exclude.

    Returns:
        A list with one AP value per pair.
    """
    ap_per_sample = []
    for sample_pred, sample_tgt in zip(predict, target):
        tgt_row = sample_tgt.unsqueeze(0)
        tgt_broadcast = tgt_row.expand_as(sample_pred)

        # Boolean mask over the flattened layout: True where the label is
        # a real class rather than the ignore marker.
        keep = tgt_broadcast.data.cpu().numpy().reshape(-1) != ignore_index

        # One-hot encode along dim 0. Ignored labels are clamped to class 0
        # so scatter_ has a valid index; they are masked out below anyway.
        one_hot = torch.zeros_like(tgt_broadcast)
        one_hot.scatter_(0, tgt_row.clamp(min=0).long(), 1)

        scores = np.nan_to_num(sample_pred.data.cpu().numpy().reshape(-1))
        labels = one_hot.data.cpu().numpy().reshape(-1)

        ap = average_precision_score(labels[keep], scores[keep])
        ap_per_sample.append(np.nan_to_num(ap))
    return ap_per_sample
def batch_pix_accuracy(predict, target):
    """Count correctly labelled pixels and labelled pixels in a batch.

    Both inputs are moved to CPU, converted to NumPy and shifted by +1, so
    that a target value of -1 becomes 0 and is treated as unlabelled.

    Args:
        predict: tensor of predicted class indices, compared element-wise
            against ``target`` (no argmax is applied here).
        target: tensor of ground-truth class indices.

    Returns:
        ``(pixel_correct, pixel_labeled)``: the number of labelled pixels
        whose prediction matches, and the number of labelled pixels.
    """
    pred_np = predict.cpu().numpy() + 1
    tgt_np = target.cpu().numpy() + 1

    # After the +1 shift, anything > 0 is a pixel that carries a real label.
    labeled_mask = tgt_np > 0
    pixel_labeled = labeled_mask.sum()
    pixel_correct = ((pred_np == tgt_np) & labeled_mask).sum()

    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(predict, target, nclass):
    """Compute per-class intersection and union areas for a batch.

    Both inputs are moved to CPU, converted to NumPy and shifted by +1, so
    classes occupy ``1..nclass`` and a target value of -1 becomes 0
    (unlabelled). Predictions at unlabelled pixels are zeroed out, and class
    areas are counted with ``np.histogram`` using ``nclass`` bins over the
    range ``(1, nclass)``; zeros fall outside that range and are not counted.

    Args:
        predict: tensor of predicted class indices, compared element-wise
            against ``target`` (no argmax is applied here).
        target: tensor of ground-truth class indices.
        nclass: number of categories (int).

    Returns:
        ``(area_inter, area_union)``: two arrays of length ``nclass`` holding
        the per-class intersection and union pixel counts.
    """
    pred_np = predict.cpu().numpy() + 1
    tgt_np = target.cpu().numpy() + 1

    # Drop predictions wherever the target carries no label.
    pred_np = pred_np * (tgt_np > 0).astype(pred_np.dtype)
    # Keep the class id where prediction and target agree, 0 elsewhere.
    matched = pred_np * (pred_np == tgt_np)

    def class_areas(values):
        # Per-class pixel counts over the shifted class ids 1..nclass.
        counts, _ = np.histogram(values, bins=nclass, range=(1, nclass))
        return counts

    area_inter = class_areas(matched)
    area_pred = class_areas(pred_np)
    area_lab = class_areas(tgt_np)
    area_union = area_pred + area_lab - area_inter

    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union
|