# Spaces: Paused
import gradio as gr | |
import torch | |
import numpy as np | |
import nibabel as nib | |
from model import My3DModel | |
# Instantiate the network and restore its trained weights from "model.pt".
# map_location=cpu lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only ship trusted
# checkpoints here (consider weights_only=True if the installed torch
# version supports it).
model = My3DModel()
model.load_state_dict(
    torch.load("model.pt", map_location=torch.device("cpu"))
)
# Inference only: put the module into evaluation mode.
model.eval()
# Output labels, indexed by the model's class dimension.
# NOTE(review): presumably CN = cognitively normal, MCI = mild cognitive
# impairment, AD = Alzheimer's disease -- confirm against the training code.
_CLASS_LABELS = ("CN", "MCI", "AD")


def predict_from_mri(file_obj):
    """Classify an uploaded NIfTI MRI volume with the module-level ``model``.

    Args:
        file_obj: Value passed by the ``gr.File`` input. Older Gradio
            versions pass a temp-file wrapper that exposes the path as
            ``.name``; Gradio 4+ passes the path itself as a ``str`` by
            default. Both are accepted. ``None`` means nothing was uploaded.

    Returns:
        ``"CN"``, ``"MCI"`` or ``"AD"``: the label at the argmax (dim 1)
        of the model output.

    Raises:
        gr.Error: If no file was uploaded, or the image is not a 3-D volume
            after dropping trailing singleton axes.
    """
    if file_obj is None:
        raise gr.Error("Please upload an MRI .nii file.")
    # Gradio 3.x hands over a tempfile wrapper (path in `.name`); Gradio 4+
    # hands over the path string itself, which has no `.name` attribute.
    path = getattr(file_obj, "name", file_obj)

    img = nib.load(path)
    data = img.get_fdata().astype(np.float32)
    # Some NIfTI files store a 3-D scan as (X, Y, Z, 1); drop such trailing
    # singleton axes so the network still receives a 3-D volume.
    while data.ndim > 3 and data.shape[-1] == 1:
        data = data[..., 0]
    if data.ndim != 3:
        raise gr.Error(f"Expected a 3-D MRI volume, got shape {data.shape}.")

    # NOTE(review): no intensity normalisation or resampling happens here;
    # confirm this matches the preprocessing used during training.
    # Add channel and batch axes: (D, H, W) -> (1, 1, D, H, W).
    tensor = torch.tensor(data).unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        output = model(tensor)
    pred_class = torch.argmax(output, dim=1).item()
    return _CLASS_LABELS[pred_class]
# Gradio UI: one file-upload input wired to predict_from_mri, with the
# returned label shown as plain text. The user-facing strings are Turkish:
# the title reads "3D MRI Alzheimer Diagnosis Model" and the input label
# reads "MRI .nii file".
iface = gr.Interface(
    title="3D MRI Alzheimer Teşhis Modeli",
    fn=predict_from_mri,
    inputs=gr.File(label="MRI .nii dosyası"),
    outputs="text",
)
# Start the web app; note this runs at import time (there is no
# `if __name__ == "__main__":` guard).
iface.launch()