Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -31,8 +31,9 @@ global_model = None
|
|
| 31 |
def load_model():
|
| 32 |
"""Load the model at startup"""
|
| 33 |
global global_model
|
|
|
|
| 34 |
try:
|
| 35 |
-
checkpoint = torch.load(
|
| 36 |
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
| 37 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 38 |
model.to(device)
|
|
@@ -86,7 +87,7 @@ class UNetWrapper:
|
|
| 86 |
}
|
| 87 |
|
| 88 |
# Save model locally
|
| 89 |
-
pth_name = '
|
| 90 |
torch.save(save_dict, pth_name)
|
| 91 |
|
| 92 |
# Create repo if it doesn't exist
|
|
@@ -115,14 +116,20 @@ tags:
|
|
| 115 |
- pix2pix
|
| 116 |
- pytorch
|
| 117 |
library_name: pytorch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
---
|
| 119 |
|
| 120 |
# Pix2Pix UNet Model
|
| 121 |
|
| 122 |
## Model Description
|
| 123 |
Custom UNet model for Pix2Pix image translation.
|
| 124 |
-
- **Image Size:**
|
| 125 |
-
- **Model Type:**
|
| 126 |
|
| 127 |
## Usage
|
| 128 |
|
|
@@ -130,9 +137,10 @@ Custom UNet model for Pix2Pix image translation.
|
|
| 130 |
import torch
|
| 131 |
from small_256_model import UNet as small_UNet
|
| 132 |
from big_1024_model import UNet as big_UNet
|
| 133 |
-
|
| 134 |
# Load the model
|
| 135 |
-
|
|
|
|
| 136 |
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
| 137 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 138 |
model.eval()
|
|
|
|
| 31 |
def load_model():
|
| 32 |
"""Load the model at startup"""
|
| 33 |
global global_model
|
| 34 |
+
weights_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
|
| 35 |
try:
|
| 36 |
+
checkpoint = torch.load(weights_name, map_location=device)
|
| 37 |
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
| 38 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 39 |
model.to(device)
|
|
|
|
| 87 |
}
|
| 88 |
|
| 89 |
# Save model locally
|
| 90 |
+
pth_name = 'big_model_weights.pth' if big else 'small_model_weights.pth'
|
| 91 |
torch.save(save_dict, pth_name)
|
| 92 |
|
| 93 |
# Create repo if it doesn't exist
|
|
|
|
| 116 |
- pix2pix
|
| 117 |
- pytorch
|
| 118 |
library_name: pytorch
|
| 119 |
+
license: wtfpl
|
| 120 |
+
datasets:
|
| 121 |
+
- K00B404/pix2pix_flux_set
|
| 122 |
+
language:
|
| 123 |
+
- en
|
| 124 |
+
pipeline_tag: image-to-image
|
| 125 |
---
|
| 126 |
|
| 127 |
# Pix2Pix UNet Model
|
| 128 |
|
| 129 |
## Model Description
|
| 130 |
Custom UNet model for Pix2Pix image translation.
|
| 131 |
+
- **Image Size:** 1024
|
| 132 |
+
- **Model Type:** Big (1024)
|
| 133 |
|
| 134 |
## Usage
|
| 135 |
|
|
|
|
| 137 |
import torch
|
| 138 |
from small_256_model import UNet as small_UNet
|
| 139 |
from big_1024_model import UNet as big_UNet
|
| 140 |
+
big = True
|
| 141 |
# Load the model
|
| 142 |
+
name='big_model_weights.pth' if big else 'small_model_weights.pth'
|
| 143 |
+
checkpoint = torch.load(name)
|
| 144 |
model = big_UNet() if checkpoint['model_config']['big'] else small_UNet()
|
| 145 |
model.load_state_dict(checkpoint['model_state_dict'])
|
| 146 |
model.eval()
|