Add mixin
Browse files- briarmbg.py +3 -1
- mixin.py +15 -0
briarmbg.py
CHANGED
|
@@ -2,6 +2,8 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
|
|
|
|
|
|
| 5 |
class REBNCONV(nn.Module):
|
| 6 |
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
| 7 |
super(REBNCONV,self).__init__()
|
|
@@ -344,7 +346,7 @@ class myrebnconv(nn.Module):
|
|
| 344 |
return self.rl(self.bn(self.conv(x)))
|
| 345 |
|
| 346 |
|
| 347 |
-
class BriaRMBG(nn.Module):
|
| 348 |
|
| 349 |
def __init__(self,in_ch=3,out_ch=1):
|
| 350 |
super(BriaRMBG,self).__init__()
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
| 5 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 6 |
+
|
| 7 |
class REBNCONV(nn.Module):
|
| 8 |
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
| 9 |
super(REBNCONV,self).__init__()
|
|
|
|
| 346 |
return self.rl(self.bn(self.conv(x)))
|
| 347 |
|
| 348 |
|
| 349 |
+
class BriaRMBG(nn.Module, PyTorchModelHubMixin):
|
| 350 |
|
| 351 |
def __init__(self,in_ch=3,out_ch=1):
|
| 352 |
super(BriaRMBG,self).__init__()
|
mixin.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from briarmbg import BriaRMBG
|
| 2 |
+
import torch
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
|
| 5 |
+
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
|
| 6 |
+
|
| 7 |
+
net = BriaRMBG()
|
| 8 |
+
net.load_state_dict(torch.load(model_path, map_location="cpu"))
|
| 9 |
+
net.eval()
|
| 10 |
+
|
| 11 |
+
# push to hub
|
| 12 |
+
net.push_to_hub("nielsr/RMBG-1.4")
|
| 13 |
+
|
| 14 |
+
# reload
|
| 15 |
+
net = BriaRMBG.from_pretrained("nielsr/RMBG-1.4")
|