from __future__ import annotations
from typing import Any, List, Dict
import torch
import torch.optim as optim
import copy
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import torch.nn as nn
from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask
import torch.nn.functional as F
from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter
from det_map.det.dal.mmdet3d.models import builder
from mmcv.utils import TORCH_VERSION, digit_version
from typing import Any, List, Dict
import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from det_map.map.map_target import MapTargetBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
import torch.optim as optim
# Wildcard imports are needed for their side effect: importing these modules
# registers the map assigner/head/loss/module classes with the mm* registries.
# The previous ``try: ... except Exception: raise Exception`` wrapper discarded
# the original ImportError and its traceback; letting the import fail naturally
# preserves the real cause.
from det_map.map.assigners import *  # noqa: F401,F403
from det_map.map.dense_heads import *  # noqa: F401,F403
from det_map.map.losses import *  # noqa: F401,F403
from det_map.map.modules import *  # noqa: F401,F403
class MapAgent(AbstractAgent):
    """NavSim agent wrapping a LiDAR/camera map-detection model.

    The network itself (feature extraction, ``pts_bbox_head``, loss terms)
    lives in ``self.model``; this class adapts it to the
    :class:`AbstractAgent` training interface (feature/target builders,
    forward, loss, optimizers).
    """

    def __init__(
        self,
        model,
        pipelines,
        lr: float,
        checkpoint_path: str | None = None,
        **kwargs,
    ):
        """Store the wrapped model, preprocessing pipelines and base lr.

        :param model: detection/mapping network exposing ``pts_bbox_head``
            with ``loss``, ``k_one2many`` and ``lambda_one2many``.
        :param pipelines: preprocessing pipelines handed to the feature builder.
        :param lr: base learning rate (backbone trains at 0.1x, see
            ``get_optimizers``).
        :param checkpoint_path: optional checkpoint loaded in ``initialize``.
        """
        super().__init__()
        # TODO: eval everything (note carried over from original author)
        self.model = model
        self.pipelines = pipelines
        self._checkpoint_path = checkpoint_path
        self._lr = lr

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass."""
        # Lightning checkpoints prefix every key with "agent."; strip the
        # prefix so the state dict matches this module's parameter names.
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path)["state_dict"]
        self.load_state_dict({k.replace("agent.", ""): v for k, v in state_dict.items()})

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_all_sensors(True)

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Targets: map ground truth produced by :class:`MapTargetBuilder`."""
        return [
            MapTargetBuilder(),
        ]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Features: LiDAR + camera tensors built with ``self.pipelines``."""
        return [
            LiDARCameraFeatureBuilder(self.pipelines)
        ]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on the built features."""
        return self.model(features)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None
    ) -> tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the map-head loss (one2one branch + one2many branch).

        Fix vs. original: the return annotation used ``Union``, which was
        never imported and was wrong anyway — the method returns a 2-tuple
        ``(total_loss, per_term_losses)``.

        :param features: model inputs (unused here; the head consumes
            ``predictions``).
        :param targets: must contain ``gt_bboxes_3d`` and ``gt_labels_3d``.
        :param predictions: model outputs; must contain ``one2many_outs``.
        :param tokens: unused, kept for interface compatibility.
        :return: tuple of (summed scalar loss, dict of individual loss terms).
        """
        losses: Dict[str, torch.Tensor] = dict()
        # NOTE(review): depth supervision (encoder depth loss) and
        # BEV/PV segmentation supervision were disabled in the original;
        # the seg masks are therefore passed as None.
        gt_bboxes_3d = targets["gt_bboxes_3d"]
        gt_labels_3d = targets["gt_labels_3d"]
        gt_seg_mask = None
        gt_pv_seg_mask = None

        # Standard one-to-one matching branch.
        loss_inputs = [gt_bboxes_3d, gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, predictions]
        losses_pts = self.model.pts_bbox_head.loss(*loss_inputs, img_metas=None)
        losses.update(losses_pts)

        # Auxiliary one-to-many branch (hybrid matching): replicate every GT
        # instance k_one2many times so multiple queries can match each one.
        k_one2many = self.model.pts_bbox_head.k_one2many
        multi_gt_bboxes_3d = copy.deepcopy(gt_bboxes_3d)
        multi_gt_labels_3d = copy.deepcopy(gt_labels_3d)
        for i, (each_gt_bboxes_3d, each_gt_labels_3d) in enumerate(
                zip(multi_gt_bboxes_3d, multi_gt_labels_3d)):
            # Python list repetition duplicates the instances k times.
            each_gt_bboxes_3d.instance_list = each_gt_bboxes_3d.instance_list * k_one2many
            each_gt_bboxes_3d.instance_labels = each_gt_bboxes_3d.instance_labels * k_one2many
            # Tensor labels are tiled to match the duplicated instances.
            multi_gt_labels_3d[i] = each_gt_labels_3d.repeat(k_one2many)

        one2many_outs = predictions['one2many_outs']
        loss_one2many_inputs = [multi_gt_bboxes_3d, multi_gt_labels_3d,
                                gt_seg_mask, gt_pv_seg_mask, one2many_outs]
        loss_dict_one2many = self.model.pts_bbox_head.loss(*loss_one2many_inputs, img_metas=None)

        # Merge auxiliary terms under "<key>_one2many", down-weighted.
        lambda_one2many = self.model.pts_bbox_head.lambda_one2many
        for key, value in loss_dict_one2many.items():
            one2many_key = key + "_one2many"
            if one2many_key in losses:
                losses[one2many_key] += value * lambda_one2many
            else:
                losses[one2many_key] = value * lambda_one2many

        loss = sum(losses.values())
        return loss, losses

    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        """Build the AdamW optimizer via ``initialize_optimizer``."""
        optimizer = initialize_optimizer(self.model, self._lr)
        return {'optimizer': optimizer}
def initialize_optimizer(model, lr):
    """Create an AdamW optimizer with a reduced lr for the image backbone.

    Parameters whose qualified name contains ``img_backbone`` are trained at
    ``0.1 * lr`` (backbone group first); all other parameters use ``lr``.
    Weight decay 0.01 applies to both groups.

    :param model: ``torch.nn.Module`` whose parameters are optimized.
    :param lr: base learning rate.
    :return: configured ``torch.optim.AdamW`` instance with two param groups.
    """
    # Single pass over named_parameters (the original traversed it twice).
    backbone_params, other_params = [], []
    for name, param in model.named_parameters():
        (backbone_params if 'img_backbone' in name else other_params).append(param)
    return optim.AdamW(
        [
            {'params': backbone_params, 'lr': lr * 0.1},
            {'params': other_params, 'lr': lr},
        ],
        weight_decay=0.01,
    )