Upload model
Browse files- modeling_basnet.py +18 -19
modeling_basnet.py
CHANGED
|
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
|
| 14 |
|
| 15 |
|
| 16 |
@dataclass
|
| 17 |
-
class
|
| 18 |
dout: torch.Tensor
|
| 19 |
d1: Optional[torch.Tensor] = None
|
| 20 |
d2: Optional[torch.Tensor] = None
|
|
@@ -25,6 +25,11 @@ class BASNetModelOutput(ModelOutput):
|
|
| 25 |
db: Optional[torch.Tensor] = None
|
| 26 |
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
class RefUnet(nn.Module):
|
| 29 |
def __init__(self, in_ch: int, inc_ch: int) -> None:
|
| 30 |
super().__init__()
|
|
@@ -466,27 +471,21 @@ class BASNetModel(PreTrainedModel):
|
|
| 466 |
d6_act = torch.sigmoid(d6)
|
| 467 |
db_act = torch.sigmoid(db)
|
| 468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
if not return_dict:
|
| 470 |
-
return (
|
| 471 |
-
dout_act,
|
| 472 |
-
d1_act,
|
| 473 |
-
d2_act,
|
| 474 |
-
d3_act,
|
| 475 |
-
d4_act,
|
| 476 |
-
d5_act,
|
| 477 |
-
d6_act,
|
| 478 |
-
db_act,
|
| 479 |
-
)
|
| 480 |
|
| 481 |
return BASNetModelOutput(
|
| 482 |
-
|
| 483 |
-
d1=d1_act,
|
| 484 |
-
d2=d2_act,
|
| 485 |
-
d3=d3_act,
|
| 486 |
-
d4=d4_act,
|
| 487 |
-
d5=d5_act,
|
| 488 |
-
d6=d6_act,
|
| 489 |
-
db=db_act,
|
| 490 |
)
|
| 491 |
|
| 492 |
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
@dataclass
|
| 17 |
+
class BasNetSideOutput(ModelOutput):
|
| 18 |
dout: torch.Tensor
|
| 19 |
d1: Optional[torch.Tensor] = None
|
| 20 |
d2: Optional[torch.Tensor] = None
|
|
|
|
| 25 |
db: Optional[torch.Tensor] = None
|
| 26 |
|
| 27 |
|
| 28 |
+
@dataclass
|
| 29 |
+
class BASNetModelOutput(ModelOutput):
|
| 30 |
+
activated: BasNetSideOutput
|
| 31 |
+
|
| 32 |
+
|
| 33 |
class RefUnet(nn.Module):
|
| 34 |
def __init__(self, in_ch: int, inc_ch: int) -> None:
|
| 35 |
super().__init__()
|
|
|
|
| 471 |
d6_act = torch.sigmoid(d6)
|
| 472 |
db_act = torch.sigmoid(db)
|
| 473 |
|
| 474 |
+
side_outputs = (
|
| 475 |
+
dout_act,
|
| 476 |
+
d1_act,
|
| 477 |
+
d2_act,
|
| 478 |
+
d3_act,
|
| 479 |
+
d4_act,
|
| 480 |
+
d5_act,
|
| 481 |
+
d6_act,
|
| 482 |
+
db_act,
|
| 483 |
+
)
|
| 484 |
if not return_dict:
|
| 485 |
+
return (side_outputs,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
|
| 487 |
return BASNetModelOutput(
|
| 488 |
+
activated=BasNetSideOutput(*side_outputs),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
)
|
| 490 |
|
| 491 |
|