|
|
|
import numpy as np |
|
import torch |
|
from mmcv.ops import RoIAlignRotated |
|
|
|
from .utils import (euclidean_distance_matrix, feature_embedding, |
|
normalize_adjacent_matrix) |
|
|
|
|
|
class LocalGraphs:
    """Generate local graphs for GCN to classify the neighbors of a pivot for
    DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape Text
    Detection.

    [https://arxiv.org/abs/2003.07493]. This code was partially adapted from
    https://github.com/GXYM/DRRG licensed under the MIT license.

    Args:
        k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
        num_adjacent_linkages (int): The number of linkages when constructing
            adjacent matrix.
        node_geo_feat_len (int): The length of embedded geometric feature
            vector of a text component.
        pooling_scale (float): The spatial scale of rotated RoI-Align.
        pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
        local_graph_thr(float): The threshold for filtering out identical local
            graphs.
    """

    def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
                 pooling_scale, pooling_output_size, local_graph_thr):

        assert len(k_at_hops) == 2
        assert all(isinstance(n, int) for n in k_at_hops)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert all(isinstance(n, int) for n in pooling_output_size)
        assert isinstance(local_graph_thr, float)

        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        self.node_geo_feat_dim = node_geo_feat_len
        self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
        self.local_graph_thr = local_graph_thr

    def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
        """Generate local graphs for GCN to predict which instance a text
        component belongs to.

        A local graph of a pivot consists of the pivot, its 1-hop neighbors
        and the 2-hop neighbors reached through them. A pivot is skipped
        when an already kept pivot of the same non-zero label has it among
        its k-nearest neighbors and their local graphs overlap (IoU of the
        neighbor sets) by more than ``local_graph_thr``.

        Args:
            sorted_dist_inds (ndarray): The complete graph node indices, which
                is sorted according to the Euclidean distance.
            gt_comp_labels(ndarray): The ground truth labels define the
                instance to which the text components (nodes in graphs) belong.

        Returns:
            pivot_local_graphs(list[list[int]]): The list of local graph
                neighbor indices of pivots. The pivot is always the first
                element of its local graph.
            pivot_knns(list[list[int]]): The list of k-nearest neighbor indices
                of pivots. The pivot is always the first element.
        """

        assert sorted_dist_inds.ndim == 2
        assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
                gt_comp_labels.shape[0])

        # Column 0 is skipped: it is the closest node, i.e. the node itself.
        knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
        second_hop_graph = sorted_dist_inds[:, 1:self.k_at_hops[1] + 1]

        pivot_local_graphs = []
        pivot_knns = []
        # Neighbor sets (pivot excluded) of the kept local graphs, cached so
        # that they are not rebuilt for every IoU computation.
        kept_neighbor_sets = []

        for pivot_ind, knn in enumerate(knn_graph):

            local_graph_neighbors = set(knn)
            for neighbor_ind in knn:
                local_graph_neighbors.update(
                    set(second_hop_graph[neighbor_ind]))
            local_graph_neighbors.discard(pivot_ind)

            pivot_local_graph = [pivot_ind] + list(local_graph_neighbors)
            pivot_knn = [pivot_ind] + list(knn)

            pivot_label = gt_comp_labels[pivot_ind]
            is_duplicate = False
            # Components labelled 0 are never filtered out.
            if pivot_label != 0:
                for added_knn, added_neighbors in zip(pivot_knns,
                                                      kept_neighbor_sets):
                    # Cheap checks first; the set operations only run for
                    # candidates that can actually suppress this pivot.
                    if pivot_ind not in added_knn:
                        continue
                    if gt_comp_labels[added_knn[0]] != pivot_label:
                        continue

                    union = len(local_graph_neighbors | added_neighbors)
                    intersect = len(local_graph_neighbors & added_neighbors)
                    local_graph_iou = intersect / (union + 1e-8)
                    if local_graph_iou > self.local_graph_thr:
                        is_duplicate = True
                        break

            if not is_duplicate:
                pivot_local_graphs.append(pivot_local_graph)
                pivot_knns.append(pivot_knn)
                kept_neighbor_sets.append(local_graph_neighbors)

        return pivot_local_graphs, pivot_knns

    def _build_adjacent_matrix(self, pivot_local_graph, node2ind_map,
                               sorted_dist_inds):
        """Build the symmetric, unnormalized adjacency matrix of one local
        graph.

        Args:
            pivot_local_graph (list[int]): The node indices of the local
                graph.
            node2ind_map (dict): The mapping from a node index to its
                position inside ``pivot_local_graph``.
            sorted_dist_inds (ndarray): The node indices sorted according to
                the Euclidean distance.

        Returns:
            ndarray: The float32 matrix of shape (num_nodes, num_nodes), in
                which an entry is 1 when one of the two nodes is among the
                ``num_adjacent_linkages`` nearest neighbors of the other.
        """
        num_nodes = len(pivot_local_graph)
        adjacent_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
        for node in pivot_local_graph:
            node_pos = node2ind_map[node]
            neighbors = sorted_dist_inds[node,
                                         1:self.num_adjacent_linkages + 1]
            for neighbor in neighbors:
                # The dict lookup replaces a linear scan of the node list;
                # neighbors outside the local graph are ignored.
                neighbor_pos = node2ind_map.get(neighbor)
                if neighbor_pos is None:
                    continue
                adjacent_matrix[node_pos, neighbor_pos] = 1
                adjacent_matrix[neighbor_pos, node_pos] = 1
        return adjacent_matrix

    def generate_gcn_input(self, node_feat_batch, node_label_batch,
                           local_graph_batch, knn_batch,
                           sorted_dist_ind_batch):
        """Generate graph convolution network input data.

        Every local graph is padded with zeros to the node count of the
        largest local graph in the batch, so that all graphs can be stacked.

        Args:
            node_feat_batch (List[Tensor]): The batched graph node features.
            node_label_batch (List[ndarray]): The batched text component
                labels.
            local_graph_batch (List[List[list[int]]]): The local graph node
                indices of image batch.
            knn_batch (List[List[list[int]]]): The knn graph node indices of
                image batch.
            sorted_dist_ind_batch (list[ndarray]): The node indices sorted
                according to the Euclidean distance.

        Returns:
            local_graphs_node_feat (Tensor): The node features of graph.
            adjacent_matrices (Tensor): The adjacent matrices of local graphs.
            pivots_knn_inds (Tensor): The k-nearest neighbor indices in
                local graph.
            gt_linkage (Tensor): The supervision signal of GCN for linkage
                prediction.
        """
        assert isinstance(node_feat_batch, list)
        assert isinstance(node_label_batch, list)
        assert isinstance(local_graph_batch, list)
        assert isinstance(knn_batch, list)
        assert isinstance(sorted_dist_ind_batch, list)

        num_max_nodes = max(
            len(pivot_local_graph) for pivot_local_graphs in local_graph_batch
            for pivot_local_graph in pivot_local_graphs)

        local_graphs_node_feat = []
        adjacent_matrices = []
        pivots_knn_inds = []
        pivots_gt_linkage = []

        for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
            node_feats = node_feat_batch[batch_ind]
            pivot_local_graphs = local_graph_batch[batch_ind]
            pivot_knns = knn_batch[batch_ind]
            node_labels = node_label_batch[batch_ind]
            device = node_feats.device

            for graph_ind, pivot_knn in enumerate(pivot_knns):
                pivot_local_graph = pivot_local_graphs[graph_ind]
                num_nodes = len(pivot_local_graph)
                pivot_ind = pivot_local_graph[0]
                node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}

                # Positions of the pivot's k-nearest neighbors inside the
                # local graph (the pivot itself, pivot_knn[0], is excluded).
                knn_inds = torch.tensor(
                    [node2ind_map[i] for i in pivot_knn[1:]])

                # Node features are expressed relative to the pivot.
                pivot_feats = node_feats[pivot_ind]
                normalized_feats = node_feats[pivot_local_graph] - pivot_feats

                adjacent_matrix = self._build_adjacent_matrix(
                    pivot_local_graph, node2ind_map, sorted_dist_inds)
                adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)

                pad_adjacent_matrix = torch.zeros(
                    (num_max_nodes, num_max_nodes),
                    dtype=torch.float,
                    device=device)
                pad_adjacent_matrix[:num_nodes, :num_nodes] = torch.from_numpy(
                    adjacent_matrix)

                pad_normalized_feats = torch.cat([
                    normalized_feats,
                    torch.zeros(
                        (num_max_nodes - num_nodes, normalized_feats.shape[1]),
                        dtype=torch.float,
                        device=device)
                ],
                                                 dim=0)

                # A neighbor is linked to the pivot when both carry the same
                # positive label.
                local_graph_labels = node_labels[pivot_local_graph]
                knn_labels = local_graph_labels[knn_inds]
                link_labels = ((node_labels[pivot_ind] == knn_labels) &
                               (node_labels[pivot_ind] > 0)).astype(np.int64)
                link_labels = torch.from_numpy(link_labels)

                local_graphs_node_feat.append(pad_normalized_feats)
                adjacent_matrices.append(pad_adjacent_matrix)
                pivots_knn_inds.append(knn_inds)
                pivots_gt_linkage.append(link_labels)

        local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
        adjacent_matrices = torch.stack(adjacent_matrices, 0)
        pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
        pivots_gt_linkage = torch.stack(pivots_gt_linkage, 0)

        return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
                pivots_gt_linkage)

    def __call__(self, feat_maps, comp_attribs):
        """Generate local graphs as GCN input.

        Args:
            feat_maps (Tensor): The feature maps to extract the content
                features of text components.
            comp_attribs (ndarray): The text component attributes of shape
                (batch, num_max_comps, 8). ``comp_attribs[b, 0, 0]`` is read
                as the number of valid components of image ``b``, columns
                1-6 as geometric attributes (the first two are used as the
                component center, the last two as cosine and sine of the
                orientation; the middle two are presumably the component
                size - TODO confirm) and column 7 as the component label.
                The array is not modified.

        Returns:
            local_graphs_node_feat (Tensor): The node features of graph.
            adjacent_matrices (Tensor): The adjacent matrices of local graphs.
            pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
                graph.
            gt_linkage (Tensor): The supervision signal of GCN for linkage
                prediction.
        """

        assert isinstance(feat_maps, torch.Tensor)
        assert comp_attribs.ndim == 3
        assert comp_attribs.shape[2] == 8

        sorted_dist_inds_batch = []
        local_graph_batch = []
        knn_batch = []
        node_feat_batch = []
        node_label_batch = []
        device = feat_maps.device

        for batch_ind in range(comp_attribs.shape[0]):
            num_comps = int(comp_attribs[batch_ind, 0, 0])
            # Copy the slice: the cosine column is clipped below and basic
            # slicing would otherwise write through to the caller's array.
            comp_geo_attribs = comp_attribs[batch_ind, :num_comps,
                                            1:7].copy()
            node_labels = comp_attribs[batch_ind, :num_comps,
                                       7].astype(np.int32)

            comp_centers = comp_geo_attribs[:, 0:2]
            distance_matrix = euclidean_distance_matrix(
                comp_centers, comp_centers)

            # The RoI batch index is always 0 because the pooling below runs
            # on a single-image slice of the feature maps.
            batch_id = np.zeros((comp_geo_attribs.shape[0], 1),
                                dtype=np.float32)
            # Clip the cosine into the domain of arccos.
            comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
            angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
                comp_geo_attribs[:, -1])
            angle = angle.reshape((-1, 1))
            rotated_rois = np.hstack(
                [batch_id, comp_geo_attribs[:, :-2], angle])
            rois = torch.from_numpy(rotated_rois).to(device)
            content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
                                         rois)

            content_feats = content_feats.view(content_feats.shape[0],
                                               -1).to(device)
            geo_feats = feature_embedding(comp_geo_attribs,
                                          self.node_geo_feat_dim)
            geo_feats = torch.from_numpy(geo_feats).to(device)
            node_feats = torch.cat([content_feats, geo_feats], dim=-1)

            sorted_dist_inds = np.argsort(distance_matrix, axis=1)
            pivot_local_graphs, pivot_knns = self.generate_local_graphs(
                sorted_dist_inds, node_labels)

            node_feat_batch.append(node_feats)
            node_label_batch.append(node_labels)
            local_graph_batch.append(pivot_local_graphs)
            knn_batch.append(pivot_knns)
            sorted_dist_inds_batch.append(sorted_dist_inds)

        (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
            self.generate_gcn_input(node_feat_batch,
                                    node_label_batch,
                                    local_graph_batch,
                                    knn_batch,
                                    sorted_dist_inds_batch)

        return node_feats, adjacent_matrices, knn_inds, gt_linkage
|
|