File size: 6,264 Bytes
da2e2ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import annotations

from typing import Any, List, Dict

import torch
import torch.optim as optim
import copy
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
import torch.nn as nn
from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder

from det_map.det.dal.mmdet3d.models.utils.grid_mask import GridMask

import torch.nn.functional as F

from det_map.det.dal.mmdet3d.ops import Voxelization, DynamicScatter
from det_map.det.dal.mmdet3d.models import builder
from mmcv.utils import TORCH_VERSION, digit_version


from typing import Any, List, Dict

import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from det_map.data.datasets.dataclasses import SensorConfig, Scene
from det_map.data.datasets.feature_builders import LiDARCameraFeatureBuilder
from det_map.map.map_target import MapTargetBuilder
from navsim.agents.abstract_agent import AbstractAgent
from navsim.planning.training.abstract_feature_target_builder import AbstractFeatureBuilder, AbstractTargetBuilder
import torch.optim as optim

# Wildcard imports of the map sub-packages. They are presumably needed for their
# import-time side effects (e.g. registering assigners / heads / losses / modules
# with a builder registry) and/or for names resolved from configs -- confirm.
# They are no longer wrapped in ``try: ... except Exception: raise Exception``:
# that replaced the real ImportError (message + traceback naming the failing
# module) with a bare, message-less Exception. The original error, still an
# Exception subclass, now propagates unchanged.
from det_map.map.assigners import *
from det_map.map.dense_heads import *
from det_map.map.losses import *
from det_map.map.modules import *
class MapAgent(AbstractAgent):
    """NAVSIM agent wrapping a map-prediction network.

    ``forward`` delegates to ``model``. ``compute_loss`` uses
    ``model.pts_bbox_head`` for a one-to-one loss plus an auxiliary
    one-to-many loss weighted by ``pts_bbox_head.lambda_one2many``.
    """

    def __init__(
        self,
        model,
        pipelines,
        lr: float,
        checkpoint_path: str | None = None,
        **kwargs,
    ):
        """
        :param model: callable network mapping the feature dict to predictions;
            must expose ``pts_bbox_head`` (used by :meth:`compute_loss`).
        :param pipelines: forwarded to ``LiDARCameraFeatureBuilder``.
        :param lr: base learning rate passed to ``initialize_optimizer``.
        :param checkpoint_path: checkpoint read by :meth:`initialize`.
        :param kwargs: accepted and ignored.
        """
        super().__init__()
        # TODO: eval everything
        self.model = model
        self.pipelines = pipelines
        self._checkpoint_path = checkpoint_path
        self._lr = lr

    def name(self) -> str:
        """Inherited, see superclass."""
        return self.__class__.__name__

    def initialize(self) -> None:
        """Inherited, see superclass.

        Loads ``state_dict`` from ``self._checkpoint_path`` and strips the
        leading ``"agent."`` prefix from each key before loading into this agent.
        """
        # Without map_location, tensors saved on GPU cannot be deserialized on a
        # CPU-only host; on GPU hosts keep the default (None) behaviour.
        map_location = None if torch.cuda.is_available() else torch.device("cpu")
        state_dict: Dict[str, Any] = torch.load(self._checkpoint_path, map_location=map_location)["state_dict"]
        prefix = "agent."
        # Strip only a *leading* prefix; str.replace would also mangle keys that
        # merely contain "agent." somewhere inside (e.g. "sub_agent.weight").
        self.load_state_dict(
            {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
        )

    def get_sensor_config(self) -> SensorConfig:
        """Inherited, see superclass."""
        return SensorConfig.build_all_sensors(True)

    def get_target_builders(self) -> List[AbstractTargetBuilder]:
        """Return the target builders (a single ``MapTargetBuilder``)."""
        return [
            MapTargetBuilder(),
        ]

    def get_feature_builders(self) -> List[AbstractFeatureBuilder]:
        """Return the feature builders (a LiDAR+camera builder using ``self.pipelines``)."""
        return [
            LiDARCameraFeatureBuilder(self.pipelines)
        ]

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on ``features``."""
        return self.model(features)

    def compute_loss(
        self,
        features: Dict[str, torch.Tensor],
        targets: Dict[str, torch.Tensor],
        predictions: Dict[str, torch.Tensor],
        tokens=None,
    ) -> tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the total training loss and its named components.

        :param features: unused; part of the AbstractAgent interface.
        :param targets: must contain ``"gt_bboxes_3d"`` and ``"gt_labels_3d"``.
        :param predictions: output of :meth:`forward`; must contain
            ``"one2many_outs"`` for the auxiliary branch.
        :param tokens: unused.
        :return: ``(loss, losses)``; ``loss`` is the sum of every value in
            ``losses``, and one-to-many terms are stored under ``<key>_one2many``.
        """
        losses = {}

        # NOTE: depth supervision (predictions["depth"] vs targets["gt_depth"]) and
        # the BEV / perspective-view segmentation masks are currently disabled;
        # the head receives None for both masks.
        gt_bboxes_3d = targets["gt_bboxes_3d"]
        gt_labels_3d = targets["gt_labels_3d"]
        gt_seg_mask = None
        gt_pv_seg_mask = None

        # One-to-one loss.
        losses_pts = self.model.pts_bbox_head.loss(
            gt_bboxes_3d, gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, predictions, img_metas=None
        )
        losses.update(losses_pts)

        # Auxiliary one-to-many loss: ground truth is repeated k_one2many times
        # (sequence repetition of instance_list/instance_labels, .repeat() on the
        # labels). Work on deep copies so the caller's targets are not mutated.
        k_one2many = self.model.pts_bbox_head.k_one2many
        multi_gt_bboxes_3d = copy.deepcopy(gt_bboxes_3d)
        multi_gt_labels_3d = copy.deepcopy(gt_labels_3d)
        for i, (each_gt_bboxes_3d, each_gt_labels_3d) in enumerate(zip(multi_gt_bboxes_3d, multi_gt_labels_3d)):
            each_gt_bboxes_3d.instance_list = each_gt_bboxes_3d.instance_list * k_one2many
            each_gt_bboxes_3d.instance_labels = each_gt_bboxes_3d.instance_labels * k_one2many
            multi_gt_labels_3d[i] = each_gt_labels_3d.repeat(k_one2many)

        one2many_outs = predictions['one2many_outs']
        loss_dict_one2many = self.model.pts_bbox_head.loss(
            multi_gt_bboxes_3d, multi_gt_labels_3d, gt_seg_mask, gt_pv_seg_mask, one2many_outs, img_metas=None
        )

        lambda_one2many = self.model.pts_bbox_head.lambda_one2many
        for key, value in loss_dict_one2many.items():
            name = key + "_one2many"
            if name in losses:
                losses[name] += value * lambda_one2many
            else:
                losses[name] = value * lambda_one2many

        # sum() starts from 0, matching the original accumulate-from-zero loop.
        loss = sum(losses.values())
        return loss, losses

    def get_optimizers(self) -> Optimizer | Dict[str, Optimizer | LRScheduler]:
        """Return ``{'optimizer': AdamW}`` built by ``initialize_optimizer``."""
        optimizer = initialize_optimizer(self.model, self._lr)
        return {'optimizer': optimizer}

def initialize_optimizer(model, lr, backbone_lr_scale=0.1, weight_decay=0.01):
    """Build an AdamW optimizer with a reduced learning rate for the image backbone.

    Parameters whose qualified name contains ``"img_backbone"`` form the first
    param group, with learning rate ``lr * backbone_lr_scale``. All remaining
    parameters form the second group, with learning rate ``lr``. Both groups
    share ``weight_decay``.

    :param model: module whose ``named_parameters()`` are optimised.
    :param lr: base learning rate.
    :param backbone_lr_scale: multiplier applied to ``lr`` for backbone parameters.
    :param weight_decay: AdamW weight decay used by both groups.
    :return: the configured ``torch.optim.AdamW`` instance.
    """
    backbone_params = []
    other_params = []
    # One pass over the parameters, preserving their order within each group.
    for name, param in model.named_parameters():
        if 'img_backbone' in name:
            backbone_params.append(param)
        else:
            other_params.append(param)
    # Group order (backbone first) is kept so param_groups indices stay stable
    # for schedulers and saved optimizer states.
    return optim.AdamW([
        {'params': backbone_params, 'lr': lr * backbone_lr_scale},
        {'params': other_params, 'lr': lr},
    ], weight_decay=weight_decay)