Update prediction.py
Browse files- prediction.py +14 -10
prediction.py
CHANGED
@@ -149,12 +149,12 @@ def both_joint_loss_function(y_true, y_pred):
|
|
149 |
return detection_loss + pp_loss
|
150 |
|
151 |
|
|
|
|
|
|
|
|
|
152 |
def load_old_model(model_file):
|
153 |
-
import h5py
|
154 |
-
from keras.models import load_model
|
155 |
-
from keras.optimizers import Adam # or the relevant optimizer you expect
|
156 |
print("Loading pre-trained model")
|
157 |
-
|
158 |
custom_objects = {
|
159 |
'mse': mse,
|
160 |
'weighted_mean_se': weighted_mean_se,
|
@@ -174,18 +174,22 @@ def load_old_model(model_file):
|
|
174 |
pass
|
175 |
|
176 |
try:
|
177 |
-
#
|
178 |
-
with h5py.File(model_file, 'r') as f:
|
179 |
if 'training_config' in f.attrs:
|
180 |
-
|
181 |
-
training_config = json.loads(
|
182 |
-
|
|
|
183 |
config = optimizer_config.get('config', {})
|
184 |
|
185 |
-
# Replace deprecated 'lr' with 'learning_rate'
|
186 |
if 'lr' in config:
|
187 |
config['learning_rate'] = config.pop('lr')
|
188 |
optimizer_config['config'] = config
|
|
|
|
|
|
|
|
|
189 |
|
190 |
return load_model(model_file, custom_objects=custom_objects)
|
191 |
|
|
|
149 |
return detection_loss + pp_loss
|
150 |
|
151 |
|
152 |
+
import h5py
|
153 |
+
import json
|
154 |
+
from keras.models import load_model
|
155 |
+
|
156 |
def load_old_model(model_file):
|
|
|
|
|
|
|
157 |
print("Loading pre-trained model")
|
|
|
158 |
custom_objects = {
|
159 |
'mse': mse,
|
160 |
'weighted_mean_se': weighted_mean_se,
|
|
|
174 |
pass
|
175 |
|
176 |
try:
|
177 |
+
# Patch 'lr' to 'learning_rate' if necessary
|
178 |
+
with h5py.File(model_file, 'r+') as f:
|
179 |
if 'training_config' in f.attrs:
|
180 |
+
training_config_json = f.attrs['training_config']
|
181 |
+
training_config = json.loads(training_config_json) # <- FIXED
|
182 |
+
|
183 |
+
optimizer_config = training_config.get('optimizer_config', {})
|
184 |
config = optimizer_config.get('config', {})
|
185 |
|
|
|
186 |
if 'lr' in config:
|
187 |
config['learning_rate'] = config.pop('lr')
|
188 |
optimizer_config['config'] = config
|
189 |
+
training_config['optimizer_config'] = optimizer_config
|
190 |
+
|
191 |
+
# Update the HDF5 file with modified config
|
192 |
+
f.attrs.modify('training_config', json.dumps(training_config))
|
193 |
|
194 |
return load_model(model_file, custom_objects=custom_objects)
|
195 |
|