Spaces:
	Runtime error

NoelShin committed

Commit 35188e4 · Parent(s): 6c45278

Add application file

This view is limited to 50 files because it contains too many changes. See raw diff.
- .DS_Store +0 -0
 - .idea/.gitignore +8 -0
 - .idea/deployment.xml +15 -0
 - .idea/inspectionProfiles/Project_Default.xml +23 -0
 - .idea/inspectionProfiles/profiles_settings.xml +6 -0
 - .idea/misc.xml +4 -0
 - .idea/modules.xml +8 -0
 - .idea/selfmask_demo.iml +8 -0
 - .idea/sonarlint/issuestore/index.pb +0 -0
 - .idea/webServers.xml +14 -0
 - __pycache__/bilateral_solver.cpython-38.pyc +0 -0
 - __pycache__/utils.cpython-38.pyc +0 -0
 - app.py +134 -0
 - bilateral_solver.py +206 -0
 - duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml +56 -0
 - networks/__init__.py +0 -0
 - networks/__pycache__/__init__.cpython-38.pyc +0 -0
 - networks/__pycache__/timm_deit.cpython-38.pyc +0 -0
 - networks/__pycache__/timm_vit.cpython-38.pyc +0 -0
 - networks/__pycache__/vision_transformer.cpython-38.pyc +0 -0
 - networks/maskformer/__pycache__/maskformer.cpython-38.pyc +0 -0
 - networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc +0 -0
 - networks/maskformer/maskformer.py +267 -0
 - networks/maskformer/positional_embedding.py +48 -0
 - networks/maskformer/transformer_decoder.py +376 -0
 - networks/module_helper.py +176 -0
 - networks/resnet.py +60 -0
 - networks/resnet_backbone.py +194 -0
 - networks/resnet_models.py +273 -0
 - networks/timm_deit.py +254 -0
 - networks/timm_vit.py +819 -0
 - networks/vision_transformer.py +569 -0
 - resources/.DS_Store +0 -0
 - resources/0053.jpg +0 -0
 - resources/0236.jpg +0 -0
 - resources/0239.jpg +0 -0
 - resources/0403.jpg +0 -0
 - resources/0412.jpg +0 -0
 - resources/ILSVRC2012_test_00005309.jpg +0 -0
 - resources/ILSVRC2012_test_00012622.jpg +0 -0
 - resources/ILSVRC2012_test_00022698.jpg +0 -0
 - resources/ILSVRC2012_test_00040725.jpg +0 -0
 - resources/ILSVRC2012_test_00075738.jpg +0 -0
 - resources/ILSVRC2012_test_00080683.jpg +0 -0
 - resources/ILSVRC2012_test_00085874.jpg +0 -0
 - resources/im052.jpg +0 -0
 - resources/sun_ainjbonxmervsvpv.jpg +0 -0
 - resources/sun_alfntqzssslakmss.jpg +0 -0
 - resources/sun_amnrcxhisjfrliwa.jpg +0 -0
 - resources/sun_bvyxpvkouzlfwwod.jpg +0 -0
 
    	
.DS_Store
ADDED
Binary file (6.15 kB)
    	
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
    	
.idea/deployment.xml
ADDED
@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="PublishConfigData" autoUpload="Always" serverName="mydev" remoteFilesAllowedToDisappearOnAutoupload="false">
+    <serverData>
+      <paths name="mydev">
+        <serverdata>
+          <mappings>
+            <mapping deploy="/" local="$PROJECT_DIR$" web="/" />
+          </mappings>
+        </serverdata>
+      </paths>
+    </serverData>
+    <option name="myAutoUpload" value="ALWAYS" />
+  </component>
+</project>
    	
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,23 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="10">
+            <item index="0" class="java.lang.String" itemvalue="prettytable" />
+            <item index="1" class="java.lang.String" itemvalue="interrogate" />
+            <item index="2" class="java.lang.String" itemvalue="pytest" />
+            <item index="3" class="java.lang.String" itemvalue="yapf" />
+            <item index="4" class="java.lang.String" itemvalue="cityscapesscripts" />
+            <item index="5" class="java.lang.String" itemvalue="Wand" />
+            <item index="6" class="java.lang.String" itemvalue="isort" />
+            <item index="7" class="java.lang.String" itemvalue="xdoctest" />
+            <item index="8" class="java.lang.String" itemvalue="codecov" />
+            <item index="9" class="java.lang.String" itemvalue="flake8" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
    	
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
    	
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (pytorch)" project-jdk-type="Python SDK" />
+</project>
    	
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/selfmask_demo.iml" filepath="$PROJECT_DIR$/.idea/selfmask_demo.iml" />
+    </modules>
+  </component>
+</project>
    	
.idea/selfmask_demo.iml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
    	
.idea/sonarlint/issuestore/index.pb
ADDED
File without changes
    	
.idea/webServers.xml
ADDED
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="WebServers">
+    <option name="servers">
+      <webServer id="12e2cf4d-3b81-4241-9665-54a333f70567" name="mydev">
+        <fileTransfer rootFolder="/users/gyungin/selfmask_demo" accessType="SFTP" host="mydev" port="22" sshConfigId="3e23a652-ab3c-4dc2-a117-84c2bf217891" sshConfig="gyungin@mydev:22 password">
+          <advancedOptions>
+            <advancedOptions dataProtectionLevel="Private" passiveMode="true" shareSSLContext="true" />
+          </advancedOptions>
+        </fileTransfer>
+      </webServer>
+    </option>
+  </component>
+</project>
    	
__pycache__/bilateral_solver.cpython-38.pyc
ADDED
Binary file (6.76 kB)
    	
__pycache__/utils.cpython-38.pyc
ADDED
Binary file (2.9 kB)
    	
app.py
ADDED
@@ -0,0 +1,134 @@
+from argparse import ArgumentParser, Namespace
+from typing import Dict, List, Tuple
+import yaml
+import numpy as np
+import cv2
+from PIL import Image
+import torch
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, normalize, resize
+import gradio as gr
+from utils import get_model
+from bilateral_solver import bilateral_solver_output
+import os
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+state_dict: dict = torch.hub.load_state_dict_from_url(
+    "https://github.com/NoelShin/selfmask/releases/download/v1.0.0/selfmask_nq20.pt",
+    map_location=device  # "cuda" if torch.cuda.is_available() else "cpu"
+)["model"]
+
+parser = ArgumentParser("SelfMask demo")
+parser.add_argument(
+    "--config",
+    type=str,
+    default="duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml"
+)
+
+# parser.add_argument(
+#     "--p_state_dict",
+#     type=str,
+#     default="/users/gyungin/selfmask_bak/ckpt/nq20_ndl6_bc_sr10100_duts_pm_all_k2,3,4_md_seed0_final/eval/hku_is/best_model.pt",
+# )
+#
+# parser.add_argument(
+#     "--dataset_name", '-dn', type=str, default="duts",
+#     choices=["dut_omron", "duts", "ecssd"]
+# )
+
+# independent variables
+# parser.add_argument("--use_gpu", type=bool, default=True)
+# parser.add_argument('--seed', default=0, type=int)
+# parser.add_argument("--dir_root", type=str, default="..")
+# parser.add_argument("--gpu_id", type=int, default=2)
+# parser.add_argument("--suffix", type=str, default='')
+args: Namespace = parser.parse_args()
+base_args = yaml.safe_load(open(f"{args.config}", 'r'))
+base_args.pop("dataset_name")
+args: dict = vars(args)
+args.update(base_args)
+args: Namespace = Namespace(**args)
+
+model = get_model(arch="maskformer", configs=args).to(device)
+model.load_state_dict(state_dict)
+model.eval()
+
+
+@torch.no_grad()
+def main(
+        image: Image.Image,
+        size: int = 384,
+        max_size: int = 512,
+        mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
+        std: Tuple[float, float, float] = (0.229, 0.224, 0.225)
+):
+    pil_image: Image.Image = resize(image, size=size, max_size=max_size)
+    image: torch.Tensor = normalize(to_tensor(pil_image), mean=list(mean), std=list(std))  # 3 x H x W
+    dict_outputs = model(image[None].to(device))
+
+    batch_pred_masks: torch.Tensor = dict_outputs["mask_pred"]  # [0, 1]
+    batch_objectness: torch.Tensor = dict_outputs.get("objectness", None)  # [0, 1]
+
+    if len(batch_pred_masks.shape) == 5:
+        # b x n_layers x n_queries x h x w -> b x n_queries x h x w
+        batch_pred_masks = batch_pred_masks[:, -1, ...]  # extract the output from the last decoder layer
+
+        if batch_objectness is not None:
+            # b x n_layers x n_queries x 1 -> b x n_queries x 1
+            batch_objectness = batch_objectness[:, -1, ...]
+
+    # resize prediction to original resolution
+    # note: upsampling by 4 and cutting the padded region allows for a better result
+    H, W = image.shape[-2:]
+    batch_pred_masks = F.interpolate(
+        batch_pred_masks, scale_factor=4, mode="bilinear", align_corners=False
+    )[..., :H, :W]
+
+    # iterate over batch dimension
+    for batch_index, pred_masks in enumerate(batch_pred_masks):
+        # n_queries x 1 -> n_queries
+        objectness: torch.Tensor = batch_objectness[batch_index].squeeze(dim=-1)
+        ranks = torch.argsort(objectness, descending=True)  # n_queries
+        pred_mask: torch.Tensor = pred_masks[ranks[0]]  # H x W
+    pred_mask: np.ndarray = (pred_mask > 0.5).cpu().numpy().astype(np.uint8) * 255
+
+    pred_mask_bi, _ = bilateral_solver_output(img=pil_image, target=pred_mask)  # float64
+    pred_mask_bi: np.ndarray = np.clip(pred_mask_bi, 0, 255).astype(np.uint8)
+
+    attn_map = cv2.cvtColor(cv2.applyColorMap(pred_mask_bi, cv2.COLORMAP_VIRIDIS), cv2.COLOR_BGR2RGB)
+    super_imposed_img = cv2.addWeighted(attn_map, 0.5, np.array(pil_image), 0.5, 0)
+    return super_imposed_img
+    # return pred_mask_bi
+
+demo = gr.Interface(
+    fn=main,
+    inputs=gr.inputs.Image(type="pil"),
+    outputs="image",
+    examples=[f"resources/{fname}.jpg" for fname in [
+        "0053",
+        "0236",
+        "0239",
+        "0403",
+        "0412",
+        "ILSVRC2012_test_00005309",
+        "ILSVRC2012_test_00012622",
+        "ILSVRC2012_test_00022698",
+        "ILSVRC2012_test_00040725",
+        "ILSVRC2012_test_00075738",
+        "ILSVRC2012_test_00080683",
+        "ILSVRC2012_test_00085874",
+        "im052",
+        "sun_ainjbonxmervsvpv",
+        "sun_alfntqzssslakmss",
+        "sun_amnrcxhisjfrliwa",
+        "sun_bvyxpvkouzlfwwod"
+    ]],
+    title="Unsupervised Salient Object Detection with Spectral Cluster Voting",
+    allow_flagging="never",
+    analytics_enabled=False
+)
+
+demo.launch(
+    # share=True
+)
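A minimal sketch (not part of this commit) of exercising the demo without the Gradio UI, by calling main() directly on one of the bundled example images. It assumes importing app succeeds end to end (checkpoint download, utils.get_model and the networks package available); the output filename is arbitrary.

    from PIL import Image

    import app  # runs the module-level setup in app.py: argument parsing, checkpoint download, model build

    overlay = app.main(Image.open("resources/0053.jpg").convert("RGB"))  # uint8 RGB overlay as a numpy array
    Image.fromarray(overlay).save("overlay_0053.png")  # refined saliency map blended over the resized input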
    	
bilateral_solver.py
ADDED
@@ -0,0 +1,206 @@
+from scipy.sparse import diags
+from scipy.sparse.linalg import cg
+from scipy.sparse import csr_matrix
+import numpy as np
+from skimage.io import imread
+from scipy import ndimage
+import torch
+import PIL.Image as Image
+import os
+from argparse import ArgumentParser, Namespace
+from typing import Dict, Union
+from collections import defaultdict
+import yaml
+import ujson as json
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+
+RGB_TO_YUV = np.array([
+    [0.299, 0.587, 0.114],
+    [-0.168736, -0.331264, 0.5],
+    [0.5, -0.418688, -0.081312]])
+YUV_TO_RGB = np.array([
+    [1.0, 0.0, 1.402],
+    [1.0, -0.34414, -0.71414],
+    [1.0, 1.772, 0.0]])
+YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)
+MAX_VAL = 255.0
+
+
+def rgb2yuv(im):
+    return np.tensordot(im, RGB_TO_YUV, ([2], [1])) + YUV_OFFSET
+
+
+def yuv2rgb(im):
+    return np.tensordot(im.astype(float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))
+
+
+def get_valid_idx(valid, candidates):
+    """Find which values are present in a list and where they are located"""
+    locs = np.searchsorted(valid, candidates)
+    # Handle edge case where the candidate is larger than all valid values
+    locs = np.clip(locs, 0, len(valid) - 1)
+    # Identify which values are actually present
+    valid_idx = np.flatnonzero(valid[locs] == candidates)
+    locs = locs[valid_idx]
+    return valid_idx, locs
+
+
+class BilateralGrid(object):
+    def __init__(self, im, sigma_spatial=32, sigma_luma=8, sigma_chroma=8):
+        im_yuv = rgb2yuv(im)
+        # Compute 5-dimensional XYLUV bilateral-space coordinates
+        Iy, Ix = np.mgrid[:im.shape[0], :im.shape[1]]
+        x_coords = (Ix / sigma_spatial).astype(int)
+        y_coords = (Iy / sigma_spatial).astype(int)
+        luma_coords = (im_yuv[..., 0] / sigma_luma).astype(int)
+        chroma_coords = (im_yuv[..., 1:] / sigma_chroma).astype(int)
+        coords = np.dstack((x_coords, y_coords, luma_coords, chroma_coords))
+        coords_flat = coords.reshape(-1, coords.shape[-1])
+        self.npixels, self.dim = coords_flat.shape
+        # Hacky "hash vector" for coordinates,
+        # Requires all scaled coordinates be < MAX_VAL
+        self.hash_vec = (MAX_VAL ** np.arange(self.dim))
+        # Construct S and B matrix
+        self._compute_factorization(coords_flat)
+
+    def _compute_factorization(self, coords_flat):
+        # Hash each coordinate in grid to a unique value
+        hashed_coords = self._hash_coords(coords_flat)
+        unique_hashes, unique_idx, idx = \
+            np.unique(hashed_coords, return_index=True, return_inverse=True)
+        # Identify unique set of vertices
+        unique_coords = coords_flat[unique_idx]
+        self.nvertices = len(unique_coords)
+        # Construct sparse splat matrix that maps from pixels to vertices
+        self.S = csr_matrix((np.ones(self.npixels), (idx, np.arange(self.npixels))))
+        # Construct sparse blur matrices.
+        # Note that these represent [1 0 1] blurs, excluding the central element
+        self.blurs = []
+        for d in range(self.dim):
+            blur = 0.0
+            for offset in (-1, 1):
+                offset_vec = np.zeros((1, self.dim))
+                offset_vec[:, d] = offset
+                neighbor_hash = self._hash_coords(unique_coords + offset_vec)
+                valid_coord, idx = get_valid_idx(unique_hashes, neighbor_hash)
+                blur = blur + csr_matrix((np.ones((len(valid_coord),)),
+                                          (valid_coord, idx)),
+                                         shape=(self.nvertices, self.nvertices))
+            self.blurs.append(blur)
+
+    def _hash_coords(self, coord):
+        """Hacky function to turn a coordinate into a unique value"""
+        return np.dot(coord.reshape(-1, self.dim), self.hash_vec)
+
+    def splat(self, x):
+        return self.S.dot(x)
+
+    def slice(self, y):
+        return self.S.T.dot(y)
+
+    def blur(self, x):
+        """Blur a bilateral-space vector with a 1 2 1 kernel in each dimension"""
+        assert x.shape[0] == self.nvertices
+        out = 2 * self.dim * x
+        for blur in self.blurs:
+            out = out + blur.dot(x)
+        return out
+
+    def filter(self, x):
+        """Apply bilateral filter to an input x"""
+        return self.slice(self.blur(self.splat(x))) / \
+               self.slice(self.blur(self.splat(np.ones_like(x))))
+
+
+def bistochastize(grid, maxiter=10):
+    """Compute diagonal matrices to bistochastize a bilateral grid"""
+    m = grid.splat(np.ones(grid.npixels))
+    n = np.ones(grid.nvertices)
+    for i in range(maxiter):
+        n = np.sqrt(n * m / grid.blur(n))
+    # Correct m to satisfy the assumption of bistochastization regardless
+    # of how many iterations have been run.
+    m = n * grid.blur(n)
+    Dm = diags(m, 0)
+    Dn = diags(n, 0)
+    return Dn, Dm
+
+
+class BilateralSolver(object):
+    def __init__(self, grid, params):
+        self.grid = grid
+        self.params = params
+        self.Dn, self.Dm = bistochastize(grid)
+
+    def solve(self, x, w):
+        # Check that w is a vector or a nx1 matrix
+        if w.ndim == 2:
+            assert (w.shape[1] == 1)
+        elif w.dim == 1:
+            w = w.reshape(w.shape[0], 1)
+        A_smooth = (self.Dm - self.Dn.dot(self.grid.blur(self.Dn)))
+        w_splat = self.grid.splat(w)
+        A_data = diags(w_splat[:, 0], 0)
+        A = self.params["lam"] * A_smooth + A_data
+        xw = x * w
+        b = self.grid.splat(xw)
+        # Use simple Jacobi preconditioner
+        A_diag = np.maximum(A.diagonal(), self.params["A_diag_min"])
+        M = diags(1 / A_diag, 0)
+        # Flat initialization
+        y0 = self.grid.splat(xw) / w_splat
+        yhat = np.empty_like(y0)
+        for d in range(x.shape[-1]):
+            yhat[..., d], info = cg(A, b[..., d], x0=y0[..., d], M=M, maxiter=self.params["cg_maxiter"],
+                                    tol=self.params["cg_tol"])
+        xhat = self.grid.slice(yhat)
+        return xhat
+
+
+def bilateral_solver_output(
+        img: Image.Image,
+        target: np.ndarray,
+        sigma_spatial=16,
+        sigma_luma=16,
+        sigma_chroma=8
+):
+    reference = np.array(img)
+    h, w = target.shape
+    confidence = np.ones((h, w)) * 0.999
+
+    grid_params = {
+        'sigma_luma': sigma_luma,  # Brightness bandwidth
+        'sigma_chroma': sigma_chroma,  # Color bandwidth
+        'sigma_spatial': sigma_spatial  # Spatial bandwidth
+    }
+
+    bs_params = {
+        'lam': 256,  # The strength of the smoothness parameter
+        'A_diag_min': 1e-5,  # Clamp the diagonal of the A diagonal in the Jacobi preconditioner.
+        'cg_tol': 1e-5,  # The tolerance on the convergence in PCG
+        'cg_maxiter': 25  # The number of PCG iterations
+    }
+
+    grid = BilateralGrid(reference, **grid_params)
+
+    t = target.reshape(-1, 1).astype(np.double)
+    c = confidence.reshape(-1, 1).astype(np.double)
+
+    ## output solver, which is a soft value
+    output_solver = BilateralSolver(grid, bs_params).solve(t, c).reshape((h, w))
+
+    binary_solver = ndimage.binary_fill_holes(output_solver > 0.5)
+    labeled, nr_objects = ndimage.label(binary_solver)
+
+    nb_pixel = [np.sum(labeled == i) for i in range(nr_objects + 1)]
+    pixel_order = np.argsort(nb_pixel)
+    try:
+        binary_solver = labeled == pixel_order[-2]
+    except:
+        binary_solver = np.ones((h, w), dtype=bool)
+
+    return output_solver, binary_solver
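The solver can also be used on its own, the same way app.py feeds its thresholded prediction into bilateral_solver_output. A minimal sketch (not part of this commit), with a hypothetical box mask standing in for a model prediction:

    import numpy as np
    from PIL import Image
    from bilateral_solver import bilateral_solver_output

    image = Image.open("resources/0053.jpg").convert("RGB")

    # hypothetical coarse mask in {0, 255} with the same height/width as the image
    rough = np.zeros((image.height, image.width), dtype=np.uint8)
    rough[image.height // 4: 3 * image.height // 4, image.width // 4: 3 * image.width // 4] = 255

    # soft is a per-pixel float map; binary keeps the hole-filled, largest non-background component
    soft, binary = bilateral_solver_output(img=image, target=rough, sigma_spatial=16, sigma_luma=16, sigma_chroma=8)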
    	
        duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml
    ADDED
    
    | 
         @@ -0,0 +1,56 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
# augmentations
use_copy_paste: false
scale_range: [ 0.1, 1.0 ]
repeat_image: false

# base directories
dir_ckpt: "/users/gyungin/selfmask/ckpt"  # "/work/gyungin/selfmask/ckpt"
dir_dataset: "/scratch/shared/beegfs/gyungin/datasets"

# clustering
k: [2, 3, 4]
clustering_mode: "spectral"
use_gpu: true  # if you want to use gpu-accelerated code for clustering
scale_factor: 2  # "how much you want to upsample encoder features before clustering"

# dataset
dataset_name: "duts"
use_pseudo_masks: true
train_image_size: 224
eval_image_size: 224
n_percent: 100
n_copy_pastes: null
pseudo_masks_fp: "/users/gyungin/selfmask/datasets/swav_mocov2_dino_p16_k234.json"

# dataloader
batch_size: 8
num_workers: 4
pin_memory: true

# networks
abs_2d_pe_init: false
arch: "vit_small"
lateral_connection: false
learnable_pixel_decoder: false  # if False, use the bilinear interpolation
use_binary_classifier: true  # if True, use a binary classifier to get an objectness for each query from transformer decoder
n_decoder_layers: 6
n_queries: 20
num_layers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
patch_size: 8
training_method: "dino"  # "supervised", "deit", "dino", "mocov2", "swav"

# objective
loss_every_decoder_layer: true
weight_dice_loss: 1.0
weight_focal_loss: 0.0

# optimizer
lr: 0.000006  # default: 0.00006
lr_warmup_duration: 0  # 5
momentum: 0.9
n_epochs: 12
weight_decay: 0.01
optimizer_type: "adamw"

# validation
benchmarks: null
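Since the file above is plain YAML, it can be read into a dictionary before being handed to the training or demo entry point. A minimal sketch, assuming PyYAML is installed; load_config is a hypothetical helper name, not part of this commit:

# Sketch only: read the commit's YAML config into a Python dict.
# Assumes PyYAML; "load_config" is a hypothetical helper, not repository code.
import yaml

def load_config(path: str) -> dict:
    with open(path, "r") as f:
        return yaml.safe_load(f)

if __name__ == "__main__":
    cfg = load_config("duts-dino-k234-nq20-224-swav-mocov2-dino-p16-sr10100.yaml")
    print(cfg["arch"], cfg["n_queries"], cfg["patch_size"])  # vit_small 20 8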
    	
networks/__init__.py
ADDED
File without changes

networks/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (146 Bytes)

networks/__pycache__/timm_deit.cpython-38.pyc
ADDED
Binary file (7.08 kB)

networks/__pycache__/timm_vit.cpython-38.pyc
ADDED
Binary file (27.7 kB)

networks/__pycache__/vision_transformer.cpython-38.pyc
ADDED
Binary file (15.8 kB)

networks/maskformer/__pycache__/maskformer.cpython-38.pyc
ADDED
Binary file (8.51 kB)

networks/maskformer/__pycache__/transformer_decoder.cpython-38.pyc
ADDED
Binary file (8.83 kB)

networks/maskformer/maskformer.py
ADDED
@@ -0,0 +1,267 @@
from typing import Dict, List
from math import sqrt, log
import torch
import torch.nn as nn
import torch.nn.functional as F

from networks.maskformer.transformer_decoder import TransformerDecoderLayer, TransformerDecoder
from utils import get_model


class MaskFormer(nn.Module):
    def __init__(
            self,
            n_queries: int = 100,
            arch: str = "vit_small",
            patch_size: int = 8,
            training_method: str = "dino",
            n_decoder_layers: int = 6,
            normalize_before: bool = False,
            return_intermediate: bool = False,
            learnable_pixel_decoder: bool = False,
            lateral_connection: bool = False,
            scale_factor: int = 2,
            abs_2d_pe_init: bool = False,
            use_binary_classifier: bool = False
    ):
        """Define an encoder and decoder along with queries to be learned through the decoder."""
        super(MaskFormer, self).__init__()

        if arch == "vit_small":
            self.encoder = get_model(arch=arch, patch_size=patch_size, training_method=training_method)
            n_dims: int = self.encoder.n_embs
            n_heads: int = self.encoder.n_heads
            mlp_ratio: int = self.encoder.mlp_ratio
        else:
            self.encoder = get_model(arch=arch, training_method=training_method)
            n_dims_resnet: int = self.encoder.n_embs
            n_dims: int = 384
            n_heads: int = 6
            mlp_ratio: int = 4
            self.linear_layer = nn.Conv2d(n_dims_resnet, n_dims, kernel_size=1)

        decoder_layer = TransformerDecoderLayer(
            n_dims, n_heads, n_dims * mlp_ratio, 0., activation="relu", normalize_before=normalize_before
        )
        self.decoder = TransformerDecoder(
            decoder_layer,
            n_decoder_layers,
            norm=nn.LayerNorm(n_dims),
            return_intermediate=return_intermediate
        )

        self.query_embed = nn.Embedding(n_queries, n_dims).weight  # initialized with gaussian(0, 1)

        if use_binary_classifier:
            # self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
            # self.linear_classifier = nn.Linear(n_dims, 1)
            self.ffn = MLP(n_dims, n_dims, 1, num_layers=3)
            # self.norm = nn.LayerNorm(n_dims)
        else:
            # self.ffn = None
            # self.linear_classifier = None
            # self.norm = None
            self.ffn = MLP(n_dims, n_dims, n_dims, num_layers=3)
            self.linear_classifier = nn.Linear(n_dims, 2)
            self.norm = nn.LayerNorm(n_dims)

        self.arch = arch
        self.use_binary_classifier = use_binary_classifier
        self.lateral_connection = lateral_connection
        self.learnable_pixel_decoder = learnable_pixel_decoder
        self.scale_factor = scale_factor

    # copy-pasted from https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
    @staticmethod
    def positional_encoding_2d(n_dims: int, height: int, width: int):
        """
        :param n_dims: dimension of the model
        :param height: height of the positions
        :param width: width of the positions
        :return: d_model*height*width position matrix
        """
        if n_dims % 4 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dimension (got dim={:d})".format(n_dims))
        pe = torch.zeros(n_dims, height, width)
        # Each dimension uses half of d_model
        d_model = int(n_dims / 2)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(log(10000.0) / d_model))
        pos_w = torch.arange(0., width).unsqueeze(1)
        pos_h = torch.arange(0., height).unsqueeze(1)
        pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
        pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
        pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

        return pe

    def forward_encoder(self, x: torch.Tensor):
        """
        :param x: b x c x h x w
        :return patch_tokens: b x depth x hw x n_dims
        """
        if self.arch == "vit_small":
            encoder_outputs: Dict[str, torch.Tensor] = self.encoder(x)  # [:, 1:, :]
            all_patch_tokens: List[torch.Tensor] = list()
            for layer_name in [f"layer{num_layer}" for num_layer in range(1, self.encoder.depth + 1)]:
                patch_tokens: torch.Tensor = encoder_outputs[layer_name][:, 1:, :]  # b x hw x n_dims
                all_patch_tokens.append(patch_tokens)

            all_patch_tokens: torch.Tensor = torch.stack(all_patch_tokens, dim=0)  # depth x b x hw x n_dims
            all_patch_tokens = all_patch_tokens.permute(1, 0, 3, 2)  # b x depth x n_dims x hw
            return all_patch_tokens
        else:
            encoder_outputs = self.linear_layer(self.encoder(x)[-1])  # b x n_dims x h x w
            return encoder_outputs

    def forward_transformer_decoder(self, patch_tokens: torch.Tensor, skip_decoder: bool = False) -> torch.Tensor:
        """Forward transformer decoder given patch tokens from the encoder's last layer.
        :param patch_tokens: b x n_dims x hw -> hw x b x n_dims
        :param skip_decoder: if True, skip the decoder and produce mask predictions directly by matrix multiplication
        between learnable queries and encoder features (i.e., patch tokens). This is for the purpose of an overfitting
        experiment.
        :return queries: n_queries x b x n_dims -> b x n_queries x n_dims or b x n_layers x n_queries x n_dims
        """
        b = patch_tokens.shape[0]
        patch_tokens = patch_tokens.permute(2, 0, 1)  # b x n_dims x hw -> hw x b x n_dims

        # n_queries x n_dims -> n_queries x b x n_dims
        queries: torch.Tensor = self.query_embed.unsqueeze(1).repeat(1, b, 1)
        queries: torch.Tensor = self.decoder.forward(
            tgt=torch.zeros_like(queries),
            memory=patch_tokens,
            query_pos=queries
        ).squeeze(dim=0)

        if len(queries.shape) == 3:
            queries: torch.Tensor = queries.permute(1, 0, 2)  # n_queries x b x n_dims -> b x n_queries x n_dims
        elif len(queries.shape) == 4:
            # n_layers x n_queries x b x n_dims -> b x n_layers x n_queries x n_dims
            queries: torch.Tensor = queries.permute(2, 0, 1, 3)
        return queries

    def forward_pixel_decoder(self, patch_tokens: torch.Tensor, input_size=None):
        """ Upsample patch tokens by self.scale_factor and produce mask predictions
        :param patch_tokens: b (x depth) x n_dims x hw -> b (x depth) x n_dims x h x w
        :param queries: b x n_queries x n_dims
        :return mask_predictions: b x n_queries x h x w
        """

        if input_size is None:
            # assume square shape features
            hw = patch_tokens.shape[-1]
            h = w = int(sqrt(hw))
        else:
            # arbitrary shape features
            h, w = input_size
        patch_tokens = patch_tokens.view(*patch_tokens.shape[:-1], h, w)

        assert len(patch_tokens.shape) == 4
        patch_tokens = F.interpolate(patch_tokens, scale_factor=self.scale_factor, mode="bilinear")
        return patch_tokens

    def forward(self, x, encoder_only=False, skip_decoder: bool = False):
        """
        x: b x c x h x w
        patch_tokens: b x n_patches x n_dims -> n_patches x b x n_dims
        query_emb: n_queries x n_dims -> n_queries x b x n_dims
        """
        dict_outputs: dict = dict()

        # b x depth x n_dims x hw (vit) or b x n_dims x h x w (resnet50)
        features: torch.Tensor = self.forward_encoder(x)

        if self.arch == "vit_small":
            # extract the last layer for decoder input
            last_layer_features: torch.Tensor = features[:, -1, ...]  # b x n_dims x hw
        else:
            # transform the shape of the features to the one compatible with transformer decoder
            b, n_dims, h, w = features.shape
            last_layer_features: torch.Tensor = features.view(b, n_dims, h * w)  # b x n_dims x hw

        if encoder_only:
            _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
            _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size

            b, n_dims, hw = last_layer_features.shape
            dict_outputs.update({"patch_tokens": last_layer_features.view(b, _h, _w, n_dims)})
            return dict_outputs

        # transformer decoder forward
        queries: torch.Tensor = self.forward_transformer_decoder(
            last_layer_features,
            skip_decoder=skip_decoder
        )  # b x n_queries x n_dims or b x n_layers x n_queries x n_dims

        # pixel decoder forward (upsampling the patch tokens by self.scale_factor)
        if self.arch == "vit_small":
            _h, _w = self.encoder.make_input_divisible(x).shape[-2:]
            _h, _w = _h // self.encoder.patch_size, _w // self.encoder.patch_size
        else:
            _h, _w = h, w
        features: torch.Tensor = self.forward_pixel_decoder(
            patch_tokens=features if self.lateral_connection else last_layer_features,
            input_size=(_h, _w)
        )  # b x n_dims x h x w

        # queries: b x n_queries x n_dims or b x n_layers x n_queries x n_dims
        # features: b x n_dims x h x w
        # mask_pred: b x n_queries x h x w or b x n_layers x n_queries x h x w
        if len(queries.shape) == 3:
            mask_pred = torch.einsum("bqn,bnhw->bqhw", queries, features)
        else:
            if self.use_binary_classifier:
                mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
            else:
                mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", self.ffn(queries), features))

        if self.use_binary_classifier:
            # queries: b x n_layers x n_queries x n_dims -> n_layers x b x n_queries x n_dims
            queries = queries.permute(1, 0, 2, 3)
            objectness: List[torch.Tensor] = list()
            for n_layer, queries_per_layer in enumerate(queries):  # queries_per_layer: b x n_queries x n_dims
                # objectness_per_layer = self.linear_classifier(
                #     self.ffn(self.norm(queries_per_layer))
                # )  # b x n_queries x 1
                objectness_per_layer = self.ffn(queries_per_layer)  # b x n_queries x 1
                objectness.append(objectness_per_layer)
            # n_layers x b x n_queries x 1 -> b x n_layers x n_queries x 1
            objectness: torch.Tensor = torch.stack(objectness).permute(1, 0, 2, 3)
            dict_outputs.update({
                "objectness": torch.sigmoid(objectness),
                "mask_pred": mask_pred
            })

        return dict_outputs


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class UpsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, n_groups=32, scale_factor=2):
        super(UpsampleBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
            nn.GroupNorm(n_groups, out_channels),
            nn.ReLU()
        )
        self.scale_factor = scale_factor

    def forward(self, x):
        return F.interpolate(self.block(x), scale_factor=self.scale_factor, mode="bilinear")
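MaskFormer above turns the decoder queries into masks by combining them with the upsampled patch tokens through an einsum, giving one soft mask per query (and per decoder layer when intermediate outputs are returned). The following shape-only sketch reproduces just that step with random stand-in tensors, so it runs without the pretrained encoder that utils.get_model would otherwise load; the sizes mirror the config in this commit (20 queries, 384-dim ViT-S/8 features), but the tensors are dummies.

# Shape-only sketch of the query/feature combination used in MaskFormer.forward above.
# All tensors are random stand-ins; no pretrained encoder is loaded here.
import torch

b, n_layers, n_queries, n_dims, h, w = 2, 6, 20, 384, 28, 28
queries = torch.randn(b, n_layers, n_queries, n_dims)   # decoder outputs, one set per decoder layer
features = torch.randn(b, n_dims, h, w)                 # upsampled patch tokens from the pixel decoder

# one soft mask per query per decoder layer: b x n_layers x n_queries x h x w
mask_pred = torch.sigmoid(torch.einsum("bdqn,bnhw->bdqhw", queries, features))
print(mask_pred.shape)  # torch.Size([2, 6, 20, 28, 28])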
    	
networks/maskformer/positional_embedding.py
ADDED
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
"""
Various positional encodings for the transformer.
"""
import math

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
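PositionEmbeddingSine maps an image-shaped feature tensor to a fixed sinusoidal positional map with 2 * num_pos_feats channels, matching the input's spatial size. A minimal check, assuming the class defined above is in scope; the input is a random dummy feature map and the channel count of the input does not matter beyond its shape:

# Quick shape check of PositionEmbeddingSine on a dummy feature map (sketch, not repository code).
import torch

pos_enc = PositionEmbeddingSine(num_pos_feats=64, normalize=True)
feats = torch.randn(2, 128, 32, 32)  # b x c x h x w
pos = pos_enc(feats)                 # b x (2 * num_pos_feats) x h x w
print(pos.shape)                     # torch.Size([2, 128, 32, 32])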
    	
networks/maskformer/transformer_decoder.py
ADDED
@@ -0,0 +1,376 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
"""
Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Transformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",  # noel - dino used GeLU
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        if mask is not None:
            mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(
            tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
        )
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers: nn.ModuleList = _get_clones(decoder_layer, num_layers)
        self.num_layers: int = num_layers
        self.norm = norm
        self.return_intermediate: bool = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
         
     | 
| 129 | 
         
            +
                            output,
         
     | 
| 130 | 
         
            +
                            memory,
         
     | 
| 131 | 
         
            +
                            tgt_mask=tgt_mask,
         
     | 
| 132 | 
         
            +
                            memory_mask=memory_mask,
         
     | 
| 133 | 
         
            +
                            tgt_key_padding_mask=tgt_key_padding_mask,
         
     | 
| 134 | 
         
            +
                            memory_key_padding_mask=memory_key_padding_mask,
         
     | 
| 135 | 
         
            +
                            pos=pos,
         
     | 
| 136 | 
         
            +
                            query_pos=query_pos,
         
     | 
| 137 | 
         
            +
                        )
         
     | 
| 138 | 
         
            +
                        if self.return_intermediate:
         
     | 
| 139 | 
         
            +
                            intermediate.append(self.norm(output))
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
                    if self.norm is not None:
         
     | 
| 142 | 
         
            +
                        output = self.norm(output)
         
     | 
| 143 | 
         
            +
                        if self.return_intermediate:
         
     | 
| 144 | 
         
            +
                            intermediate.pop()
         
     | 
| 145 | 
         
            +
                            intermediate.append(output)
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                    if self.return_intermediate:
         
     | 
| 148 | 
         
            +
                        return torch.stack(intermediate)
         
     | 
| 149 | 
         
            +
             
     | 
| 150 | 
         
            +
                    return output.unsqueeze(0)
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
             
     | 
| 153 | 
         
            +
            class TransformerEncoderLayer(nn.Module):
         
     | 
| 154 | 
         
            +
                def __init__(
         
     | 
| 155 | 
         
            +
                    self,
         
     | 
| 156 | 
         
            +
                    d_model,
         
     | 
| 157 | 
         
            +
                    nhead,
         
     | 
| 158 | 
         
            +
                    dim_feedforward=2048,
         
     | 
| 159 | 
         
            +
                    dropout=0.1,
         
     | 
| 160 | 
         
            +
                    activation="relu",
         
     | 
| 161 | 
         
            +
                    normalize_before=False,
         
     | 
| 162 | 
         
            +
                ):
         
     | 
| 163 | 
         
            +
                    super().__init__()
         
     | 
| 164 | 
         
            +
                    self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
         
     | 
| 165 | 
         
            +
                    # Implementation of Feedforward model
         
     | 
| 166 | 
         
            +
                    self.linear1 = nn.Linear(d_model, dim_feedforward)
         
     | 
| 167 | 
         
            +
                    self.dropout = nn.Dropout(dropout)
         
     | 
| 168 | 
         
            +
                    self.linear2 = nn.Linear(dim_feedforward, d_model)
         
     | 
| 169 | 
         
            +
             
     | 
| 170 | 
         
            +
                    self.norm1 = nn.LayerNorm(d_model)
         
     | 
| 171 | 
         
            +
                    self.norm2 = nn.LayerNorm(d_model)
         
     | 
| 172 | 
         
            +
                    self.dropout1 = nn.Dropout(dropout)
         
     | 
| 173 | 
         
            +
                    self.dropout2 = nn.Dropout(dropout)
         
     | 
| 174 | 
         
            +
             
     | 
| 175 | 
         
            +
                    self.activation = _get_activation_fn(activation)
         
     | 
| 176 | 
         
            +
                    self.normalize_before = normalize_before
         
     | 
| 177 | 
         
            +
             
     | 
| 178 | 
         
            +
                def with_pos_embed(self, tensor, pos: Optional[Tensor]):
         
     | 
| 179 | 
         
            +
                    return tensor if pos is None else tensor + pos
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
                def forward_post(
         
     | 
| 182 | 
         
            +
                    self,
         
     | 
| 183 | 
         
            +
                    src,
         
     | 
| 184 | 
         
            +
                    src_mask: Optional[Tensor] = None,
         
     | 
| 185 | 
         
            +
                    src_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 186 | 
         
            +
                    pos: Optional[Tensor] = None,
         
     | 
| 187 | 
         
            +
                ):
         
     | 
| 188 | 
         
            +
                    q = k = self.with_pos_embed(src, pos)
         
     | 
| 189 | 
         
            +
                    src2 = self.self_attn(
         
     | 
| 190 | 
         
            +
                        q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
         
     | 
| 191 | 
         
            +
                    )[0]
         
     | 
| 192 | 
         
            +
                    src = src + self.dropout1(src2)
         
     | 
| 193 | 
         
            +
                    src = self.norm1(src)
         
     | 
| 194 | 
         
            +
                    src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
         
     | 
| 195 | 
         
            +
                    src = src + self.dropout2(src2)
         
     | 
| 196 | 
         
            +
                    src = self.norm2(src)
         
     | 
| 197 | 
         
            +
                    return src
         
     | 
| 198 | 
         
            +
             
     | 
| 199 | 
         
            +
                def forward_pre(
         
     | 
| 200 | 
         
            +
                    self,
         
     | 
| 201 | 
         
            +
                    src,
         
     | 
| 202 | 
         
            +
                    src_mask: Optional[Tensor] = None,
         
     | 
| 203 | 
         
            +
                    src_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 204 | 
         
            +
                    pos: Optional[Tensor] = None,
         
     | 
| 205 | 
         
            +
                ):
         
     | 
| 206 | 
         
            +
                    src2 = self.norm1(src)
         
     | 
| 207 | 
         
            +
                    q = k = self.with_pos_embed(src2, pos)
         
     | 
| 208 | 
         
            +
                    src2 = self.self_attn(
         
     | 
| 209 | 
         
            +
                        q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
         
     | 
| 210 | 
         
            +
                    )[0]
         
     | 
| 211 | 
         
            +
                    src = src + self.dropout1(src2)
         
     | 
| 212 | 
         
            +
                    src2 = self.norm2(src)
         
     | 
| 213 | 
         
            +
                    src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
         
     | 
| 214 | 
         
            +
                    src = src + self.dropout2(src2)
         
     | 
| 215 | 
         
            +
                    return src
         
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
                def forward(
         
     | 
| 218 | 
         
            +
                    self,
         
     | 
| 219 | 
         
            +
                    src,
         
     | 
| 220 | 
         
            +
                    src_mask: Optional[Tensor] = None,
         
     | 
| 221 | 
         
            +
                    src_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 222 | 
         
            +
                    pos: Optional[Tensor] = None,
         
     | 
| 223 | 
         
            +
                ):
         
     | 
| 224 | 
         
            +
                    if self.normalize_before:
         
     | 
| 225 | 
         
            +
                        return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
         
     | 
| 226 | 
         
            +
                    return self.forward_post(src, src_mask, src_key_padding_mask, pos)
         
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
             
     | 
| 229 | 
         
            +
            class TransformerDecoderLayer(nn.Module):
         
     | 
| 230 | 
         
            +
                def __init__(
         
     | 
| 231 | 
         
            +
                    self,
         
     | 
| 232 | 
         
            +
                    d_model,
         
     | 
| 233 | 
         
            +
                    nhead,
         
     | 
| 234 | 
         
            +
                    dim_feedforward=2048,
         
     | 
| 235 | 
         
            +
                    dropout=0.1,
         
     | 
| 236 | 
         
            +
                    activation="relu",
         
     | 
| 237 | 
         
            +
                    normalize_before=False,
         
     | 
| 238 | 
         
            +
                ):
         
     | 
| 239 | 
         
            +
                    super().__init__()
         
     | 
| 240 | 
         
            +
                    self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
         
     | 
| 241 | 
         
            +
                    self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
         
     | 
| 242 | 
         
            +
                    # Implementation of Feedforward model
         
     | 
| 243 | 
         
            +
                    self.linear1 = nn.Linear(d_model, dim_feedforward)
         
     | 
| 244 | 
         
            +
                    self.dropout = nn.Dropout(dropout)
         
     | 
| 245 | 
         
            +
                    self.linear2 = nn.Linear(dim_feedforward, d_model)
         
     | 
| 246 | 
         
            +
             
     | 
| 247 | 
         
            +
                    self.norm1 = nn.LayerNorm(d_model)
         
     | 
| 248 | 
         
            +
                    self.norm2 = nn.LayerNorm(d_model)
         
     | 
| 249 | 
         
            +
                    self.norm3 = nn.LayerNorm(d_model)
         
     | 
| 250 | 
         
            +
                    self.dropout1 = nn.Dropout(dropout)
         
     | 
| 251 | 
         
            +
                    self.dropout2 = nn.Dropout(dropout)
         
     | 
| 252 | 
         
            +
                    self.dropout3 = nn.Dropout(dropout)
         
     | 
| 253 | 
         
            +
             
     | 
| 254 | 
         
            +
                    self.activation = _get_activation_fn(activation)
         
     | 
| 255 | 
         
            +
                    self.normalize_before = normalize_before
         
     | 
| 256 | 
         
            +
             
     | 
| 257 | 
         
            +
                def with_pos_embed(self, tensor, pos: Optional[Tensor]):
         
     | 
| 258 | 
         
            +
                    return tensor if pos is None else tensor + pos
         
     | 
| 259 | 
         
            +
             
     | 
| 260 | 
         
            +
                def forward_post(
         
     | 
| 261 | 
         
            +
                    self,
         
     | 
| 262 | 
         
            +
                    tgt,
         
     | 
| 263 | 
         
            +
                    memory,
         
     | 
| 264 | 
         
            +
                    tgt_mask: Optional[Tensor] = None,
         
     | 
| 265 | 
         
            +
                    memory_mask: Optional[Tensor] = None,
         
     | 
| 266 | 
         
            +
                    tgt_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 267 | 
         
            +
                    memory_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 268 | 
         
            +
                    pos: Optional[Tensor] = None,
         
     | 
| 269 | 
         
            +
                    query_pos: Optional[Tensor] = None,
         
     | 
| 270 | 
         
            +
                ):
         
     | 
| 271 | 
         
            +
                    q = k = self.with_pos_embed(tgt, query_pos)
         
     | 
| 272 | 
         
            +
             
     | 
| 273 | 
         
            +
                    tgt2 = self.self_attn(
         
     | 
| 274 | 
         
            +
                        q,
         
     | 
| 275 | 
         
            +
                        k,
         
     | 
| 276 | 
         
            +
                        value=tgt,
         
     | 
| 277 | 
         
            +
                        attn_mask=tgt_mask,
         
     | 
| 278 | 
         
            +
                        key_padding_mask=tgt_key_padding_mask
         
     | 
| 279 | 
         
            +
                    )[0]
         
     | 
| 280 | 
         
            +
                    tgt = tgt + self.dropout1(tgt2)
         
     | 
| 281 | 
         
            +
                    tgt = self.norm1(tgt)
         
     | 
| 282 | 
         
            +
             
     | 
| 283 | 
         
            +
                    tgt2 = self.multihead_attn(
         
     | 
| 284 | 
         
            +
                        query=self.with_pos_embed(tgt, query_pos),
         
     | 
| 285 | 
         
            +
                        key=self.with_pos_embed(memory, pos),
         
     | 
| 286 | 
         
            +
                        value=memory,
         
     | 
| 287 | 
         
            +
                        attn_mask=memory_mask,
         
     | 
| 288 | 
         
            +
                        key_padding_mask=memory_key_padding_mask,
         
     | 
| 289 | 
         
            +
                    )[0]
         
     | 
| 290 | 
         
            +
                    tgt = tgt + self.dropout2(tgt2)
         
     | 
| 291 | 
         
            +
                    tgt = self.norm2(tgt)
         
     | 
| 292 | 
         
            +
             
     | 
| 293 | 
         
            +
                    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
         
     | 
| 294 | 
         
            +
                    tgt = tgt + self.dropout3(tgt2)
         
     | 
| 295 | 
         
            +
                    tgt = self.norm3(tgt)
         
     | 
| 296 | 
         
            +
             
     | 
| 297 | 
         
            +
                    return tgt
         
     | 
| 298 | 
         
            +
             
     | 
| 299 | 
         
            +
                def forward_pre(
         
     | 
| 300 | 
         
            +
                    self,
         
     | 
| 301 | 
         
            +
                    tgt,
         
     | 
| 302 | 
         
            +
                    memory,
         
     | 
| 303 | 
         
            +
                    tgt_mask: Optional[Tensor] = None,
         
     | 
| 304 | 
         
            +
                    memory_mask: Optional[Tensor] = None,
         
     | 
| 305 | 
         
            +
                    tgt_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 306 | 
         
            +
                    memory_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 307 | 
         
            +
                    pos: Optional[Tensor] = None,
         
     | 
| 308 | 
         
            +
                    query_pos: Optional[Tensor] = None,
         
     | 
| 309 | 
         
            +
                ):
         
     | 
| 310 | 
         
            +
                    tgt2 = self.norm1(tgt)
         
     | 
| 311 | 
         
            +
                    q = k = self.with_pos_embed(tgt2, query_pos)
         
     | 
| 312 | 
         
            +
                    tgt2 = self.self_attn(
         
     | 
| 313 | 
         
            +
                        q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
         
     | 
| 314 | 
         
            +
                    )[0]
         
     | 
| 315 | 
         
            +
                    tgt = tgt + self.dropout1(tgt2)
         
     | 
| 316 | 
         
            +
                    tgt2 = self.norm2(tgt)
         
     | 
| 317 | 
         
            +
                    tgt2 = self.multihead_attn(
         
     | 
| 318 | 
         
            +
                        query=self.with_pos_embed(tgt2, query_pos),
         
     | 
| 319 | 
         
            +
                        key=self.with_pos_embed(memory, pos),
         
     | 
| 320 | 
         
            +
                        value=memory,
         
     | 
| 321 | 
         
            +
                        attn_mask=memory_mask,
         
     | 
| 322 | 
         
            +
                        key_padding_mask=memory_key_padding_mask,
         
     | 
| 323 | 
         
            +
                    )[0]
         
     | 
| 324 | 
         
            +
                    tgt = tgt + self.dropout2(tgt2)
         
     | 
| 325 | 
         
            +
                    tgt2 = self.norm3(tgt)
         
     | 
| 326 | 
         
            +
                    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
         
     | 
| 327 | 
         
            +
                    tgt = tgt + self.dropout3(tgt2)
         
     | 
| 328 | 
         
            +
                    return tgt
         
     | 
| 329 | 
         
            +
             
     | 
| 330 | 
         
            +
                def forward(
         
     | 
| 331 | 
         
            +
                    self,
         
     | 
| 332 | 
         
            +
                    tgt,
         
     | 
| 333 | 
         
            +
                    memory,
         
     | 
| 334 | 
         
            +
                    tgt_mask: Optional[Tensor] = None,
         
     | 
| 335 | 
         
            +
                    memory_mask: Optional[Tensor] = None,
         
     | 
| 336 | 
         
            +
                    tgt_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 337 | 
         
            +
                    memory_key_padding_mask: Optional[Tensor] = None,
         
     | 
| 338 | 
         
            +
                    pos: Optional[Tensor] = None,
         
     | 
| 339 | 
         
            +
                    query_pos: Optional[Tensor] = None,
         
     | 
| 340 | 
         
            +
                ):
         
     | 
| 341 | 
         
            +
                    if self.normalize_before:
         
     | 
| 342 | 
         
            +
                        return self.forward_pre(
         
     | 
| 343 | 
         
            +
                            tgt,
         
     | 
| 344 | 
         
            +
                            memory,
         
     | 
| 345 | 
         
            +
                            tgt_mask,
         
     | 
| 346 | 
         
            +
                            memory_mask,
         
     | 
| 347 | 
         
            +
                            tgt_key_padding_mask,
         
     | 
| 348 | 
         
            +
                            memory_key_padding_mask,
         
     | 
| 349 | 
         
            +
                            pos,
         
     | 
| 350 | 
         
            +
                            query_pos,
         
     | 
| 351 | 
         
            +
                        )
         
     | 
| 352 | 
         
            +
                    return self.forward_post(
         
     | 
| 353 | 
         
            +
                        tgt,
         
     | 
| 354 | 
         
            +
                        memory,
         
     | 
| 355 | 
         
            +
                        tgt_mask,
         
     | 
| 356 | 
         
            +
                        memory_mask,
         
     | 
| 357 | 
         
            +
                        tgt_key_padding_mask,
         
     | 
| 358 | 
         
            +
                        memory_key_padding_mask,
         
     | 
| 359 | 
         
            +
                        pos,
         
     | 
| 360 | 
         
            +
                        query_pos,
         
     | 
| 361 | 
         
            +
                    )
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
             
     | 
| 364 | 
         
            +
            def _get_clones(module, N):
         
     | 
| 365 | 
         
            +
                return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
         
     | 
| 366 | 
         
            +
             
     | 
| 367 | 
         
            +
             
     | 
| 368 | 
         
            +
            def _get_activation_fn(activation):
         
     | 
| 369 | 
         
            +
                """Return an activation function given a string"""
         
     | 
| 370 | 
         
            +
                if activation == "relu":
         
     | 
| 371 | 
         
            +
                    return F.relu
         
     | 
| 372 | 
         
            +
                if activation == "gelu":
         
     | 
| 373 | 
         
            +
                    return F.gelu
         
     | 
| 374 | 
         
            +
                if activation == "glu":
         
     | 
| 375 | 
         
            +
                    return F.glu
         
     | 
| 376 | 
         
            +
                raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
         
     | 
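
For orientation (not part of the committed file above): a minimal smoke-test sketch showing how the encoder/decoder stacks defined in transformer_decoder.py fit together. The sizes below (d_model=256, nhead=8, 6 layers, 20 queries, a 196-token feature map) are illustrative assumptions, not values taken from this Space's configuration.

# Illustrative smoke test for the classes above; run in the same module after the definitions.
import torch
from torch import nn

d_model, nhead, hw, bs, n_queries = 256, 8, 196, 2, 20  # assumed toy sizes

enc_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1)
encoder = TransformerEncoder(enc_layer, num_layers=6)

dec_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward=2048, dropout=0.1)
decoder = TransformerDecoder(
    dec_layer, num_layers=6, norm=nn.LayerNorm(d_model), return_intermediate=True
)

src = torch.randn(hw, bs, d_model)           # flattened feature map, HWxNxC
pos = torch.randn(hw, bs, d_model)           # positional encoding, same shape as src
query = torch.randn(n_queries, bs, d_model)  # query embeddings
tgt = torch.zeros_like(query)

memory = encoder(src, pos=pos)
hs = decoder(tgt, memory, pos=pos, query_pos=query)
print(hs.shape)  # (num_layers, n_queries, bs, d_model) because return_intermediate=True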
    	
networks/module_helper.py
ADDED
@@ -0,0 +1,176 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You ([email protected])
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve


class FixedBatchNorm(nn.BatchNorm2d):
    def forward(self, input):
        return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps)


class ModuleHelper(object):
    @staticmethod
    def BNReLU(num_features, norm_type=None, **kwargs):
        if norm_type == 'batchnorm':
            return nn.Sequential(
                nn.BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm2d
            return nn.Sequential(
                BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'instancenorm':
            return nn.Sequential(
                nn.InstanceNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif norm_type == 'fixed_batchnorm':
            return nn.Sequential(
                FixedBatchNorm(num_features, **kwargs),
                nn.ReLU()
            )
        else:
            raise ValueError('Unsupported BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm3d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm3d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm3d
            return BatchNorm3d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm3d
        else:
            raise ValueError('Unsupported BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm2d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm2d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm2d
            return BatchNorm2d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm2d
        else:
            raise ValueError('Unsupported BN type: {}.'.format(norm_type))

    @staticmethod
    def BatchNorm1d(norm_type=None, ret_cls=False):
        if norm_type == 'batchnorm':
            return nn.BatchNorm1d
        elif norm_type == 'encsync_batchnorm':
            from encoding.nn import BatchNorm1d
            return BatchNorm1d
        elif norm_type == 'instancenorm':
            return nn.InstanceNorm1d
        else:
            raise ValueError('Unsupported BN type: {}.'.format(norm_type))

    @staticmethod
    def load_model(model, pretrained=None, all_match=True, map_location='cpu'):
        if pretrained is None:
            return model

        if not os.path.exists(pretrained):
            pretrained = pretrained.replace("..", "/home/gishin-temp/projects/open_set/segmentation")
            if os.path.exists(pretrained):
                pass
            else:
                raise FileNotFoundError('{} does not exist.'.format(pretrained))

        print('Loading pretrained model: {}'.format(pretrained))
        if all_match:
            pretrained_dict = torch.load(pretrained, map_location=map_location)
            model_dict = model.state_dict()
            load_dict = dict()
            for k, v in pretrained_dict.items():
                if 'prefix.{}'.format(k) in model_dict:
                    load_dict['prefix.{}'.format(k)] = v
                else:
                    load_dict[k] = v
            model.load_state_dict(load_dict)

        else:
            pretrained_dict = torch.load(pretrained)
            model_dict = model.state_dict()
            load_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            print('Matched Keys: {}'.format(load_dict.keys()))
            model_dict.update(load_dict)
            model.load_state_dict(model_dict)

        return model

    @staticmethod
    def load_url(url, map_location=None):
        model_dir = os.path.expanduser(os.path.join('~', '.TorchCV', 'model'))
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        filename = url.split('/')[-1]
        cached_file = os.path.join(model_dir, filename)
        if not os.path.exists(cached_file):
            print('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)

        print('Loading pretrained model: {}'.format(cached_file))
        return torch.load(cached_file, map_location=map_location)

    @staticmethod
    def constant_init(module, val, bias=0):
        nn.init.constant_(module.weight, val)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def xavier_init(module, gain=1, bias=0, distribution='normal'):
        assert distribution in ['uniform', 'normal']
        if distribution == 'uniform':
            nn.init.xavier_uniform_(module.weight, gain=gain)
        else:
            nn.init.xavier_normal_(module.weight, gain=gain)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def normal_init(module, mean=0, std=1, bias=0):
        nn.init.normal_(module.weight, mean, std)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def uniform_init(module, a=0, b=1, bias=0):
        nn.init.uniform_(module.weight, a, b)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)

    @staticmethod
    def kaiming_init(module,
                     mode='fan_in',
                     nonlinearity='leaky_relu',
                     bias=0,
                     distribution='normal'):
        assert distribution in ['uniform', 'normal']
        if distribution == 'uniform':
            nn.init.kaiming_uniform_(
                module.weight, mode=mode, nonlinearity=nonlinearity)
        else:
            nn.init.kaiming_normal_(
                module.weight, mode=mode, nonlinearity=nonlinearity)
        if hasattr(module, 'bias') and module.bias is not None:
            nn.init.constant_(module.bias, bias)
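
For orientation (not part of the committed file above): a small, hedged example of how ModuleHelper is typically consumed. The conv block below is hypothetical and only exercises BNReLU and kaiming_init.

# Hypothetical conv block wired with the helpers above (illustrative only).
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
block = nn.Sequential(
    conv,
    ModuleHelper.BNReLU(64, norm_type='batchnorm'),  # BatchNorm2d followed by ReLU
)
ModuleHelper.kaiming_init(conv)  # initialise the conv weights in place

out = block(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 64, 32, 32])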
    	
networks/resnet.py
ADDED
@@ -0,0 +1,60 @@
import os

import torch
import torch.nn as nn
from .resnet_backbone import ResNetBackbone


class ResNet50(nn.Module):
    def __init__(
            self,
            weight_type: str = "supervised",
            use_dilated_resnet: bool = True
    ):
        super(ResNet50, self).__init__()
        self.network = ResNetBackbone(backbone=f"resnet50{'_dilated8' if use_dilated_resnet else ''}", pretrained=None)
        self.n_embs = self.network.num_features
        self.use_dilated_resnet = use_dilated_resnet
        self._load_pretrained(weight_type)

    def _load_pretrained(self, training_method: str) -> None:
        curr_state_dict = self.network.state_dict()
        if training_method == "mocov2":
            state_dict = torch.load("/users/gyungin/sos/networks/pretrained/moco_v2_800ep_pretrain.pth.tar")["state_dict"]

            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("fc.0", "fc.2")]):
                    state_dict.pop(k)

        elif training_method == "swav":
            state_dict = torch.load("/users/gyungin/sos/networks/pretrained/swav_800ep_pretrain.pth.tar")
            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("projection_head", "prototypes")]):
                    state_dict.pop(k)

        elif training_method == "supervised":
            # Note - pytorch resnet50 model doesn't have num_batches_tracked layers. Need to know why.
            # for k in list(curr_state_dict.keys()):
            #     if k.find("num_batches_tracked") != -1:
            #         curr_state_dict.pop(k)
            # state_dict = torch.load("../networks/pretrained/resnet50-pytorch.pth")

            from torchvision.models.resnet import resnet50
            resnet50_supervised = resnet50(True, True)
            state_dict = resnet50_supervised.state_dict()
            for k in list(state_dict.keys()):
                if any([k.find(w) != -1 for w in ("fc.weight", "fc.bias")]):
                    state_dict.pop(k)

        assert len(curr_state_dict) == len(state_dict), f"# layers are different: {len(curr_state_dict)} != {len(state_dict)}"
        for k_curr, k in zip(curr_state_dict.keys(), state_dict.keys()):
            curr_state_dict[k_curr].copy_(state_dict[k])
        print(f"ResNet50{' (dilated)' if self.use_dilated_resnet else ''} initialised with {training_method} weights.")
        return

    def forward(self, x):
        return self.network(x)


if __name__ == '__main__':
    resnet = ResNet50("mocov2")
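
For orientation (not part of the committed file above): a hedged usage sketch for the ResNet50 wrapper. Only the "supervised" branch is expected to work outside the author's environment, since the "mocov2" and "swav" branches load checkpoints from hard-coded absolute paths.

# Illustrative usage; assumes torchvision can download the ImageNet weights.
import torch

backbone = ResNet50(weight_type="supervised", use_dilated_resnet=True)
backbone.eval()

with torch.no_grad():
    out = backbone(torch.randn(1, 3, 224, 224))
# `out` is whatever ResNetBackbone.forward returns (typically per-stage feature maps);
# the channel width of the last stage is backbone.n_embs (2048).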
    	
networks/resnet_backbone.py
ADDED
@@ -0,0 +1,194 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You([email protected])


import torch.nn as nn
from networks.resnet_models import *


class NormalResnetBackbone(nn.Module):
    def __init__(self, orig_resnet):
        super(NormalResnetBackbone, self).__init__()

        self.num_features = 2048
        # take pretrained resnet, except AvgPool and FC
        self.prefix = orig_resnet.prefix
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def get_num_features(self):
        return self.num_features

    def forward(self, x):
        tuple_features = list()
        x = self.prefix(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        tuple_features.append(x)
        x = self.layer2(x)
        tuple_features.append(x)
        x = self.layer3(x)
        tuple_features.append(x)
        x = self.layer4(x)
        tuple_features.append(x)

        return tuple_features


class DilatedResnetBackbone(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8, multi_grid=(1, 2, 4)):
        super(DilatedResnetBackbone, self).__init__()

        self.num_features = 2048
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            if multi_grid is None:
                orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
            else:
                for i, r in enumerate(multi_grid):
                    orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(4 * r)))

        elif dilate_scale == 16:
            if multi_grid is None:
                orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))
            else:
                for i, r in enumerate(multi_grid):
                    orig_resnet.layer4[i].apply(partial(self._nostride_dilate, dilate=int(2 * r)))

        # Take pretrained resnet, except AvgPool and FC
        self.prefix = orig_resnet.prefix
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def get_num_features(self):
        return self.num_features

    def forward(self, x):
        tuple_features = list()

        x = self.prefix(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        tuple_features.append(x)
        x = self.layer2(x)
        tuple_features.append(x)
        x = self.layer3(x)
        tuple_features.append(x)
        x = self.layer4(x)
        tuple_features.append(x)

        return tuple_features


def ResNetBackbone(backbone=None, width_multiplier=1.0, pretrained=None, multi_grid=None, norm_type='batchnorm'):
    arch = backbone

    if arch == 'resnet18':
        orig_resnet = resnet18(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
        arch_net.num_features = 512

    elif arch == 'resnet18_dilated8':
        orig_resnet = resnet18(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet34':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)
        arch_net.num_features = 512

    elif arch == 'resnet34_dilated8':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet34_dilated16':
        orig_resnet = resnet34(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)
        arch_net.num_features = 512

    elif arch == 'resnet50':
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'resnet50_dilated8':
        orig_resnet = resnet50(pretrained=pretrained, width_multiplier=width_multiplier)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'resnet50_dilated16':
        orig_resnet = resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet50':
        if pretrained:
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'deepbase_resnet50_dilated8':
        if pretrained:
            pretrained = 'models/backbones/pretrained/3x3resnet50-imagenet.pth'
            # pretrained = "/home/gishin/Projects/DeepLearning/Oxford/cct/models/backbones/pretrained/3x3resnet50-imagenet.pth"
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet50_dilated16':
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'resnet101':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'resnet101_dilated8':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'resnet101_dilated16':
        orig_resnet = resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet101':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = NormalResnetBackbone(orig_resnet)

    elif arch == 'deepbase_resnet101_dilated8':
        if pretrained:
            pretrained = 'backbones/backbones/pretrained/3x3resnet101-imagenet.pth'
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=8, multi_grid=multi_grid)

    elif arch == 'deepbase_resnet101_dilated16':
        orig_resnet = deepbase_resnet101(pretrained=pretrained)
        arch_net = DilatedResnetBackbone(orig_resnet, dilate_scale=16, multi_grid=multi_grid)

    else:
        raise Exception('Architecture undefined!')

    return arch_net
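For reference, a minimal usage sketch of the ResNetBackbone factory above. It assumes ModuleHelper.load_model leaves the model untouched when pretrained is None (i.e. random initialisation), and the printed shapes assume a 224x224 input:

import torch
from networks.resnet_backbone import ResNetBackbone

# dilated ResNet-50: the strides in layer3/layer4 are replaced by dilation,
# so the deepest features stay at 1/8 of the input resolution instead of 1/32
backbone = ResNetBackbone(backbone='resnet50_dilated8', pretrained=None, multi_grid=(1, 2, 4))

feats = backbone(torch.randn(2, 3, 224, 224))  # list of 4 feature maps (layer1..layer4)
print([tuple(f.shape) for f in feats])
# [(2, 256, 56, 56), (2, 512, 28, 28), (2, 1024, 28, 28), (2, 2048, 28, 28)]
print(backbone.get_num_features())  # 2048 for ResNet-50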
    	
networks/resnet_models.py ADDED
@@ -0,0 +1,273 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Author: Donny You([email protected])
import math
import torch.nn as nn
from collections import OrderedDict
from .module_helper import ModuleHelper


model_urls = {
    'resnet18': 'https://download.pytorch.org/backbones/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/backbones/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/backbones/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/backbones/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/backbones/resnet152-b121ed2d.pth'
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, norm_type=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = ModuleHelper.BatchNorm2d(norm_type=norm_type)(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, width_multiplier=1.0, num_classes=1000, deep_base=False, norm_type=None):
        super(ResNet, self).__init__()
        self.inplanes = 128 if deep_base else int(64 * width_multiplier)
        self.width_multiplier = width_multiplier
        if deep_base:
            self.prefix = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)),
                ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
                ('relu1', nn.ReLU(inplace=False)),
                ('conv2', nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)),
                ('bn2', ModuleHelper.BatchNorm2d(norm_type=norm_type)(64)),
                ('relu2', nn.ReLU(inplace=False)),
                ('conv3', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False)),
                ('bn3', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
                ('relu3', nn.ReLU(inplace=False))]
            ))
        else:
            self.prefix = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
                ('bn1', ModuleHelper.BatchNorm2d(norm_type=norm_type)(self.inplanes)),
                ('relu', nn.ReLU(inplace=False))]
            ))

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)  # change.

        self.layer1 = self._make_layer(block, int(64 * width_multiplier), layers[0], norm_type=norm_type)
        self.layer2 = self._make_layer(block, int(128 * width_multiplier), layers[1], stride=2, norm_type=norm_type)
        self.layer3 = self._make_layer(block, int(256 * width_multiplier), layers[2], stride=2, norm_type=norm_type)
        self.layer4 = self._make_layer(block, int(512 * width_multiplier), layers[3], stride=2, norm_type=norm_type)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(int(512 * block.expansion * width_multiplier), num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, ModuleHelper.BatchNorm2d(norm_type=norm_type, ret_cls=True)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, norm_type=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                ModuleHelper.BatchNorm2d(norm_type=norm_type)(int(planes * block.expansion * self.width_multiplier)),
            )

        layers = []
        layers.append(block(self.inplanes, planes,
                            stride, downsample, norm_type=norm_type))

        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, norm_type=norm_type))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.prefix(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
        norm_type (str): choose norm type
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet18(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-18 model with a deep (3x3 conv) stem.
    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet34(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-34 model with a deep (3x3 conv) stem.
    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    # note: expects width_multiplier to be passed as a keyword argument
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type,
                   width_multiplier=kwargs["width_multiplier"])
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet50(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-50 model with a deep (3x3 conv) stem.
    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet101(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-101 model with a deep (3x3 conv) stem.
    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=False, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model


def deepbase_resnet152(num_classes=1000, pretrained=None, norm_type='batchnorm', **kwargs):
    """Constructs a ResNet-152 model with a deep (3x3 conv) stem.

    Args:
        pretrained (str | None): path to a pre-trained checkpoint to load (None for random init)
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, deep_base=True, norm_type=norm_type)
    model = ModuleHelper.load_model(model, pretrained=pretrained)
    return model
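Likewise, a minimal sketch of building one of the classifiers above; note that resnet50 as written expects width_multiplier to be supplied as a keyword argument, and ModuleHelper.load_model is again assumed to be a no-op for pretrained=None:

import torch
from networks.resnet_models import resnet50, deepbase_resnet50

model = resnet50(num_classes=1000, pretrained=None, width_multiplier=1.0)
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]); avgpool(7) assumes a 7x7 final map, i.e. 224x224 inputs

# the deep-base variant replaces the 7x7 stem with three 3x3 convolutions
deep = deepbase_resnet50(num_classes=1000, pretrained=None)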
    	
networks/timm_deit.py ADDED
@@ -0,0 +1,254 @@
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import math
import torch
import torch.nn as nn
from functools import partial

from networks.timm_vit import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_


__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2

    def interpolate_pos_encoding(self, x, pos_embed):
        """Interpolate the learnable positional encoding to match the number of patches.

        x: B x (1 + 1 + N patches) x dim_embedding
        pos_embed: B x (1 + 1 + N patches) x dim_embedding

        return interpolated positional embedding
        """

        npatch = x.shape[1] - 2  # (H // patch_size * W // patch_size)
        N = pos_embed.shape[1] - 2  # 784 (= 28 x 28)

        if npatch == N:
            return pos_embed

        class_emb, distil_token, pos_embed = pos_embed[:, 0], pos_embed[:, 1], pos_embed[:, 2:]  # a learnable CLS token, learnable position embeddings

        dim = x.shape[-1]  # dimension of embeddings
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # B x dim x 28 x 28
            scale_factor=math.sqrt(npatch / N) + 1e-5,  # noel: this can be a float, but the output shape will be integer.
            recompute_scale_factor=True,
            mode='bicubic'
        )
        # print("pos_embed", pos_embed.shape, npatch, N, math.sqrt(npatch/N), math.sqrt(npatch/N) * int(math.sqrt(N)))
        # exit(12)
        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_emb.unsqueeze(0), distil_token.unsqueeze(0), pos_embed), dim=1)
        return pos_embed

    def get_tokens(
            self,
            x,
            layers: list,
            patch_tokens: bool = False,
            norm: bool = True,
            input_tokens: bool = False,
            post_pe: bool = False
    ):
        """Return intermediate tokens."""
        list_tokens: list = []

        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)

        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        if input_tokens:
            list_tokens.append(x)

        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        if post_pe:
            list_tokens.append(x)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)  # B x # patches x dim
            if layers is None or i in layers:
                list_tokens.append(self.norm(x) if norm else x)

        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (1 + # patches) x dim

        if not patch_tokens:
            return tokens[:, :, 0, :]  # index [CLS] tokens only, B x n_layers x dim

        else:
            return torch.cat((tokens[:, :, 0, :].unsqueeze(dim=2), tokens[:, :, 2:, :]), dim=2)  # exclude distil token.


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
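A quick usage note (not part of the committed file): the factories registered above can be called directly once the package is importable. The import path `networks.timm_deit` is assumed from this repository's layout, and `pretrained=True` would download the Facebook AI checkpoint referenced in the corresponding factory. A minimal sketch:

import torch
from networks.timm_deit import deit_small_distilled_patch16_224  # assumed import path

model = deit_small_distilled_patch16_224(pretrained=False)  # True fetches the DeiT checkpoint above
model.eval()                                                # eval mode averages the class and distillation heads
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))             # 224x224 RGB input, patch size 16
print(logits.shape)  # expected: torch.Size([1, 1000]) with the default num_classes=1000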
    	
networks/timm_vit.py
ADDED
@@ -0,0 +1,819 @@
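For orientation before the listing (an illustrative aside, not part of the committed file): the `Attention` module added below (file lines 135-159) is plain multi-head self-attention, and the following stand-alone sketch traces the same tensor shapes it manipulates, assuming only `torch` and arbitrary example sizes:

import torch

B, N, C, num_heads = 2, 197, 384, 6                  # batch, tokens (1 CLS + 196 patches), embed dim, heads
head_dim = C // num_heads
x = torch.randn(B, N, C)
qkv = torch.nn.Linear(C, C * 3)(x)                   # B x N x 3C, as produced by self.qkv
qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                     # each B x num_heads x N x head_dim
attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5  # B x num_heads x N x N attention weights
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)    # back to B x N x C, before the output projection
print(out.shape)  # torch.Size([2, 197, 384])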
| 1 | 
         
            +
            """ Vision Transformer (ViT) in PyTorch
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            A PyTorch implement of Vision Transformers as described in
         
     | 
| 4 | 
         
            +
            'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            The official jax code is released and available at https://github.com/google-research/vision_transformer
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            DeiT model defs and weights from https://github.com/facebookresearch/deit,
         
     | 
| 9 | 
         
            +
            paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            Acknowledgments:
         
     | 
| 12 | 
         
            +
            * The paper authors for releasing code and weights, thanks!
         
     | 
| 13 | 
         
            +
            * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
         
     | 
| 14 | 
         
            +
            for some einops/einsum fun
         
     | 
| 15 | 
         
            +
            * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
         
     | 
| 16 | 
         
            +
            * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
            Hacked together by / Copyright 2020 Ross Wightman
         
     | 
| 19 | 
         
            +
            """
         
     | 
| 20 | 
         
            +
            import math
         
     | 
| 21 | 
         
            +
            import logging
         
     | 
| 22 | 
         
            +
            from functools import partial
         
     | 
| 23 | 
         
            +
            from collections import OrderedDict
         
     | 
| 24 | 
         
            +
            from copy import deepcopy
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
            import torch
         
     | 
| 27 | 
         
            +
            import torch.nn as nn
         
     | 
| 28 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
            from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
         
     | 
| 31 | 
         
            +
            from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
         
     | 
| 32 | 
         
            +
            from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
         
     | 
| 33 | 
         
            +
            from timm.models.registry import register_model
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
            _logger = logging.getLogger(__name__)
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
            def _cfg(url='', **kwargs):
         
     | 
| 39 | 
         
            +
                return {
         
     | 
| 40 | 
         
            +
                    'url': url,
         
     | 
| 41 | 
         
            +
                    'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         
     | 
| 42 | 
         
            +
                    'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         
     | 
| 43 | 
         
            +
                    'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         
     | 
| 44 | 
         
            +
                    'first_conv': 'patch_embed.proj', 'classifier': 'head',
         
     | 
| 45 | 
         
            +
                    **kwargs
         
     | 
| 46 | 
         
            +
                }
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
             
     | 
| 49 | 
         
            +
            default_cfgs = {
         
     | 
| 50 | 
         
            +
                # patch models (my experiments)
         
     | 
| 51 | 
         
            +
                'vit_small_patch16_224': _cfg(
         
     | 
| 52 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
         
     | 
| 53 | 
         
            +
                ),
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
                # patch models (weights ported from official Google JAX impl)
         
     | 
| 56 | 
         
            +
                'vit_base_patch16_224': _cfg(
         
     | 
| 57 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
         
     | 
| 58 | 
         
            +
                    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
         
     | 
| 59 | 
         
            +
                ),
         
     | 
| 60 | 
         
            +
                'vit_base_patch32_224': _cfg(
         
     | 
| 61 | 
         
            +
                    url='',  # no official model weights for this combo, only for in21k
         
     | 
| 62 | 
         
            +
                    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         
     | 
| 63 | 
         
            +
                'vit_base_patch16_384': _cfg(
         
     | 
| 64 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
         
     | 
| 65 | 
         
            +
                    input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
         
     | 
| 66 | 
         
            +
                'vit_base_patch32_384': _cfg(
         
     | 
| 67 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
         
     | 
| 68 | 
         
            +
                    input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
         
     | 
| 69 | 
         
            +
                'vit_large_patch16_224': _cfg(
         
     | 
| 70 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
         
     | 
| 71 | 
         
            +
                    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         
     | 
| 72 | 
         
            +
                'vit_large_patch32_224': _cfg(
         
     | 
| 73 | 
         
            +
                    url='',  # no official model weights for this combo, only for in21k
         
     | 
| 74 | 
         
            +
                    mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         
     | 
| 75 | 
         
            +
                'vit_large_patch16_384': _cfg(
         
     | 
| 76 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
         
     | 
| 77 | 
         
            +
                    input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
         
     | 
| 78 | 
         
            +
                'vit_large_patch32_384': _cfg(
         
     | 
| 79 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
         
     | 
| 80 | 
         
            +
                    input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
         
     | 
| 81 | 
         
            +
             
     | 
| 82 | 
         
            +
                # patch models, imagenet21k (weights ported from official Google JAX impl)
         
     | 
| 83 | 
         
            +
                'vit_base_patch16_224_in21k': _cfg(
         
     | 
| 84 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
         
     | 
| 85 | 
         
            +
                    num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         
     | 
| 86 | 
         
            +
                'vit_base_patch32_224_in21k': _cfg(
         
     | 
| 87 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
         
     | 
| 88 | 
         
            +
                    num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         
     | 
| 89 | 
         
            +
                'vit_large_patch16_224_in21k': _cfg(
         
     | 
| 90 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
         
     | 
| 91 | 
         
            +
                    num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         
     | 
| 92 | 
         
            +
                'vit_large_patch32_224_in21k': _cfg(
         
     | 
| 93 | 
         
            +
                    url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
         
     | 
| 94 | 
         
            +
                    num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         
     | 
| 95 | 
         
            +
                'vit_huge_patch14_224_in21k': _cfg(
         
     | 
| 96 | 
         
            +
                    hf_hub='timm/vit_huge_patch14_224_in21k',
         
     | 
| 97 | 
         
            +
                    num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                # deit models (FB weights)
         
     | 
| 100 | 
         
            +
                'vit_deit_tiny_patch16_224': _cfg(
         
     | 
| 101 | 
         
            +
                    url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
         
     | 
| 102 | 
         
            +
                'vit_deit_small_patch16_224': _cfg(
         
     | 
| 103 | 
         
            +
                    url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
         
     | 
| 104 | 
         
            +
                'vit_deit_base_patch16_224': _cfg(
         
     | 
| 105 | 
         
            +
                    url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
         
     | 
| 106 | 
         
            +
                'vit_deit_base_patch16_384': _cfg(
         
     | 
| 107 | 
         
            +
                    url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
         
     | 
| 108 | 
         
            +
                    input_size=(3, 384, 384), crop_pct=1.0),
         
     | 
| 109 | 
         
            +
                'vit_deit_tiny_distilled_patch16_224': _cfg(
         
     | 
| 110 | 
         
            +
                    url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
         
     | 
| 111 | 
         
            +
                    classifier=('head', 'head_dist')),
         
     | 
| 112 | 
         
            +
                'vit_deit_small_distilled_patch16_224': _cfg(
         
     | 
| 113 | 
         
            +
                    url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
         
     | 
| 114 | 
         
            +
                    classifier=('head', 'head_dist')),
         
     | 
| 115 | 
         
            +
                'vit_deit_base_distilled_patch16_224': _cfg(
         
     | 
| 116 | 
         
            +
                    url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
         
     | 
| 117 | 
         
            +
                    classifier=('head', 'head_dist')),
         
     | 
| 118 | 
         
            +
                'vit_deit_base_distilled_patch16_384': _cfg(
         
     | 
| 119 | 
         
            +
                    url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
         
     | 
| 120 | 
         
            +
                    input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
         
     | 
| 121 | 
         
            +
             
     | 
| 122 | 
         
            +
                # ViT ImageNet-21K-P pretraining
         
     | 
| 123 | 
         
            +
                'vit_base_patch16_224_miil_in21k': _cfg(
         
     | 
| 124 | 
         
            +
                    url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
         
     | 
| 125 | 
         
            +
                    mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
         
     | 
| 126 | 
         
            +
                ),
         
     | 
| 127 | 
         
            +
                'vit_base_patch16_224_miil': _cfg(
         
     | 
| 128 | 
         
            +
                    url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
         
     | 
| 129 | 
         
            +
                        '/vit_base_patch16_224_1k_miil_84_4.pth',
         
     | 
| 130 | 
         
            +
                    mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
         
     | 
| 131 | 
         
            +
                ),
         
     | 
| 132 | 
         
            +
            }
         
     | 
| 133 | 
         
            +
             
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
            class Attention(nn.Module):
         
     | 
| 136 | 
         
            +
                def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
         
     | 
| 137 | 
         
            +
                    super().__init__()
         
     | 
| 138 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 139 | 
         
            +
                    head_dim = dim // num_heads
         
     | 
| 140 | 
         
            +
                    self.scale = qk_scale or head_dim ** -0.5
         
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         
     | 
| 143 | 
         
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         
     | 
| 144 | 
         
            +
                    self.proj = nn.Linear(dim, dim)
         
     | 
| 145 | 
         
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         
     | 
| 146 | 
         
            +
             
     | 
| 147 | 
         
            +
                def forward(self, x):
         
     | 
| 148 | 
         
            +
                    B, N, C = x.shape
         
     | 
| 149 | 
         
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         
     | 
| 150 | 
         
            +
                    q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
                    attn = (q @ k.transpose(-2, -1)) * self.scale
         
     | 
| 153 | 
         
            +
                    attn = attn.softmax(dim=-1)
         
     | 
| 154 | 
         
            +
                    attn = self.attn_drop(attn)
         
     | 
| 155 | 
         
            +
             
     | 
| 156 | 
         
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
         
     | 
| 157 | 
         
            +
                    x = self.proj(x)
         
     | 
| 158 | 
         
            +
                    x = self.proj_drop(x)
         
     | 
| 159 | 
         
            +
                    return x
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
            +
             
     | 
| 162 | 
         
            +
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


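# NOTE (editor's addition, not part of the original file): hedged sketch of `Block`. A Block is a pre-norm
# residual unit (LayerNorm -> Attention -> residual, then LayerNorm -> MLP -> residual), so it also keeps the
# (batch, num_tokens, dim) shape unchanged; the VisionTransformer below simply stacks `depth` of these.
def _demo_block_shapes():
    block = Block(dim=192, num_heads=3, mlp_ratio=4., qkv_bias=True)
    tokens = torch.randn(2, 197, 192)
    out = block(tokens)
    assert out.shape == tokens.shape
    return out

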
class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init='',
                 # noel
                 img_size_eval: int = 224):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer (nn.Module): normalization layer
            weight_init (str): weight init scheme
            img_size_eval (int): image size expected at evaluation time (the positional encoding is interpolated
                to the matching patch grid)
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if weight_init.startswith('jax'):
            # leave cls token as zeros to match jax impl
            for n, m in self.named_modules():
                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

        # noel
        self.depth = depth
        self.distilled = distilled
        self.patch_size = patch_size
        self.patch_embed.img_size = (img_size_eval, img_size_eval)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    # def forward(self, x):
    #     x = self.forward_features(x)
    #     if self.head_dist is not None:
    #         x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
    #         if self.training and not torch.jit.is_scripting():
    #             # during inference, return the average of both classifier predictions
    #             return x, x_dist
    #         else:
    #             return (x + x_dist) / 2
    #     else:
    #         x = self.head(x)
    #     return x

    # noel - start
    def make_square(self, x: torch.Tensor):
        """Pad some pixels to make the input size divisible by the patch size."""
        B, _, H_0, W_0 = x.shape
        pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
        pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
        x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=x.mean())

        H_p, W_p = H_0 + pad_h, W_0 + pad_w
        x = nn.functional.pad(x, (0, H_p - W_p, 0, 0) if H_p > W_p else (0, 0, 0, W_p - H_p), value=x.mean())
        return x

    def interpolate_pos_encoding(self, x, pos_embed, size):
        """Interpolate the learnable positional encoding to match the number of patches.

        x: B x (1 + N patches) x dim_embedding
        pos_embed: B x (1 + N patches) x dim_embedding

        return interpolated positional embedding
        """
        npatch = x.shape[1] - 1  # (H // patch_size * W // patch_size)
        N = pos_embed.shape[1] - 1  # 784 (= 28 x 28)
        if npatch == N:
            return pos_embed
        class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # a learnable CLS token, learnable position embeddings

        dim = x.shape[-1]  # dimension of embeddings
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # B x dim x 28 x 28
            size=size,
            mode='bicubic',
            align_corners=False
        )

        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
        return pos_embed

    # def interpolate_pos_encoding(self, x, pos_embed):
    #     """Interpolate the learnable positional encoding to match the number of patches.
    #
    #     x: B x (1 + N patches) x dim_embedding
    #     pos_embed: B x (1 + N patches) x dim_embedding
    #
    #     return interpolated positional embedding
    #     """
    #     npatch = x.shape[1] - 1  # (H // patch_size * W // patch_size)
    #     N = pos_embed.shape[1] - 1  # 784 (= 28 x 28)
    #     if npatch == N:
    #         return pos_embed
    #     class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # a learnable CLS token, learnable position embeddings
    #
    #     dim = x.shape[-1]  # dimension of embeddings
    #     pos_embed = nn.functional.interpolate(
    #         pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # B x dim x 28 x 28
    #         scale_factor=math.sqrt(npatch / N) + 1e-5,  # noel: this can be a float, but the output shape will be integer.
    #         recompute_scale_factor=True,
    #         mode='bicubic',
    #         align_corners=False
    #     )
    #
    #     pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    #     pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
    #     return pos_embed

    def prepare_tokens(self, x):
        B, nc, h, w = x.shape
        patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
        x = self.patch_embed(x)  # patch linear embedding

        # add the [CLS] token to the embed patch tokens
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # add positional encoding to each token
        x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
        return self.pos_drop(x)

    def get_tokens(
            self,
            x,
            layers: list,
            patch_tokens: bool = False,
            norm: bool = True,
            input_tokens: bool = False,
            post_pe: bool = False
    ):
        """Return intermediate tokens."""
        list_tokens: list = []

        B = x.shape[0]
        # the patch grid size is needed below to interpolate the positional encoding (cf. prepare_tokens)
        patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = torch.cat((cls_tokens, x), dim=1)

        if input_tokens:
            list_tokens.append(x)

        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
        x = x + pos_embed

        if post_pe:
            list_tokens.append(x)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)  # B x # patches x dim
            if layers is None or i in layers:
                list_tokens.append(self.norm(x) if norm else x)

        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (1 + # patches) x dim

        if not patch_tokens:
            return tokens[:, :, 0, :]  # index [CLS] tokens only, B x n_layers x dim

        else:
            return tokens

    def forward(self, x, layer: str = None):
        x = self.prepare_tokens(x)

        features: dict = {}
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            features[f"layer{i + 1}"] = self.norm(x)

        if layer is not None:
            return features[layer]
        else:
            return features[f"layer{self.depth}"]  # final block's features ("layer12" for the default depth)
    # noel - end


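# NOTE (editor's addition, not part of the original file): hedged usage sketch of the class above. It shows
# how `img_size_eval` together with `interpolate_pos_encoding` lets the backbone run at a resolution other
# than the 224 x 224 grid the positional embedding was trained for. The tiny configuration (embed_dim=192,
# depth=4) is chosen only to keep the example cheap; it is not a configuration used by SelfMask.
def _demo_vision_transformer_eval_resolution():
    model = VisionTransformer(
        img_size=224, patch_size=16, embed_dim=192, depth=4, num_heads=3, num_classes=0, img_size_eval=320)
    model.eval()
    x = torch.randn(2, 3, 320, 320)  # 320 / 16 = 20, i.e. a 20 x 20 patch grid instead of 14 x 14
    with torch.no_grad():
        feats = model(x, layer="layer4")  # per-block features are keyed "layer1" ... f"layer{depth}"
    assert feats.shape == (2, 1 + 20 * 20, 192)  # [CLS] token + 400 patch tokens, each of dimension 192
    return feats

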
def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(m, nn.Linear):
        if n.startswith('head'):
            nn.init.zeros_(m.weight)
            nn.init.constant_(m.bias, head_bias)
        elif n.startswith('pre_logits'):
            lecun_normal_(m.weight)
            nn.init.zeros_(m.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    if 'mlp' in n:
                        nn.init.normal_(m.bias, std=1e-6)
                    else:
                        nn.init.zeros_(m.bias)
            else:
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    elif jax_impl and isinstance(m, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)


def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
    assert len(gs_new) >= 2
    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    return posemb


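# NOTE (editor's addition, not part of the original file): hedged sketch of `resize_pos_embed`. It is the
# checkpoint-loading counterpart of `VisionTransformer.interpolate_pos_encoding`: a stored positional
# embedding of shape (1, num_tokens + gs_old**2, dim) is rescaled so that, e.g., 224x224 / patch-16 weights
# (a 14 x 14 grid) can initialise a 384x384 / patch-16 model (a 24 x 24 grid).
def _demo_resize_pos_embed():
    posemb_224 = torch.randn(1, 1 + 14 * 14, 768)  # [CLS] slot + 14 x 14 grid
    posemb_384 = torch.zeros(1, 1 + 24 * 24, 768)  # target layout; only its shape is used
    resized = resize_pos_embed(posemb_224, posemb_384, num_tokens=1)
    assert resized.shape == posemb_384.shape
    return resized

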
def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
            # For old models that I trained prior to conv based patchification
            O, I, H, W = model.patch_embed.proj.weight.shape
            v = v.reshape(O, -1, H, W)
        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
        out_dict[k] = v
    return out_dict


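# NOTE (editor's addition, not part of the original file): hedged sketch of `checkpoint_filter_fn` on a dummy
# state_dict. A checkpoint that stored the patch embedding as a flattened linear projection of shape
# (embed_dim, 3 * patch_size**2) is reshaped to the (embed_dim, 3, patch_size, patch_size) layout expected by
# the conv-based `PatchEmbed`; a `pos_embed` of mismatching shape would likewise go through `resize_pos_embed`.
def _demo_checkpoint_filter_fn():
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=192, depth=1, num_heads=3)
    old_state_dict = {'patch_embed.proj.weight': torch.randn(192, 3 * 16 * 16)}  # old, flattened layout
    filtered = checkpoint_filter_fn(old_state_dict, model)
    assert filtered['patch_embed.proj.weight'].shape == model.patch_embed.proj.weight.shape
    return filtered

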
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
    default_cfg = default_cfg or default_cfgs[variant]
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    # NOTE this extra code to support handling of repr size for in21k pretrained models
    default_num_classes = default_cfg['num_classes']
    num_classes = kwargs.get('num_classes', default_num_classes)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        # Remove representation layer if fine-tuning. This may not always be the desired action,
        # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
        _logger.warning("Removing representation layer for fine-tuning.")
        repr_size = None

    model = build_model_with_cfg(
        VisionTransformer, variant, pretrained,
        default_cfg=default_cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
    return model


@register_model
def vit_small_patch16_224(pretrained=False, **kwargs):
    """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
    NOTE:
        * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
        * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
        qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
    if pretrained:
        # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
        model_kwargs.setdefault('qk_scale', 768 ** -0.5)
    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


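# NOTE (editor's addition, not part of the original file): hedged sketch showing how the factories above are
# used. Each function is decorated with timm's `@register_model`, so the same model could also be built via
# `timm.create_model('vit_base_patch16_224', ...)` if preferred. With `pretrained=False` nothing is
# downloaded, and `num_classes=0` replaces the ImageNet-1k classifier head with an identity so the network
# serves purely as a feature backbone.
def _demo_build_backbone():
    backbone = vit_base_patch16_224(pretrained=False, num_classes=0)  # randomly initialised, headless
    backbone.eval()
    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        feats = backbone(x, layer="layer12")  # final-block tokens: [CLS] + 14 x 14 patches
    assert feats.shape == (1, 1 + 14 * 14, 768)
    return feats

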
@register_model
def vit_base_patch32_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
    """
    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_384(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_384(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
    """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model_kwargs = dict(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
    model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model
         
     | 
| 736 | 
         
            +
             
     | 
| 737 | 
         
            +
             
     | 
| 738 | 
         
            +
            @register_model
         
     | 
| 739 | 
         
            +
            def vit_deit_base_patch16_224(pretrained=False, **kwargs):
         
     | 
| 740 | 
         
            +
                """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
         
     | 
| 741 | 
         
            +
                ImageNet-1k weights from https://github.com/facebookresearch/deit.
         
     | 
| 742 | 
         
            +
                """
         
     | 
| 743 | 
         
            +
                model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
         
     | 
| 744 | 
         
            +
                model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
         
     | 
| 745 | 
         
            +
                return model
         
     | 
| 746 | 
         
            +
             
     | 
| 747 | 
         
            +
             
     | 
| 748 | 
         
            +
            @register_model
         
     | 
| 749 | 
         
            +
            def vit_deit_base_patch16_384(pretrained=False, **kwargs):
         
     | 
| 750 | 
         
            +
                """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
         
     | 
| 751 | 
         
            +
                ImageNet-1k weights from https://github.com/facebookresearch/deit.
         
     | 
| 752 | 
         
            +
                """
         
     | 
| 753 | 
         
            +
                model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
         
     | 
| 754 | 
         
            +
                model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
         
     | 
| 755 | 
         
            +
                return model
         
     | 
| 756 | 
         
            +
             
     | 
| 757 | 
         
            +
             
     | 
| 758 | 
         
            +
            @register_model
         
     | 
| 759 | 
         
            +
            def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
         
     | 
| 760 | 
         
            +
                """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
         
     | 
| 761 | 
         
            +
                ImageNet-1k weights from https://github.com/facebookresearch/deit.
         
     | 
| 762 | 
         
            +
                """
         
     | 
| 763 | 
         
            +
                model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
         
     | 
| 764 | 
         
            +
                model = _create_vision_transformer(
         
     | 
| 765 | 
         
            +
                    'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained,  distilled=True, **model_kwargs)
         
     | 
| 766 | 
         
            +
                return model
         
     | 
| 767 | 
         
            +
             
     | 
| 768 | 
         
            +
             
     | 
| 769 | 
         
            +
            @register_model
         
     | 
| 770 | 
         
            +
            def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
         
     | 
| 771 | 
         
            +
                """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
         
     | 
| 772 | 
         
            +
                ImageNet-1k weights from https://github.com/facebookresearch/deit.
         
     | 
| 773 | 
         
            +
                """
         
     | 
| 774 | 
         
            +
                model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
         
     | 
| 775 | 
         
            +
                model = _create_vision_transformer(
         
     | 
| 776 | 
         
            +
                    'vit_deit_small_distilled_patch16_224', pretrained=pretrained,  distilled=True, **model_kwargs)
         
     | 
| 777 | 
         
            +
                return model
         
     | 
| 778 | 
         
            +
             
     | 
| 779 | 
         
            +
             
     | 
| 780 | 
         
            +
            @register_model
         
     | 
| 781 | 
         
            +
            def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
         
     | 
| 782 | 
         
            +
                """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
         
     | 
| 783 | 
         
            +
                ImageNet-1k weights from https://github.com/facebookresearch/deit.
         
     | 
| 784 | 
         
            +
                """
         
     | 
| 785 | 
         
            +
                model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
         
     | 
| 786 | 
         
            +
                model = _create_vision_transformer(
         
     | 
| 787 | 
         
            +
                    'vit_deit_base_distilled_patch16_224', pretrained=pretrained,  distilled=True, **model_kwargs)
         
     | 
| 788 | 
         
            +
                return model
         
     | 
| 789 | 
         
            +
             
     | 
| 790 | 
         
            +
             
     | 
| 791 | 
         
            +
            @register_model
         
     | 
| 792 | 
         
            +
            def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
         
     | 
| 793 | 
         
            +
                """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
         
     | 
| 794 | 
         
            +
                ImageNet-1k weights from https://github.com/facebookresearch/deit.
         
     | 
| 795 | 
         
            +
                """
         
     | 
| 796 | 
         
            +
                model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
         
     | 
| 797 | 
         
            +
                model = _create_vision_transformer(
         
     | 
| 798 | 
         
            +
                    'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
         
     | 
| 799 | 
         
            +
                return model
         
     | 
| 800 | 
         
            +
             
     | 
| 801 | 
         
            +
             
     | 
| 802 | 
         
            +
            @register_model
         
     | 
| 803 | 
         
            +
            def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
         
     | 
| 804 | 
         
            +
                """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
         
     | 
| 805 | 
         
            +
                Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
         
     | 
| 806 | 
         
            +
                """
         
     | 
| 807 | 
         
            +
                model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
         
     | 
| 808 | 
         
            +
                model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
         
     | 
| 809 | 
         
            +
                return model
         
     | 
| 810 | 
         
            +
             
     | 
| 811 | 
         
            +
             
     | 
| 812 | 
         
            +
            @register_model
         
     | 
| 813 | 
         
            +
            def vit_base_patch16_224_miil(pretrained=False, **kwargs):
         
     | 
| 814 | 
         
            +
                """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
         
     | 
| 815 | 
         
            +
                Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
         
     | 
| 816 | 
         
            +
                """
         
     | 
| 817 | 
         
            +
                model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
         
     | 
| 818 | 
         
            +
                model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
         
     | 
| 819 | 
         
            +
                return model
         
     | 
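The registrations above only declare configurations; nothing is loaded until a factory is called. As a minimal usage sketch (not part of the committed file, and assuming the repository root is on PYTHONPATH so that networks.timm_vit is importable), one of the registered DeiT variants can be built directly:

# Hypothetical usage sketch for the factories registered above.
import torch
from networks.timm_vit import vit_deit_small_patch16_224

model = vit_deit_small_patch16_224(pretrained=False)  # pretrained=True would download the DeiT ImageNet-1k weights
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # typically torch.Size([1, 1000]) with the default ImageNet-1k classifier head

Because each factory is decorated with @register_model, the same architectures should also be reachable through timm's registry (for example timm.create_model) once this module has been imported.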
    	
networks/vision_transformer.py ADDED
@@ -0,0 +1,569 @@
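This file adds the DINO-style ViT encoder used as the backbone. As a quick orientation before the listing, a minimal sketch (not part of the committed file; the arguments mirror a ViT-S/16 configuration, and the import path assumes the repository root is on PYTHONPATH):

# Hypothetical sketch of driving the VisionTransformer defined below.
import torch
from networks.vision_transformer import VisionTransformer

vit = VisionTransformer(patch_size=16, embed_dim=384, depth=12, num_heads=6)
x = torch.randn(1, 3, 224, 224)             # 224 / 16 = 14, so 14 * 14 = 196 patch tokens
features = vit(x)                           # dict with keys "layer1" ... "layer12", one entry per block
tokens = features["layer12"]                # B x (1 + 196) x 384; index 0 is the [CLS] token
cls_token, patch_tokens = tokens[:, 0], tokens[:, 1:]
print(cls_token.shape, patch_tokens.shape)  # torch.Size([1, 384]) torch.Size([1, 196, 384])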
| 1 | 
         
            +
            # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
         
     | 
| 2 | 
         
            +
            """
         
     | 
| 3 | 
         
            +
            Mostly copy-paste from timm library.
         
     | 
| 4 | 
         
            +
            https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
         
     | 
| 5 | 
         
            +
            """
         
     | 
| 6 | 
         
            +
            from typing import Optional
         
     | 
| 7 | 
         
            +
            import math
         
     | 
| 8 | 
         
            +
            from functools import partial
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            import torch
         
     | 
| 11 | 
         
            +
            import torch.nn as nn
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            def _no_grad_trunc_normal_(tensor, mean, std, a, b):
         
     | 
| 15 | 
         
            +
                # Cut & paste from PyTorch official master until it's in a few official releases - RW
         
     | 
| 16 | 
         
            +
                # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
         
     | 
| 17 | 
         
            +
                def norm_cdf(x):
         
     | 
| 18 | 
         
            +
                    # Computes standard normal cumulative distribution function
         
     | 
| 19 | 
         
            +
                    return (1. + math.erf(x / math.sqrt(2.))) / 2.
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
                if (mean < a - 2 * std) or (mean > b + 2 * std):
         
     | 
| 22 | 
         
            +
                    warnings.warn(
         
     | 
| 23 | 
         
            +
                        "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.",
         
     | 
| 24 | 
         
            +
                        stacklevel=2
         
     | 
| 25 | 
         
            +
                    )
         
     | 
| 26 | 
         
            +
             
     | 
| 27 | 
         
            +
                with torch.no_grad():
         
     | 
| 28 | 
         
            +
                    # Values are generated by using a truncated uniform distribution and
         
     | 
| 29 | 
         
            +
                    # then using the inverse CDF for the normal distribution.
         
     | 
| 30 | 
         
            +
                    # Get upper and lower cdf values
         
     | 
| 31 | 
         
            +
                    l = norm_cdf((a - mean) / std)
         
     | 
| 32 | 
         
            +
                    u = norm_cdf((b - mean) / std)
         
     | 
| 33 | 
         
            +
             
     | 
| 34 | 
         
            +
                    # Uniformly fill tensor with values from [l, u], then translate to
         
     | 
| 35 | 
         
            +
                    # [2l-1, 2u-1].
         
     | 
| 36 | 
         
            +
                    tensor.uniform_(2 * l - 1, 2 * u - 1)
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
                    # Use inverse cdf transform for normal distribution to get truncated
         
     | 
| 39 | 
         
            +
                    # standard normal
         
     | 
| 40 | 
         
            +
                    tensor.erfinv_()
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
                    # Transform to proper mean, std
         
     | 
| 43 | 
         
            +
                    tensor.mul_(std * math.sqrt(2.))
         
     | 
| 44 | 
         
            +
                    tensor.add_(mean)
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
                    # Clamp to ensure it's in the proper range
         
     | 
| 47 | 
         
            +
                    tensor.clamp_(min=a, max=b)
         
     | 
| 48 | 
         
            +
                    return tensor
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
            def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
         
     | 
| 52 | 
         
            +
                # type: (Tensor, float, float, float, float) -> Tensor
         
     | 
| 53 | 
         
            +
                return _no_grad_trunc_normal_(tensor, mean, std, a, b)
         
     | 
| 54 | 
         
            +
             
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
            def drop_path(x, drop_prob: float = 0., training: bool = False):
         
     | 
| 57 | 
         
            +
                if drop_prob == 0. or not training:
         
     | 
| 58 | 
         
            +
                    return x
         
     | 
| 59 | 
         
            +
                keep_prob = 1 - drop_prob
         
     | 
| 60 | 
         
            +
                shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
         
     | 
| 61 | 
         
            +
                random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
         
     | 
| 62 | 
         
            +
                random_tensor.floor_()  # binarize
         
     | 
| 63 | 
         
            +
                output = x.div(keep_prob) * random_tensor
         
     | 
| 64 | 
         
            +
                return output
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
             
     | 
| 67 | 
         
            +
            class DropPath(nn.Module):
         
     | 
| 68 | 
         
            +
                """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
         
     | 
| 69 | 
         
            +
                """
         
     | 
| 70 | 
         
            +
                def __init__(self, drop_prob=None):
         
     | 
| 71 | 
         
            +
                    super(DropPath, self).__init__()
         
     | 
| 72 | 
         
            +
                    self.drop_prob = drop_prob
         
     | 
| 73 | 
         
            +
             
     | 
| 74 | 
         
            +
                def forward(self, x):
         
     | 
| 75 | 
         
            +
                    return drop_path(x, self.drop_prob, self.training)
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
            class Mlp(nn.Module):
         
     | 
| 79 | 
         
            +
                def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
         
     | 
| 80 | 
         
            +
                    super().__init__()
         
     | 
| 81 | 
         
            +
                    out_features = out_features or in_features
         
     | 
| 82 | 
         
            +
                    hidden_features = hidden_features or in_features
         
     | 
| 83 | 
         
            +
                    self.fc1 = nn.Linear(in_features, hidden_features)
         
     | 
| 84 | 
         
            +
                    self.act = act_layer()
         
     | 
| 85 | 
         
            +
                    self.fc2 = nn.Linear(hidden_features, out_features)
         
     | 
| 86 | 
         
            +
                    self.drop = nn.Dropout(drop)
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                def forward(self, x):
         
     | 
| 89 | 
         
            +
                    x = self.fc1(x)
         
     | 
| 90 | 
         
            +
                    x = self.act(x)
         
     | 
| 91 | 
         
            +
                    x = self.drop(x)
         
     | 
| 92 | 
         
            +
                    x = self.fc2(x)
         
     | 
| 93 | 
         
            +
                    x = self.drop(x)
         
     | 
| 94 | 
         
            +
                    return x
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
             
     | 
| 97 | 
         
            +
            class Attention(nn.Module):
         
     | 
| 98 | 
         
            +
                def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
         
     | 
| 99 | 
         
            +
                    super().__init__()
         
     | 
| 100 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 101 | 
         
            +
                    head_dim = dim // num_heads
         
     | 
| 102 | 
         
            +
                    self.scale = qk_scale or head_dim ** -0.5  # square root of dimension for normalisation
         
     | 
| 103 | 
         
            +
             
     | 
| 104 | 
         
            +
                    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         
     | 
| 105 | 
         
            +
                    self.attn_drop = nn.Dropout(attn_drop)
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                    self.proj = nn.Linear(dim, dim)
         
     | 
| 108 | 
         
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
                def forward(self, x):
         
     | 
| 111 | 
         
            +
                    B, N, C = x.shape  # B x (cls token + # patch tokens) x dim
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
                    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
         
     | 
| 114 | 
         
            +
                    # qkv: 3 x B x Nh x (cls token + # patch tokens) x (dim // Nh)
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
                    q, k, v = qkv[0], qkv[1], qkv[2]
         
     | 
| 117 | 
         
            +
                    # q, k, v: B x Nh x (cls token + # patch tokens) x (dim // Nh)
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                    # q: B x Nh x (cls token + # patch tokens) x (dim // Nh)
         
     | 
| 120 | 
         
            +
                    # k.transpose(-2, -1) = B x Nh x (dim // Nh) x (cls token + # patch tokens)
         
     | 
| 121 | 
         
            +
                    # attn: B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
         
     | 
| 122 | 
         
            +
                    attn = (q @ k.transpose(-2, -1)) * self.scale  # @ operator is for matrix multiplication
         
     | 
| 123 | 
         
            +
                    attn = attn.softmax(dim=-1)  # B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
         
     | 
| 124 | 
         
            +
                    attn = self.attn_drop(attn)
         
     | 
| 125 | 
         
            +
             
     | 
| 126 | 
         
            +
                    # attn = B x Nh x (cls token + # patch tokens) x (cls token + # patch tokens)
         
     | 
| 127 | 
         
            +
                    # v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
         
     | 
| 128 | 
         
            +
                    # attn @ v = B x Nh x (cls token + # patch tokens) x (dim // Nh)
         
     | 
| 129 | 
         
            +
                    # (attn @ v).transpose(1, 2) = B x (cls token + # patch tokens) x Nh x (dim // Nh)
         
     | 
| 130 | 
         
            +
                    x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # B x (cls token + # patch tokens) x dim
         
     | 
| 131 | 
         
            +
                    x = self.proj(x)  # B x (cls token + # patch tokens) x dim
         
     | 
| 132 | 
         
            +
                    x = self.proj_drop(x)
         
     | 
| 133 | 
         
            +
                    return x, attn
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
             
     | 
| 136 | 
         
            +
            class Block(nn.Module):
         
     | 
| 137 | 
         
            +
                def __init__(self,
         
     | 
| 138 | 
         
            +
                             dim, num_heads,
         
     | 
| 139 | 
         
            +
                             mlp_ratio=4.,
         
     | 
| 140 | 
         
            +
                             qkv_bias=False,
         
     | 
| 141 | 
         
            +
                             qk_scale=None,
         
     | 
| 142 | 
         
            +
                             drop=0.,
         
     | 
| 143 | 
         
            +
                             attn_drop=0.,
         
     | 
| 144 | 
         
            +
                             drop_path=0.,
         
     | 
| 145 | 
         
            +
                             act_layer=nn.GELU,
         
     | 
| 146 | 
         
            +
                             norm_layer=nn.LayerNorm):
         
     | 
| 147 | 
         
            +
                    super().__init__()
         
     | 
| 148 | 
         
            +
                    self.norm1 = norm_layer(dim)
         
     | 
| 149 | 
         
            +
                    self.attn = Attention(
         
     | 
| 150 | 
         
            +
                        dim,
         
     | 
| 151 | 
         
            +
                        num_heads=num_heads,
         
     | 
| 152 | 
         
            +
                        qkv_bias=qkv_bias,
         
     | 
| 153 | 
         
            +
                        qk_scale=qk_scale,
         
     | 
| 154 | 
         
            +
                        attn_drop=attn_drop,
         
     | 
| 155 | 
         
            +
                        proj_drop=drop
         
     | 
| 156 | 
         
            +
                    )
         
     | 
| 157 | 
         
            +
             
     | 
| 158 | 
         
            +
                    self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
                    self.norm2 = norm_layer(dim)
         
     | 
| 161 | 
         
            +
                    mlp_hidden_dim = int(dim * mlp_ratio)
         
     | 
| 162 | 
         
            +
                    self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
                def forward(self, x, return_attention=False):
         
     | 
| 165 | 
         
            +
                    y, attn = self.attn(self.norm1(x))
         
     | 
| 166 | 
         
            +
                    if return_attention:
         
     | 
| 167 | 
         
            +
                        return attn
         
     | 
| 168 | 
         
            +
                    x = x + self.drop_path(y)
         
     | 
| 169 | 
         
            +
                    x = x + self.drop_path(self.mlp(self.norm2(x)))
         
     | 
| 170 | 
         
            +
                    return x
         
     | 
| 171 | 
         
            +
             
     | 
| 172 | 
         
            +
             
     | 
| 173 | 
         
            +
            class PatchEmbed(nn.Module):
         
     | 
| 174 | 
         
            +
                """ Image to Patch Embedding"""
         
     | 
| 175 | 
         
            +
                def __init__(self, img_size=(224, 224), patch_size=16, in_chans=3, embed_dim=768):
         
     | 
| 176 | 
         
            +
                    super().__init__()
         
     | 
| 177 | 
         
            +
                    num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
         
     | 
| 178 | 
         
            +
                    self.img_size = img_size
         
     | 
| 179 | 
         
            +
                    self.patch_size = patch_size
         
     | 
| 180 | 
         
            +
                    self.num_patches = num_patches
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
         
     | 
| 183 | 
         
            +
             
     | 
| 184 | 
         
            +
                def forward(self, x):
         
     | 
| 185 | 
         
            +
                    B, C, H, W = x.shape
         
     | 
| 186 | 
         
            +
                    x = self.proj(x)
         
     | 
| 187 | 
         
            +
                    x = x.flatten(2).transpose(1, 2)  # B x (P_H * P_W) x C
         
     | 
| 188 | 
         
            +
                    return x
         
     | 
| 189 | 
         
            +
             
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
            class VisionTransformer(nn.Module):
         
     | 
| 192 | 
         
            +
                """ Vision Transformer """
         
     | 
| 193 | 
         
            +
                def __init__(self,
         
     | 
| 194 | 
         
            +
                             img_size=(224, 224),
         
     | 
| 195 | 
         
            +
                             patch_size=16,
         
     | 
| 196 | 
         
            +
                             in_chans=3,
         
     | 
| 197 | 
         
            +
                             num_classes=0,
         
     | 
| 198 | 
         
            +
                             embed_dim=768,
         
     | 
| 199 | 
         
            +
                             depth=12,
         
     | 
| 200 | 
         
            +
                             num_heads=12,
         
     | 
| 201 | 
         
            +
                             mlp_ratio=4.,
         
     | 
| 202 | 
         
            +
                             qkv_bias=False,
         
     | 
| 203 | 
         
            +
                             qk_scale=None,
         
     | 
| 204 | 
         
            +
                             drop_rate=0.,
         
     | 
| 205 | 
         
            +
                             attn_drop_rate=0.,
         
     | 
| 206 | 
         
            +
                             drop_path_rate=0.,
         
     | 
| 207 | 
         
            +
                             norm_layer=nn.LayerNorm):
         
     | 
| 208 | 
         
            +
                    super().__init__()
         
     | 
| 209 | 
         
            +
                    self.num_features = self.embed_dim = embed_dim
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                    self.patch_embed = PatchEmbed(
         
     | 
| 212 | 
         
            +
                        img_size=(224, 224),  # noel: this is to load pretrained model.
         
     | 
| 213 | 
         
            +
                        patch_size=patch_size,
         
     | 
| 214 | 
         
            +
                        in_chans=in_chans,
         
     | 
| 215 | 
         
            +
                        embed_dim=embed_dim
         
     | 
| 216 | 
         
            +
                    )
         
     | 
| 217 | 
         
            +
                    num_patches = self.patch_embed.num_patches
         
     | 
| 218 | 
         
            +
             
     | 
| 219 | 
         
            +
                    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         
     | 
| 220 | 
         
            +
                    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
         
     | 
| 221 | 
         
            +
                    self.pos_drop = nn.Dropout(p=drop_rate)
         
     | 
| 222 | 
         
            +
             
     | 
| 223 | 
         
            +
                    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         
     | 
| 224 | 
         
            +
                    self.blocks = nn.ModuleList([
         
     | 
| 225 | 
         
            +
                        Block(
         
     | 
| 226 | 
         
            +
                            dim=embed_dim,
         
     | 
| 227 | 
         
            +
                            num_heads=num_heads,
         
     | 
| 228 | 
         
            +
                            mlp_ratio=mlp_ratio,
         
     | 
| 229 | 
         
            +
                            qkv_bias=qkv_bias,
         
     | 
| 230 | 
         
            +
                            qk_scale=qk_scale,
         
     | 
| 231 | 
         
            +
                            drop=drop_rate,
         
     | 
| 232 | 
         
            +
                            attn_drop=attn_drop_rate,
         
     | 
| 233 | 
         
            +
                            drop_path=dpr[i],
         
     | 
| 234 | 
         
            +
                            norm_layer=norm_layer
         
     | 
| 235 | 
         
            +
                        ) for i in range(depth)])
         
     | 
| 236 | 
         
            +
                    self.norm = norm_layer(embed_dim)
         
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
                    # Classifier head
         
     | 
| 239 | 
         
            +
                    self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
         
     | 
| 240 | 
         
            +
             
     | 
| 241 | 
         
            +
                    trunc_normal_(self.pos_embed, std=.02)
         
     | 
| 242 | 
         
            +
                    trunc_normal_(self.cls_token, std=.02)
         
     | 
| 243 | 
         
            +
                    self.apply(self._init_weights)
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                    self.depth = depth
         
     | 
| 246 | 
         
            +
                    self.embed_dim = self.n_embs = embed_dim
         
     | 
| 247 | 
         
            +
                    self.mlp_ratio = mlp_ratio
         
     | 
| 248 | 
         
            +
                    self.n_heads = num_heads
         
     | 
| 249 | 
         
            +
                    self.patch_size = patch_size
         
     | 
| 250 | 
         
            +
             
     | 
| 251 | 
         
            +
                def _init_weights(self, m):
         
     | 
| 252 | 
         
            +
                    if isinstance(m, nn.Linear):
         
     | 
| 253 | 
         
            +
                        trunc_normal_(m.weight, std=.02)
         
     | 
| 254 | 
         
            +
                        if isinstance(m, nn.Linear) and m.bias is not None:
         
     | 
| 255 | 
         
            +
                            nn.init.constant_(m.bias, 0)
         
     | 
| 256 | 
         
            +
                    elif isinstance(m, nn.LayerNorm):
         
     | 
| 257 | 
         
            +
                        nn.init.constant_(m.bias, 0)
         
     | 
| 258 | 
         
            +
                        nn.init.constant_(m.weight, 1.0)
         
     | 
| 259 | 
         
            +
             
     | 
| 260 | 
         
            +
                def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
         
     | 
| 261 | 
         
            +
                    """Pad some pixels to make the input size divisible by the patch size."""
         
     | 
| 262 | 
         
            +
                    B, _, H_0, W_0 = x.shape
         
     | 
| 263 | 
         
            +
                    pad_w = (self.patch_size - W_0 % self.patch_size) % self.patch_size
         
     | 
| 264 | 
         
            +
                    pad_h = (self.patch_size - H_0 % self.patch_size) % self.patch_size
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                    x = nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)
         
     | 
| 267 | 
         
            +
                    return x
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
                def prepare_tokens(self, x):
         
     | 
| 270 | 
         
            +
                    B, nc, h, w = x.shape
         
     | 
| 271 | 
         
            +
                    x: torch.Tensor = self.make_input_divisible(x)
         
     | 
| 272 | 
         
            +
                    patch_embed_h, patch_embed_w = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
                    x = self.patch_embed(x)  # patch linear embedding
         
     | 
| 275 | 
         
            +
             
     | 
| 276 | 
         
            +
                    # add positional encoding to each token
         
     | 
| 277 | 
         
            +
                    # add the [CLS] token to the embed patch tokens
         
     | 
| 278 | 
         
            +
                    cls_tokens = self.cls_token.expand(B, -1, -1)
         
     | 
| 279 | 
         
            +
                    x = torch.cat((cls_tokens, x), dim=1)
         
     | 
| 280 | 
         
            +
                    x = x + self.interpolate_pos_encoding(x, self.pos_embed, size=(patch_embed_h, patch_embed_w))
         
     | 
| 281 | 
         
            +
                    return self.pos_drop(x)
         
     | 
| 282 | 
         
            +
             
     | 
| 283 | 
         
            +
                @staticmethod
         
     | 
| 284 | 
         
            +
                def split_token(x, token_type: str):
         
     | 
| 285 | 
         
            +
                    if token_type == "cls":
         
     | 
| 286 | 
         
            +
                        return x[:, 0, :]
         
     | 
| 287 | 
         
            +
                    elif token_type == "patch":
         
     | 
| 288 | 
         
            +
                        return x[:, 1:, :]
         
     | 
| 289 | 
         
            +
                    else:
         
     | 
| 290 | 
         
            +
                        return x
         
     | 
| 291 | 
         
            +
             
     | 
| 292 | 
         
            +
                # noel
         
     | 
| 293 | 
         
            +
                def forward(self, x, layer: Optional[str] = None):
         
     | 
| 294 | 
         
            +
                    x: torch.Tensor = self.prepare_tokens(x)
         
     | 
| 295 | 
         
            +
             
     | 
| 296 | 
         
            +
                    features: dict = {}
         
     | 
| 297 | 
         
            +
                    for i, blk in enumerate(self.blocks):
         
     | 
| 298 | 
         
            +
                        x = blk(x)
         
     | 
| 299 | 
         
            +
                        features[f"layer{i + 1}"] = self.norm(x)
         
     | 
| 300 | 
         
            +
             
     | 
| 301 | 
         
            +
                    if layer is not None:
         
     | 
| 302 | 
         
            +
                        return features[layer]
         
     | 
| 303 | 
         
            +
                    else:
         
     | 
| 304 | 
         
            +
                        return features
         
     | 
| 305 | 
         
            +
             
     | 
| 306 | 
         
            +
                # noel - for DINO's visual
         
     | 
| 307 | 
         
            +
                def get_last_selfattention(self, x):
         
     | 
| 308 | 
         
            +
                    x = self.prepare_tokens(x)
         
     | 
| 309 | 
         
            +
                    for i, blk in enumerate(self.blocks):
         
     | 
| 310 | 
         
            +
                        if i < len(self.blocks) - 1:
         
     | 
| 311 | 
         
            +
                            x = blk(x)
         
     | 
| 312 | 
         
            +
                        else:
         
     | 
| 313 | 
         
            +
                            # return attention of the last block
         
     | 
| 314 | 
         
            +
                            return blk(x, return_attention=True)
         
     | 
| 315 | 
         
            +
             
     | 
| 316 | 
         
            +
                def get_tokens(
         
     | 
| 317 | 
         
            +
                        self,
         
     | 
| 318 | 
         
            +
                        x,
         
     | 
| 319 | 
         
            +
                        layers: list,
         
            patch_tokens: bool = False,
            norm: bool = True,
            input_tokens: bool = False,
            post_pe: bool = False
    ):
        """Return intermediate tokens."""
        list_tokens: list = []

        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = torch.cat((cls_tokens, x), dim=1)

        if input_tokens:
            list_tokens.append(x)

        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed

        if post_pe:
            list_tokens.append(x)

        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x)  # B x (1 + # patches) x dim
            if layers is None or i in layers:
                list_tokens.append(self.norm(x) if norm else x)

        tokens = torch.stack(list_tokens, dim=1)  # B x n_layers x (1 + # patches) x dim

        if not patch_tokens:
            return tokens[:, :, 0, :]  # index [CLS] tokens only, B x n_layers x dim
        else:
            return tokens

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x[:, 0]

    def interpolate_pos_encoding(self, x, pos_embed, size=None):
        """Interpolate the learnable positional encoding to match the number of patches.

        x: B x (1 + N patches) x dim_embedding
        pos_embed: B x (1 + N patches) x dim_embedding

        return interpolated positional embedding
        """
        npatch = x.shape[1] - 1  # (H // patch_size) * (W // patch_size)
        N = pos_embed.shape[1] - 1  # e.g. 784 (= 28 x 28)
        if npatch == N:
            return pos_embed
        class_emb, pos_embed = pos_embed[:, 0], pos_embed[:, 1:]  # a learnable CLS token, learnable position embeddings

        dim = x.shape[-1]  # dimension of embeddings
        if size is None:
            # `size` was a required argument, but the call sites in this file invoke
            # interpolate_pos_encoding(x, self.pos_embed) with two arguments only;
            # fall back to a square grid inferred from the number of patch tokens.
            size = (int(math.sqrt(npatch)), int(math.sqrt(npatch)))
        pos_embed = nn.functional.interpolate(
            pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),  # 1 x dim x 28 x 28
            size=size,
            mode='bicubic',
            align_corners=False
        )

        pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
        return pos_embed

    def forward_selfattention(self, x, return_interm_attn=False):
        B, nc, w, h = x.shape
        N = self.pos_embed.shape[1] - 1
        x = self.patch_embed(x)

        # interpolate patch embeddings
        dim = x.shape[-1]
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic'
        )
        # pad with zeros if bicubic rounding left the grid one cell short of (w0, h0)
        if w0 != patch_pos_embed.shape[-2]:
            helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
            patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
        if h0 != patch_pos_embed.shape[-1]:
            helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
            patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)  # fixed: was assigned to pos_embed, which discarded the padding
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # self.cls_token: 1 x 1 x emb_dim -> B x 1 x emb_dim
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
        x = self.pos_drop(x)

        if return_interm_attn:
            list_attn = []
            for i, blk in enumerate(self.blocks):
                attn = blk(x, return_attention=True)
                x = blk(x)
                list_attn.append(attn)
            return torch.cat(list_attn, dim=0)

        else:
            for i, blk in enumerate(self.blocks):
                if i < len(self.blocks) - 1:
                    x = blk(x)
                else:
                    return blk(x, return_attention=True)

    def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed
        x = self.pos_drop(x)

        # we will return the [CLS] tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                # get only the [CLS] token (B x dim)
                output.append(self.norm(x)[:, 0])
        if return_patch_avgpool:
            x = self.norm(x)
            # In addition to the [CLS] tokens from the `n` last blocks, we also return
            # the average-pooled patch tokens from the last block. This is useful for linear eval.
            output.append(torch.mean(x[:, 1:], dim=1))
        return torch.cat(output, dim=-1)

    def return_patch_emb_from_n_last_blocks(self, x, n=1, return_patch_avgpool=False):
        """Return intermediate patch embeddings, rather than the [CLS] token, from the last n blocks."""
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)

        x = torch.cat((cls_tokens, x), dim=1)
        pos_embed = self.interpolate_pos_encoding(x, self.pos_embed)
        x = x + pos_embed
        x = self.pos_drop(x)

        # we will return the patch tokens from the `n` last blocks
        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if len(self.blocks) - i <= n:
                output.append(self.norm(x)[:, 1:])  # patch tokens only (B x n_patches x dim)

        if return_patch_avgpool:
            x = self.norm(x)
            # In addition to the patch tokens from the `n` last blocks, we also return
            # the average-pooled patch tokens from the last block. This is useful for linear eval.
            output.append(torch.mean(x[:, 1:], dim=1))
        return torch.stack(output, dim=-1)  # B x n_patches x dim x n

def deit_tiny(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    return model


def deit_small(patch_size=16, **kwargs):
    depth = kwargs.pop("depth") if "depth" in kwargs else 12
    model = VisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=depth,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs
    )
    return model


def vit_base(patch_size=16, **kwargs):
    model = VisionTransformer(
        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
     | 
    	
resources/.DS_Store
ADDED
    Binary file (8.2 kB)
    	
resources/0053.jpg
ADDED
resources/0236.jpg
ADDED
resources/0239.jpg
ADDED
resources/0403.jpg
ADDED
resources/0412.jpg
ADDED
resources/ILSVRC2012_test_00005309.jpg
ADDED
resources/ILSVRC2012_test_00012622.jpg
ADDED
resources/ILSVRC2012_test_00022698.jpg
ADDED
resources/ILSVRC2012_test_00040725.jpg
ADDED
resources/ILSVRC2012_test_00075738.jpg
ADDED
resources/ILSVRC2012_test_00080683.jpg
ADDED
resources/ILSVRC2012_test_00085874.jpg
ADDED
resources/im052.jpg
ADDED
resources/sun_ainjbonxmervsvpv.jpg
ADDED
resources/sun_alfntqzssslakmss.jpg
ADDED
resources/sun_amnrcxhisjfrliwa.jpg
ADDED
resources/sun_bvyxpvkouzlfwwod.jpg
ADDED