RP3D-Diag
[Nature Communications 2024] Large-scale Long-tailed Disease Diagnosis on Radiology Images
The detailed parameters we used for training are as follows:
start_class: 0
end_class: 5569
backbone: 'resnet'
level: 'articles' # represents the disorder level
depth: 32
ltype: 'MultiLabel' # represents the Binary Cross Entropy Loss
augment: True # represents the medical data augmentation
split: 'late' # represents the late fusion strategy
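The `ltype: 'MultiLabel'` setting trains with binary cross-entropy, which treats each disease class as an independent yes/no decision instead of a single softmax choice, so one scan can be positive for several disorders at once. A minimal pure-Python sketch of that loss for a single sample, assuming mean reduction over classes (the helper name `multilabel_bce` is hypothetical, not from the RP3D-Diag code):

```python
import math


def multilabel_bce(logits, targets):
    """Mean binary cross-entropy over classes for one sample (hypothetical helper).

    Each class gets its own sigmoid, so several labels can be positive at once.
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for z, y in zip(logits, targets):
        # Numerically stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))]
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# Two classes, both logits at 0 (probability 0.5): each term is log(2)
print(multilabel_bce([0.0, 0.0], [1, 0]))  # ~0.6931
```

For a single sample this is the same quantity that `torch.nn.BCEWithLogitsLoss` computes with its default mean reduction; whether RP3D-Diag calls that class directly is not stated here.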
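import torch

# Load backbone
# RadNet comes from the RP3D-Diag code on GitHub. backbone, depth, ltype and augment correspond to
# the settings above; num_classes, fuse, ke, encoded and adapter must also be defined before this call.
model = RadNet(num_cls=num_classes, backbone=backbone, depth=depth, ltype=ltype, augment=augment, fuse=fuse, ke=ke, encoded=encoded, adapter=adapter)
pretrained_weights = torch.load("path/to/pytorch_model_32_late.bin", map_location="cpu")
# strict=False skips mismatched keys instead of raising; check both lists after loading
missing, unexpected = model.load_state_dict(pretrained_weights, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpected)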
# If ke is set to True, also load the MedCPT text encoder
medcpt = MedCPT_clinical(bert_model_name='ncbi/MedCPT-Query-Encoder')
checkpoint = torch.load('path/to/epoch_state.pt', map_location='cpu')['state_dict']
# Remove the 'module.' prefix added to parameter names when a DataParallel/DDP-wrapped model is saved
load_checkpoint = {key.replace('module.', ''): value for key, value in checkpoint.items()}
missing, unexpected = medcpt.load_state_dict(load_checkpoint, strict=False)
print("missing_cpt:", missing)
print("unexpect_cpt:", unexpected)
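Both loading steps rely on `strict=False`, which makes `load_state_dict` report mismatched parameter names instead of raising, and the text-encoder step strips the `module.` prefix that `torch.nn.DataParallel` and `DistributedDataParallel` put on every key when a wrapped model is saved. The plain-Python sketch below illustrates both ideas on key names only; `strip_prefix` and `key_mismatch` are hypothetical helpers, not part of RP3D-Diag or PyTorch. Unlike `key.replace('module.', '')`, which removes every occurrence, `str.removeprefix` (Python 3.9+) only drops a leading prefix.

```python
def strip_prefix(state_dict, prefix="module."):
    """Copy state_dict with a leading prefix removed from each key (hypothetical helper)."""
    # str.removeprefix leaves inner occurrences such as 'proj.module.bias' intact
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}


def key_mismatch(model_keys, checkpoint_keys):
    """Return (missing, unexpected) names, mirroring what strict=False loading reports.

    Hypothetical helper: compares names only. In PyTorch a tensor shape
    mismatch still raises even with strict=False.
    """
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = [k for k in model_keys if k not in ckpt_set]           # model expects, checkpoint lacks
    unexpected = [k for k in checkpoint_keys if k not in model_set]  # checkpoint has, model ignores
    return missing, unexpected


# Toy example with placeholder values instead of tensors
checkpoint = {"module.encoder.weight": 1, "module.proj.module.bias": 2, "module.extra.weight": 3}
cleaned = strip_prefix(checkpoint)
missing, unexpected = key_mismatch(["encoder.weight", "proj.module.bias", "head.weight"], cleaned)
print("missing_cpt:", missing)      # ['head.weight']
print("unexpect_cpt:", unexpected)  # ['extra.weight']
```

An empty missing list means every parameter the model defines was found in the checkpoint.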
All early-fusion checkpoints can be further fine-tuned from this checkpoint. To do so, set:
checkpoint: "None"
safetensor: path to this checkpoint (pytorch_model.bin)
If you need checkpoints trained with other parameter settings, there are two options:
1. Email the author: [email protected]
2. Refer to RP3D-DiagDS.
For more information on downloading and using the models, please refer to our instructions on GitHub.