"""Gradio app: predict LLI, LAMPE and LAMNE for an LFP cell from a DTW image.

A sample curve taken from ``data_LFP.npy`` is aligned against the reference
curve ``data[0][0]`` with dynamic time warping (dtaidistance); the normalised
warping-path matrix is fed as a one-channel image to the Keras model loaded
below.
"""
import numpy as np
import gradio as gr
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
import pandas as pd
from dtaidistance import dtw

# Cycle-number radio value -> index along the second axis of ``data``.
# Any value not listed (i.e. "1000") falls back to index 5, exactly like the
# original if/else chain did.
_CYCLE_INDEX = {"10": 0, "50": 1, "100": 2, "200": 3, "400": 4}
_DEFAULT_CYCLE_INDEX = 5


def getDTWImage(IC_reference, sample, size):
    """Return the DTW warping-path matrix of two curves as a [0, 1] image.

    Args:
        IC_reference: reference sequence.
        sample: sequence to align against the reference.
        size: the DTW ``window`` passed to dtaidistance is ``int(size / 2)``.

    Returns:
        The warping-path matrix as float32, scaled so its maximum is 1, with a
        trailing channel axis of length 1 appended.
    """
    _, paths = dtw.warping_paths(IC_reference, sample, window=int(size / 2), psi=2)
    x = np.array(paths)
    # Cells the DTW did not fill are +inf; those and any negative values are
    # zeroed so they do not dominate the normalisation below.
    x = np.where(np.isposinf(x) | (x < 0), 0.0, x)
    # Normalise to [0, 1]; skip the division for an all-zero matrix, which
    # would otherwise turn every cell into NaN (0 / 0).
    peak = np.max(x)
    if peak > 0:
        x = x / peak
    # Append a channel axis so the matrix can be fed as a 1-channel image.
    return np.expand_dims(x, -1).astype("float32")


def normalise_data(x, x_min, x_max):
    """Min-max scale ``x`` to [0, 1] using the given bounds.

    Returns an all-zero array when ``x_max == x_min`` instead of dividing by 0.
    """
    span = x_max - x_min
    if span == 0:
        return np.zeros_like(x, dtype=float)
    return (x - x_min) / span


data = np.load('./data_LFP.npy')
model = tf.keras.models.load_model('./models/model-bestLFP_V2.h5', compile=False)


def predict(Cell_number, Duty_Cycle, Cycle_number):
    """Run the model on one (duty cycle, cycle number) sample.

    Args:
        Cell_number: value of the cell radio ("Cell #1", "Cell #2", "Cell #3").
            NOTE(review): currently unused -- only one data file is loaded, so
            every cell selection returns the same result.
        Duty_Cycle: 1-based index into the first axis of ``data`` (slider).
        Cycle_number: one of "10", "50", "100", "200", "400", "1000".

    Returns:
        ``(pred, df, im)``: a label -> value dict for the Label output, a
        DataFrame with the pristine and degraded curves for the Timeseries
        output, and a PIL image of the DTW matrix.
    """
    # ------------------------ Prediction ------------------------
    cycle = _CYCLE_INDEX.get(Cycle_number, _DEFAULT_CYCLE_INDEX)

    IC_reference = data[0][0]
    # The slider value may arrive as a float; numpy indices must be ints.
    sample = data[int(Duty_Cycle) - 1][cycle]
    # ``size`` was never defined in the original (NameError on every call).
    # The DTW window is derived from the sample length here.
    # NOTE(review): confirm this matches the window used to build the
    # training images for model-bestLFP_V2.
    size = len(sample)
    sample_DTW = getDTWImage(IC_reference, sample, size)

    prediction = model.predict(np.expand_dims(sample_DTW, axis=0))
    # gr.outputs.Label expects numeric confidences; strings would be sorted
    # and rendered incorrectly.
    pred = {"LLI ": float(prediction[0][0]),
            "LAMPE ": float(prediction[0][1]),
            "LAMNE ": float(prediction[0][2])}

    # --------------------------- ICA + image ----------------------------
    n_points = len(IC_reference)
    df = pd.DataFrame(data={' ': np.linspace(1, n_points, n_points),
                            'pristine': IC_reference,
                            'degraded': sample})

    # Drop the channel axis to get a 2-D array for colour mapping.
    image_array = sample_DTW[..., 0]
    image_array = normalise_data(image_array, np.min(image_array), np.max(image_array))
    im = Image.fromarray(np.uint8(cm.inferno(image_array) * 255))
    return pred, df, im


# The stray fourth "checkbox" input was removed: predict() takes exactly three
# arguments, so Gradio's call with four values failed with a TypeError.
iface = gr.Interface(
    fn=predict,
    inputs=[gr.inputs.Radio(["Cell #1", "Cell #2", "Cell #3"]),
            gr.inputs.Slider(1, 1000, step=1),
            gr.inputs.Radio(["10", "50", "100", "200", "400", "1000"])],
    title="LFP degradation diagnosis",
    description="Enter cell number, duty cycle and cycle number to predict the percentage of LLI, LAMPE and LAMNE",
    outputs=[gr.outputs.Label(label="Prediction"),
             gr.outputs.Timeseries(x=" ", y=["pristine", "degraded"]),
             gr.outputs.Image(type='pil', label="DTW image")],
    allow_screenshot=False,
    layout="unaligned",
)

if __name__ == "__main__":
    iface.launch(share=True)