|
|
import gradio as gr |
|
|
import timm |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
def change_num_input_channels(model, in_channels=1):
    """Adapt the first 3-input-channel convolution of ``model`` to ``in_channels``.

    Assumes the model's input layer expects 3 channels. Walks ``model.modules()``
    in order and rewrites only the first ``nn.Conv2d`` / ``nn.Conv3d`` whose
    ``in_channels`` is 3: its weight is summed over the input-channel axis,
    divided by ``in_channels`` and tiled ``in_channels`` times, giving a kernel
    of shape ``(out_channels, in_channels, *kernel_size)``. The bias is left
    untouched. If no such layer exists the model is returned unchanged.

    Args:
        model: module to modify in place.
        in_channels: number of input channels the adapted layer should accept.

    Returns:
        The same ``model`` object (modified in place), for call chaining.

    Raises:
        ValueError: if ``in_channels`` is smaller than 1.
    """
    if in_channels < 1:
        raise ValueError(f"in_channels must be >= 1, got {in_channels}")

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == 3:
            # Pure weight surgery: no autograd graph is needed for these ops.
            with torch.no_grad():
                # Collapse the RGB filters into one and share it evenly across the
                # new channels, so an input repeating one image on every channel
                # gets the same response a grayscale RGB image (x, x, x) did.
                w = m.weight.sum(1, keepdim=True) / in_channels
                repeats = [1] * w.ndim
                repeats[1] = in_channels
                w = w.repeat(repeats)
            m.in_channels = in_channels
            m.weight = nn.Parameter(w)
            break

    return model
|
|
|
|
|
|
|
|
class Net2D(nn.Module):
    """Single-output EfficientNetV2-S network over 2-channel 2D inputs.

    The timm backbone is created without pretrained weights, without its own
    global pooling and without a classification head. Its first 3-channel
    convolution is adapted to 2 input channels. A pooling / dropout / linear
    head is attached, and the given state dict is then loaded (strict by
    default in PyTorch).
    """

    def __init__(self, weights):
        super().__init__()
        backbone = timm.create_model(
            "tf_efficientnetv2_s",
            pretrained=False,
            global_pool="",
            num_classes=0,
        )
        self.backbone = change_num_input_channels(backbone, 2)
        self.pool_layer = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.2)
        # 1280 must equal the backbone's feature width for the state dict to fit.
        self.classifier = nn.Linear(1280, 1)
        # Must run last so every submodule above exists when keys are matched.
        self.load_state_dict(weights)

    def forward(self, x):
        """Return one prediction per sample.

        With a single output unit the trailing axis is dropped, giving shape
        ``(batch,)``. Otherwise the ``(batch, outputs)`` tensor is returned.
        """
        features = self.backbone(x)
        batch_size = features.size(0)
        pooled = self.pool_layer(features).view(batch_size, -1)
        logits = self.classifier(self.dropout(pooled))
        if logits.size(1) == 1:
            return logits[:, 0]
        return logits
|
|
|
|
|
|
|
|
class Ensemble(nn.Module):
    """Average the outputs of several models applied to the same input."""

    def __init__(self, model_list):
        super().__init__()
        # ModuleList registers every member, so .eval()/.to()/state_dict reach them.
        self.model_list = nn.ModuleList(model_list)

    def forward(self, x):
        """Run each member on ``x`` and return the element-wise mean of their outputs."""
        stacked = torch.stack(tuple(member(x) for member in self.model_list), dim=0)
        return torch.mean(stacked, dim=0)
|
|
|
|
|
|
|
|
# Per-fold checkpoints whose models are averaged by the ensemble below.
checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]

# map_location="cpu" lets checkpoints saved from a GPU load on a CPU-only host.
# The Net2D modules are built on the CPU, so load_state_dict copies there anyway.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if these
# checkpoints hold non-tensor objects that will fail. Only relax it for trusted files.
weights = [torch.load(ckpt, map_location="cpu")["state_dict"] for ckpt in checkpoints]

# Keys are presumably prefixed with "model." by the training wrapper. Strip only
# that leading prefix (str.replace would also delete "model." found later in a key)
# so the names match Net2D's own parameters.
_KEY_PREFIX = "model."
weights = [
    {(k[len(_KEY_PREFIX):] if k.startswith(_KEY_PREFIX) else k): v for k, v in wt.items()}
    for wt in weights
]

models = [Net2D(wt) for wt in weights]

# eval() switches every member to inference mode (e.g. disables Net2D's dropout).
ensemble = Ensemble(models).eval()
|
|
|
|
|
def predict_bone_age(Radiograph, Sex):
    """Estimate bone age from a grayscale radiograph with the model ensemble.

    Args:
        Radiograph: numpy array from the ``gr.Image(image_mode="L")`` input, or
            ``None`` when nothing has been uploaded. Assumed 2-D ``(H, W)`` with
            non-negative pixel values — TODO confirm for the installed gradio.
        Sex: index from ``gr.Radio(["Male", "Female"], type="index")``
            (0 = "Male", 1 = "Female"), or ``None`` when nothing is selected.

    Returns:
        ``"Estimated Bone Age: N years, M months"``, or a prompt asking for the
        missing input.
    """
    if Radiograph is None:
        return "Please upload a radiograph."
    if Sex is None:
        # Without this, an unselected radio would silently be treated as "Male".
        return "Please select a sex."

    # (H, W) -> (1, 1, H, W): add batch and channel axes.
    img = torch.from_numpy(Radiograph).unsqueeze(0).unsqueeze(0)

    # Scale by the brightest pixel, then map [0, 1] -> [-1, 1]. Skipping the
    # division for an all-black image avoids 0 / 0 = NaN propagating to int().
    peak = img.max()
    if peak > 0:
        img = img / peak
    img = (img - 0.5) * 2.0

    # Second channel is a constant plane encoding sex: +1 for index 1, -1 otherwise.
    sex_plane = torch.full_like(img, 1.0 if Sex == 1 else -1.0)
    img = torch.cat([img, sex_plane], dim=1)

    with torch.no_grad():
        bone_age = ensemble(img.float())[0].item()

    # A negative age is meaningless; clamp before splitting into years/months.
    bone_age = max(bone_age, 0.0)
    years = int(bone_age)
    # Subtracting whole years (instead of `bone_age % years`) also works when
    # years == 0, i.e. for predictions under one year.
    months = int((bone_age - years) * 12)
    return f"Estimated Bone Age: {years} years, {months} months"
|
|
|
|
|
|
|
|
# Gradio UI: a grayscale image and a sex selector feed predict_bone_age, and the
# string it returns is displayed in a Label.
image = gr.Image(image_mode="L", shape=(512, 512))

# type="index" passes the position of the selected choice (0 or 1), not its text.
sex = gr.Radio(choices=["Male", "Female"], type="index")

label = gr.Label(label="Result", show_label=True)

demo = gr.Interface(
    inputs=[image, sex],
    outputs=label,
    fn=predict_bone_age,
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    demo.launch()
|
|
|
|
|
|
|
|
|