edyoshikun commited on
Commit
f3ebb52
·
1 Parent(s): c6e4ba9

disabling hf auto install that overrides gradio version update

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +7 -39
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🌈
4
  colorFrom: green
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: true
10
  license: bsd-3-clause
 
4
  colorFrom: green
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.27.1
8
  app_file: app.py
9
  pinned: true
10
  license: bsd-3-clause
app.py CHANGED
@@ -28,7 +28,7 @@ class VSGradio:
28
  architecture="UNeXt2_2D",
29
  model_config=self.model_config,
30
  )
31
- self.model.to(self.device) # Move the model to the correct device (GPU/CPU)
32
  self.model.eval()
33
  print("Model loaded successfully and set to evaluation mode")
34
  except Exception as e:
@@ -42,7 +42,6 @@ class VSGradio:
42
  return (input - mean) / std
43
 
44
  def preprocess_image_standard(self, input: ArrayLike):
45
- # Perform standard preprocessing here
46
  input = exposure.equalize_adapthist(input)
47
  return input
48
 
@@ -62,19 +61,16 @@ class VSGradio:
62
  # Normalize the input and convert to tensor
63
  inp = self.normalize_fov(inp)
64
  original_shape = inp.shape
65
- # Resize the input image to the expected cell diameter
66
  inp = apply_rescale_image(inp, scaling_factor)
67
 
68
  # Convert the input to a tensor
69
  inp = torch.from_numpy(np.array(inp).astype(np.float32))
70
 
71
- # Prepare the input dictionary and move input to the correct device (GPU or CPU)
72
  test_dict = dict(
73
  index=None,
74
  source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
75
  )
76
 
77
- # Run model inference
78
  with torch.inference_mode():
79
  self.model.on_predict_start() # Necessary preprocessing for the model
80
  pred = (
@@ -89,18 +85,15 @@ class VSGradio:
89
  nuc_pred = resize(nuc_pred, original_shape, anti_aliasing=True)
90
  mem_pred = resize(mem_pred, original_shape, anti_aliasing=True)
91
 
92
- # Define colormaps
93
- green_colormap = cmap.Colormap("green") # Nucleus: black to green
94
  magenta_colormap = cmap.Colormap("magenta")
95
 
96
- # Apply the colormap to the predictions
97
  nuc_rgb = apply_colormap(nuc_pred, green_colormap)
98
  mem_rgb = apply_colormap(mem_pred, magenta_colormap)
99
 
100
- return nuc_rgb, mem_rgb # Return both nucleus and membrane images
101
  except Exception as e:
102
  print(f"Error during prediction: {e}")
103
- # Return empty images of the right shape and type in case of error
104
  empty_img = np.zeros((300, 300, 3), dtype=np.uint8)
105
  return empty_img, empty_img
106
 
@@ -109,13 +102,8 @@ def apply_colormap(prediction, colormap: cmap.Colormap):
109
  """Apply a colormap to a single-channel prediction image."""
110
  # Ensure the prediction is within the valid range [0, 1]
111
  prediction = exposure.rescale_intensity(prediction, out_range=(0, 1))
112
-
113
- # Apply the colormap to get an RGB image
114
  rgb_image = colormap(prediction)
115
-
116
- # Convert the output from [0, 1] to [0, 255] for display
117
  rgb_image_uint8 = (rgb_image * 255).astype(np.uint8)
118
-
119
  return rgb_image_uint8
120
 
121
 
@@ -125,53 +113,38 @@ def merge_images(nuc_rgb: ArrayLike, mem_rgb: ArrayLike) -> ArrayLike:
125
 
126
 
127
  def apply_image_adjustments(image, invert_image: bool, gamma_factor: float):
128
- """Applies all the image adjustments (invert, contrast, gamma) in sequence"""
129
- # Apply invert
130
  if invert_image:
131
  image = invert(image, signed_float=False)
132
-
133
- # Apply gamma adjustment
134
  image = exposure.adjust_gamma(image, gamma_factor)
135
-
136
  return exposure.rescale_intensity(image, out_range=(0, 255)).astype(np.uint8)
137
 
138
 
139
  def apply_rescale_image(image, scaling_factor: float):
140
- """Resize the input image according to the scaling factor"""
141
  scaling_factor = float(scaling_factor)
142
- image = resize(
143
  image,
144
  (int(image.shape[0] * scaling_factor), int(image.shape[1] * scaling_factor)),
145
  anti_aliasing=True,
146
  )
147
- return image
148
 
149
 
150
- # Function to clear outputs when a new image is uploaded
151
  def clear_outputs(image):
152
- return (
153
- image,
154
- None,
155
- None,
156
- ) # Return None for adjusted_image, output_nucleus, and output_membrane
157
 
158
 
159
  def load_css(file_path):
160
- """Load custom CSS"""
161
  with open(file_path, "r") as file:
162
  return file.read()
163
 
164
 
165
  if __name__ == "__main__":
166
  try:
167
- # Download the model checkpoint from Hugging Face
168
  print("Downloading model checkpoint...")
169
  model_ckpt_path = hf_hub_download(
170
  repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
171
  )
172
  print(f"Model downloaded successfully to: {model_ckpt_path}")
173
 
174
- # Model configuration
175
  model_config = {
176
  "in_channels": 1,
177
  "out_channels": 2,
@@ -241,10 +214,7 @@ if __name__ == "__main__":
241
  visible=False,
242
  )
243
 
244
- # Checkbox for applying invert
245
  preprocess_invert = gr.Checkbox(label="Invert Image", value=False)
246
-
247
- # Slider for gamma adjustment
248
  gamma_factor = gr.Slider(
249
  label="Adjust Gamma", minimum=0.01, maximum=5.0, value=1.0, step=0.1
250
  )
@@ -328,14 +298,13 @@ if __name__ == "__main__":
328
  output_membrane,
329
  ],
330
  )
331
- # Clear everything when the input image changes
332
  input_image.change(
333
  fn=clear_outputs,
334
  inputs=input_image,
335
  outputs=[adjusted_image, output_nucleus, output_membrane],
336
  )
337
 
338
- # Function to handle merging the two predictions after they are shown
339
  def merge_predictions_fn(nucleus_image, membrane_image, merge):
340
  if merge:
341
  merged = merge_images(nucleus_image, membrane_image)
@@ -353,7 +322,6 @@ if __name__ == "__main__":
353
  gr.update(visible=True),
354
  )
355
 
356
- # Toggle between merged and separate views when the checkbox is checked
357
  merge_checkbox.change(
358
  fn=merge_predictions_fn,
359
  inputs=[output_nucleus, output_membrane, merge_checkbox],
@@ -435,8 +403,8 @@ if __name__ == "__main__":
435
  </div>
436
  """
437
  )
 
438
 
439
  # Launch the Gradio app
440
- demo.launch(server_name="0.0.0.0", share=False)
441
  except Exception as e:
442
  print(f"Error initializing VSGradio: {e}")
 
28
  architecture="UNeXt2_2D",
29
  model_config=self.model_config,
30
  )
31
+ self.model.to(self.device)
32
  self.model.eval()
33
  print("Model loaded successfully and set to evaluation mode")
34
  except Exception as e:
 
42
  return (input - mean) / std
43
 
44
  def preprocess_image_standard(self, input: ArrayLike):
 
45
  input = exposure.equalize_adapthist(input)
46
  return input
47
 
 
61
  # Normalize the input and convert to tensor
62
  inp = self.normalize_fov(inp)
63
  original_shape = inp.shape
 
64
  inp = apply_rescale_image(inp, scaling_factor)
65
 
66
  # Convert the input to a tensor
67
  inp = torch.from_numpy(np.array(inp).astype(np.float32))
68
 
 
69
  test_dict = dict(
70
  index=None,
71
  source=inp.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(self.device),
72
  )
73
 
 
74
  with torch.inference_mode():
75
  self.model.on_predict_start() # Necessary preprocessing for the model
76
  pred = (
 
85
  nuc_pred = resize(nuc_pred, original_shape, anti_aliasing=True)
86
  mem_pred = resize(mem_pred, original_shape, anti_aliasing=True)
87
 
88
+ green_colormap = cmap.Colormap("green")
 
89
  magenta_colormap = cmap.Colormap("magenta")
90
 
 
91
  nuc_rgb = apply_colormap(nuc_pred, green_colormap)
92
  mem_rgb = apply_colormap(mem_pred, magenta_colormap)
93
 
94
+ return nuc_rgb, mem_rgb
95
  except Exception as e:
96
  print(f"Error during prediction: {e}")
 
97
  empty_img = np.zeros((300, 300, 3), dtype=np.uint8)
98
  return empty_img, empty_img
99
 
 
102
  """Apply a colormap to a single-channel prediction image."""
103
  # Ensure the prediction is within the valid range [0, 1]
104
  prediction = exposure.rescale_intensity(prediction, out_range=(0, 1))
 
 
105
  rgb_image = colormap(prediction)
 
 
106
  rgb_image_uint8 = (rgb_image * 255).astype(np.uint8)
 
107
  return rgb_image_uint8
108
 
109
 
 
113
 
114
 
115
  def apply_image_adjustments(image, invert_image: bool, gamma_factor: float):
 
 
116
  if invert_image:
117
  image = invert(image, signed_float=False)
 
 
118
  image = exposure.adjust_gamma(image, gamma_factor)
 
119
  return exposure.rescale_intensity(image, out_range=(0, 255)).astype(np.uint8)
120
 
121
 
122
  def apply_rescale_image(image, scaling_factor: float):
 
123
  scaling_factor = float(scaling_factor)
124
+ return resize(
125
  image,
126
  (int(image.shape[0] * scaling_factor), int(image.shape[1] * scaling_factor)),
127
  anti_aliasing=True,
128
  )
 
129
 
130
 
 
131
  def clear_outputs(image):
132
+ return image, None, None
 
 
 
 
133
 
134
 
135
  def load_css(file_path):
 
136
  with open(file_path, "r") as file:
137
  return file.read()
138
 
139
 
140
  if __name__ == "__main__":
141
  try:
 
142
  print("Downloading model checkpoint...")
143
  model_ckpt_path = hf_hub_download(
144
  repo_id="compmicro-czb/VSCyto2D", filename="epoch=399-step=23200.ckpt"
145
  )
146
  print(f"Model downloaded successfully to: {model_ckpt_path}")
147
 
 
148
  model_config = {
149
  "in_channels": 1,
150
  "out_channels": 2,
 
214
  visible=False,
215
  )
216
 
 
217
  preprocess_invert = gr.Checkbox(label="Invert Image", value=False)
 
 
218
  gamma_factor = gr.Slider(
219
  label="Adjust Gamma", minimum=0.01, maximum=5.0, value=1.0, step=0.1
220
  )
 
298
  output_membrane,
299
  ],
300
  )
301
+
302
  input_image.change(
303
  fn=clear_outputs,
304
  inputs=input_image,
305
  outputs=[adjusted_image, output_nucleus, output_membrane],
306
  )
307
 
 
308
  def merge_predictions_fn(nucleus_image, membrane_image, merge):
309
  if merge:
310
  merged = merge_images(nucleus_image, membrane_image)
 
322
  gr.update(visible=True),
323
  )
324
 
 
325
  merge_checkbox.change(
326
  fn=merge_predictions_fn,
327
  inputs=[output_nucleus, output_membrane, merge_checkbox],
 
403
  </div>
404
  """
405
  )
406
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
407
 
408
  # Launch the Gradio app
 
409
  except Exception as e:
410
  print(f"Error initializing VSGradio: {e}")