Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import numpy as np | |
from sklearn.metrics import average_precision_score | |
# Utils for concept encoding | |
def embed_concepts(
    clip,
    t5,
    concepts: list[str],
    batch_size=1
):
    """
    Encode a list of concept strings for use alongside the text embeddings.

    Args:
        clip: callable taking a prompt string and returning a tensor; only
            the shape/dtype/device of its output is used here.
        t5: callable taking a single string and returning a tensor that is
            indexed as ``[0, 0, :]`` (presumably (batch, tokens, dim) --
            confirm against the encoder in use).
        concepts: concept strings, each embedded independently.
        batch_size: leading dimension of the returned ``concept_ids``.

    Returns:
        A tuple ``(concept_embeddings, concept_ids, vec)`` where
        ``concept_embeddings`` has shape ``(1, len(concepts), dim)``,
        ``concept_ids`` is an all-zero tensor of shape
        ``(batch_size, len(concepts), 3)`` and ``vec`` is an all-zero tensor
        shaped like the output of ``clip`` on the space-joined concepts.
    """
    # Each concept goes through T5 on its own; only the first token of the
    # first (and only) prompt is kept as that concept's vector.
    first_tokens = [t5(name)[0, 0, :] for name in concepts]
    concept_embeddings = torch.stack(first_tokens).unsqueeze(0)

    # Placeholder ids: one all-zero triple per concept, per batch element.
    num_concepts = concept_embeddings.shape[1]
    concept_ids = torch.zeros(batch_size, num_concepts, 3)

    # CLIP is still invoked on the joined concepts, but its values are
    # discarded: only a zero tensor with matching shape/dtype/device is
    # returned (zeros_like already keeps the device of its input).
    clip_out = clip(" ".join(concepts))
    vec = torch.zeros_like(clip_out)

    return concept_embeddings, concept_ids, vec
def linear_normalization(x, dim):
    """
    Min-shift ``x`` along ``dim`` and rescale so each slice sums to one.

    Args:
        x: input tensor.
        dim: dimension along which the minimum and the sum are taken.

    Returns:
        A tensor shaped like ``x`` whose values along ``dim`` are
        non-negative and sum to one; slices whose shifted values are all
        zero (i.e. constant slices) are returned as all zeros.
    """
    # Shift so the smallest entry along `dim` becomes zero.
    shifted = x - x.min(dim=dim, keepdim=True).values

    # Total mass per slice; a zero total is swapped for one so the division
    # below leaves such slices at zero instead of producing NaN.
    total = shifted.sum(dim=dim, keepdim=True)
    safe_total = torch.where(total == 0, torch.ones_like(total), total)

    return shifted / safe_total
################################## Metrics ################################## | |
def get_ap_scores(predict, target, ignore_index=-1):
    """
    Compute one average-precision score per (prediction, target) pair.

    Args:
        predict: iterable of score tensors; dimension 0 of each is the one
            indexed by label values (presumably (classes, H, W) -- confirm
            against the caller).
        target: iterable of label tensors, each broadcastable to the
            matching prediction after a leading ``unsqueeze(0)``.
        ignore_index: label value whose positions are excluded from the
            score.

    Returns:
        A list with one score per pair, with NaN results replaced via
        ``np.nan_to_num``.
    """
    scores = []
    for sample_pred, sample_tgt in zip(predict, target):
        tgt_row = sample_tgt.unsqueeze(0)
        tgt_broadcast = tgt_row.expand_as(sample_pred)

        # Positions to keep: everything not labelled `ignore_index`.
        keep = tgt_broadcast.detach().cpu().numpy().reshape(-1) != ignore_index

        # One-hot encode the labels along dim 0. Negative labels are clamped
        # to 0 only so that scatter_ gets a valid index; with the default
        # ignore_index those positions are dropped by `keep` afterwards.
        one_hot = torch.zeros_like(tgt_broadcast)
        one_hot = one_hot.scatter_(0, tgt_row.clamp(min=0).long(), 1)

        flat_scores = np.nan_to_num(
            sample_pred.detach().cpu().numpy().reshape(-1)
        )
        flat_labels = one_hot.detach().cpu().numpy().reshape(-1)

        ap = average_precision_score(flat_labels[keep], flat_scores[keep])
        scores.append(np.nan_to_num(ap))
    return scores
def batch_pix_accuracy(predict, target):
    """Count correctly predicted and labeled pixels for a batch.

    Args:
        predict: tensor of predicted labels (docstring upstream says 3D).
        target: tensor of ground-truth labels with the same shape.

    Returns:
        Tuple ``(pixel_correct, pixel_labeled)``: the number of labeled
        pixels whose prediction equals the target, and the number of
        labeled pixels.
    """
    # Both arrays are shifted by one so that a target of -1 becomes 0 and is
    # excluded by the `> 0` mask below.
    pred_np = predict.cpu().numpy() + 1
    label_np = target.cpu().numpy() + 1

    labeled = label_np > 0
    pixel_labeled = np.sum(labeled)
    pixel_correct = np.sum((pred_np == label_np) & labeled)

    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(predict, target, nclass):
    """Per-class intersection and union areas for a batch.

    Args:
        predict: tensor of predicted labels (docstring upstream says 3D).
        target: tensor of ground-truth labels with the same shape.
        nclass: number of categories (int); also the number of histogram
            bins.

    Returns:
        Tuple ``(area_inter, area_union)`` of length-``nclass`` arrays with
        the per-class intersection and union pixel counts.
    """
    # Shift labels to start at 1 so that 0 can mean "not counted": a target
    # of -1 becomes 0 and falls outside the histogram range [1, nclass].
    pred_np = predict.cpu().numpy() + 1
    label_np = target.cpu().numpy() + 1

    def _count_per_class(values):
        # Histogram over the shifted label range, one bin per class.
        counts, _ = np.histogram(values, bins=nclass, range=(1, nclass))
        return counts

    # Zero out predictions wherever the target is unlabeled, then keep only
    # the predictions that agree with the target.
    pred_np = pred_np * (label_np > 0).astype(pred_np.dtype)
    matched = pred_np * (pred_np == label_np)

    area_inter = _count_per_class(matched)
    area_pred = _count_per_class(pred_np)
    area_lab = _count_per_class(label_np)
    area_union = area_pred + area_lab - area_inter

    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union