# blumenstiel's picture
# Add model weights
# 60bda88
# lightning.pytorch==2.4.0
# Seed value handed to the CLI's `seed_everything` option; fixed to 0 here.
# NOTE(review): presumably seeds all RNGs for reproducibility - confirm.
seed_everything: 0
# Trainer section: hardware/precision selection, logger, callbacks and
# training-loop limits.
# NOTE(review): nesting rebuilt from the class_path/init_args pattern and the
# public signatures of the referenced Lightning classes - confirm against the
# original dump.
trainer:
  accelerator: auto
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 16-mixed
  logger:
    class_path: lightning.pytorch.loggers.TensorBoardLogger
    init_args:
      # NOTE(review): absolute, site-specific path - adjust before reuse.
      save_dir: /dccstor/geofm-finetuning/benchmark-geo-bench-paolo/
      name: test2
      log_graph: false
      default_hp_metric: true
      prefix: ''
      comment: ''
      max_queue: 10
      flush_secs: 120
      filename_suffix: ''
  callbacks:
    - class_path: lightning.pytorch.callbacks.RichProgressBar
      init_args:
        refresh_rate: 1
        leave: false
        # Colour/format theme for the progress bar.
        theme:
          description: white
          progress_bar: '#6206E0'
          progress_bar_finished: '#6206E0'
          progress_bar_pulse: '#6206E0'
          batch_progress: white
          time: grey54
          processing_speed: grey70
          metrics: white
          metrics_text_delimiter: ' '
          metrics_format: .3f
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: epoch
        log_momentum: false
        log_weight_decay: false
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        # Watches val/loss and treats lower as better (mode: min).
        monitor: val/loss
        min_delta: 0.0
        # NOTE(review): validation runs every 2 epochs
        # (check_val_every_n_epoch below); if patience counts validation
        # checks, 20 checks span 40 of the 50 max_epochs - confirm intent.
        patience: 20
        verbose: false
        mode: min
        strict: true
        check_finite: true
        log_rank_zero_only: false
  fast_dev_run: false
  max_epochs: 50
  # -1 here, so the epoch limit above is the effective bound.
  max_steps: -1
  overfit_batches: 0.0
  check_val_every_n_epoch: 2
  log_every_n_steps: 10
  enable_checkpointing: true
  accumulate_grad_batches: 1
  inference_mode: true
  use_distributed_sampler: true
  detect_anomaly: false
  barebones: false
  sync_batchnorm: false
  reload_dataloaders_every_n_epochs: 0
  # NOTE(review): absolute, site-specific path - adjust before reuse.
  default_root_dir: /dccstor/geofm-finetuning/benchmark-geo-bench-paolo/
# Model section: a semantic segmentation task built from a pretrained
# backbone plus a UperNet decoder, with 2 output classes.
# NOTE(review): nesting rebuilt from the class_path/init_args pattern and the
# public TerraTorch task signature - confirm against the original dump.
model:
  class_path: terratorch.tasks.SemanticSegmentationTask
  init_args:
    # Arguments forwarded to the model factory named below.
    model_args:
      backbone_pretrained: true
      backbone: prithvi_eo_v2_300_tl
      decoder: UperNetDecoder
      decoder_channels: 256
      decoder_scale_modules: true
      num_classes: 2
      rescale: true
      # Same six bands, in the same order, as the `bands` list in the data
      # section.
      backbone_bands:
        - BLUE
        - GREEN
        - RED
        - NIR_NARROW
        - SWIR_1
        - SWIR_2
      head_dropout: 0.1
      # Applied in the listed order between backbone and decoder.
      # NOTE(review): going by the names, the first picks the backbone
      # outputs at the given indices and the second reshapes token
      # sequences into image layout - confirm.
      necks:
        - name: SelectIndices
          indices:
            - 5
            - 11
            - 17
            - 23
        - name: ReshapeTokensToImage
    model_factory: EncoderDecoderFactory
    loss: ce
    # Same value as `no_label_replace` (-1) in the data section.
    ignore_index: -1
    # NOTE(review): the top-level `optimizer` section sets lr 5.0e-05;
    # confirm which of the two values takes effect.
    lr: 0.001
    freeze_backbone: false
    freeze_decoder: false
    plot_on_val: 10
# Data section: Sen1Floods11 datamodule with separate albumentations
# pipelines for train, validation and test.
# NOTE(review): nesting rebuilt from the class_path/init_args pattern and the
# public signatures of the referenced classes - confirm against the original
# dump.
data:
  class_path: terratorch.datamodules.Sen1Floods11NonGeoDataModule
  init_args:
    # NOTE(review): absolute, site-specific path - adjust before reuse.
    data_root: /dccstor/geofm-finetuning/datasets/sen1floods11
    batch_size: 16
    num_workers: 8
    # Same six bands, in the same order, as `backbone_bands` in the model
    # section.
    bands:
      - BLUE
      - GREEN
      - RED
      - NIR_NARROW
      - SWIR_1
      - SWIR_2
    # Train: resize to 448x448, random 224x224 crop, random flips (p=0.5
    # each), then conversion to tensors.
    train_transform:
      - class_path: albumentations.Resize
        init_args:
          height: 448
          width: 448
          # NOTE(review): presumably cv2.INTER_LINEAR - confirm.
          interpolation: 1
          always_apply: false
          p: 1
      - class_path: albumentations.RandomCrop
        init_args:
          height: 224
          width: 224
          always_apply: false
          p: 1.0
      - class_path: albumentations.HorizontalFlip
        init_args:
          always_apply: false
          p: 0.5
      - class_path: albumentations.VerticalFlip
        init_args:
          always_apply: false
          p: 0.5
      - class_path: albumentations.pytorch.ToTensorV2
        init_args:
          transpose_mask: false
          always_apply: true
          p: 1.0
    # Validation: resize to 448x448 and convert to tensors; no crop, so
    # inputs are 448x448 here versus 224x224 crops in training.
    val_transform:
      - class_path: albumentations.Resize
        init_args:
          height: 448
          width: 448
          interpolation: 1
          always_apply: false
          p: 1
      - class_path: albumentations.pytorch.ToTensorV2
        init_args:
          transpose_mask: false
          always_apply: true
          p: 1.0
    # Test: identical to the validation pipeline.
    test_transform:
      - class_path: albumentations.Resize
        init_args:
          height: 448
          width: 448
          interpolation: 1
          always_apply: false
          p: 1
      - class_path: albumentations.pytorch.ToTensorV2
        init_args:
          transpose_mask: false
          always_apply: true
          p: 1.0
    drop_last: true
    constant_scale: 0.0001
    no_data_replace: 0.0
    # Same value as `ignore_index` (-1) in the model section.
    no_label_replace: -1
    use_metadata: false
# NOTE(review): presumably the dtype used when predictions are written out -
# confirm against the CLI that consumes this file.
out_dtype: int16
# NOTE(review): presumably asks the CLI to also emit a deployment config -
# confirm.
deploy_config_file: true
# Checkpoint settings: monitors val/loss (lower is better), save_top_k is 1,
# and save_weights_only is false.
# NOTE(review): grouping of the keys under this section is rebuilt from the
# public ModelCheckpoint signature - confirm against the original dump.
ModelCheckpoint:
  filename: '{epoch}'
  monitor: val/loss
  verbose: false
  save_top_k: 1
  mode: min
  save_weights_only: false
  auto_insert_metric_name: true
  enable_version_counter: true
# Second checkpoint section with the same monitor/mode/top-k values as
# `ModelCheckpoint`, but save_weights_only is true and the filename carries a
# `_state_dict` suffix.
# NOTE(review): grouping of the keys under this section is rebuilt from the
# public ModelCheckpoint signature - confirm against the original dump.
StateDictModelCheckpoint:
  filename: '{epoch}_state_dict'
  monitor: val/loss
  verbose: false
  save_top_k: 1
  mode: min
  save_weights_only: true
  auto_insert_metric_name: true
  enable_version_counter: true
# Optimizer: AdamW with lr 5.0e-05 and weight decay 0.05.
# NOTE(review): the model section also carries `lr: 0.001`; confirm which
# value takes effect.
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 5.0e-05
    betas:
      - 0.9
      - 0.999
    eps: 1.0e-08
    weight_decay: 0.05
    amsgrad: false
    maximize: false
    capturable: false
    differentiable: false
# Learning-rate schedule: cosine annealing down to eta_min 0.
lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    # Equal to trainer `max_epochs` (50).
    T_max: 50
    eta_min: 0
    last_epoch: -1
    # NOTE(review): the literal string `deprecated` is kept exactly as
    # dumped; presumably torch's placeholder for the deprecated `verbose`
    # argument - confirm before changing.
    verbose: deprecated