Spaces:
Build error
Build error
| """ | |
| Locality aware nms. | |
| This code is refered from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py | |
| """ | |
| import numpy as np | |
| from shapely.geometry import Polygon | |
| def intersection(g, p): | |
| """ | |
| Intersection. | |
| """ | |
| g = Polygon(g[:8].reshape((4, 2))) | |
| p = Polygon(p[:8].reshape((4, 2))) | |
| g = g.buffer(0) | |
| p = p.buffer(0) | |
| if not g.is_valid or not p.is_valid: | |
| return 0 | |
| inter = Polygon(g).intersection(Polygon(p)).area | |
| union = g.area + p.area - inter | |
| if union == 0: | |
| return 0 | |
| else: | |
| return inter / union | |
| def intersection_iog(g, p): | |
| """ | |
| Intersection_iog. | |
| """ | |
| g = Polygon(g[:8].reshape((4, 2))) | |
| p = Polygon(p[:8].reshape((4, 2))) | |
| if not g.is_valid or not p.is_valid: | |
| return 0 | |
| inter = Polygon(g).intersection(Polygon(p)).area | |
| #union = g.area + p.area - inter | |
| union = p.area | |
| if union == 0: | |
| print("p_area is very small") | |
| return 0 | |
| else: | |
| return inter / union | |
| def weighted_merge(g, p): | |
| """ | |
| Weighted merge. | |
| """ | |
| g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8]) | |
| g[8] = (g[8] + p[8]) | |
| return g | |
| def standard_nms(S, thres): | |
| """ | |
| Standard nms. | |
| """ | |
| order = np.argsort(S[:, 8])[::-1] | |
| keep = [] | |
| while order.size > 0: | |
| i = order[0] | |
| keep.append(i) | |
| ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) | |
| inds = np.where(ovr <= thres)[0] | |
| order = order[inds + 1] | |
| return S[keep] | |
| def standard_nms_inds(S, thres): | |
| """ | |
| Standard nms, retun inds. | |
| """ | |
| order = np.argsort(S[:, 8])[::-1] | |
| keep = [] | |
| while order.size > 0: | |
| i = order[0] | |
| keep.append(i) | |
| ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) | |
| inds = np.where(ovr <= thres)[0] | |
| order = order[inds + 1] | |
| return keep | |
| def nms(S, thres): | |
| """ | |
| nms. | |
| """ | |
| order = np.argsort(S[:, 8])[::-1] | |
| keep = [] | |
| while order.size > 0: | |
| i = order[0] | |
| keep.append(i) | |
| ovr = np.array([intersection(S[i], S[t]) for t in order[1:]]) | |
| inds = np.where(ovr <= thres)[0] | |
| order = order[inds + 1] | |
| return keep | |
| def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2): | |
| """ | |
| soft_nms | |
| :para boxes_in, N x 9 (coords + score) | |
| :para threshould, eliminate cases min score(0.001) | |
| :para Nt_thres, iou_threshi | |
| :para sigma, gaussian weght | |
| :method, linear or gaussian | |
| """ | |
| boxes = boxes_in.copy() | |
| N = boxes.shape[0] | |
| if N is None or N < 1: | |
| return np.array([]) | |
| pos, maxpos = 0, 0 | |
| weight = 0.0 | |
| inds = np.arange(N) | |
| tbox, sbox = boxes[0].copy(), boxes[0].copy() | |
| for i in range(N): | |
| maxscore = boxes[i, 8] | |
| maxpos = i | |
| tbox = boxes[i].copy() | |
| ti = inds[i] | |
| pos = i + 1 | |
| #get max box | |
| while pos < N: | |
| if maxscore < boxes[pos, 8]: | |
| maxscore = boxes[pos, 8] | |
| maxpos = pos | |
| pos = pos + 1 | |
| #add max box as a detection | |
| boxes[i, :] = boxes[maxpos, :] | |
| inds[i] = inds[maxpos] | |
| #swap | |
| boxes[maxpos, :] = tbox | |
| inds[maxpos] = ti | |
| tbox = boxes[i].copy() | |
| pos = i + 1 | |
| #NMS iteration | |
| while pos < N: | |
| sbox = boxes[pos].copy() | |
| ts_iou_val = intersection(tbox, sbox) | |
| if ts_iou_val > 0: | |
| if method == 1: | |
| if ts_iou_val > Nt_thres: | |
| weight = 1 - ts_iou_val | |
| else: | |
| weight = 1 | |
| elif method == 2: | |
| weight = np.exp(-1.0 * ts_iou_val**2 / sigma) | |
| else: | |
| if ts_iou_val > Nt_thres: | |
| weight = 0 | |
| else: | |
| weight = 1 | |
| boxes[pos, 8] = weight * boxes[pos, 8] | |
| #if box score falls below thresold, discard the box by | |
| #swaping last box update N | |
| if boxes[pos, 8] < threshold: | |
| boxes[pos, :] = boxes[N - 1, :] | |
| inds[pos] = inds[N - 1] | |
| N = N - 1 | |
| pos = pos - 1 | |
| pos = pos + 1 | |
| return boxes[:N] | |
| def nms_locality(polys, thres=0.3): | |
| """ | |
| locality aware nms of EAST | |
| :param polys: a N*9 numpy array. first 8 coordinates, then prob | |
| :return: boxes after nms | |
| """ | |
| S = [] | |
| p = None | |
| for g in polys: | |
| if p is not None and intersection(g, p) > thres: | |
| p = weighted_merge(g, p) | |
| else: | |
| if p is not None: | |
| S.append(p) | |
| p = g | |
| if p is not None: | |
| S.append(p) | |
| if len(S) == 0: | |
| return np.array([]) | |
| return standard_nms(np.array(S), thres) | |
| if __name__ == '__main__': | |
| # 343,350,448,135,474,143,369,359 | |
| print( | |
| Polygon(np.array([[343, 350], [448, 135], [474, 143], [369, 359]])) | |
| .area) |