Spaces:
Runtime error
Runtime error
Commit
·
1329cff
1
Parent(s):
8a369ed
Update app.py
Browse files
app.py
CHANGED
@@ -181,12 +181,16 @@ def train_main(train_loader, test_loader, num_epochs, optimizer, model, device='
|
|
181 |
train_accs.append(train_acc)
|
182 |
test_accs.append(test_acc)
|
183 |
print('Epoch: {}, train loss: {:.4f}, test loss: {:.4f}, train accuracy: {:.2f}%, test accuracy: {:.2f}%'.format(epoch+1, train_loss, test_loss, train_acc, test_acc))
|
184 |
-
|
|
|
|
|
185 |
# save the results to a text file
|
|
|
|
|
186 |
with open("results.txt", "w") as f:
|
187 |
for epoch in range(num_epochs):
|
188 |
f.write("Epoch: {}, train loss: {:.4f}, test loss: {:.4f}, train accuracy: {:.2f}%, test accuracy: {:.2f}%\n".format(epoch+1, train_losses[epoch], test_losses[epoch], train_accs[epoch], test_accs[epoch]))
|
189 |
-
|
190 |
# plot the loss and accuracy curves side by side
|
191 |
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
|
192 |
axs[0].plot(train_losses, label='Train Loss')
|
@@ -200,8 +204,7 @@ def train_main(train_loader, test_loader, num_epochs, optimizer, model, device='
|
|
200 |
axs[1].set_ylabel('Accuracy')
|
201 |
axs[1].legend()
|
202 |
plt.savefig('loss_and_accuracy.png')
|
203 |
-
|
204 |
-
plt.show()
|
205 |
|
206 |
num_epochs=3
|
207 |
train_main(train_loader, test_loader, num_epochs, optimizer, model, device)
|
|
|
181 |
train_accs.append(train_acc)
|
182 |
test_accs.append(test_acc)
|
183 |
print('Epoch: {}, train loss: {:.4f}, test loss: {:.4f}, train accuracy: {:.2f}%, test accuracy: {:.2f}%'.format(epoch+1, train_loss, test_loss, train_acc, test_acc))
|
184 |
+
print("Saving the model")
|
185 |
+
torch.save(model, 'trained_model.pt')
|
186 |
+
print("Model saved successfully")
|
187 |
# save the results to a text file
|
188 |
+
|
189 |
+
print("Saving training logs")
|
190 |
with open("results.txt", "w") as f:
|
191 |
for epoch in range(num_epochs):
|
192 |
f.write("Epoch: {}, train loss: {:.4f}, test loss: {:.4f}, train accuracy: {:.2f}%, test accuracy: {:.2f}%\n".format(epoch+1, train_losses[epoch], test_losses[epoch], train_accs[epoch], test_accs[epoch]))
|
193 |
+
print("Logs saved")
|
194 |
# plot the loss and accuracy curves side by side
|
195 |
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
|
196 |
axs[0].plot(train_losses, label='Train Loss')
|
|
|
204 |
axs[1].set_ylabel('Accuracy')
|
205 |
axs[1].legend()
|
206 |
plt.savefig('loss_and_accuracy.png')
|
207 |
+
|
|
|
208 |
|
209 |
num_epochs=3
|
210 |
train_main(train_loader, test_loader, num_epochs, optimizer, model, device)
|