Hemaxi committed on
Commit
0d2a77e
·
verified ·
1 Parent(s): 01b25ea

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +14 -10
prediction.py CHANGED
@@ -149,12 +149,12 @@ def both_joint_loss_function(y_true, y_pred):
149
  return detection_loss + pp_loss
150
 
151
 
 
 
 
 
152
  def load_old_model(model_file):
153
- import h5py
154
- from keras.models import load_model
155
- from keras.optimizers import Adam # or the relevant optimizer you expect
156
  print("Loading pre-trained model")
157
-
158
  custom_objects = {
159
  'mse': mse,
160
  'weighted_mean_se': weighted_mean_se,
@@ -174,18 +174,22 @@ def load_old_model(model_file):
174
  pass
175
 
176
  try:
177
- # Intercept and fix optimizer config before loading
178
- with h5py.File(model_file, 'r') as f:
179
  if 'training_config' in f.attrs:
180
- import json
181
- training_config = json.loads(f.attrs['training_config'].decode('utf-8'))
182
- optimizer_config = training_config['optimizer_config']
 
183
  config = optimizer_config.get('config', {})
184
 
185
- # Replace deprecated 'lr' with 'learning_rate'
186
  if 'lr' in config:
187
  config['learning_rate'] = config.pop('lr')
188
  optimizer_config['config'] = config
 
 
 
 
189
 
190
  return load_model(model_file, custom_objects=custom_objects)
191
 
 
149
  return detection_loss + pp_loss
150
 
151
 
152
+ import h5py
153
+ import json
154
+ from keras.models import load_model
155
+
156
  def load_old_model(model_file):
 
 
 
157
  print("Loading pre-trained model")
 
158
  custom_objects = {
159
  'mse': mse,
160
  'weighted_mean_se': weighted_mean_se,
 
174
  pass
175
 
176
  try:
177
+ # Patch 'lr' to 'learning_rate' if necessary
178
+ with h5py.File(model_file, 'r+') as f:
179
  if 'training_config' in f.attrs:
180
+ training_config_json = f.attrs['training_config']
181
+ training_config = json.loads(training_config_json) # <- FIXED
182
+
183
+ optimizer_config = training_config.get('optimizer_config', {})
184
  config = optimizer_config.get('config', {})
185
 
 
186
  if 'lr' in config:
187
  config['learning_rate'] = config.pop('lr')
188
  optimizer_config['config'] = config
189
+ training_config['optimizer_config'] = optimizer_config
190
+
191
+ # Update the HDF5 file with modified config
192
+ f.attrs.modify('training_config', json.dumps(training_config))
193
 
194
  return load_model(model_file, custom_objects=custom_objects)
195