import torch


def distance_to_similarity(distances, temperature=1.0):
    """
    Convert a distance matrix into a similarity matrix with an exponential
    kernel, so that distribution-based metrics (entropy, Gini, KL) apply.
    Lower distance => higher similarity; `temperature` controls the decay.
    """
    similarities = torch.exp(-distances / temperature)
    similarities = torch.clamp(similarities, min=1e-8)  # avoid log(0) downstream
    return similarities
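
# Usage sketch (added for illustration; values assume the defaults above):
#   >>> distance_to_similarity(torch.tensor([0.0, 1.0, 4.0]))
#   tensor([1.0000, 0.3679, 0.0183])
#   >>> distance_to_similarity(torch.tensor([0.0, 1.0, 4.0]), temperature=4.0)
#   tensor([1.0000, 0.7788, 0.3679])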

#################################
##   "New Object" Detection    ##
#################################
def detect_newness_two_sided(distances, k=3, quantile=0.97):
    """
    Two-sided (mutual) top-k matching plus a distance threshold.
    A target patch is "new" if it has no mutual top-k match with any
    source patch AND its min distance exceeds the chosen quantile.
    distances: [N_src, N_tgt]
    """
    device = distances.device
    N_src, N_tgt = distances.shape
    # Top-k nearest sources for each target, and top-k nearest targets per source
    topk_src_idx_t = torch.topk(distances, k, dim=0, largest=False).indices  # [k, N_tgt]
    topk_tgt_idx_s = torch.topk(distances, k, dim=1, largest=False).indices  # [N_src, k]
    src_to_tgt_mask = torch.zeros((N_src, N_tgt), device=device)
    tgt_to_src_mask = torch.zeros((N_src, N_tgt), device=device)
    row_indices = topk_src_idx_t  # [k, N_tgt]
    col_indices = torch.arange(N_tgt, device=device).unsqueeze(0).repeat(k, 1)  # [k, N_tgt]
    src_to_tgt_mask[row_indices, col_indices] = 1.0  # mark top-k sources per target
    row_indices = torch.arange(N_src, device=device).unsqueeze(1).repeat(1, k)  # [N_src, k]
    col_indices = topk_tgt_idx_s  # [N_src, k]
    tgt_to_src_mask[row_indices, col_indices] = 1.0  # mark top-k targets per source
    # Targets with at least one mutual (two-sided) top-k match
    overlap_mask = (src_to_tgt_mask * tgt_to_src_mask).sum(dim=0) > 0  # [N_tgt]
    # Work on a copy so the caller's tensor is not mutated in place
    distances = distances.clone()
    distances[:, overlap_mask] = 0.0  # matched targets can never pass the threshold
    two_sided_mask = (~overlap_mask).float()
    min_distances, _ = distances.min(dim=0)
    threshold = torch.quantile(min_distances, quantile)
    threshold_mask = (min_distances > threshold).float()
    combined_mask = two_sided_mask * threshold_mask
    return combined_mask

def detect_newness_distance(min_distances, quantile=0.97):
    """
    Old approach: threshold the min distance at a chosen percentile.
    """
    threshold = torch.quantile(min_distances, quantile)
    newness_mask = (min_distances > threshold).float()
    return newness_mask

def detect_newness_topk_margin(distances, top_k=2, quantile=0.03):
    """
    Top-k margin approach in distance space.
    distances: [N_src, N_tgt]
    Sort each column ascending => best match at index 0, second best at index 1.
    A smaller margin => ambiguous match => likely new.
    The margin is thresholded at the given quantile.
    """
    sorted_dists, _ = torch.sort(distances, dim=0)
    best = sorted_dists[0]  # [N_tgt]
    # With top_k < 2 the margin degenerates to zero everywhere
    second_best = sorted_dists[1] if top_k >= 2 else sorted_dists[0]  # [N_tgt]
    margin = second_best - best  # [N_tgt]
    # If margin < threshold => ambiguous => "new";
    # the threshold is taken as a quantile of the margins
    threshold = torch.quantile(margin, quantile)
    newness_mask = (margin < threshold).float()
    return newness_mask
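
# Worked example (added for illustration): if a target's two smallest
# distances are 0.10 and 0.12, the margin is 0.02 (ambiguous => likely new);
# if they are 0.10 and 0.50, the margin is 0.40 (confident match => not new).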

def detect_newness_entropy(distances, temperature=1.0, quantile=0.97):
    """
    Entropy-based approach. First convert distance -> similarity with an
    exponential, then normalize to get a distribution over source patches
    for each target patch and compute its Shannon entropy.
    High entropy => no strong match => new object.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)  # [N_src, N_tgt]
    # Shannon entropy: -sum(p log p); the clamp in distance_to_similarity
    # keeps probs > 0, so the log is safe
    entropy = -torch.sum(probs * torch.log(probs), dim=0)  # [N_tgt]
    threshold = torch.quantile(entropy, quantile)
    newness_mask = (entropy > threshold).float()
    return newness_mask
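
# Reference bounds (added note): the entropy is 0 for a one-hot match and
# peaks at log(N_src) for a perfectly uniform distribution (e.g. N_src = 64
# gives log(64) ≈ 4.16 nats), so the 0.97 quantile flags the most uniform 3%.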

def detect_newness_gini(distances, temperature=1.0, quantile=0.97):
    """
    Gini impurity-based approach. Convert distances to similarities,
    normalize to a distribution, compute the Gini impurity.
    High Gini => spread-out distribution => new object.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)
    # Gini: sum(p_i * (1 - p_i)) => high if the distribution is spread out
    gini = torch.sum(probs * (1.0 - probs), dim=0)  # [N_tgt]
    threshold = torch.quantile(gini, quantile)
    newness_mask = (gini > threshold).float()
    return newness_mask
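
# Reference bounds (added note): sum(p*(1-p)) = 1 - sum(p^2), so the Gini
# score is 0 for a one-hot match and peaks at 1 - 1/N_src for a uniform
# distribution.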

def detect_newness_kl(distances, temperature=1.0, quantile=0.97):
    """
    KL-based approach: compare each target's match distribution to uniform.
    1) Convert distances -> similarities
    2) p(x) = similarities / sum(similarities)
    3) KL(p || uniform) = sum p(x) log(p(x) / (1/N_src))
    4) If p is near uniform => KL small => new object,
       so we invert the score: newness ~ 1/KL.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    N_src = distances.shape[0]
    probs = similarities / similarities.sum(dim=0, keepdim=True)
    uniform_val = 1.0 / float(N_src)
    kl_vals = torch.sum(probs * torch.log(probs / uniform_val), dim=0)  # [N_tgt]
    inv_kl = 1.0 / (kl_vals + 1e-8)  # large => near-uniform distribution => new
    threshold = torch.quantile(inv_kl, quantile)
    newness_mask = (inv_kl > threshold).float()
    return newness_mask
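
# Added note: KL(p || uniform) = log(N_src) - H(p), so this score is a
# monotone transform of the entropy score above; at the same quantile the
# two detectors flag essentially the same target patches.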

def detect_newness_variation_ratio(distances, temperature=1.0, quantile=0.97):
    """
    Variation ratio: 1 - max(prob).
    1) Convert distance -> similarity
    2) p(x) = sim(x) / sum_x'(sim(x'))
    3) var_ratio = 1 - max(p)
    High variation ratio => no dominant match => new object.
    """
    similarities = distance_to_similarity(distances, temperature=temperature)
    probs = similarities / similarities.sum(dim=0, keepdim=True)
    max_prob, _ = torch.max(probs, dim=0)  # [N_tgt]
    var_ratio = 1.0 - max_prob
    threshold = torch.quantile(var_ratio, quantile)
    newness_mask = (var_ratio > threshold).float()
    return newness_mask

def detect_newness_two_sided_ratio(
    distances,
    top_k_ratio_quantile=0.03,
    two_sided=True
):
    """
    Two-sided matching plus a Lowe-style ratio test in distance space.
    Ratio test: for each target t, let d0 = best distance, d1 = second best,
    and ratio = d0 / (d1 + 1e-8). A ratio near 1 means the best match is
    barely better than the second best (ambiguous); a small ratio means a
    distinctive match. Note that the ratio mask is computed below but not
    yet combined into the result: the returned mask reflects only the
    two-sided consistency check.
    distances: [N_src, N_tgt]
    """
    N_src, N_tgt = distances.shape
    # Target -> source: best match for each target
    _, best_s_for_t = torch.min(distances, dim=0)
    # Source -> target: best match for each source
    _, best_t_for_s = torch.min(distances, dim=1)
    # Two-sided consistency check: a target is "new" if its best source
    # does not point back to it
    twosided_mask = torch.zeros(N_tgt, device=distances.device)
    if two_sided:
        t_idx = torch.arange(N_tgt, device=distances.device)
        twosided_mask = (best_t_for_s[best_s_for_t] != t_idx).float()
    # Ratio test: ambiguous if the best match is not clearly better than
    # the second best
    sorted_dists, _ = torch.sort(distances, dim=0)
    d0 = sorted_dists[0]
    d1 = sorted_dists[1]
    ratio = d0 / (d1 + 1e-8)
    ratio_threshold = torch.quantile(ratio, top_k_ratio_quantile)
    ratio_mask = (ratio < ratio_threshold).float()
    # Combine checks (currently using only the two-sided result)
    newness_mask = twosided_mask
    return newness_mask
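
# ---------------------------------------------------------------
# Smoke test (an added sketch, not part of the original file): run
# every detector on a random distance matrix and report the fraction
# of target patches flagged as "new". The sizes are arbitrary choices.
# ---------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    dists = torch.rand(64, 100)  # [N_src, N_tgt]
    min_dists, _ = dists.min(dim=0)
    detectors = {
        "two_sided": detect_newness_two_sided(dists),
        "distance": detect_newness_distance(min_dists),
        "topk_margin": detect_newness_topk_margin(dists),
        "entropy": detect_newness_entropy(dists),
        "gini": detect_newness_gini(dists),
        "kl": detect_newness_kl(dists),
        "variation_ratio": detect_newness_variation_ratio(dists),
        "two_sided_ratio": detect_newness_two_sided_ratio(dists),
    }
    for name, mask in detectors.items():
        print(f"{name:>16}: {mask.mean().item():.1%} flagged as new")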