File size: 2,699 Bytes
a1306dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import timm
import torch 
import torch.nn as nn


def change_num_input_channels(model, in_channels=1):
    """Adapt the first 3-input-channel conv layer of ``model`` to ``in_channels``.

    The first ``nn.Conv2d``/``nn.Conv3d`` found in ``model.modules()`` order
    whose ``in_channels`` is 3 is modified in place. Its kernel is summed over
    the input-channel axis, divided by ``in_channels`` and tiled
    ``in_channels`` times. An input whose channels all carry the same value
    therefore produces the same response as before. Only that one layer is
    changed. If no such layer exists, the model is returned unmodified.

    Args:
        model: module to modify in place.
        in_channels: new number of input channels for the adapted layer.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Conv3d)) and module.in_channels == 3:
            # This is one-off weight surgery, so no autograd history is needed.
            with torch.no_grad():
                # Collapse the RGB axis, then rescale so that tiling the kernel
                # `in_channels` times preserves the summed response.
                weight = module.weight.sum(1, keepdim=True) / in_channels
                repeats = [1] * weight.ndim
                repeats[1] = in_channels
                weight = weight.repeat(repeats)
            module.in_channels = in_channels
            # Re-wrap as a fresh trainable Parameter (requires_grad=True).
            module.weight = nn.Parameter(weight)
            break
    return model


class Net2D(nn.Module):
    """EfficientNetV2-S backbone with a 2-channel input and a single-unit linear head.

    ``weights`` is a state dict loaded (strictly, the ``load_state_dict``
    default) at the end of ``__init__``. The backbone is created with
    ``pretrained=False``, so every parameter comes from ``weights``.
    """

    def __init__(self, weights):
        super().__init__()
        # Backbone without a classifier head (num_classes=0) or global pooling
        # (global_pool=""); pooling is done by `pool_layer` below.
        backbone = timm.create_model("tf_efficientnetv2_s", pretrained=False, global_pool="", num_classes=0)
        # Adapt the first 3-channel conv to 2 input channels.
        self.backbone = change_num_input_channels(backbone, 2)
        self.pool_layer = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.2)
        # 1280 must match the feature width produced by the backbone above.
        self.classifier = nn.Linear(1280, 1)
        self.load_state_dict(weights)

    def forward(self, x):
        """Return shape (B,) when the head has one output unit, else (B, K)."""
        features = self.backbone(x)
        pooled = self.pool_layer(features)
        flat = pooled.view(pooled.size(0), -1)
        head_out = self.classifier(self.dropout(flat))
        if head_out.size(1) == 1:
            return head_out[:, 0]
        return head_out


class Ensemble(nn.Module):
    """Average the outputs of several models evaluated on the same input."""

    def __init__(self, model_list):
        super().__init__()
        # ModuleList registers the members as submodules, so e.g. .eval() reaches them.
        self.model_list = nn.ModuleList(model_list)

    def forward(self, x):
        """Stack every member's output on a new leading axis and return its mean."""
        outputs = [member(x) for member in self.model_list]
        return torch.mean(torch.stack(outputs), dim=0)


# Cross-validation fold checkpoints. One Net2D is built from each, and
# Ensemble averages their outputs.
checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]


def _strip_key_prefix(state_dict, prefix="model."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Only a leading occurrence is stripped. ``str.replace`` would also rewrite
    ``prefix`` wherever it appeared later in a key. Keys without the prefix
    are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Inference in this app runs on CPU tensors. map_location="cpu" lets
# checkpoints saved from GPU tensors load on a machine without CUDA.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True. If these
# checkpoints store non-tensor objects, loading may need weights_only=False
# (acceptable only for trusted local files) — confirm.
weights = [torch.load(ckpt, map_location="cpu")["state_dict"] for ckpt in checkpoints]
weights = [_strip_key_prefix(wt) for wt in weights]
models = [Net2D(wt) for wt in weights]
# eval() puts every member in inference mode (e.g. disables Net2D's Dropout).
ensemble = Ensemble(models).eval()

def predict_bone_age(Radiograph, Sex):
    """Run the fold ensemble on one radiograph and format the estimate as text.

    Args:
        Radiograph: numpy image array from the Gradio ``Image`` input. It is
            assumed to be 2-D ``(H, W)`` grayscale, since the input is configured
            with ``image_mode="L"``. TODO: confirm for the installed Gradio version.
        Sex: index from the Gradio ``Radio`` input (``type="index"`` over
            ``["Male", "Female"]``). A value of 1 fills the extra input channel
            with +1; anything else fills it with -1.

    Returns:
        ``"Estimated Bone Age: <Y> years, <M> months"``. Whole years and
        months are truncated, and negative model outputs are clamped to 0.
    """
    img = torch.from_numpy(Radiograph).float()
    # (H, W) -> (1, 1, H, W): add batch and channel axes.
    img = img.unsqueeze(0).unsqueeze(0)
    # Normalise by the image maximum, then map [0, 1] -> [-1, 1]. An all-black
    # image would divide by zero and yield NaNs, so it is left at 0 (-> -1).
    peak = img.max()
    if peak > 0:
        img = img / peak
    img = (img - 0.5) * 2.0
    # The second input channel is a constant plane encoding sex.
    sex_value = 1.0 if Sex == 1 else -1.0
    img = torch.cat([img, torch.full_like(img, sex_value)], dim=1)
    with torch.no_grad():
        bone_age = ensemble(img)[0].item()
    # A negative estimate is not a meaningful age.
    bone_age = max(bone_age, 0.0)
    # divmod(x, 1) splits off the fractional year without dividing by the
    # integer part, which is 0 for estimates under one year.
    years, year_fraction = divmod(bone_age, 1.0)
    months = int(year_fraction * 12)
    return f"Estimated Bone Age: {int(years)} years, {months} months"


# Gradio components: a grayscale radiograph, a sex selector passed to the
# callback as an index (0 = "Male", 1 = "Female"), and a text result.
# NOTE(review): `shape=` is a Gradio 3.x `Image` argument; confirm the
# installed Gradio version still accepts it.
image = gr.Image(image_mode="L", shape=(512, 512))
sex = gr.Radio(choices=["Male", "Female"], type="index")
label = gr.Label(label="Result", show_label=True)

demo = gr.Interface(fn=predict_bone_age, inputs=[image, sex], outputs=label)


if __name__ == "__main__":
    demo.launch()