import gradio as gr
import timm
import torch
import torch.nn as nn
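
# Estimates bone age from a radiograph and the patient's sex, served as a
# Gradio demo. Three cross-validation folds of an EfficientNetV2-S regressor
# are ensembled by averaging their predictions.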

def change_num_input_channels(model, in_channels=1):
    """
    Assumes the number of input channels in the model is 3.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)) and m.in_channels == 3:
            m.in_channels = in_channels
            # First, sum across channels
            W = m.weight.sum(1, keepdim=True)
            # Then, divide by number of channels
            W = W / in_channels
            # Then, repeat by number of channels
            size = [1] * W.ndim
            size[1] = in_channels
            W = W.repeat(size)
            m.weight = nn.Parameter(W)
            break
    return model
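
# Illustrative usage (hypothetical; this app uses in_channels=2 below):
#   backbone = timm.create_model("resnet18", pretrained=True)
#   backbone = change_num_input_channels(backbone, in_channels=1)
# Summing the RGB kernel, dividing by the new channel count, and tiling it
# preserves the layer's expected activation scale, so pretrained weights
# remain a sensible initialization for non-RGB input.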

class Net2D(nn.Module):

    def __init__(self, weights):
        super().__init__()
        # Headless EfficientNetV2-S backbone that outputs feature maps
        self.backbone = timm.create_model(
            "tf_efficientnetv2_s", pretrained=False, global_pool="", num_classes=0
        )
        # 2 input channels: the radiograph plus a constant sex channel
        self.backbone = change_num_input_channels(self.backbone, 2)
        self.pool_layer = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.2)
        self.classifier = nn.Linear(1280, 1)  # 1280 = EfficientNetV2-S feature dim
        self.load_state_dict(weights)

    def forward(self, x):
        x = self.backbone(x)
        # Global average pooling -> (batch, 1280)
        x = self.pool_layer(x).view(x.size(0), -1)
        x = self.dropout(x)
        x = self.classifier(x)
        # Squeeze the singleton output dimension for regression
        return x[:, 0] if x.size(1) == 1 else x

class Ensemble(nn.Module):

    def __init__(self, model_list):
        super().__init__()
        self.model_list = nn.ModuleList(model_list)

    def forward(self, x):
        # Average the predictions of the individual fold models
        return torch.stack([model(x) for model in self.model_list]).mean(0)

# Load the three cross-validation fold checkpoints and build the ensemble
checkpoints = ["fold0.ckpt", "fold1.ckpt", "fold2.ckpt"]
weights = [torch.load(ckpt, map_location="cpu")["state_dict"] for ckpt in checkpoints]
# Strip the "model." prefix the training wrapper added to the state dict keys
weights = [{k.replace("model.", ""): v for k, v in wt.items()} for wt in weights]
models = [Net2D(wt) for wt in weights]
ensemble = Ensemble(models).eval()  # eval mode disables dropout for inference
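
# Sanity check (illustrative; requires the .ckpt files above to be present):
#   x = torch.randn(1, 2, 512, 512)
#   with torch.no_grad():
#       print(ensemble(x).shape)  # torch.Size([1])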

def predict_bone_age(Radiograph, Sex):
    # Scale the grayscale radiograph to [-1, 1] and add batch/channel dims
    img = torch.from_numpy(Radiograph)
    img = img.unsqueeze(0).unsqueeze(0)
    img = img / img.max()
    img = img - 0.5
    img = img * 2.0
    # Encode sex as a constant second channel: +1 for female (index 1), -1 for male
    if Sex == 1:
        img = torch.cat([img, torch.zeros_like(img) + 1], dim=1)
    else:
        img = torch.cat([img, torch.zeros_like(img) - 1], dim=1)
    with torch.no_grad():
        bone_age = ensemble(img.float())[0].item()
    # Split the prediction into whole years and remaining months
    years = int(bone_age)
    months = int((bone_age - years) * 12)
    return f"Estimated Bone Age: {years} years, {months} months"

# Gradio UI: grayscale image resized to 512x512, plus a Male/Female selector
# that passes its index (0 or 1) to the prediction function
image = gr.Image(shape=(512, 512), image_mode="L")
sex = gr.Radio(["Male", "Female"], type="index")
label = gr.Label(show_label=True, label="Result")

demo = gr.Interface(
    fn=predict_bone_age,
    inputs=[image, sex],
    outputs=label,
)

if __name__ == "__main__":
    demo.launch()