Upload 5 files

- swin2_mose/model.py +1157 -0
- swin2_mose/moe.py +323 -0
- swin2_mose/run.py +20 -0
- swin2_mose/utils.py +56 -0
- swin2_mose/weights/model-70.pt +3 -0

swin2_mose/model.py  (ADDED)

@@ -0,0 +1,1157 @@
#
# Source code: https://github.com/mv-lab/swin2sr
#
# -----------------------------------------------------------------------------------
# Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345
# Written by Conde and Choi et al.
# -----------------------------------------------------------------------------------

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from utils import window_reverse, Mlp, window_partition
from moe import MoE


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                 pretrained_window_size=[0, 0],
                 use_lepe=False,
                 use_cpb_bias=True,
                 use_rpe_bias=False):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.pretrained_window_size = pretrained_window_size
        self.num_heads = num_heads

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

        self.use_cpb_bias = use_cpb_bias

        if self.use_cpb_bias:
            print('positional encoder: CPB')
            # mlp to generate continuous relative position bias
            self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(512, num_heads, bias=False))

            # get relative_coords_table
            relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
            relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
            relative_coords_table = torch.stack(
                torch.meshgrid([relative_coords_h,
                                relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
            if pretrained_window_size[0] > 0:
                relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
            else:
                relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
            relative_coords_table *= 8  # normalize to -8, 8
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)

            self.register_buffer("relative_coords_table", relative_coords_table)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

        self.use_rpe_bias = use_rpe_bias
        if self.use_rpe_bias:
            print('positional encoder: RPE')
            # define a parameter table of relative position bias
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            rpe_relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("rpe_relative_position_index", rpe_relative_position_index)

            trunc_normal_(self.relative_position_bias_table, std=.02)

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

        self.use_lepe = use_lepe
        if self.use_lepe:
            print('positional encoder: LEPE')
            self.get_v = nn.Conv2d(
                dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        if self.use_lepe:
            lepe = self.lepe_pos(v)

        # cosine attention
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
        logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
        attn = attn * logit_scale

        if self.use_cpb_bias:
            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
            attn = attn + relative_position_bias.unsqueeze(0)

        if self.use_rpe_bias:
            relative_position_bias = self.relative_position_bias_table[self.rpe_relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v)

        if self.use_lepe:
            x = x + lepe

        x = x.transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def lepe_pos(self, v):
        B, NH, HW, NW = v.shape
        C = NH * NW
        H = W = int(math.sqrt(HW))
        v = v.transpose(-2, -1).contiguous().view(B, C, H, W)
        lepe = self.get_v(v)
        lepe = lepe.reshape(-1, self.num_heads, NW, HW)
        lepe = lepe.permute(0, 1, 3, 2).contiguous()
        return lepe

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, ' \
               f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        #  x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
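
A minimal usage sketch for WindowAttention (editor's illustration, not part of the uploaded file; dim=96, num_heads=6 and an 8x8 window are assumed values):

    # Hypothetical smoke test: 4 windows of 8x8 tokens, 96 channels, 6 heads.
    attn = WindowAttention(dim=96, window_size=(8, 8), num_heads=6)
    x = torch.randn(4, 8 * 8, 96)  # (num_windows*B, N, C), as in the docstring
    out = attn(x, mask=None)       # -> (4, 64, 96)
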
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        pretrained_window_size (int): Window size in pre-training.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0,
                 use_lepe=False,
                 use_cpb_bias=True,
                 MoE_config=None,
                 use_rpe_bias=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
            pretrained_window_size=to_2tuple(pretrained_window_size),
            use_lepe=use_lepe,
            use_cpb_bias=use_cpb_bias,
            use_rpe_bias=use_rpe_bias)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        if MoE_config is None:
            print('-->>> MLP')
            self.mlp = Mlp(
                in_features=dim, hidden_features=mlp_hidden_dim,
                act_layer=act_layer, drop=drop)
        else:
            print('-->>> MOE')
            print(MoE_config)
            self.mlp = MoE(
                input_size=dim, output_size=dim, hidden_size=mlp_hidden_dim,
                **MoE_config)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        H, W = x_size
        B, L, C = x.shape

        shortcut = x
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA (the mask is recomputed when the test image size differs from the
        # training resolution, so input shapes only need to be multiples of the window size)
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(self.norm1(x))

        # FFN (the MoE variant returns an (output, aux_loss) tuple; the plain MLP returns a tensor)
        loss_moe = None
        res = self.mlp(x)
        if not torch.is_tensor(res):
            res, loss_moe = res

        x = x + self.drop_path(self.norm2(res))

        return x, loss_moe

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
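
A short sketch to make the tuple return explicit (editor's illustration with assumed sizes; with the default MoE_config=None the auxiliary loss comes back as None):

    blk = SwinTransformerBlock(dim=96, input_resolution=(64, 64), num_heads=6,
                               window_size=8, shift_size=4)
    x = torch.randn(1, 64 * 64, 96)        # (B, H*W, C)
    y, loss_moe = blk(x, x_size=(64, 64))  # y: (1, 4096, 96); loss_moe is None on the MLP path
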
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(2 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.reduction(x)
        x = self.norm(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        flops += H * W * self.dim // 2
        return flops
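
A minimal shape check for PatchMerging (editor's illustration; the 64x64 resolution and dim=96 are assumptions):

    pm = PatchMerging(input_resolution=(64, 64), dim=96)
    x = torch.randn(1, 64 * 64, 96)
    y = pm(x)   # -> (1, 32 * 32, 192): resolution halved, channels doubled
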
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        pretrained_window_size (int): Local window size in pre-training.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 pretrained_window_size=0,
                 use_lepe=False,
                 use_cpb_bias=True,
                 MoE_config=None,
                 use_rpe_bias=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 pretrained_window_size=pretrained_window_size,
                                 use_lepe=use_lepe,
                                 use_cpb_bias=use_cpb_bias,
                                 MoE_config=MoE_config,
                                 use_rpe_bias=use_rpe_bias)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        loss_moe_all = 0
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)

            if not torch.is_tensor(x):
                x, loss_moe = x
                loss_moe_all += loss_moe or 0

        if self.downsample is not None:
            x = self.downsample(x)
        return x, loss_moe_all

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops

    def _init_respostnorm(self):
        for blk in self.blocks:
            nn.init.constant_(blk.norm1.bias, 0)
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)
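
A minimal sketch of a stage forward pass (editor's illustration, assumed sizes), showing how the per-block MoE losses are aggregated:

    layer = BasicLayer(dim=96, input_resolution=(64, 64), depth=2,
                       num_heads=6, window_size=8)
    x = torch.randn(1, 64 * 64, 96)
    y, loss_moe_all = layer(x, x_size=(64, 64))  # loss_moe_all stays 0 without MoE blocks
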
| 509 | 
            +
            class PatchEmbed(nn.Module):
         | 
| 510 | 
            +
                r""" Image to Patch Embedding
         | 
| 511 | 
            +
                Args:
         | 
| 512 | 
            +
                    img_size (int): Image size.  Default: 224.
         | 
| 513 | 
            +
                    patch_size (int): Patch token size. Default: 4.
         | 
| 514 | 
            +
                    in_chans (int): Number of input image channels. Default: 3.
         | 
| 515 | 
            +
                    embed_dim (int): Number of linear projection output channels. Default: 96.
         | 
| 516 | 
            +
                    norm_layer (nn.Module, optional): Normalization layer. Default: None
         | 
| 517 | 
            +
                """
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
         | 
| 520 | 
            +
                    super().__init__()
         | 
| 521 | 
            +
                    img_size = to_2tuple(img_size)
         | 
| 522 | 
            +
                    patch_size = to_2tuple(patch_size)
         | 
| 523 | 
            +
                    patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
         | 
| 524 | 
            +
                    self.img_size = img_size
         | 
| 525 | 
            +
                    self.patch_size = patch_size
         | 
| 526 | 
            +
                    self.patches_resolution = patches_resolution
         | 
| 527 | 
            +
                    self.num_patches = patches_resolution[0] * patches_resolution[1]
         | 
| 528 | 
            +
             | 
| 529 | 
            +
                    self.in_chans = in_chans
         | 
| 530 | 
            +
                    self.embed_dim = embed_dim
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
         | 
| 533 | 
            +
                    if norm_layer is not None:
         | 
| 534 | 
            +
                        self.norm = norm_layer(embed_dim)
         | 
| 535 | 
            +
                    else:
         | 
| 536 | 
            +
                        self.norm = None
         | 
| 537 | 
            +
             | 
| 538 | 
            +
                def forward(self, x):
         | 
| 539 | 
            +
                    B, C, H, W = x.shape
         | 
| 540 | 
            +
                    # FIXME look at relaxing size constraints
         | 
| 541 | 
            +
                    # assert H == self.img_size[0] and W == self.img_size[1],
         | 
| 542 | 
            +
                    #     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
         | 
| 543 | 
            +
                    x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
         | 
| 544 | 
            +
                    if self.norm is not None:
         | 
| 545 | 
            +
                        x = self.norm(x)
         | 
| 546 | 
            +
                    return x
         | 
| 547 | 
            +
             | 
| 548 | 
            +
                def flops(self):
         | 
| 549 | 
            +
                    Ho, Wo = self.patches_resolution
         | 
| 550 | 
            +
                    flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
         | 
| 551 | 
            +
                    if self.norm is not None:
         | 
| 552 | 
            +
                        flops += Ho * Wo * self.embed_dim
         | 
| 553 | 
            +
                    return flops
         | 
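

# Illustrative usage sketch (added for this listing, not part of the upstream
# file). With patch_size=1, as the SR models here use, every pixel becomes one
# token: (B, C, H, W) -> (B, H*W, C).
def _demo_patch_embed():
    embed = PatchEmbed(img_size=64, patch_size=1, in_chans=96, embed_dim=96)
    tokens = embed(torch.randn(2, 96, 64, 64))   # B, C, H, W in
    assert tokens.shape == (2, 64 * 64, 96)      # B, H*W, C out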


class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
                 img_size=224, patch_size=4, resi_connection='1conv',
                 use_lepe=False,
                 use_cpb_bias=True,
                 MoE_config=None,
                 use_rpe_bias=False):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(dim=dim,
                                         input_resolution=input_resolution,
                                         depth=depth,
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
                                         downsample=downsample,
                                         use_checkpoint=use_checkpoint,
                                         use_lepe=use_lepe,
                                         use_cpb_bias=use_cpb_bias,
                                         MoE_config=MoE_config,
                                         use_rpe_bias=use_rpe_bias)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                                      nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                      nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        loss_moe = None
        res = self.residual_group(x, x_size)

        if not torch.is_tensor(res):
            res, loss_moe = res

        res = self.patch_embed(self.conv(self.patch_unembed(res, x_size)))
        return res + x, loss_moe

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops
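

# Illustrative sketch (added, not part of the upstream file): RSTB.forward
# returns a (features, loss) pair, with loss_moe left as None when no MoE
# layer is configured, so callers treat a missing loss as zero (see
# Swin2SR.forward_features below).
def _demo_rstb_forward(block, x, x_size):
    feats, loss_moe = block(x, x_size)   # residual output plus optional MoE loss
    return feats, (loss_moe or 0)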


class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        B, HW, C = x.shape
        x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1])  # B C H W
        return x

    def flops(self):
        flops = 0
        return flops
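

# Illustrative usage sketch (added, not part of the upstream file):
# PatchUnEmbed inverts PatchEmbed's flatten/transpose, restoring a
# conv-friendly feature map for the residual convolution in RSTB.
def _demo_patch_unembed():
    unembed = PatchUnEmbed(in_chans=96, embed_dim=96)
    fmap = unembed(torch.randn(2, 64 * 64, 96), x_size=(64, 64))
    assert fmap.shape == (2, 96, 64, 64)   # B, C, H, W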


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
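

# Illustrative sketch (added, not part of the upstream file): each
# Conv2d -> PixelShuffle(2) stage trades a 4x channel expansion for a 2x
# spatial upscale, so x4 stacks two stages.
def _demo_upsample():
    up = Upsample(scale=4, num_feat=64)
    y = up(torch.randn(1, 64, 16, 16))
    assert y.shape == (1, 64, 64, 64)   # channels preserved, 16 -> 64 spatially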


class Upsample_hf(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample_hf, self).__init__(*m)


class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always
       only has 1conv + 1pixelshuffle). Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = []
        m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
        m.append(nn.PixelShuffle(scale))
        super(UpsampleOneStep, self).__init__(*m)

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.num_feat * 3 * 9
        return flops
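

# Illustrative sketch (added, not part of the upstream file): the one-step
# variant emits scale**2 * num_out_ch channels from a single conv, so for x2
# RGB output the conv is num_feat -> 12 channels instead of Upsample's
# num_feat -> 4 * num_feat expansion.
def _demo_upsample_one_step():
    up = UpsampleOneStep(scale=2, num_feat=60, num_out_ch=3)
    y = up(torch.randn(1, 60, 16, 16))
    assert y.shape == (1, 3, 32, 32)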


class Swin2SR(nn.Module):
    r""" Swin2SR
        A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.

    Args:
        img_size (int | tuple(int)): Input image size. Default: 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
    """

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
                 use_lepe=False,
                 use_cpb_bias=True,
                 MoE_config=None,
                 use_rpe_bias=False,
                 **kwargs):
        super(Swin2SR, self).__init__()
        print('==== SWIN 2SR')
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler
        self.window_size = window_size

        #####################################################################################################
        ################################### 1, shallow feature extraction ###################################
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        #####################################################################################################
        ################################### 2, deep feature extraction ######################################
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(dim=embed_dim,
                         input_resolution=(patches_resolution[0],
                                           patches_resolution[1]),
                         depth=depths[i_layer],
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                         norm_layer=norm_layer,
                         downsample=None,
                         use_checkpoint=use_checkpoint,
                         img_size=img_size,
                         patch_size=patch_size,
                         resi_connection=resi_connection,
                         use_lepe=use_lepe,
                         use_cpb_bias=use_cpb_bias,
                         MoE_config=MoE_config,
                         use_rpe_bias=use_rpe_bias)
            self.layers.append(layer)

        if self.upsampler == 'pixelshuffle_hf':
            self.layers_hf = nn.ModuleList()
            for i_layer in range(self.num_layers):
                layer = RSTB(dim=embed_dim,
                             input_resolution=(patches_resolution[0],
                                               patches_resolution[1]),
                             depth=depths[i_layer],
                             num_heads=num_heads[i_layer],
                             window_size=window_size,
                             mlp_ratio=self.mlp_ratio,
                             qkv_bias=qkv_bias,
                             drop=drop_rate, attn_drop=attn_drop_rate,
                             drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                             norm_layer=norm_layer,
                             downsample=None,
                             use_checkpoint=use_checkpoint,
                             img_size=img_size,
                             patch_size=patch_size,
                             resi_connection=resi_connection,
                             use_lepe=use_lepe,
                             use_cpb_bias=use_cpb_bias,
                             MoE_config=MoE_config,
                             use_rpe_bias=use_rpe_bias)
                self.layers_hf.append(layer)

        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        #####################################################################################################
        ################################ 3, high quality image reconstruction ################################
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffle_aux':
            self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.conv_after_aux = nn.Sequential(
                nn.Conv2d(3, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        elif self.upsampler == 'pixelshuffle_hf':
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.upsample_hf = Upsample_hf(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
                                               nn.LeakyReLU(inplace=True))
            self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
            self.conv_before_upsample_hf = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def check_image_size(self, x):
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x
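
    # Illustrative sketch (added, not part of the upstream file): inputs are
    # reflect-padded up to a multiple of window_size; e.g. with window_size=8,
    # a 30x45 input gains (8 - 30 % 8) % 8 = 2 rows and (8 - 45 % 8) % 8 = 3
    # columns, and forward() crops back to H*upscale x W*upscale at the end.
    def _demo_check_image_size(self):
        padded = self.check_image_size(torch.randn(1, 3, 30, 45))
        assert padded.shape[2] % self.window_size == 0
        assert padded.shape[3] % self.window_size == 0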

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        loss_moe_all = 0
        for layer in self.layers:
            x = layer(x, x_size)

            if not torch.is_tensor(x):
                x, loss_moe = x
                loss_moe_all += loss_moe or 0

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x, loss_moe_all
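
    # Illustrative training sketch (added, not part of the upstream file;
    # criterion and lambda_moe are hypothetical): the per-layer MoE auxiliary
    # losses accumulated above are weighted into the task loss.
    def _demo_moe_training_step(self, lr, hr, criterion, lambda_moe=0.1):
        sr, loss_moe = self(lr)   # non-aux upsamplers return (SR, loss_moe)
        return criterion(sr, hr) + lambda_moe * loss_moe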

    def forward_features_hf(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        loss_moe_all = 0
        for layer in self.layers_hf:
            x = layer(x, x_size)

            if not torch.is_tensor(x):
                x, loss_moe = x
                loss_moe_all += loss_moe or 0

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x, loss_moe_all

    def forward_backbone(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)

            res = self.forward_features(x)
            if not torch.is_tensor(res):
                res, loss_moe = res

            x = self.conv_after_body(res) + x
        else:
            raise NotImplementedError('forward_backbone only supports the pixelshuffledirect upsampler')

        x = x / self.img_range + self.mean
        return x

    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        loss_moe = 0
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)

            res = self.forward_features(x)
            if not torch.is_tensor(res):
                res, loss_moe = res

            x = self.conv_after_body(res) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffle_aux':
            bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
            bicubic = self.conv_bicubic(bicubic)
            x = self.conv_first(x)

            res = self.forward_features(x)
            if not torch.is_tensor(res):
                res, loss_moe = res

            x = self.conv_after_body(res) + x
            x = self.conv_before_upsample(x)
            aux = self.conv_aux(x)  # b, 3, LR_H, LR_W
            x = self.conv_after_aux(aux)
            x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
            x = self.conv_last(x)
            aux = aux / self.img_range + self.mean
        elif self.upsampler == 'pixelshuffle_hf':
            # for classical SR with HF
            x = self.conv_first(x)

            res = self.forward_features(x)
            if not torch.is_tensor(res):
                res, loss_moe = res

            x = self.conv_after_body(res) + x
            x_before = self.conv_before_upsample(x)
            x_out = self.conv_last(self.upsample(x_before))

            x_hf = self.conv_first_hf(x_before)

            res_hf = self.forward_features_hf(x_hf)
            if not torch.is_tensor(res_hf):
                res_hf, loss_moe_hf = res_hf
                loss_moe += loss_moe_hf

            x_hf = self.conv_after_body_hf(res_hf) + x_hf
            x_hf = self.conv_before_upsample_hf(x_hf)
            x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
            x = x_out + x_hf
            x_hf = x_hf / self.img_range + self.mean

        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)

            res = self.forward_features(x)
            if not torch.is_tensor(res):
                res, loss_moe = res

            x = self.conv_after_body(res) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)

            res = self.forward_features(x)
            if not torch.is_tensor(res):
                res, loss_moe = res

            x = self.conv_after_body(res) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)

            res = self.forward_features(x_first)
            if not torch.is_tensor(res):
                res, loss_moe = res

            res = self.conv_after_body(res) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean
        if self.upsampler == "pixelshuffle_aux":
            return x[:, :, :H * self.upscale, :W * self.upscale], aux, loss_moe

        elif self.upsampler == "pixelshuffle_hf":
            x_out = x_out / self.img_range + self.mean
            return x_out[:, :, :H * self.upscale, :W * self.upscale], x[:, :, :H * self.upscale, :W * self.upscale], x_hf[:, :, :H * self.upscale, :W * self.upscale], loss_moe

        else:
            return x[:, :, :H * self.upscale, :W * self.upscale], loss_moe
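
    # Illustrative sketch (added, not part of the upstream file): forward's
    # return arity depends on the upsampler, so callers unpack per mode.
    def _demo_unpack_forward(self, lr):
        out = self(lr)
        if self.upsampler == 'pixelshuffle_aux':
            sr, aux, loss_moe = out                 # SR, LR auxiliary image, MoE loss
        elif self.upsampler == 'pixelshuffle_hf':
            sr_base, sr, sr_hf, loss_moe = out      # base SR, combined SR, HF branch
        else:
            sr, loss_moe = out
        return sr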

    def flops(self):
        flops = 0
        H, W = self.patches_resolution
        flops += H * W * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        flops += H * W * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
        return flops
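

# Illustrative construction sketch (added for this listing; these
# hyperparameters are hypothetical and need not match the shipped
# weights/model-70.pt checkpoint).
def _demo_build_and_run():
    model = Swin2SR(img_size=64, patch_size=1, in_chans=3, embed_dim=96,
                    depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
                    window_size=8, upscale=2, img_range=1.,
                    upsampler='pixelshuffledirect', resi_connection='1conv',
                    MoE_config=None)   # pass an MoE config to enable the MoE FFN blocks
    sr, loss_moe = model(torch.randn(1, 3, 64, 64))
    assert sr.shape == (1, 3, 128, 128)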
    	
        swin2_mose/moe.py
    ADDED
    
    | @@ -0,0 +1,323 @@ | |
#
# Source code: https://github.com/davidmrau/mixture-of-experts
#

# Sparsely-Gated Mixture-of-Experts Layers.
# See "Outrageously Large Neural Networks"
# https://arxiv.org/abs/1701.06538
#
# Author: David Rau
#
# The code is based on the TensorFlow implementation:
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py


import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from copy import deepcopy

from utils import Mlp as MLP


class SparseDispatcher(object):
    """Helper for implementing a mixture of experts.
    The purpose of this class is to create input minibatches for the
    experts and to combine the results of the experts to form a unified
    output tensor.
    There are two functions:
    dispatch - take an input Tensor and create input Tensors for each expert.
    combine - take output Tensors from each expert and form a combined output
      Tensor.  Outputs from different experts for the same batch element are
      summed together, weighted by the provided "gates".
    The class is initialized with a "gates" Tensor, which specifies which
    batch elements go to which experts, and the weights to use when combining
    the outputs.  Batch element b is sent to expert e iff gates[b, e] != 0.
    The inputs and outputs are all two-dimensional [batch, depth].
    Caller is responsible for collapsing additional dimensions prior to
    calling this class and reshaping the output to the original shape.
    See common_layers.reshape_like().
    Example use:
    gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
    inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
    experts: a list of length `num_experts` containing sub-networks.
    dispatcher = SparseDispatcher(num_experts, gates)
    expert_inputs = dispatcher.dispatch(inputs)
    expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
    outputs = dispatcher.combine(expert_outputs)
    The preceding code sets the output for a particular example b to:
    output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
    This class takes advantage of sparsity in the gate matrix by including in the
    `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher."""
        self._gates = gates
        self._num_experts = num_experts
        # sort expert indices so samples for the same expert are adjacent
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # drop the batch column, keep the expert column
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # batch index for each (sample, expert) pair, in expert order
        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        # number of samples each expert receives
        self._part_sizes = (gates > 0).sum(0).tolist()
        # expand gates to match self._batch_index
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input Tensor for each expert.
        The `Tensor` for an expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] > 0`.
        Args:
          inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`
        Returns:
          a list of `num_experts` `Tensor`s with shapes
            `[expert_batch_size_i, <extra_input_dims>]`.
        """
        # assign samples to the experts whose gates are nonzero:
        # expand according to batch index so we can just split by _part_sizes
        inp_exp = inp[self._batch_index].squeeze(1)
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True, cnn_combine=None):
        """Sum together the expert output, weighted by the gates.
        The slice corresponding to a particular batch element `b` is computed
        as the sum over all experts `i` of the expert output, weighted by the
        corresponding gate values.  If `multiply_by_gates` is set to False, the
        gate values are ignored.
        Args:
          expert_out: a list of `num_experts` `Tensor`s, each with shape
            `[expert_batch_size_i, <extra_output_dims>]`.
          multiply_by_gates: a boolean
        Returns:
          a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
        """
        # concatenate the per-expert outputs back into a single tensor
        stitched = torch.cat(expert_out, 0)

        if multiply_by_gates:
            stitched = stitched.mul(self._nonzero_gates.unsqueeze(1))

        if cnn_combine is not None:
            return self.smartly_combine(stitched, cnn_combine)

        # sum the (gate-weighted) outputs of the k experts for each sample
        zeros = torch.zeros((self._gates.size(0),) + expert_out[-1].shape[1:],
                            requires_grad=True, device=stitched.device)
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        return combined

    def smartly_combine(self, stitched, cnn_combine):
        # group the stitched rows by batch element, then let a small CNN merge
        # each element's k expert outputs instead of a plain weighted sum
        idxes = []
        for i in self._batch_index.unique():
            idx = (self._batch_index == i).nonzero().squeeze(1)
            idxes.append(idx)
        idxes = torch.stack(idxes)
        return cnn_combine(stitched[idxes]).squeeze(1)

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert `Tensor`s.
        Returns:
          a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32`
              and shapes `[expert_batch_size_i]`
        """
        # split nonzero gates for each expert
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)


def build_experts(experts_cfg, default_cfg, num_experts):
    experts_cfg = deepcopy(experts_cfg)
    if experts_cfg is None:
        # old build way: identical MLP experts
        return nn.ModuleList([
            MLP(*default_cfg)
            for _ in range(num_experts)])
    # new build way: mix mlp with leff
    experts = []
    for e_cfg in experts_cfg:
        type_ = e_cfg.pop('type')
        if type_ == 'mlp':
            experts.append(MLP(*default_cfg))
    return nn.ModuleList(experts)


class MoE(nn.Module):
    """A sparsely gated mixture-of-experts layer with 1-layer
       feed-forward networks as experts.

    Args:
    input_size: integer - size of the input
    output_size: integer - size of the output
    num_experts: an integer - number of experts
    hidden_size: an integer - hidden size of the experts
    noisy_gating: a boolean
    k: an integer - how many experts to use for each batch element
    """

    def __init__(self, input_size, output_size, num_experts, hidden_size,
                 experts=None, noisy_gating=True, k=4,
                 x_gating=None, with_noise=True, with_smart_merger=None):
        super(MoE, self).__init__()
        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.output_size = output_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.k = k
        self.with_noise = with_noise
        # instantiate experts
        self.experts = build_experts(
            experts,
            (self.input_size, self.hidden_size, self.output_size),
            num_experts)
        self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
        self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)

        self.x_gating = x_gating
        if self.x_gating == 'conv1d':
            # learned pooling over the (hard-coded) 4096-token dimension
            self.x_gate = nn.Conv1d(4096, 1, kernel_size=3, padding=1)

        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        assert self.k <= self.num_experts

        self.cnn_combine = None
        if with_smart_merger == 'v1':
            print('with SMART MERGER')
            self.cnn_combine = nn.Conv2d(self.k, 1, kernel_size=3, padding=1)

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.
        Useful as a loss to encourage a positive distribution to be more uniform.
        Epsilons added for numerical stability.
        Returns 0 for an empty Tensor.
        Args:
        x: a `Tensor`.
        Returns:
        a `Scalar`.
        """
        eps = 1e-10
        # degenerate case: only one expert
        if x.shape[0] == 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean()**2 + eps)

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.
        The load is the number of examples for which the corresponding gate is > 0.
        Args:
        gates: a `Tensor` of shape [batch_size, n]
        Returns:
        a float32 `Tensor` of shape [n]
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Helper function to NoisyTopKGating.
        Computes the probability that value is in top k, given different random noise.
        This gives us a way of backpropagating from a loss that balances the number
        of times each expert is in the top k experts per example.
        In the case of no noise, pass in None for noise_stddev, and the result will
        not be differentiable.
        Args:
        clean_values: a `Tensor` of shape [batch, n].
        noisy_values: a `Tensor` of shape [batch, n].  Equal to clean values plus
          normally distributed noise with standard deviation noise_stddev.
        noise_stddev: a `Tensor` of shape [batch, n], or None
        noisy_top_values: a `Tensor` of shape [batch, m].
           "values" Output of tf.top_k(noisy_top_values, m).  m >= k+1
        Returns:
        a `Tensor` of shape [batch, n].
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()

        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        is_in = torch.gt(noisy_values, threshold_if_in)
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
        # is each value currently in the top k
        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Noisy top-k gating.
          See paper: https://arxiv.org/abs/1701.06538.
          Args:
            x: input Tensor with shape [batch_size, input_size]
            train: a boolean - we only add noise at training time.
            noise_epsilon: a float
          Returns:
            gates: a Tensor with shape [batch_size, num_experts]
            load: a Tensor with shape [num_experts]
        """
        clean_logits = x @ self.w_gate
        if self.noisy_gating and train:
            raw_noise_stddev = x @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits

        # calculate top k + 1 logits, needed for the noisy gates
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
        top_k_logits = top_logits[:, :self.k]
        top_k_indices = top_indices[:, :self.k]
        top_k_gates = self.softmax(top_k_logits)

        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if self.noisy_gating and self.k < self.num_experts and train:
            load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def forward(self, x, loss_coef=1e-2):
        """Args:
        x: tensor shape [batch_size, num_tokens, input_size]
        loss_coef: a scalar - multiplier on load-balancing losses

        Returns:
        y: a tensor with shape [batch_size, num_tokens, output_size].
        extra_training_loss: a scalar.  This should be added into the overall
        training loss of the model.  The backpropagation of this loss
        encourages all experts to be approximately equally used across a batch.
        """
        # pool the token dimension (with a learned conv or a plain mean) to get
        # one gating vector per batch element
        if self.x_gating is not None:
            xg = self.x_gate(x).squeeze(1)
        else:
            xg = x.mean(1)

        gates, load = self.noisy_top_k_gating(
            xg, self.training and self.with_noise)
        # calculate importance loss
        importance = gates.sum(0)
        loss = self.cv_squared(importance) + self.cv_squared(load)
        loss *= loss_coef

        dispatcher = SparseDispatcher(self.num_experts, gates)
        expert_inputs = dispatcher.dispatch(x)
        gates = dispatcher.expert_to_gates()
        expert_outputs = [self.experts[i](expert_inputs[i])
                          for i in range(self.num_experts)]
        y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine)
        return y, loss
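
For intuition about the MoE layer's interface, a minimal sketch routing dummy token sequences through it. The sizes below are illustrative assumptions, not values fixed by this file (except that the optional 'conv1d' gate hard-codes 4096 tokens):

    import torch
    from moe import MoE

    moe = MoE(input_size=90, output_size=90, num_experts=4,
              hidden_size=180, k=2)
    x = torch.randn(2, 4096, 90)   # (batch, tokens, channels)
    y, aux_loss = moe(x)           # y: (2, 4096, 90)
    # aux_loss is the load-balancing term; add it to the task loss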
    	
        swin2_mose/run.py
    ADDED
    
@@ -0,0 +1,20 @@
import torch
from model import Swin2SR

# checkpoint shipped with this repo under swin2_mose/weights/
model_weights = "weights/model-70.pt"
model_params = {
    "upscale": 2,
    "in_chans": 4,
    "img_size": 64,
    "window_size": 16,
    "img_range": 1.,
    "depths": [6, 6, 6, 6],
    "embed_dim": 90,
    "num_heads": [6, 6, 6, 6],
    "mlp_ratio": 2,
    "upsampler": "pixelshuffledirect",
    "resi_connection": "1conv"
}

sr_model = Swin2SR(**model_params)
# map_location keeps the load device-agnostic; move the model afterwards
sr_model.load_state_dict(torch.load(model_weights, map_location="cpu"))
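
After loading, inference would look roughly as below. A sketch under assumptions: the forward pass follows the upstream Swin2SR signature (a (B, C, H, W) float tensor in), and this MoE variant may return auxiliary outputs alongside the image:

    sr_model.eval()
    with torch.no_grad():
        lr = torch.rand(1, 4, 64, 64)   # one 4-band 64x64 low-res patch
        out = sr_model(lr)              # 2x upscale, expected ~(1, 4, 128, 128)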
    	
        swin2_mose/utils.py
    ADDED
    
@@ -0,0 +1,56 @@
from torch import nn


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size,
                     window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size,
               W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
        -1, window_size, window_size, C)
    return windows
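
window_partition and window_reverse are exact inverses whenever H and W are divisible by window_size, which a quick round trip confirms (the shapes here are arbitrary test values):

    import torch

    x = torch.randn(2, 32, 32, 16)           # (B, H, W, C)
    w = window_partition(x, window_size=8)   # (2*16, 8, 8, 16)
    assert torch.equal(window_reverse(w, 8, 32, 32), x)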
    	
        swin2_mose/weights/model-70.pt
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c9f1229521879af2c8162f7a32fe278e487d0bc0826dddccc87a4e22294aa067
size 118890958
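
Note that the checkpoint is stored as a Git LFS pointer, so the ~119 MB weight file must be fetched explicitly after cloning, e.g. with the standard commands `git lfs install` followed by `git lfs pull --include="swin2_mose/weights/model-70.pt"`.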

