Spaces: Running on Zero
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text | |
| #import spaces | |
# Device that tensors in this module are sent to via `.to(DEVICE)`.
# NOTE(review): hard-coded to the first CUDA GPU. The torch.cuda.is_available()
# fallback below is commented out, so `.to(DEVICE)` calls here will raise on a
# machine without CUDA. Presumably disabled on purpose for the hosting runtime
# (the `spaces` import above is commented out too) — confirm before re-enabling
# a CPU fallback.
DEVICE = "cuda:0"
#if torch.cuda.is_available():
#DEVICE = "cuda:0"
def pred_3d_upsample(
    pred,  # n_subsampled_pts, feat_dim
    part_text_embeds,  # n_parts, feat_dim
    temperature,
    xyz_sub,  # n_subsampled_pts, 3
    xyz_full,  # n_pts, 3
    N_CHUNKS=1,
):
    """Propagate part predictions from a subsampled cloud to the full cloud.

    Each subsampled point is scored against every part text embedding. A
    constant zero "unlabeled" logit is prepended, and each full-resolution
    point copies the scores of its nearest subsampled point (Euclidean
    distance).

    Args:
        pred: Per-point features of the subsampled cloud, (n_sub, feat_dim).
        part_text_embeds: Part text embeddings, (n_parts, feat_dim).
        temperature: Scalar multiplied into the logits before the softmax.
        xyz_sub: Coordinates of the subsampled points, (n_sub, 3).
        xyz_full: Coordinates of the full cloud, (n_pts, 3). Leading batch
            dimensions of size one are flattened away. The tensor may live on
            another device than ``xyz_sub``; it is moved there chunk by chunk.
        N_CHUNKS: Number of chunks the full cloud is split into for the
            nearest-neighbour search. More chunks give smaller intermediates.

    Returns:
        A tuple ``(all_logits, all_probs, pred_full)``:
        - ``all_logits``: raw part logits, (n_pts, n_parts).
        - ``all_probs``: softmax probabilities including the unlabeled
          column, (n_pts, n_parts + 1).
        - ``pred_full``: CPU argmax labels, (n_pts,). Label 0 means
          unlabeled; label k means row k - 1 of ``part_text_embeds``.

    Raises:
        ValueError: If ``N_CHUNKS`` is smaller than 1.
    """
    if N_CHUNKS < 1:
        raise ValueError(f"N_CHUNKS must be a positive integer, got {N_CHUNKS!r}")
    # Flatten leading singleton/batch dims. Unlike .squeeze(), this keeps a
    # single-point cloud as (1, 3) instead of collapsing it to (3,).
    xyz_full = xyz_full.reshape(-1, xyz_sub.shape[-1])
    logits = pred @ part_text_embeds.T  # n_subsampled_pts, n_parts
    # Column 0 is a constant zero "unlabeled" logit. It is created on the
    # logits' own device/dtype rather than a hard-coded global device.
    logits_prepend0 = torch.cat([logits.new_zeros((logits.shape[0], 1)), logits], dim=1)
    pred_softmax = torch.softmax(logits_prepend0 * temperature, dim=1)

    n_full = xyz_full.shape[0]
    # N_CHUNKS * chunk_len > n_full, so every point lands in some chunk.
    chunk_len = n_full // N_CHUNKS + 1
    # Seed with an empty index tensor so torch.cat also works for an empty cloud.
    closest_idx_list = [torch.zeros(0, dtype=torch.long, device=xyz_sub.device)]
    for start in range(0, n_full, chunk_len):
        cur_chunk = xyz_full[start:start + chunk_len].to(xyz_sub.device)
        # Squared distances have the same argmin as Euclidean ones, so the sqrt
        # is skipped. The intermediate difference tensor is (chunk_len, n_sub, 3).
        sq_dist = ((xyz_sub.unsqueeze(0) - cur_chunk.unsqueeze(1)) ** 2).sum(dim=-1)
        closest_idx_list.append(sq_dist.argmin(dim=1))
        del sq_dist
    all_nn_idxs = torch.cat(closest_idx_list, dim=0)

    # Nearest-neighbour upsampling: every full-resolution point copies the
    # logits/probabilities of its closest subsampled point.
    all_probs = pred_softmax[all_nn_idxs]
    all_logits = logits[all_nn_idxs]
    # 0 is unlabeled; 1..n_parts correspond to actual part assignments.
    pred_full = all_probs.argmax(dim=1).cpu()
    return all_logits, all_probs, pred_full
def get_segmentation_rgb(model, data, N_CHUNKS=5):  # evaluate loader can only have batch size=1
    """Run ``model`` on one point cloud and colour the upsampled part labels.

    Every tensor in ``data`` whose key does not contain "full" is replaced
    in place by a copy on DEVICE. The model is then called as
    ``model(x=data)``, and its output is upsampled to the full-resolution
    cloud with ``pred_3d_upsample``. The resulting labels are passed to
    ``get_seg_color``.

    Args:
        model: Network exposing ``ln_logit_scale``. ``exp`` of that value is
            used as the logit multiplier.
        data: Dict holding at least "label_embeds", "coord" and "xyz_full".
            Only batch size 1 is supported.
        N_CHUNKS: Chunk count forwarded to ``pred_3d_upsample``.

    Returns:
        Whatever ``get_seg_color`` produces for the CPU label tensor.
    """
    logit_scale = np.exp(model.ln_logit_scale.item())
    with torch.no_grad():
        # Keys containing "full" (full-resolution data) are left where they are.
        keys_to_move = [
            name for name, value in data.items()
            if isinstance(value, torch.Tensor) and "full" not in name
        ]
        for name in keys_to_move:
            data[name] = data[name].to(DEVICE)
        net_out = model(x=data)
        upsampled = pred_3d_upsample(
            net_out,  # n_subsampled_pts, feat_dim
            data['label_embeds'],  # n_parts, feat_dim
            logit_scale,
            data["coord"],
            data["xyz_full"],  # n_pts, 3
            N_CHUNKS=N_CHUNKS,
        )
        labels = upsampled[2]
        seg_rgb = get_seg_color(labels.cpu())
    return seg_rgb
def get_heatmap_rgb(model, data, N_CHUNKS=5):  # evaluate loader can only have batch size=1
    """Run ``model`` on one point cloud and map the query logits through ``jet``.

    Every tensor in ``data`` whose key does not contain "full" is replaced
    in place by a copy on DEVICE. After the forward pass, the raw part logits
    are upsampled to the full cloud with ``pred_3d_upsample`` and coloured
    with matplotlib's ``jet`` colormap. Alpha is dropped, so only RGB is kept.

    NOTE: the raw logits are not rescaled. Matplotlib colormaps expect floats
    in [0, 1], and values outside that range get the colormap's under/over
    colours, which are the end colours by default.
    NOTE(review): written for a single query. With several text embeddings,
    the squeeze leaves a 2-D array and ``[:, :3]`` no longer selects RGB
    channels.

    Args:
        model: Network exposing ``ln_logit_scale``. ``exp`` of that value is
            used as the logit multiplier.
        data: Dict holding at least "label_embeds", "coord" and "xyz_full".
            Only batch size 1 is supported.
        N_CHUNKS: Chunk count forwarded to ``pred_3d_upsample``.

    Returns:
        A float64 torch tensor of per-point RGB values, squeezed.
    """
    logit_scale = np.exp(model.ln_logit_scale.item())
    with torch.no_grad():
        for name, value in list(data.items()):
            if not isinstance(value, torch.Tensor) or "full" in name:
                continue
            data[name] = value.to(DEVICE)
        net_out = model(x=data)
        upsampled_logits = pred_3d_upsample(
            net_out,  # n_subsampled_pts, feat_dim
            data['label_embeds'],  # n_parts, feat_dim
            logit_scale,
            data["coord"],
            data["xyz_full"],  # n_pts, 3
            N_CHUNKS=N_CHUNKS,
        )[0]
        score_array = upsampled_logits.squeeze().cpu().numpy()
        rgba = plt.cm.jet(score_array)
        heatmap_rgb = torch.tensor(rgba[:, :3]).squeeze()
    return heatmap_rgb
def segment_obj(xyz, rgb, normal, queries):
    """Segment a point cloud into the parts named by ``queries``.

    A model is obtained from ``load_model()`` on every call. ``xyz``, ``rgb``
    and ``normal`` are each converted with ``torch.tensor``, cast to float32
    and moved to DEVICE before ``preprocess_pcd``. The encoded queries are
    stored under "label_embeds".

    Args:
        xyz: Point coordinates (anything ``torch.tensor`` accepts).
        rgb: Per-point colours, same convention.
        normal: Per-point normals, same convention.
        queries: Part descriptions passed unchanged to ``encode_text``.

    Returns:
        The output of ``get_segmentation_rgb`` for the preprocessed cloud.
    """
    model = load_model()
    xyz_t, rgb_t, normal_t = (
        torch.tensor(arr).float().to(DEVICE) for arr in (xyz, rgb, normal)
    )
    data_dict = preprocess_pcd(xyz_t, rgb_t, normal_t)
    data_dict["label_embeds"] = encode_text(queries)
    return get_segmentation_rgb(model, data_dict)
def get_heatmap(xyz, rgb, normal, query):
    """Colour each point by its match score against a single text ``query``.

    A model is obtained from ``load_model()`` on every call. The point arrays
    are converted with ``torch.tensor``, cast to float32, moved to DEVICE and
    preprocessed. ``query`` is encoded as a one-element list.

    Args:
        xyz: Point coordinates (anything ``torch.tensor`` accepts).
        rgb: Per-point colours, same convention.
        normal: Per-point normals, same convention.
        query: A single part description.

    Returns:
        The output of ``get_heatmap_rgb`` for the preprocessed cloud.
    """
    model = load_model()
    point_tensors = [torch.tensor(arr).float().to(DEVICE) for arr in (xyz, rgb, normal)]
    data_dict = preprocess_pcd(*point_tensors)
    data_dict["label_embeds"] = encode_text([query])  # batch of exactly one query
    return get_heatmap_rgb(model, data_dict)