hanszhu committed on
Commit
164f8a3
·
verified ·
1 Parent(s): 92843ca

Upload chart_elementnet_swin.py

Browse files
Files changed (1) hide show
  1. chart_elementnet_swin.py +399 -0
chart_elementnet_swin.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # chart_elementnet_swin.py - Cascade R-CNN config with a Swin Transformer backbone (derived from cascade_rcnn_r50_fpn_meta.py)
2
+ #
3
+ # PROGRESSIVE LOSS STRATEGY:
4
+ # - All 3 Cascade stages start with SmoothL1Loss for stable initial training
5
+ # - At epoch 5, Stage 3 (final stage) switches to GIoULoss via ProgressiveLossHook
6
+ # - Stage 1 & 2 remain SmoothL1Loss throughout training
7
+ # - This ensures model stability before introducing more complex IoU-based losses
8
+ _base_ = [
9
+ '../../mmdetection/configs/_base_/datasets/coco_detection.py',
10
+ '../../mmdetection/configs/_base_/schedules/schedule_1x.py',
11
+ '../../mmdetection/configs/_base_/default_runtime.py'
12
+ ]
13
+
14
+ # Custom imports - this registers our modules without polluting config namespace
15
+ custom_imports = dict(
16
+ imports=[
17
+ 'legend_match_swin.custom_models.custom_dataset',
18
+ 'legend_match_swin.custom_models.register',
19
+ 'legend_match_swin.custom_models.custom_hooks',
20
+ 'legend_match_swin.custom_models.progressive_loss_hook',
21
+ ],
22
+ allow_failed_imports=False
23
+ )
24
+
25
+ # Add to Python path
26
+ import sys
27
+ import os
28
+ # Build the path from the current working directory instead of __file__; this assumes the config is launched from the expected directory
29
+ sys.path.insert(0, os.path.join(os.getcwd(), '..', '..'))
30
+
31
+ # Custom Cascade model with coordinate handling for chart data
32
+ model = dict(
33
+ type='CustomCascadeWithMeta', # Use custom model with coordinate handling
34
+ coordinate_standardization=dict(
35
+ enabled=True,
36
+ origin='bottom_left', # Match annotation creation coordinate system
37
+ normalize=True,
38
+ relative_to_plot=False, # Keep simple for now
39
+ scale_to_axis=False # Keep simple for now
40
+ ),
41
+ data_preprocessor=dict(
42
+ type='DetDataPreprocessor',
43
+ mean=[123.675, 116.28, 103.53],
44
+ std=[58.395, 57.12, 57.375],
45
+ bgr_to_rgb=True,
46
+ pad_size_divisor=32),
47
+ # ----- Swin Transformer Base (22K) Backbone + FPN -----
48
+ backbone=dict(
49
+ type='SwinTransformer',
50
+ embed_dims=128, # Swin Base embedding dimensions
51
+ depths=[2, 2, 18, 2], # Swin Base depths
52
+ num_heads=[4, 8, 16, 32], # Swin Base attention heads
53
+ window_size=7,
54
+ mlp_ratio=4,
55
+ qkv_bias=True,
56
+ qk_scale=None,
57
+ drop_rate=0.0,
58
+ attn_drop_rate=0.0,
59
+ drop_path_rate=0.3, # Slightly higher for more complex model
60
+ patch_norm=True,
61
+ out_indices=(0, 1, 2, 3),
62
+ with_cp=False,
63
+ convert_weights=True,
64
+ init_cfg=dict(
65
+ type='Pretrained',
66
+ checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/swin/swin_base_patch4_window7_224_22k_20220317-4f79f7c0.pth'
67
+ )
68
+ ),
69
+ neck=dict(
70
+ type='FPN',
71
+ in_channels=[128, 256, 512, 1024], # Swin Base: embed_dims * 2^(stage)
72
+ out_channels=256,
73
+ num_outs=6,
74
+ start_level=0,
75
+ add_extra_convs='on_input'
76
+ ),
77
+ # Enhanced RPN with smaller anchors for tiny objects + improved losses
78
+ rpn_head=dict(
79
+ type='RPNHead',
80
+ in_channels=256,
81
+ feat_channels=256,
82
+ anchor_generator=dict(
83
+ type='AnchorGenerator',
84
+ scales=[1, 2, 4, 8], # Even smaller scales for tiny objects
85
+ ratios=[0.5, 1.0, 2.0], # Multiple aspect ratios
86
+ strides=[4, 8, 16, 32, 64, 128]), # Extended FPN strides
87
+ bbox_coder=dict(
88
+ type='DeltaXYWHBBoxCoder',
89
+ target_means=[.0, .0, .0, .0],
90
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
91
+ loss_cls=dict(
92
+ type='CrossEntropyLoss',
93
+ use_sigmoid=True,
94
+ loss_weight=1.0),
95
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
96
+ # Progressive Loss Strategy: Start with SmoothL1 for all 3 stages
97
+ # Stage 3 (final stage) will switch to GIoU at epoch 5 via ProgressiveLossHook
98
+ roi_head=dict(
99
+ type='CascadeRoIHead',
100
+ num_stages=3,
101
+ stage_loss_weights=[1, 0.5, 0.25],
102
+ bbox_roi_extractor=dict(
103
+ type='SingleRoIExtractor',
104
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
105
+ out_channels=256,
106
+ featmap_strides=[4, 8, 16, 32]),
107
+ bbox_head=[
108
+ # Stage 1: Always SmoothL1Loss (coarse detection)
109
+ dict(
110
+ type='Shared2FCBBoxHead',
111
+ in_channels=256,
112
+ fc_out_channels=1024,
113
+ roi_feat_size=7,
114
+ num_classes=21, # 21 enhanced categories
115
+ bbox_coder=dict(
116
+ type='DeltaXYWHBBoxCoder',
117
+ target_means=[0., 0., 0., 0.],
118
+ target_stds=[0.05, 0.05, 0.1, 0.1]),
119
+ reg_class_agnostic=True,
120
+ loss_cls=dict(
121
+ type='CrossEntropyLoss',
122
+ use_sigmoid=False,
123
+ loss_weight=1.0),
124
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
125
+ # Stage 2: Always SmoothL1Loss (intermediate refinement)
126
+ dict(
127
+ type='Shared2FCBBoxHead',
128
+ in_channels=256,
129
+ fc_out_channels=1024,
130
+ roi_feat_size=7,
131
+ num_classes=21, # 21 enhanced categories
132
+ bbox_coder=dict(
133
+ type='DeltaXYWHBBoxCoder',
134
+ target_means=[0., 0., 0., 0.],
135
+ target_stds=[0.033, 0.033, 0.067, 0.067]),
136
+ reg_class_agnostic=True,
137
+ loss_cls=dict(
138
+ type='CrossEntropyLoss',
139
+ use_sigmoid=False,
140
+ loss_weight=1.0),
141
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)),
142
+ # Stage 3: SmoothL1 → GIoU at epoch 5 (progressive switching)
143
+ dict(
144
+ type='Shared2FCBBoxHead',
145
+ in_channels=256,
146
+ fc_out_channels=1024,
147
+ roi_feat_size=7,
148
+ num_classes=21, # 21 enhanced categories
149
+ bbox_coder=dict(
150
+ type='DeltaXYWHBBoxCoder',
151
+ target_means=[0., 0., 0., 0.],
152
+ target_stds=[0.02, 0.02, 0.05, 0.05]),
153
+ reg_class_agnostic=True,
154
+ loss_cls=dict(
155
+ type='CrossEntropyLoss',
156
+ use_sigmoid=False,
157
+ loss_weight=1.0),
158
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
159
+ ]),
160
+ train_cfg=dict(
161
+ rpn=dict(
162
+ assigner=dict(
163
+ type='MaxIoUAssigner',
164
+ pos_iou_thr=0.7,
165
+ neg_iou_thr=0.3,
166
+ min_pos_iou=0.3,
167
+ match_low_quality=True,
168
+ ignore_iof_thr=-1),
169
+ sampler=dict(
170
+ type='RandomSampler',
171
+ num=256,
172
+ pos_fraction=0.5,
173
+ neg_pos_ub=-1,
174
+ add_gt_as_proposals=False),
175
+ allowed_border=0,
176
+ pos_weight=-1,
177
+ debug=False),
178
+ rpn_proposal=dict(
179
+ nms_pre=2000,
180
+ max_per_img=2000,
181
+ nms=dict(type='nms', iou_threshold=0.8),
182
+ min_bbox_size=0),
183
+ rcnn=[
184
+ dict(
185
+ assigner=dict(
186
+ type='MaxIoUAssigner',
187
+ pos_iou_thr=0.4,
188
+ neg_iou_thr=0.4,
189
+ min_pos_iou=0.4,
190
+ match_low_quality=False,
191
+ ignore_iof_thr=-1),
192
+ sampler=dict(
193
+ type='RandomSampler',
194
+ num=512,
195
+ pos_fraction=0.25,
196
+ neg_pos_ub=-1,
197
+ add_gt_as_proposals=True),
198
+ pos_weight=-1,
199
+ debug=False),
200
+ dict(
201
+ assigner=dict(
202
+ type='MaxIoUAssigner',
203
+ pos_iou_thr=0.6,
204
+ neg_iou_thr=0.6,
205
+ min_pos_iou=0.6,
206
+ match_low_quality=False,
207
+ ignore_iof_thr=-1),
208
+ sampler=dict(
209
+ type='RandomSampler',
210
+ num=512,
211
+ pos_fraction=0.25,
212
+ neg_pos_ub=-1,
213
+ add_gt_as_proposals=True),
214
+ pos_weight=-1,
215
+ debug=False),
216
+ dict(
217
+ assigner=dict(
218
+ type='MaxIoUAssigner',
219
+ pos_iou_thr=0.7,
220
+ neg_iou_thr=0.7,
221
+ min_pos_iou=0.7,
222
+ match_low_quality=False,
223
+ ignore_iof_thr=-1),
224
+ sampler=dict(
225
+ type='RandomSampler',
226
+ num=512,
227
+ pos_fraction=0.25,
228
+ neg_pos_ub=-1,
229
+ add_gt_as_proposals=True),
230
+ pos_weight=-1,
231
+ debug=False)
232
+ ]),
233
+ # Enhanced test configuration with soft-NMS (multi-scale testing is only available via the commented-out TTA block at the end of this file)
234
+ test_cfg=dict(
235
+ rpn=dict(
236
+ nms_pre=1000,
237
+ max_per_img=1000,
238
+ nms=dict(type='nms', iou_threshold=0.7),
239
+ min_bbox_size=0),
240
+ rcnn=dict(
241
+ score_thr=0.005, # Even lower threshold to catch more classes
242
+ nms=dict(
243
+ type='soft_nms', # Soft-NMS for better small object detection
244
+ iou_threshold=0.5,
245
+ min_score=0.005,
246
+ method='gaussian',
247
+ sigma=0.5),
248
+ max_per_img=500))) # Allow more detections
249
+
250
+ # Dataset settings - using cleaned annotations
251
+ dataset_type = 'ChartDataset'
252
+ data_root = '' # Left empty to avoid duplicating the prefix: ann_file and data_prefix below already carry the full paths
253
+
254
+ # Define the 21 chart element classes that match the annotations
255
+ CLASSES = (
256
+ 'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
257
+ 'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
258
+ 'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
259
+ 'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area'
260
+ )
261
+
262
+ # Updated to use cleaned annotation files
263
+ train_dataloader = dict(
264
+ batch_size=2, # Increased back to 2
265
+ num_workers=2,
266
+ persistent_workers=True,
267
+ sampler=dict(type='DefaultSampler', shuffle=True),
268
+ dataset=dict(
269
+ type=dataset_type,
270
+ data_root=data_root,
271
+ ann_file='legend_data/annotations_JSON_cleaned/train_enriched.json', # Full path
272
+ data_prefix=dict(img='legend_data/train/images/'), # Full path
273
+ metainfo=dict(classes=CLASSES), # Tell dataset what classes to expect
274
+ filter_cfg=dict(filter_empty_gt=True, min_size=0, class_specific_min_sizes={
275
+ 'data-point': 16, # Back to 16x16 from 32x32
276
+ 'data-bar': 16, # Back to 16x16 from 32x32
277
+ 'tick-label': 16, # Back to 16x16 from 32x32
278
+ 'x-tick-label': 16, # Back to 16x16 from 32x32
279
+ 'y-tick-label': 16 # Back to 16x16 from 32x32
280
+ }),
281
+ pipeline=[
282
+ dict(type='LoadImageFromFile'),
283
+ dict(type='LoadAnnotations', with_bbox=True),
284
+ dict(type='Resize', scale=(1600, 1000), keep_ratio=True), # Higher resolution for tiny objects
285
+ dict(type='RandomFlip', prob=0.5),
286
+ dict(type='ClampBBoxes'), # Ensure bboxes stay within image bounds
287
+ dict(type='PackDetInputs')
288
+ ]
289
+ )
290
+ )
291
+
292
+ val_dataloader = dict(
293
+ batch_size=1,
294
+ num_workers=2,
295
+ persistent_workers=True,
296
+ drop_last=False,
297
+ sampler=dict(type='DefaultSampler', shuffle=False),
298
+ dataset=dict(
299
+ type=dataset_type,
300
+ data_root=data_root,
301
+ ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json', # Full path
302
+ data_prefix=dict(img='legend_data/train/images/'), # All images are in train/images
303
+ metainfo=dict(classes=CLASSES), # Tell dataset what classes to expect
304
+ test_mode=True,
305
+ pipeline=[
306
+ dict(type='LoadImageFromFile'),
307
+ dict(type='Resize', scale=(1600, 1000), keep_ratio=True), # Base resolution for validation
308
+ dict(type='LoadAnnotations', with_bbox=True),
309
+ dict(type='ClampBBoxes'), # Ensure bboxes stay within image bounds
310
+ dict(type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', 'scale_factor'))
311
+ ]
312
+ )
313
+ )
314
+
315
+ test_dataloader = val_dataloader
316
+
317
+ # Enhanced evaluators with debugging
318
+ val_evaluator = dict(
319
+ type='CocoMetric',
320
+ ann_file='legend_data/annotations_JSON_cleaned/val_enriched_with_info.json', # Using cleaned annotations
321
+ metric='bbox',
322
+ format_only=False,
323
+ classwise=True, # Enable detailed per-class metrics table
324
+ proposal_nums=(100, 300, 1000)) # More detailed AR metrics
325
+
326
+ test_evaluator = val_evaluator
327
+
328
+ # Add custom hooks for debugging empty results
329
+ default_hooks = dict(
330
+ timer=dict(type='IterTimerHook'),
331
+ logger=dict(type='LoggerHook', interval=50),
332
+ param_scheduler=dict(type='ParamSchedulerHook'),
333
+ checkpoint=dict(type='CompatibleCheckpointHook', interval=1, save_best='auto', max_keep_ckpts=3),
334
+ sampler_seed=dict(type='DistSamplerSeedHook'),
335
+ visualization=dict(type='DetVisualizationHook'))
336
+
337
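+ # Custom hooks: bad-sample skipping, class-distribution and missing-image monitoring, NaN recovery for graceful handling like Faster R-CNN, and progressive loss switching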
338
+ custom_hooks = [
339
+ dict(type='SkipBadSamplesHook', interval=1), # Skip samples with bad GT data
340
+ dict(type='ChartTypeDistributionHook', interval=500), # Monitor class distribution
341
+ dict(type='MissingImageReportHook', interval=1000), # Track missing images
342
+ dict(type='NanRecoveryHook', # For logging & monitoring
343
+ fallback_loss=1.0,
344
+ max_consecutive_nans=100,
345
+ log_interval=50),
346
+ dict(type='ProgressiveLossHook', # Progressive loss switching
347
+ switch_epoch=5, # Switch stage 3 to GIoU at epoch 5
348
+ target_loss_type='GIoULoss', # Use GIoU for stage 3 (final stage)
349
+ loss_weight=1.0, # Keep same loss weight
350
+ warmup_epochs=2, # Monitor for 2 epochs after switch
351
+ monitor_stage_weights=True), # Log stage loss details
352
+ ]
353
+
354
+ # Training configuration - extended to 40 epochs for Swin Base on small objects
355
+ train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=40, val_interval=1)
356
+ val_cfg = dict(type='ValLoop')
357
+ test_cfg = dict(type='TestLoop')
358
+
359
+ # Optimizer with standard stable settings
360
+ optim_wrapper = dict(
361
+ type='OptimWrapper',
362
+ optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),
363
+ clip_grad=dict(max_norm=35.0, norm_type=2)
364
+ )
365
+
366
+ # Extended learning rate schedule with cosine annealing for Swin Base
367
+ param_scheduler = [
368
+ dict(
369
+ type='LinearLR',
370
+ start_factor=0.05, # 1e-3 / 2e-2 = 0.05 (warmup from 1e-3 to 2e-2)
371
+ by_epoch=False,
372
+ begin=0,
373
+ end=1000), # 1k iteration warmup
374
+ dict(
375
+ type='CosineAnnealingLR',
376
+ begin=0,
377
+ end=40, # Match max_epochs
378
+ by_epoch=True,
379
+ T_max=40,
380
+ eta_min=1e-6, # Minimum learning rate
381
+ convert_to_iter_based=True)
382
+ ]
383
+
384
+ # Work directory
385
+ work_dir = './work_dirs/cascade_rcnn_swin_base_40ep_cosine_fpn_meta'
386
+
387
+ # Multi-scale test configuration (uncomment to enable)
388
+ # img_scales = [(800, 500), (1600, 1000), (2400, 1500)] # 0.5x, 1.0x, 1.5x scales
389
+ # tta_model = dict(
390
+ # type='DetTTAModel',
391
+ # tta_cfg=dict(
392
+ # nms=dict(type='nms', iou_threshold=0.5),
393
+ # max_per_img=100)
394
+ # )
395
+
396
+ # Fresh start
397
+ resume = False
398
+ load_from = None
399
+