# helblazer811's picture
# "Orphan branch commit with a readme"
# 55866f4
import torch
import numpy as np
from sklearn.metrics import average_precision_score
# Utils for concept encoding
def embed_concepts(
    clip,
    t5,
    concepts: list[str],
    batch_size=1
):
    """Embed each concept string and build the accompanying placeholder tensors.

    Logic adapted from concept_attention.flux/sampling.py: prepare().

    Args:
        clip: Callable applied once to the space-joined concepts. Only the
            shape, dtype and device of its result are used.
        t5: Callable applied to each concept on its own. Its result is
            indexed as ``[0, 0, :]`` (presumably (prompt, token, feature)
            -- TODO confirm against the encoder).
        concepts: Concept strings to embed.
        batch_size: Leading dimension of the returned ``concept_ids``.

    Returns:
        A ``(concept_embeddings, concept_ids, vec)`` tuple:
        ``concept_embeddings`` stacks one vector per concept behind a leading
        dimension of size 1, ``concept_ids`` is an all-zero tensor of shape
        ``(batch_size, len(concepts), 3)``, and ``vec`` is an all-zero tensor
        shaped like the ``clip`` output.
    """
    # Each concept is encoded independently; only the first token of the
    # first prompt is kept as that concept's embedding.
    first_tokens = [t5(concept)[0, 0, :] for concept in concepts]
    concept_embeddings = torch.stack(first_tokens).unsqueeze(0)

    # Zero filler ids, one row of three per concept.
    # NOTE(review): the embeddings keep a leading dimension of 1 while the
    # ids use batch_size -- confirm callers expand the embeddings themselves.
    num_concepts = concept_embeddings.shape[1]
    concept_ids = torch.zeros(batch_size, num_concepts, 3)

    # The clip output is only a template: the returned vector is zeroed.
    # zeros_like already matches the template's dtype and device.
    clip_output = clip(" ".join(concepts))
    vec = torch.zeros_like(clip_output)

    return concept_embeddings, concept_ids, vec
def linear_normalization(x, dim):
    """Shift ``x`` to start at zero along ``dim`` and scale it to sum to one.

    Args:
        x: Input tensor.
        dim: Dimension along which the minimum and the sum are taken.

    Returns:
        A tensor shaped like ``x`` whose entries along ``dim`` are
        non-negative and sum to one. Slices that are constant along ``dim``
        come back as all zeros.
    """
    # Remove the per-slice minimum so every entry is >= 0.
    shifted = x - x.min(dim=dim, keepdim=True).values

    # A constant slice sums to zero after shifting; swap that zero for a one
    # so the division below yields zeros instead of NaNs.
    total = shifted.sum(dim=dim, keepdim=True)
    safe_total = torch.where(total == 0, torch.ones_like(total), total)

    return shifted / safe_total
################################## Metrics ##################################
def get_ap_scores(predict, target, ignore_index=-1):
    """Compute one average-precision score per (prediction, target) pair.

    Args:
        predict: Iterable of per-sample score tensors. Each is flattened and
            scored against a one-hot target built along dimension 0
            (assumes a (classes, ...) layout -- TODO confirm with callers).
        target: Iterable of per-sample label tensors, paired with ``predict``
            by position. After gaining a leading dimension, each must be
            expandable to the shape of its prediction.
        ignore_index: Label value whose positions are dropped before scoring.

    Returns:
        A list holding, for every pair, the ``average_precision_score``
        result passed through ``np.nan_to_num``.
    """
    scores = []
    for sample_pred, sample_tgt in zip(predict, target):
        labels = sample_tgt.unsqueeze(0)
        labels_full = labels.expand_as(sample_pred)

        # Positions labelled with ignore_index are left out of the score.
        keep = labels_full.data.cpu().numpy().reshape(-1) != ignore_index

        # One-hot encode the labels along dim 0. Negative labels are clamped
        # to 0 first so that they are valid scatter indices.
        class_index = labels.clamp(min=0).long()
        one_hot = torch.zeros_like(labels_full).scatter_(0, class_index, 1)

        flat_pred = np.nan_to_num(sample_pred.data.cpu().numpy().reshape(-1))
        flat_tgt = one_hot.data.cpu().numpy().reshape(-1)

        ap = average_precision_score(flat_tgt[keep], flat_pred[keep])
        scores.append(np.nan_to_num(ap))
    return scores
def batch_pix_accuracy(predict, target):
    """Count correctly labelled pixels and labelled pixels.

    Args:
        predict: Tensor of predicted labels, compared elementwise with
            ``target`` (so it holds labels, not per-class scores).
        target: Tensor of ground-truth labels. Entries that are not greater
            than -1 are treated as unlabelled and excluded from both counts.

    Returns:
        A ``(pixel_correct, pixel_labeled)`` tuple of counts.
    """
    # Shift both by one so that a label of -1 becomes 0 and drops out of the
    # "labelled" mask below.
    shifted_pred = predict.cpu().numpy() + 1
    shifted_tgt = target.cpu().numpy() + 1

    labeled_mask = shifted_tgt > 0
    pixel_labeled = np.sum(labeled_mask)
    pixel_correct = np.sum((shifted_pred == shifted_tgt) * labeled_mask)

    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(predict, target, nclass):
    """Compute per-class intersection and union pixel counts.

    Args:
        predict: Tensor of predicted labels, compared elementwise with
            ``target`` (so it holds labels, not per-class scores).
        target: Tensor of ground-truth labels. Entries that are not greater
            than -1 are treated as unlabelled.
        nclass: Number of categories (int); also the number of histogram bins.

    Returns:
        An ``(area_inter, area_union)`` tuple of arrays with ``nclass``
        entries each.
    """

    def _class_counts(labels):
        # Histogram over the shifted label range [1, nclass]; a value of 0
        # (unlabelled or mismatched) lies outside the range and is not counted.
        counts, _ = np.histogram(labels, bins=nclass, range=(1, nclass))
        return counts

    # Shift labels by one so that 0 is free to mean "ignore".
    shifted_tgt = target.cpu().numpy() + 1
    shifted_pred = predict.cpu().numpy() + 1

    # Zero out predictions on unlabelled pixels so they are never counted.
    labeled = (shifted_tgt > 0).astype(shifted_pred.dtype)
    shifted_pred = shifted_pred * labeled

    # Keep the class value where prediction and target agree, 0 elsewhere.
    matched = shifted_pred * (shifted_pred == shifted_tgt)

    # areas of intersection and union
    area_inter = _class_counts(matched)
    area_pred = _class_counts(shifted_pred)
    area_lab = _class_counts(shifted_tgt)
    area_union = area_pred + area_lab - area_inter

    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union