|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmdet.registry import MODELS |
|
|
|
@MODELS.register_module()
class FCHead(nn.Module):
    """Classification head: one self-attention layer followed by a 2-layer MLP.

    Layout: ``MultiheadAttention(in_channels, 8 heads)`` ->
    ``Linear(in_channels, in_channels // 2)`` -> ReLU ->
    ``Linear(in_channels // 2, num_classes)``.

    Args:
        in_channels (int): Input feature size. PyTorch's
            ``nn.MultiheadAttention`` requires it to be divisible by the
            number of heads (8 here).
        num_classes (int): Number of output logits.
        loss: Kept on the module as ``self.loss``; not used by this class.

    NOTE(review): the attention layer uses PyTorch's default
    ``batch_first=False``; a 2-D input ``(N, C)`` is treated as a single
    unbatched sequence of length ``N``, so samples would attend to each
    other — confirm the expected input layout with the caller.
    """

    def __init__(self, in_channels, num_classes, loss=None):
        super().__init__()
        hidden_dim = in_channels // 2
        # Module creation order kept: attention, fc1, fc2 (affects init RNG and state_dict order).
        self.attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.fc1 = nn.Linear(in_channels, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        self.loss = loss

    def forward(self, x):
        """Return class logits: self-attention over ``x``, then the MLP."""
        attended, _weights = self.attention(x, x, x)
        hidden = F.relu(self.fc1(attended))
        return self.fc2(hidden)
|
|
|
@MODELS.register_module()
class RegHead(nn.Module):
    """Linear regression head with optional self-attention and axis-orientation outputs.

    Args:
        in_channels (int): Input feature size (must be divisible by 8 when
            ``attention`` is enabled, as required by ``nn.MultiheadAttention``).
        out_dims (int): Size of the regression output.
        max_points: Stored as ``self.max_points``; not used by this class.
        loss: Stored as ``self.loss``; not used by this class.
        attention (bool): If truthy, apply an 8-head self-attention layer to
            the input before the linear layers.
        use_axis_info (bool): If truthy, ``forward`` also returns a 2-value
            output from a separate ``Linear(in_channels, 2)`` layer.
    """

    def __init__(self, in_channels, out_dims, max_points=None, loss=None, attention=False, use_axis_info=False):
        super().__init__()
        # Submodules are created in the original order (fc, attention_layer,
        # axis_orientation) so parameter initialisation and state_dict keys match.
        self.fc = nn.Linear(in_channels, out_dims)
        self.max_points = max_points
        self.loss = loss
        self.attention = attention
        self.use_axis_info = use_axis_info
        if attention:
            self.attention_layer = nn.MultiheadAttention(in_channels, num_heads=8)
        if use_axis_info:
            self.axis_orientation = nn.Linear(in_channels, 2)

    def compute_distance_loss(self, pred_points, gt_points):
        """One-sided nearest-neighbour loss from predictions to ground truth.

        For every predicted point, take the Euclidean distance to the closest
        ground-truth point of the same batch entry, and return the smooth-L1
        loss of those distances against zero. 2-D ``(P, D)`` inputs are
        treated as a batch of one.
        """
        pred = pred_points[None] if pred_points.dim() == 2 else pred_points
        gt = gt_points[None] if gt_points.dim() == 2 else gt_points
        # cdist -> (B, P_pred, P_gt); reduce over the ground-truth axis.
        nearest = torch.cdist(pred, gt).min(dim=2).values
        return F.smooth_l1_loss(nearest, torch.zeros_like(nearest))

    def forward(self, x):
        """Return ``fc(features)``, plus the axis-orientation output when enabled.

        ``features`` is ``x`` after self-attention when ``attention`` is set,
        otherwise ``x`` itself.
        """
        features = self.attention_layer(x, x, x)[0] if self.attention else x
        pred = self.fc(features)
        if not self.use_axis_info:
            return pred
        return pred, self.axis_orientation(features)
|
|
|
class CoordinateTransformer:
    """Helper class to transform coordinates between image space and axis space.

    ``axis_info`` holds eight values along its last dimension, in the order
    ``[x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale]``.
    It may be a single ``(8,)`` vector shared by every point, or a batch
    ``(B, 8)``. When ``points`` has more dimensions than ``axis_info`` (e.g.
    ``(B, P, 2)`` with ``(B, 8)``), each ``axis_info`` row is broadcast over
    the trailing point dimensions, so row ``b`` applies to all of ``points[b]``.

    Per coordinate: ``axis = (img - min) / (max - min) * scale + origin``.
    No guard is applied when ``max == min`` or ``scale == 0``.
    """

    @staticmethod
    def _unpack_axis_info(axis_info, points):
        """Split ``axis_info`` into its 8 components, shaped to broadcast with ``points[..., 0]``."""
        components = axis_info.unbind(-1)
        # points[..., 0] has points.dim() - 1 dims; each component has
        # axis_info.dim() - 1. Appending trailing singleton dims aligns a (B,)
        # component with the leading batch dim of a (B, P) slice instead of its
        # last (point) dim, which is what plain broadcasting would do.
        n_trailing = max(points.dim() - axis_info.dim(), 0)
        if n_trailing:
            components = tuple(
                c.reshape(tuple(c.shape) + (1,) * n_trailing) for c in components
            )
        return components

    @staticmethod
    def to_axis_relative(points, axis_info):
        """Transform points from image coordinates to axis coordinates.

        Args:
            points (torch.Tensor): Points in image coordinates, shape
                ``(..., 2)`` such as ``(N, 2)`` or ``(B, N, 2)``.
            axis_info (torch.Tensor): Axis information of shape ``(8,)`` or
                ``(B, 8)``: [x_min, x_max, y_min, y_max, x_origin, y_origin,
                x_scale, y_scale].

        Returns:
            torch.Tensor: Points in axis coordinates, last dim ``(x, y)``.
        """
        x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale = (
            CoordinateTransformer._unpack_axis_info(axis_info, points)
        )

        # Normalise into [0, 1] over the image-space axis extent...
        x_norm = (points[..., 0] - x_min) / (x_max - x_min)
        y_norm = (points[..., 1] - y_min) / (y_max - y_min)

        # ...then map into axis units.
        x_axis = x_norm * x_scale + x_origin
        y_axis = y_norm * y_scale + y_origin

        return torch.stack([x_axis, y_axis], dim=-1)

    @staticmethod
    def to_image_coordinates(points, axis_info):
        """Transform points from axis coordinates back to image coordinates.

        Exact inverse of :meth:`to_axis_relative`; accepts the same shapes.
        """
        x_min, x_max, y_min, y_max, x_origin, y_origin, x_scale, y_scale = (
            CoordinateTransformer._unpack_axis_info(axis_info, points)
        )

        x_norm = (points[..., 0] - x_origin) / x_scale
        y_norm = (points[..., 1] - y_origin) / y_scale

        x_img = x_norm * (x_max - x_min) + x_min
        y_img = y_norm * (y_max - y_min) + y_min

        return torch.stack([x_img, y_img], dim=-1)
|
|
|
@MODELS.register_module()
class DataSeriesHead(nn.Module):
    """Specialized head for data series prediction with dual attention to coordinates and axis-relative positions.

    ``forward`` pipeline: coordinate self-attention (summed with a second
    self-attention branch when ``axis_info`` is given), a residual sequence
    self-attention, a 2-layer ``nn.TransformerEncoder``, and then:

    * ``pattern_recognizer`` on the encoder output -> 5 pattern logits;
    * ``fc1`` + ReLU, followed by two regression branches that each emit
      ``max_points`` (x, y) pairs.

    Args:
        in_channels (int): Input feature size. PyTorch requires it to be
            divisible by the 8 heads used by every attention layer here.
        max_points (int): Number of (x, y) points predicted per sample.
        loss: Stored as ``self.loss``; not used by this class.

    NOTE(review): all attention/encoder layers use PyTorch's default
    ``batch_first=False`` and ``forward`` applies ``x.unsqueeze(0)`` before
    the encoder, which presumably assumes ``x`` is 2-D ``(N, in_channels)``
    — TODO confirm the expected input layout.
    """

    def __init__(self, in_channels, max_points=50, loss=None):
        super().__init__()
        self.max_points = max_points
        self.loss = loss

        # Shared projection in front of the two regression branches.
        self.fc1 = nn.Linear(in_channels, in_channels // 2)

        # Two parallel regressors over fc1 features, each emitting max_points (x, y) pairs.
        self.absolute_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2),
        )
        self.relative_branch = nn.Sequential(
            nn.Linear(in_channels // 2, in_channels // 4),
            nn.ReLU(),
            nn.Linear(in_channels // 4, max_points * 2),
        )

        self.coord_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.axis_attention = nn.MultiheadAttention(in_channels, num_heads=8)
        self.sequence_attention = nn.MultiheadAttention(in_channels, num_heads=8)

        self.sequence_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=in_channels,
                nhead=8,
                dim_feedforward=in_channels * 4,
                dropout=0.1,
            ),
            num_layers=2,
        )

        # Emits 5 logits. Its input size is in_channels, so forward must feed it
        # the encoder output, not fc1's in_channels // 2 output.
        self.pattern_recognizer = nn.Sequential(
            nn.Linear(in_channels, in_channels // 2),
            nn.ReLU(),
            nn.Linear(in_channels // 2, 5),
        )

        self.coord_transformer = CoordinateTransformer()

    def check_monotonicity(self, points, chart_type):
        """Check whether every series' y-values are monotonic.

        Args:
            points (torch.Tensor): Shape ``(..., P, 2)`` with ``(x, y)`` in the
                last dim; each series runs along dim ``-2``.
            chart_type: Only ``'line'`` and ``'scatter'`` are checked.

        Returns:
            For ``'line'``/``'scatter'``: a 0-d bool tensor, True iff each
            series is non-decreasing or non-increasing in y (each series may
            have its own direction). For any other chart type: Python ``True``.
        """
        if chart_type not in ['line', 'scatter']:
            return True
        diffs = points[..., 1].diff(dim=-1)
        # Decide the direction per series; previously the whole batch had to
        # share a single direction, so mixed increasing/decreasing series failed.
        non_decreasing = (diffs >= 0).all(dim=-1)
        non_increasing = (diffs <= 0).all(dim=-1)
        return (non_decreasing | non_increasing).all()

    def forward(self, x, axis_info=None, chart_type=None):
        """Predict data-series points, pattern logits and a monotonicity flag.

        Args:
            x (torch.Tensor): Input features (see the class NOTE on layout).
            axis_info (torch.Tensor, optional): When given, enables the
                ``axis_attention`` branch and maps ``relative_points`` through
                ``CoordinateTransformer.to_axis_relative``.
            chart_type (optional): Passed to :meth:`check_monotonicity`.

        Returns:
            tuple: ``(absolute_points, relative_points, pattern_logits,
            monotonicity)``. Both point tensors come from ``.view(-1,
            max_points, 2)``; ``monotonicity`` is ``None`` when
            ``chart_type`` is ``None``.
        """
        coord_feat = self.coord_attention(x, x, x)[0]
        if axis_info is not None:
            # NOTE(review): axis_info only toggles this branch; its values are
            # not fed into axis_attention.
            axis_feat = self.axis_attention(x, x, x)[0]
            x = coord_feat + axis_feat
        else:
            x = coord_feat

        seq_feat = self.sequence_attention(x, x, x)[0]
        x = x + seq_feat

        encoded = self.sequence_encoder(x.unsqueeze(0)).squeeze(0)

        # Must run on the in_channels-wide encoder output; feeding it the fc1
        # output (in_channels // 2 features) raised a shape-mismatch error.
        pattern_logits = self.pattern_recognizer(encoded)

        hidden = F.relu(self.fc1(encoded))
        absolute_points = self.absolute_branch(hidden).view(-1, self.max_points, 2)
        relative_points = self.relative_branch(hidden).view(-1, self.max_points, 2)

        if axis_info is not None:
            relative_points = self.coord_transformer.to_axis_relative(relative_points, axis_info)

        if chart_type is not None:
            monotonicity = self.check_monotonicity(absolute_points, chart_type)
        else:
            monotonicity = None

        return absolute_points, relative_points, pattern_logits, monotonicity

    def compute_loss(self, pred_absolute, pred_relative, gt_absolute, gt_relative,
                     pattern_logits, gt_pattern, axis_info=None, chart_type=None):
        """Compute combined loss for both absolute and relative coordinates.

        ``total = absolute + relative + 0.5 * pattern + 0.3 * monotonicity`` where

        * ``absolute``: :meth:`compute_distance_loss` of ``pred_absolute``
          against ``gt_absolute``;
        * ``relative``: when ``axis_info`` is given, the same distance loss
          between ``pred_absolute`` mapped into axis space and
          ``gt_relative``; otherwise 0;
        * ``pattern``: cross-entropy of ``pattern_logits`` vs ``gt_pattern``;
        * ``monotonicity``: BCE of the :meth:`check_monotonicity` result vs 1
          when ``chart_type`` is given; otherwise 0.

        2-D point tensors are treated as a batch of one.

        NOTE(review): ``pred_relative`` is not used by any term, so the
        relative branch receives no gradient from this loss — confirm intent.
        """
        pred_absolute, pred_relative, gt_absolute, gt_relative = (
            t.unsqueeze(0) if t.dim() == 2 else t
            for t in (pred_absolute, pred_relative, gt_absolute, gt_relative)
        )

        absolute_loss = self.compute_distance_loss(pred_absolute, gt_absolute)

        if axis_info is not None:
            pred_absolute_relative = self.coord_transformer.to_axis_relative(pred_absolute, axis_info)
            relative_loss = self.compute_distance_loss(pred_absolute_relative, gt_relative)
        else:
            relative_loss = torch.tensor(0.0, device=pred_absolute.device)

        pattern_loss = F.cross_entropy(pattern_logits, gt_pattern)

        if chart_type is not None:
            monotonic = self.check_monotonicity(pred_absolute, chart_type)
            # check_monotonicity returns a plain Python bool for chart types it
            # does not check; calling .float() on it raised AttributeError.
            monotonic = torch.as_tensor(monotonic, dtype=torch.float32, device=pred_absolute.device)
            # NOTE(review): this input is built from boolean comparisons, so the
            # term has no gradient; it only adds 0 or 0.3 * 100 (BCE clamps log
            # to >= -100) to the reported loss. Use a differentiable penalty if
            # the goal is to steer predictions.
            monotonicity_loss = F.binary_cross_entropy(monotonic, torch.ones_like(monotonic))
        else:
            monotonicity_loss = torch.tensor(0.0, device=pred_absolute.device)

        total_loss = (absolute_loss + relative_loss +
                      0.5 * pattern_loss + 0.3 * monotonicity_loss)

        return total_loss

    def compute_distance_loss(self, pred_points, gt_points):
        """Compute distance-based loss between predicted and ground truth points.

        For each predicted point, take the Euclidean distance to its nearest
        ground-truth point (same batch entry) and return the smooth-L1 loss of
        those distances against zero. 2-D ``(P, D)`` inputs are treated as a
        batch of one, matching ``RegHead.compute_distance_loss``; previously
        they failed at the ``dim=2`` reduction.
        """
        if pred_points.dim() == 2:
            pred_points = pred_points.unsqueeze(0)
        if gt_points.dim() == 2:
            gt_points = gt_points.unsqueeze(0)

        dist = torch.cdist(pred_points, gt_points)
        # Reduce over the ground-truth axis: nearest GT per predicted point.
        min_dist, _ = torch.min(dist, dim=2)

        return F.smooth_l1_loss(min_dist, torch.zeros_like(min_dist))