diff --git a/app.py b/app.py
new file mode 100755
index 0000000000000000000000000000000000000000..5136ada9d4867a360c1e05eaf561f5efd0c28eb6
--- /dev/null
+++ b/app.py
@@ -0,0 +1,76 @@
+from argparse import ArgumentParser
+from typing import Dict
+import torch
+from PIL import Image
+import modules.transforms as transforms
+from modules.primaps import PriMaPs
+from modules.backbone.dino.dinovit import DinoFeaturizerv2
+from modules.visualization import visualize_demo
+import gradio as gr
+# set seeds
+torch.manual_seed(0)
+torch.cuda.manual_seed(0)
+
+
+
+def gradio_primaps(image_path, threshold, architecture):
+    '''
+    Gradio demo to visualize PriMaPs for a single image.
+    '''
+
+    device = 'cuda:0'
+    resize_to = 320 if 'v2' not in architecture else 322
+    patch_size = 8 if 'v2' not in architecture else 14
+
+    # get SSL image encoder and primaps module
+    net = DinoFeaturizerv2(architecture, patch_size)
+    net.to(device)
+    primaps_module = PriMaPs(threshold=threshold,
+                             ignore_id=255)
+
+    # get transforms
+    demo_transforms = transforms.Compose([transforms.ToTensor(),
+                                          transforms.Resize(resize_to),
+                                          transforms.CenterCrop([resize_to, resize_to]),
+                                          transforms.Normalize()])
+
+    # load image and apply transforms
+    image = Image.open(image_path)
+    image, _ = demo_transforms(image, torch.zeros(image.size))
+    image = image.to(device)
+    # get SSL features
+    feats = net(image.unsqueeze(0).to(device), n=1).squeeze()
+    # get primaps pseudo labels
+    primaps = primaps_module._get_pseudo(image, feats, torch.zeros(image.shape[1:]))
+    # visualize overlay
+    return visualize_demo(image, primaps)
+
+
+if __name__ == '__main__':
+    # Example image paths
+    example_images = [
+        "assets/demo_examples/cityscapes_example.png",
+        "assets/demo_examples/coco_example.jpg",
+        "assets/demo_examples/potsdam_example.png"
+    ]
+
+    # Gradio interface
+    interface = gr.Interface(
+        fn=gradio_primaps,
+        inputs=[
+            gr.Image(type="filepath", label="Image"),
+            gr.Slider(0.0, 1.0, step=0.05, value=0.35, label="Threshold"),
+            gr.Dropdown(choices=['dino_vits', 'dino_vitb', 'dinov2_vits', 'dinov2_vitb'], value='dino_vitb', label="SSL Features"),
+        ],
+        outputs=gr.Image(label="PriMaPs"),
+        title="PriMaPs Demo",
+        description="Upload an image and adjust the threshold to visualize PriMaPs.",
+        examples=[
+            [example_images[0], 0.35, 'dino_vitb'],
+            [example_images[1], 0.35, 'dino_vitb'],
+            [example_images[2], 0.35, 'dino_vitb']
+        ]
+    )
+
+    # Launch the app
+    interface.launch(debug=True)
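The demo above can also be driven without the Gradio UI. Below is a minimal sketch that mirrors what gradio_primaps() does, assuming the module layout added in this patch; the example path, backbone name, and threshold simply reuse the demo defaults, and the CPU fallback is an addition not present in app.py:

    import torch
    from PIL import Image
    import modules.transforms as transforms
    from modules.primaps import PriMaPs
    from modules.backbone.dino.dinovit import DinoFeaturizerv2
    from modules.visualization import visualize_demo

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # app.py assumes CUDA
    # frozen DINO ViT-B/8 encoder and the PriMaPs module with the demo's default threshold
    net = DinoFeaturizerv2('dino_vitb', 8)
    net.to(device)
    primaps_module = PriMaPs(threshold=0.35, ignore_id=255)

    demo_transforms = transforms.Compose([transforms.ToTensor(),
                                          transforms.Resize(320),
                                          transforms.CenterCrop([320, 320]),
                                          transforms.Normalize()])

    image = Image.open('assets/demo_examples/cityscapes_example.png')
    image, _ = demo_transforms(image, torch.zeros(image.size))
    image = image.to(device)
    feats = net(image.unsqueeze(0), n=1).squeeze()           # SSL feature map
    pseudo = primaps_module._get_pseudo(image, feats, torch.zeros(image.shape[1:]))
    overlay = visualize_demo(image, pseudo)                  # same overlay the Gradio app returns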
diff --git a/modules/backbone/dino/dinovit.py b/modules/backbone/dino/dinovit.py
new file mode 100755
index 0000000000000000000000000000000000000000..f966dab3866655df2dbea18e45d5ee4c3a1dd15b
--- /dev/null
+++ b/modules/backbone/dino/dinovit.py
@@ -0,0 +1,151 @@
+import torch
+import torch.nn as nn
+# import modules.backbone.dino.vision_transformer as vits
+
+
+# class DinoFeaturizer(nn.Module):
+
+#     def __init__(self, arch, patch_size, totrain):
+#         super().__init__()
+#         self.patch_size = patch_size
+#         self.feat_type = "feat"
+
+#         self.model = vits.__dict__[arch](
+#             patch_size=patch_size,
+#             num_classes=0)
+#         for p in self.model.parameters():
+#             p.requires_grad = False
+#         self.model.eval() #.cuda()
+#         if totrain:
+#             for p in self.model.parameters():
+#                 p.requires_grad = True
+#             self.model.train()
+#         self.dropout = torch.nn.Dropout2d(p=.1)
+
+#         if arch == "vit_small" and patch_size == 16:
+#             url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
+#         elif arch == "vit_small" and patch_size == 8:
+#             url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
+#         elif arch == "vit_base" and patch_size == 16:
+#             url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
+#         elif arch == "vit_base" and patch_size == 8:
+#             url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
+#         else:
+#             raise ValueError("Unknown arch and patch size")
+
+#         # if pretrained_weights is not None:
+#         #     state_dict = torch.load(cfg.pretrained_weights, map_location="cpu")
+#         #     state_dict = state_dict["teacher"]
+#         #     # remove `module.` prefix
+#         #     state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+#         #     # remove `backbone.` prefix induced by multicrop wrapper
+#         #     state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
+
+#         #     # state_dict = {k.replace("projection_head", "mlp"): v for k, v in state_dict.items()}
+#         #     # state_dict = {k.replace("prototypes", "last_layer"): v for k, v in state_dict.items()}
+
+#         #     msg = self.model.load_state_dict(state_dict, strict=False)
+#         #     print('Pretrained weights found at {} and loaded with msg: {}'.format(cfg.pretrained_weights, msg))
+#         # else:
+#         print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
+#         state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
+#         self.model.load_state_dict(state_dict, strict=True)
+
+#         # if arch == "vit_small":
+#         #     self.n_feats = 384
+#         # else:
+#         #     self.n_feats = 768
+#         # self.cluster1 = self.make_clusterer(self.n_feats)
+#         # self.proj_type = cfg.projection_type
+#         # if
self.proj_type == "nonlinear": +# # self.cluster2 = self.make_nonlinear_clusterer(self.n_feats) + +# # def make_clusterer(self, in_channels): +# # return torch.nn.Sequential( +# # torch.nn.Conv2d(in_channels, self.dim, (1, 1))) # , + +# # def make_nonlinear_clusterer(self, in_channels): +# # return torch.nn.Sequential( +# # torch.nn.Conv2d(in_channels, in_channels, (1, 1)), +# # torch.nn.ReLU(), +# # torch.nn.Conv2d(in_channels, self.dim, (1, 1))) + +# def forward(self, img, n=1, return_class_feat=False): +# # self.model.eval() +# with torch.no_grad(): +# assert (img.shape[2] % self.patch_size == 0) +# assert (img.shape[3] % self.patch_size == 0) + +# # get selected layer activations +# feat, attn, qkv = self.model.get_intermediate_feat(img, n=n) +# if n == 1: +# feat, attn, qkv = feat[0], attn[0], qkv[0] +# else: +# feat, attn, qkv = feat[-n], attn[-n], qkv[-n] + + + +# feat_h = img.shape[2] // self.patch_size +# feat_w = img.shape[3] // self.patch_size + +# if self.feat_type == "feat": +# image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2) +# elif self.feat_type == "KK": +# image_k = qkv[1, :, :, 1:, :].reshape(feat.shape[0], 6, feat_h, feat_w, -1) +# B, H, I, J, D = image_k.shape +# image_feat = image_k.permute(0, 1, 4, 2, 3).reshape(B, H * D, I, J) +# else: +# raise ValueError("Unknown feat type:{}".format(self.feat_type)) + +# if return_class_feat: +# return image_feat, feat[:, :1, :].reshape(feat.shape[0], 1, 1, -1).permute(0, 3, 1, 2) +# else: +# return image_feat + +# # if self.proj_type is not None: +# # code = self.cluster1(self.dropout(image_feat)) +# # if self.proj_type == "nonlinear": +# # code += self.cluster2(self.dropout(image_feat)) +# # else: +# # code = image_feat + +# # if self.cfg.dropout: +# # return self.dropout(image_feat), code +# # else: +# # return image_feat, code + +class DinoFeaturizerv2(nn.Module): + + def __init__(self, arch, patch_size): + super().__init__() + self.patch_size = patch_size + self.arch = arch + if 'v2' in arch: + self.model = torch.hub.load('facebookresearch/dinov2', arch+str(patch_size)) + elif 'resnet' in arch: + rn_dino = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50') + from torchvision.models.feature_extraction import create_feature_extractor + return_nodes = {'layer4.2.relu_2': 'out'} + self.model = create_feature_extractor(rn_dino, return_nodes=return_nodes) + else: + self.model = torch.hub.load('facebookresearch/dino:main', arch+str(patch_size)) + for p in self.model.parameters(): + p.requires_grad = False + self.model.eval() + + + def forward(self, img, n=1): + with torch.no_grad(): + assert (img.shape[2] % self.patch_size == 0) + assert (img.shape[3] % self.patch_size == 0) + + if 'v2' in self.arch: + image_feat = self.model.get_intermediate_layers(img, n, reshape=True)[n-1] + elif 'resnet' in self.arch: + image_feat = self.model(img)['out'] + else: + image_feat = self.model.get_intermediate_layers(img, n)[-n][:, 1:, :].transpose(1, 2).contiguous() + image_feat = image_feat.view(image_feat.size(0), image_feat.size(1), img.size(-1)//self.patch_size, img.size(-1)//self.patch_size) + + return image_feat + \ No newline at end of file diff --git a/modules/backbone/dino/utils.py b/modules/backbone/dino/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..f2396a152a286572feb0a4f5633ce65a11eb1680 --- /dev/null +++ b/modules/backbone/dino/utils.py @@ -0,0 +1,619 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Misc functions. + +Mostly copy-paste from torchvision references or other public repos like DETR: +https://github.com/facebookresearch/detr/blob/master/util/misc.py +""" +import os +import sys +import time +import math +import random +import datetime +import subprocess +from collections import defaultdict, deque + +import numpy as np +import torch +from torch import nn +import torch.distributed as dist +from PIL import ImageFilter, ImageOps + + +class GaussianBlur(object): + """ + Apply Gaussian Blur to the PIL image. + """ + def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): + self.prob = p + self.radius_min = radius_min + self.radius_max = radius_max + + def __call__(self, img): + do_it = random.random() <= self.prob + if not do_it: + return img + + return img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ) + + +class Solarization(object): + """ + Apply Solarization to the PIL image. + """ + def __init__(self, p): + self.p = p + + def __call__(self, img): + if random.random() < self.p: + return ImageOps.solarize(img) + else: + return img + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size): + if os.path.isfile(pretrained_weights): + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + print(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) + else: + print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") + url = None + if model_name == "vit_small" and patch_size == 16: + url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" + elif model_name == "vit_small" and patch_size == 8: + url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" + elif model_name == "vit_base" and patch_size == 16: + url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" + elif model_name == "vit_base" and patch_size == 8: + url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" + if url is not None: + print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") + state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) + model.load_state_dict(state_dict, strict=True) + else: + print("There is no reference weights available for this model => We use random weights.") + + +def clip_gradients(model, clip): + norms = [] + for name, p in model.named_parameters(): + if p.grad is not None: + 
param_norm = p.grad.data.norm(2) + norms.append(param_norm.item()) + clip_coef = clip / (param_norm + 1e-6) + if clip_coef < 1: + p.grad.data.mul_(clip_coef) + return norms + + +def cancel_gradients_last_layer(epoch, model, freeze_last_layer): + if epoch >= freeze_last_layer: + return + for n, p in model.named_parameters(): + if "last_layer" in n: + p.grad = None + + +def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): + """ + Re-start from checkpoint + """ + if not os.path.isfile(ckp_path): + return + print("Found checkpoint at {}".format(ckp_path)) + + # open checkpoint file + checkpoint = torch.load(ckp_path, map_location="cpu") + + # key is what to look for in the checkpoint file + # value is the object to load + # example: {'state_dict': model} + for key, value in kwargs.items(): + if key in checkpoint and value is not None: + try: + msg = value.load_state_dict(checkpoint[key], strict=False) + print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) + except TypeError: + try: + msg = value.load_state_dict(checkpoint[key]) + print("=> loaded {} from checkpoint '{}'".format(key, ckp_path)) + except ValueError: + print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path)) + else: + print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path)) + + # re load variable important for the run + if run_variables is not None: + for var_name in run_variables: + if var_name in checkpoint: + run_variables[var_name] = checkpoint[var_name] + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule + + +def bool_flag(s): + """ + Parse boolean arguments from the command line. + """ + FALSY_STRINGS = {"off", "false", "0"} + TRUTHY_STRINGS = {"on", "true", "1"} + if s.lower() in FALSY_STRINGS: + return False + elif s.lower() in TRUTHY_STRINGS: + return True + else: + raise argparse.ArgumentTypeError("invalid value for a boolean flag") + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.6f} ({global_avg:.6f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
+ """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.6f}') + data_time = SmoothedValue(fmt='{avg:.6f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), 
eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.6f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def init_distributed_mode(args): + # launched with torch.distributed.launch + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + # launched with submitit on a slurm cluster + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + # launched naively with `python main_dino.py` + # we manually add MASTER_ADDR and MASTER_PORT to env variables + elif torch.cuda.is_available(): + print('Will run the code on one GPU.') + args.rank, args.gpu, args.world_size = 0, 0, 1 + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29500' + else: + print('Does not support training without GPU.') + sys.exit(1) + + dist.init_process_group( + backend="nccl", + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + + torch.cuda.set_device(args.gpu) + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + dist.barrier() + setup_for_distributed(args.rank == 0) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [correct[:k].reshape(-1).float().sum(0) * 100. 
/ batch_size for k in topk] + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class LARS(torch.optim.Optimizer): + """ + Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py + """ + def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, + weight_decay_filter=None, lars_adaptation_filter=None): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, + eta=eta, weight_decay_filter=weight_decay_filter, + lars_adaptation_filter=lars_adaptation_filter) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for g in self.param_groups: + for p in g['params']: + dp = p.grad + + if dp is None: + continue + + if p.ndim != 1: + dp = dp.add(p, alpha=g['weight_decay']) + + if p.ndim != 1: + param_norm = torch.norm(p) + update_norm = torch.norm(dp) + one = torch.ones_like(param_norm) + q = torch.where(param_norm > 0., + torch.where(update_norm > 0, + (g['eta'] * param_norm / update_norm), one), one) + dp = dp.mul(q) + + param_state = self.state[p] + if 'mu' not in param_state: + param_state['mu'] = torch.zeros_like(p) + mu = param_state['mu'] + mu.mul_(g['momentum']).add_(dp) + + p.add_(mu, alpha=-g['lr']) + + +class MultiCropWrapper(nn.Module): + """ + Perform forward pass separately on each resolution input. + The inputs corresponding to a single resolution are clubbed and single + forward is run on the same resolution inputs. Hence we do several + forward passes = number of different resolutions used. We then + concatenate all the output features and run the head forward on these + concatenated features. 
+ """ + def __init__(self, backbone, head): + super(MultiCropWrapper, self).__init__() + # disable layers dedicated to ImageNet labels classification + backbone.fc, backbone.head = nn.Identity(), nn.Identity() + self.backbone = backbone + self.head = head + + def forward(self, x): + # convert to list + if not isinstance(x, list): + x = [x] + idx_crops = torch.cumsum(torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in x]), + return_counts=True, + )[1], 0) + start_idx = 0 + for end_idx in idx_crops: + _out = self.backbone(torch.cat(x[start_idx: end_idx])) + if start_idx == 0: + output = _out + else: + output = torch.cat((output, _out)) + start_idx = end_idx + # Run the head forward on the concatenated features. + return self.head(output) + + +def get_params_groups(model): + regularized = [] + not_regularized = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + # we do not regularize biases nor Norm parameters + if name.endswith(".bias") or len(param.shape) == 1: + not_regularized.append(param) + else: + regularized.append(param) + return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/modules/backbone/dino/vision_transformer.py b/modules/backbone/dino/vision_transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..29419e88f6ba2cffe708931ea51cf78ed7f971cb --- /dev/null +++ b/modules/backbone/dino/vision_transformer.py @@ -0,0 +1,314 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Mostly copy-paste from timm library. +https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +""" +import math +from functools import partial + +import torch +import torch.nn as nn +from modules.backbone.dino.utils import trunc_normal_ + +def drop_path(x, drop_prob: float = 0., training: bool = False): + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, return_qkv=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x,attn, qkv + + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, return_attention=False, return_qkv = False): + y, attn, qkv = self.attn(self.norm1(x)) + if return_attention: + return attn + x = x + self.drop_path(y) + x = x + self.drop_path(self.mlp(self.norm2(x))) + if return_qkv: + return x,attn, qkv + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs): + super().__init__() + + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Classifier head + self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def interpolate_pos_encoding(self, x, w, h): + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + class_pos_embed = self.pos_embed[:, 0] + patch_pos_embed = self.pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_embed.patch_size + h0 = h // self.patch_embed.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode='bicubic', + ) + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return 
torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def prepare_tokens(self, x): + B, nc, w, h = x.shape + x = self.patch_embed(x) # patch linear embedding + + # add the [CLS] token to the embed patch tokens + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # add positional encoding to each token + x = x + self.interpolate_pos_encoding(x, w, h) + + return self.pos_drop(x) + + def forward(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x[:, 0] + + def forward_feats(self, x): + x = self.prepare_tokens(x) + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + return x + + def get_intermediate_feat(self, x, n=1): + x = self.prepare_tokens(x) + # we return the output tokens from the `n` last blocks + feat = [] + attns = [] + qkvs = [] + for i, blk in enumerate(self.blocks): + x,attn,qkv = blk(x, return_qkv=True) + if len(self.blocks) - i <= n: + feat.append(self.norm(x)) + qkvs.append(qkv) + attns.append(attn) + return feat, attns, qkvs + + def get_last_selfattention(self, x): + x = self.prepare_tokens(x) + for i, blk in enumerate(self.blocks): + if i < len(self.blocks) - 1: + x = blk(x) + else: + # return attention of the last block + return blk(x, return_attention=True) + + def get_intermediate_layers(self, x, n=1): + x = self.prepare_tokens(x) + # we return the output tokens from the `n` last blocks + output = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if len(self.blocks) - i <= n: + output.append(self.norm(x)) + return output + + +def vit_tiny(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_small(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_base(patch_size=16, **kwargs): + model = VisionTransformer( + patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, + qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +class DINOHead(nn.Module): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + super().__init__() + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = nn.Linear(in_dim, bottleneck_dim) + else: + layers = [nn.Linear(in_dim, hidden_dim)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x diff --git a/modules/crf.py b/modules/crf.py new file 
mode 100755 index 0000000000000000000000000000000000000000..88ac4d43bb5eb8a241d6612abcc5e34315f75fef --- /dev/null +++ b/modules/crf.py @@ -0,0 +1,60 @@ +# +# Authors: Wouter Van Gansbeke & Simon Vandenhende +# Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/) + +import sys +import os +import numpy as np +import pydensecrf.densecrf as dcrf +import pydensecrf.utils as utils +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as VF +sys.path.append(os.getcwd()) +from modules.transforms import UnNormalize as unnorm + + +MAX_ITER = 10 +POS_W = 3 +POS_XY_STD = 1 +Bi_W = 4 +Bi_XY_STD = 67 +Bi_RGB_STD = 3 +BGR_MEAN = np.array([104.008, 116.669, 122.675]) + + +def dense_crf(image_tensor: torch.FloatTensor, output_logits: torch.FloatTensor): + image = np.array(VF.to_pil_image(unnorm()(image_tensor)))[:, :, ::-1] + H, W = image.shape[:2] + image = np.ascontiguousarray(image) + + output_logits = F.interpolate(output_logits.unsqueeze(0), size=(H, W), mode="bilinear", + align_corners=False).squeeze() + output_probs = F.softmax(output_logits, dim=0).cpu().numpy() + + c = output_probs.shape[0] + h = output_probs.shape[1] + w = output_probs.shape[2] + + U = utils.unary_from_softmax(output_probs) + U = np.ascontiguousarray(U) + + d = dcrf.DenseCRF2D(w, h, c) + d.setUnaryEnergy(U) + d.addPairwiseGaussian(sxy=POS_XY_STD, compat=POS_W) + d.addPairwiseBilateral(sxy=Bi_XY_STD, srgb=Bi_RGB_STD, rgbim=image, compat=Bi_W) + + Q = d.inference(MAX_ITER) + Q = np.array(Q).reshape((c, h, w)) + return Q + + + +def _apply_crf(tup): + return dense_crf(tup[0], tup[1]) + + +def batched_crf(pool, img_tensor, prob_tensor): + outputs = pool.map(_apply_crf, zip(img_tensor.detach().cpu(), prob_tensor.detach().cpu())) + return torch.cat([torch.from_numpy(arr).unsqueeze(0) for arr in outputs], dim=0) + diff --git a/modules/median_pool.py b/modules/median_pool.py new file mode 100755 index 0000000000000000000000000000000000000000..d922044a29d34b5243f567595b88f9e350cbea98 --- /dev/null +++ b/modules/median_pool.py @@ -0,0 +1,52 @@ +# Original implementation: https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598 + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair, _quadruple + + +class MedianPool2d(nn.Module): + """ Median pool (usable as median filter when stride=1) module. 
+ + Args: + kernel_size: size of pooling kernel, int or 2-tuple + stride: pool stride, int or 2-tuple + padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad + same: override padding and enforce same padding, boolean + """ + def __init__(self, kernel_size=3, stride=1, padding=0, same=False): + super(MedianPool2d, self).__init__() + self.k = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _quadruple(padding) # convert to l, r, t, b + self.same = same + + def _padding(self, x): + if self.same: + ih, iw = x.size()[2:] + if ih % self.stride[0] == 0: + ph = max(self.k[0] - self.stride[0], 0) + else: + ph = max(self.k[0] - (ih % self.stride[0]), 0) + if iw % self.stride[1] == 0: + pw = max(self.k[1] - self.stride[1], 0) + else: + pw = max(self.k[1] - (iw % self.stride[1]), 0) + pl = pw // 2 + pr = pw - pl + pt = ph // 2 + pb = ph - pt + padding = (pl, pr, pt, pb) + else: + padding = self.padding + return padding + + def forward(self, x): + # using existing pytorch functions and tensor ops so that we get autograd, + # would likely be more efficient to implement from scratch at C/Cuda level + x = F.pad(x, self._padding(x), mode='reflect') + x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) + x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] + return x \ No newline at end of file diff --git a/modules/primaps.py b/modules/primaps.py new file mode 100755 index 0000000000000000000000000000000000000000..a1377d373a69844e44097a8e0b961f2f6a374215 --- /dev/null +++ b/modules/primaps.py @@ -0,0 +1,75 @@ +import torch +import sys +import os +import torch.nn.functional as F + +sys.path.append(os.getcwd()) +from modules.crf import dense_crf +from modules.median_pool import MedianPool2d + + +class PriMaPs(): + def __init__(self, + threshold=0.4, + ignore_id=27): + super(PriMaPs, self).__init__() + self.threshold = threshold + self.ignore_id = ignore_id + self.medianfilter = MedianPool2d(kernel_size=3, stride=1, padding=1) + + def _get_pseudo(self, img, feat, cls_prior): + # initialize used pixel mask + mask = torch.ones(feat.shape[-2:]).bool().to(feat.device) + mask_memory = [] + pseudo_masks = [] + # get masks until 95% of features are masked or mask does not change + while ((mask!=1).sum()/mask.numel() < 0.95): + _, _, v = torch.pca_lowrank(feat[:, mask].permute(1, 0), q=3, niter=100) + # cos similarity to c + sim = torch.einsum("c,cij->ij", v[:, 0], F.normalize(feat, dim=0)) + # refine direction with NN + sim[~mask] = 0 + v = F.normalize(feat, dim=0)[:, sim==sim.max()][:, 0] + sim = torch.einsum("c,cij->ij", v, F.normalize(feat, dim=0)) + sim[~mask] = 0 + # apply threshold and norm + sim = sim / sim.max() + sim[sim < self.threshold] = 0 + # keep the mask proposal and remove its pixels from the unassigned set + pseudo_masks.append(sim) + mask[sim > 0] = 0 + mask_memory.insert(0, mask.clone()) + if mask_memory.__len__() > 3: + mask_memory.pop() + if torch.Tensor([(mask_memory[0]==i).all() for i in mask_memory]).all(): + break + # insert bg mask and stack + pseudo_masks = (self.medianfilter(torch.stack(pseudo_masks, dim=0).unsqueeze(0)).squeeze()*10).clamp(0, 1) + bg = (torch.mean(pseudo_masks[pseudo_masks!=0])*torch.ones(feat.shape[-2:], device=feat.device)-pseudo_masks.sum(dim=0)).unsqueeze(0).clamp(0, 1) + + if (pseudo_masks.shape).__len__() == 2: + pseudo_masks = pseudo_masks.unsqueeze(0) + pseudo_masks = torch.cat([bg, pseudo_masks], dim=0) + pseudo_masks = F.log_softmax(pseudo_masks, dim=0) + # apply crf to refine masks + pseudo_masks = dense_crf(img.squeeze(), pseudo_masks).argmax(0) + pseudo_masks = torch.Tensor(pseudo_masks).to(feat.device) + + if (cls_prior == 0).all(): +
pseudolabel = pseudo_masks + pseudolabel[pseudolabel==0] = self.ignore_id + else: + pseudolabel = torch.ones(img.shape[-2:]).to(feat.device)*self.ignore_id + for i in pseudo_masks.unique()[pseudo_masks.unique()!=0]: + # only look at not assigned and attended pixels + mask = (pseudolabel==self.ignore_id)*(pseudo_masks==i) + pseudolabel[mask] = int(torch.mode(cls_prior[mask])[0]) + return pseudolabel + + # multiprocessing wrapper + def _apply_batched_decompose(self, tup): + return self._get_pseudo(tup[0], tup[1], tup[2]) + + def __call__(self, pool, imgs, features, cls_prior): + outs = pool.map(self._apply_batched_decompose, zip(imgs, features, cls_prior)) + return torch.stack(outs, dim=0) \ No newline at end of file diff --git a/modules/transforms.py b/modules/transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..533e2fdfc2ec8a3a011f040d357f071092781d8e --- /dev/null +++ b/modules/transforms.py @@ -0,0 +1,333 @@ +import torch, random +import torchvision.transforms.functional as F +import torchvision.transforms as tf +import numpy as np +from PIL import Image +from typing import Tuple, List, Callable + + +class Compose: + + def __init__(self, + transforms: List[Callable], + student_augs: bool = False): + self.transforms = transforms + self.student_augs = student_augs + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[torch.Tensor, torch.Tensor]: + + for transform in self.transforms: + if pseudo is None: + img, gt = transform(img, gt) + else: + img, gt, pseudo = transform(img, gt, pseudo) + + if self.student_augs: + aimg = img.clone() + aimg, _ = RandGaussianBlur()(aimg, gt) + if 0.5 > random.random(): + aimg, _ = ColorJitter()(aimg, gt) + else: + aimg, _ = MaskGrayscale()(aimg, gt) + + + if pseudo is None and not self.student_augs: + return img, gt + elif pseudo is None and self.student_augs: + return img, gt, aimg + elif pseudo is not None and not self.student_augs: + return img, gt, pseudo + else: + return img, gt, aimg, pseudo + +class ToTensor: + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[torch.Tensor, torch.Tensor]: + + img = F.to_tensor(np.array(img)) + gt = torch.from_numpy(np.array(gt)).unsqueeze(0) + if pseudo is not None: + pseudo = torch.from_numpy(np.array(pseudo)).unsqueeze(0) + + if pseudo is None: + return img, gt + else: + return img, gt, pseudo + +class Resize: + + def __init__(self, + resize: Tuple[int]): + + self.img_resize = tf.Resize(size=resize, + interpolation=tf.InterpolationMode.BILINEAR) + self.gt_resize = tf.Resize(size=resize, + interpolation=tf.InterpolationMode.NEAREST) + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[Image.Image, Image.Image]: + + img = self.img_resize(img) + gt = self.gt_resize(gt) + + if pseudo is None: + return img, gt + else: + return img, gt, self.gt_resize(pseudo) + +class ImgResize: + + def __init__(self, + resize: Tuple[int, int]): + self.resize = resize + self.num_pixels = self.resize[0]*self.resize[1] + + def __call__(self, + img: torch.Tensor, + gt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if torch.prod(torch.tensor(img.shape[-2:])) > self.num_pixels: + img = torch.nn.functional.interpolate(img.unsqueeze(0), size=self.resize, mode='bilinear').squeeze(0) + return img, gt + +class ImgResizePIL: + + def __init__(self, + resize: Tuple[int]): + self.resize = resize + self.num_pixels = self.resize[0]*self.resize[1] + + def __call__(self, + img: Image) -> Image: + if 
img.height*img.width > self.num_pixels: + img = img.resize((self.resize[1], self.resize[0]), Image.BILINEAR) + return img + +class Normalize: + + def __init__(self, + mean: List[float] = [0.485, 0.456, 0.406], + std: List[float] = [0.229, 0.224, 0.225]): + + self.norm = tf.Normalize(mean=mean, + std=std) + + def __call__(self, + img: torch.Tensor, + gt: torch.Tensor, + pseudo = None) -> Tuple[torch.Tensor, torch.Tensor]: + + img = self.norm(img) + + if pseudo is None: + return img, gt + else: + return img, gt, pseudo + +class UnNormalize(object): + def __init__(self, + mean: List[float] = [0.485, 0.456, 0.406], + std: List[float] = [0.229, 0.224, 0.225]): + self.mean = mean + self.std = std + + def __call__(self, image): + image2 = torch.clone(image) + for t, m, s in zip(image2, self.mean, self.std): + t.mul_(s).add_(m) + return image2 + + + +class RandomHFlip: + + def __init__(self, + percentage: float = 0.5): + + self.percentage = percentage + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[Image.Image, Image.Image]: + + if random.random() < self.percentage: + img = F.hflip(img) + gt = F.hflip(gt) + if pseudo is not None: + pseudo = F.hflip(pseudo) + + if pseudo is None: + return img, gt + else: + return img, gt, pseudo + + +class RandomResizedCrop: + + def __init__(self, + crop_size: List[int], + crop_scale: List[float], + crop_ratio: List[float]): + print('RandomResizedCrop ratio modified!!!') + self.crop_scale = tuple(crop_scale) + self.crop_ratio = tuple(crop_ratio) + self.crop = tf.RandomResizedCrop(size=tuple(crop_size), + scale=self.crop_scale, + ratio=self.crop_ratio,) + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[Image.Image, Image.Image]: + + i, j, h, w = self.crop.get_params(img=img, + scale=self.crop.scale, + ratio=self.crop.ratio) + img = F.resized_crop(img, i, j, h, w, self.crop.size, tf.InterpolationMode.BILINEAR) + gt = F.resized_crop(gt, i, j, h, w, self.crop.size, tf.InterpolationMode.NEAREST) + if pseudo is not None: + pseudo = F.resized_crop(pseudo, i, j, h, w, self.crop.size, tf.InterpolationMode.NEAREST) + + if pseudo is None: + return img, gt + else: + return img, gt, pseudo + +class CenterCrop: + + def __init__(self, + crop_size: int): + + self.crop = tf.CenterCrop(size=crop_size) + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[Image.Image, Image.Image]: + + img = self.crop(img) + gt = self.crop(gt) + + if pseudo is None: + return img, gt + else: + return img, gt, self.crop(pseudo) + +class PyramidCenterCrop: + + def __init__(self, + crop_size: List[int], + scales: List[float]): + + self.crop_size = crop_size + self.scales = scales + self.crop = tf.CenterCrop(size=crop_size) + + + def __call__(self, + img: Image.Image, + gt: Image.Image) -> Tuple[Image.Image, Image.Image]: + + imgs = [] + gts = [] + for s in self.scales: + new_size = (int(self.crop_size*1/s), int(self.crop_size*1/s*(img.shape[2]/img.shape[1]))) + img = tf.Resize(size=new_size, interpolation=tf.InterpolationMode.BILINEAR)(img) + gt = tf.Resize(size=new_size, interpolation=tf.InterpolationMode.NEAREST)(gt) + imgs.append(self.crop(img)) + gts.append(self.crop(gt)) + + return torch.stack(imgs), torch.stack(gts) + + + + + + +class IdsToTrainIds: + + def __init__(self, + source: str): + + self.source = source + self.first_nonvoid = 7 + + + def __call__(self, + img: torch.Tensor, + gt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.source == 'cityscapes': 
+ gt = gt.to(dtype=torch.int64) - self.first_nonvoid + gt[gt>26] = 255 + gt[gt<0] = 255 + elif self.source == 'cocostuff': + gt = gt.to(dtype=torch.int64) + elif self.source == 'potsdam': + gt = gt.to(dtype=torch.int64) + return img, gt + + +class ColorJitter: + def __init__(self, percentage: float = 0.3, brightness: float = 0.1, + contrast: float = 0.1, saturation: float = 0.1, hue: float = 0.1): + + self.percentage = percentage + self.jitter = tf.ColorJitter(brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue) + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[Image.Image, Image.Image]: + if random.random() < self.percentage: + img = self.jitter(img) + + if pseudo is None: + return img, gt + else: + return img, gt, pseudo + +class MaskGrayscale: + + def __init__(self, percentage: float = 0.1): + self.percentage = percentage + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[Image.Image, Image.Image]: + if self.percentage > random.random(): + img = tf.Grayscale(num_output_channels=3)(img) + if pseudo is None: + return img, gt + else: + return img, gt, pseudo + +class RandGaussianBlur: + + def __init__(self, radius: List[float] = [.1, 2.]): + self.radius = radius + + def __call__(self, + img: Image.Image, + gt: Image.Image, + pseudo = None) -> Tuple[Image.Image, Image.Image]: + + radius = random.uniform(self.radius[0], self.radius[1]) + img = tf.GaussianBlur(kernel_size=21, sigma=radius)(img) + + if pseudo is None: + return img, gt + else: + return img, gt, pseudo diff --git a/modules/visualization.py b/modules/visualization.py new file mode 100755 index 0000000000000000000000000000000000000000..c659a7145d9ca5e2f87986912fcaa5e182d40b36 --- /dev/null +++ b/modules/visualization.py @@ -0,0 +1,473 @@ +import os +import sys +import numpy as np +import torch +import matplotlib.pyplot as plt +from cityscapesscripts.helpers.labels import labels as cs_labels +from datasets.cityscapes import get_cs_labeldata +from datasets.cocostuff import get_coco_labeldata +from datasets.potsdam import get_pd_labeldata + +sys.path.append(os.getcwd()) +import modules.transforms as transforms + + +def visualize_segmentation(img = None, + label = None, + linear = None, + mlp = None, + cluster = None, + dataset_name = None, + additional = None, + additional_name = None, + additional2 = None, + additional_name2 = None, + legend = None, + name = None): + + + if dataset_name == "cityscapes": + colormap = np.array([ + [128, 64, 128], + [244, 35, 232], + [250, 170, 160], + [230, 150, 140], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [180, 165, 180], + [150, 100, 100], + [150, 120, 90], + [153, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [70, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 0, 90], + [0, 0, 110], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + [0, 0, 0], + [220, 220, 220]]) + elif dataset_name == "cocostuff": + colormap = get_coco_labeldata()[-1] + + + orig_h, orig_w = label.cpu().shape[-2:] + img = img.cpu().squeeze(0).numpy().transpose(1, 2, 0) + img = (img-img.min())/(img-img.min()).max() + label = label.cpu().squeeze(0).numpy().transpose(1, 2, 0) + #transforms.labelIdsToTrainIds(source="cityscapes", target="cityscapes") + + label[label == 255] = 27 + colored_label = colormap[label.flatten()] + colored_label = colored_label.reshape(orig_h, orig_w, 3) + + num_subplots = 3 + if linear != 
None: num_subplots += 1 + if mlp != None: num_subplots += 1 + if additional != None: num_subplots += 1 + if additional2 != None: num_subplots += 1 + + + fig = plt.figure(figsize=(8, 2), dpi=200) + fig.tight_layout() + plt.axis('off') + plt.subplot(1, num_subplots, 1) + plt.gca().set_title('Image') + plt.imshow(img) + plt.axis("off") + plt.subplot(1, num_subplots, 2) + plt.gca().set_title('Ground Truth') + plt.imshow(colored_label) + plt.axis("off") + i = 3 + if linear != None: + linear = linear.cpu().numpy().transpose(1, 2, 0).astype('uint8') + linear = colormap[linear.flatten()].reshape(linear.shape[0], linear.shape[1], 3) + plt.axis("off") + plt.subplot(1, num_subplots, i) + plt.gca().set_title('Linear') + plt.imshow(linear) + i+=1 + + if mlp != None: + mlp = mlp.cpu().numpy().transpose(1, 2, 0).astype('uint8') + mlp = colormap[mlp.flatten()].reshape(mlp.shape[0], mlp.shape[1], 3) + plt.axis("off") + plt.subplot(1, num_subplots, i) + plt.gca().set_title('MLP') + plt.imshow(mlp) + plt.axis("off") + i+=1 + + if cluster != None: + cluster = cluster.cpu().numpy().transpose(1, 2, 0).astype('uint8') + cluster = colormap[cluster.flatten()].reshape(cluster.shape[0], cluster.shape[1], 3) + plt.axis("off") + plt.subplot(1, num_subplots, i) + plt.gca().set_title('Cluster') + plt.imshow(cluster) + plt.axis("off") + i+=1 + + if additional != None: + #additional = additional.cpu().numpy() + additional = additional.cpu().numpy().transpose(1, 2, 0).astype('uint8') + additional = colormap[additional.flatten()].reshape(additional.shape[0], additional.shape[1], 3) + plt.axis("off") + plt.subplot(1, num_subplots, i) + plt.gca().set_title(additional_name) + plt.imshow(additional) + plt.axis("off") + i+=1 + + if additional2 != None: + additional2 = additional2.cpu().numpy() + plt.axis("off") + plt.subplot(1, num_subplots, i) + plt.gca().set_title(additional_name2) + plt.imshow(additional2) + plt.axis("off") + i+=1 + + + # if legend != None: + # from matplotlib.lines import Line2D + + # legend_elements = [Line2D([0], [0], color=np.array(cls[7])/255, lw=4, label=cls[0]) for cls in cs_labels[7:-1]] + + # # Create the figure + # #fig, ax = plt.subplots() + # plt.legend(handles=legend_elements, loc='right') + + + + if name != None: plt.savefig(name) + fig.canvas.draw() + # Now we can save it to a numpy array. + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close('all') + + return data + + + +def visualize_confusion_matrix(cls_names, meter, name=None): + # plot of confusion matrix + conf_matrix = (meter.histogram/meter.histogram.sum(dim=0)) + conf_matrix = np.array(conf_matrix.cpu(), dtype=np.float16) + fig, ax = plt.subplots(figsize=(15, 15)) + ax.matshow(torch.Tensor(conf_matrix).fill_diagonal_(0), cmap=plt.cm.Blues, alpha=0.8) + for i in range(conf_matrix.shape[0]): + for j in range(conf_matrix.shape[1]): + ax.text(x=j, y=i,s=(conf_matrix[i, j]*100).round(1), va='center', ha='center', size='large') + ax.set_xticks(list(range(cls_names.__len__()))) + ax.set_xticklabels(cls_names, rotation=90, ha='center', fontsize=12) + ax.set_yticks(list(range(cls_names.__len__()))) + ax.set_yticklabels(cls_names, fontsize=12) + plt.xlabel('Predictions', fontsize=18) + plt.ylabel('Actuals', fontsize=18) + plt.title('Confusion Matrix', fontsize=18) + + if name != None: plt.savefig(name) + fig.canvas.draw() + # Now we can save it to a numpy array. 
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close('all') + return data + + + + + +def batch_visualize_segmentation(img = None, + label = None, + in1 = None, + in2 = None, + in3 = None, + in4 = None, + dataset_name = None): + + + if dataset_name == "cityscapes": + colormap = get_cs_labeldata()[-1] + elif dataset_name == "cocostuff": + colormap = get_coco_labeldata()[-1] + elif dataset_name == "potsdam": + colormap = get_pd_labeldata()[-1] + + def _vis_one_img(idx, img, label, ins): + + orig_h, orig_w = label.cpu().shape[-2:] + img = img.cpu().numpy().transpose(1, 2, 0) + img = (img-img.min())/(img-img.min()).max() + label = label.cpu().numpy().transpose(1, 2, 0) + label[label > 27] = 27 + colored_label = colormap[label.flatten()].reshape(orig_h, orig_w, 3) + + num_subplots = sum([1 for x in [in1, in2, in3, in4] if x != None]) + 2 + + fig = plt.figure(figsize=(10, 2), dpi=150) + fig.tight_layout() + plt.axis('off') + plt.subplot(1, num_subplots, 1) + if idx == 0: plt.gca().set_title('Image') + plt.imshow(img) + plt.axis("off") + plt.subplot(1, num_subplots, 2) + if idx == 0: plt.gca().set_title('Ground Truth') + plt.imshow(colored_label) + plt.axis("off") + if ins != None: + i = 3 + for input in ins: + vis = input[1].cpu().numpy().transpose(1, 2, 0).astype('uint8') + vis = colormap[vis.flatten()].reshape(vis.shape[0], vis.shape[1], 3) + plt.axis("off") + plt.subplot(1, num_subplots, i) + if idx == 0: plt.gca().set_title(input[0]) + plt.imshow(vis) + plt.axis("off") + i+=1 + + fig.canvas.draw() + plt.close('all') + one_vis = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + one_vis = one_vis.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close('all') + return one_vis + + imgs = [] + for idx, (data) in enumerate(zip(img, label)): + imgs.append(_vis_one_img(idx, data[0], data[1], [[i[0], i[1][idx].unsqueeze(0)] for i in [in1, in2, in3, in4] if i!=None])) + + return np.vstack(imgs) + + + +def visualize_single_masks(img, + label, + data, + dataset_name = None): + + + if dataset_name == "cityscapes": + colormap = get_cs_labeldata()[-1] + elif dataset_name == "cocostuff": + colormap = get_coco_labeldata()[-1] + elif dataset_name == "potsdam": + colormap = get_pd_labeldata()[-1] + + + fig = plt.figure(figsize=(data['sim'].__len__()*2, 7*2), dpi=150) + fig.tight_layout() + for indx, (sim, nnsim, nnsim_thresh, crf, pamr, mask) in enumerate(zip(data['sim'], data['nnsim'], data['nnsim_tresh'], data['crf'], data['pamr'], data['outmask'])): + rows = data['sim'].__len__() + cols = 8 + plotlabel=colormap[label.squeeze(0).squeeze(0).int().cpu()] + plt.subplot(rows, cols, 1+(indx*cols)) + img = (img-img.min())/(img.max()-img.min()) + if indx == 0: plt.title('Image') + plt.imshow(img.squeeze(0).permute(1, 2, 0).cpu()) + plt.axis('off') + plt.subplot(rows, cols, 2+(indx*cols)) + if indx == 0: plt.title('GT') + plt.imshow(plotlabel) + plt.axis('off') + plt.subplot(rows, cols, 3+(indx*cols)) + if indx == 0: plt.title('1.Eig') + plt.imshow(sim.cpu().numpy()) + plt.axis('off') + plt.subplot(rows, cols, 4+(indx*cols)) + if indx == 0: plt.title('1.EigNN') + plt.imshow(nnsim.cpu().numpy()) + plt.axis('off') + plt.subplot(rows, cols, 5+(indx*cols)) + if indx == 0: plt.title('+Thresh') + plt.imshow(nnsim_thresh) + plt.axis('off') + plt.subplot(rows, cols, 6+(indx*cols)) + if indx == 0: plt.title('+CRF') + plt.imshow(crf) + plt.axis('off') + plt.subplot(rows, cols, 7+(indx*cols)) + if indx == 0: 
plt.title('PAMR') + plt.imshow(pamr.squeeze().cpu().numpy()) + plt.axis('off') + plt.subplot(rows, cols, 8+(indx*cols)) + if indx == 0: plt.title('Mask') + mask[0, 0] = 0 + plt.imshow(mask.numpy(), cmap='Greys') + plt.axis('off') + # plt.savefig(str(idx)+'.png', tight_layout=True) + + + fig.canvas.draw() + plt.close('all') + one_vis = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + one_vis = one_vis.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close('all') + return one_vis + + + + + +def visualize_pseudo_paper(img, + label, + pseudo_gt, + pseudo_plain, + dataset_name = None, + save_name = None): + + + if dataset_name == "cityscapes": + colormap = get_cs_labeldata()[-1] + elif dataset_name == "cocostuff": + colormap = get_coco_labeldata()[-1] + elif dataset_name == "potsdam": + colormap = get_pd_labeldata()[-1] + + + np.random.seed(0) + cb_colomap = np.array([list(np.random.randint(0, 255, size=(1,3))[0]) for _ in range(400)]+[[0, 0, 0]]) + pseudo_plain = pseudo_plain.int().cpu() + pseudo_plain[pseudo_plain==255] = 400 + pseudo_plain = cb_colomap[pseudo_plain.int().cpu()].squeeze() + + + + + fig = plt.figure(figsize=(8, 2), dpi=150) + fig.subplots_adjust(left=0.1, + bottom=0.1, + right=0.5, + top=0.5, + wspace=0.05, + hspace=0.0) + + plt.subplot(1, 4, 1) + img = (img-img.min())/(img.max()-img.min()) + img = img.squeeze(0).permute(1, 2, 0).cpu() + plt.imshow(img) + plt.axis('off') + + plt.subplot(1, 4, 2) + plotlabel=colormap[label.squeeze(0).squeeze(0).int().cpu()] + plt.imshow(plotlabel) + plt.axis('off') + + plt.subplot(1, 4, 3) + plotpseudo=colormap[pseudo_gt.squeeze(0).squeeze(0).int().cpu()] + # pseudo_plain = np.array(pseudo_plain.cpu(), dtype=np.int16).squeeze() + # plotpseudo = mark_boundaries(plotlabel/255, pseudo_plain, color=(1, 1, 1)) + plt.imshow(plotpseudo) + plt.axis('off') + + plt.subplot(1, 4, 4) + plt.imshow(pseudo_plain) + plt.axis('off') + plt.savefig(save_name+'.pdf', bbox_inches='tight', pad_inches=0.0) + + + save_name_single = os.path.join(os.path.dirname(save_name), 'singleimgs/') + os.makedirs(os.path.dirname(save_name_single), exist_ok=True) + for i, n in zip([img, plotlabel, plotpseudo, pseudo_plain], ['img', 'gt', 'pseudo', 'pseudoc']): + fig = plt.figure(figsize=(2, 2), dpi=300) + plt.imshow(i) + plt.axis('off') + plt.savefig(os.path.join(save_name_single, os.path.split(save_name)[-1]+'_'+n+'.png'), bbox_inches='tight', pad_inches=0.0) + + + + + +def logits_to_image(logits = None, + img = None, + label = None, + dataset_name = None, + save_path = None, + save_imggt = False): + + + if dataset_name == "cityscapes": + colormap = get_cs_labeldata()[-1] + elif dataset_name == "cocostuff": + colormap = get_coco_labeldata()[-1] + elif dataset_name == "potsdam": + colormap = get_pd_labeldata()[-1] + + vis = logits.cpu().numpy().transpose(1, 2, 0).astype('uint8') + vis = colormap[vis.flatten()].reshape(vis.shape[0], vis.shape[1], 3) + + fig = plt.figure(figsize=(2, 2), dpi=400) + fig.tight_layout() + plt.subplot(1, 1, 1) + plt.imshow(vis) + plt.axis("off") + plt.savefig(save_path+'_pred.png', bbox_inches='tight', pad_inches=0.0) + plt.close('all') + + if save_imggt: + orig_h, orig_w = label.cpu().shape[-2:] + img = img.cpu().numpy().transpose(1, 2, 0) + img = (img-img.min())/(img-img.min()).max() + label = label.cpu().numpy().transpose(1, 2, 0) + label[label > 27] = 27 + colored_label = colormap[label.flatten()].reshape(orig_h, orig_w, 3) + + fig = plt.figure(figsize=(2, 2), dpi=400) + fig.tight_layout() + plt.subplot(1, 1, 1) + plt.imshow(img) + 
plt.axis("off") + plt.savefig(save_path+'_img.png', bbox_inches='tight', pad_inches=0.0) + plt.close('all') + + fig = plt.figure(figsize=(2, 2), dpi=400) + fig.tight_layout() + plt.subplot(1, 1, 1) + plt.imshow(colored_label) + plt.axis("off") + plt.savefig(save_path+'_gt.png', bbox_inches='tight', pad_inches=0.0) + plt.close('all') + + + + +class Vis_Demo(): + def __init__(self): + super(Vis_Demo, self).__init__() + self.colormap = get_coco_labeldata()[-1] + + def apply_colors(self, logits): + vis = logits.cpu().numpy().transpose(1, 2, 0).astype('uint8') + vis = self.colormap[vis.flatten()].reshape(vis.shape[0], vis.shape[1], 3) + return vis + + + +def visualize_demo(img, pseudo, alpha = 0.5): + np.random.seed(0) + cb_colomap = np.array([list(np.random.randint(0, 255, size=(1,3))[0]) for _ in range(400)]+[[0, 0, 0]]) + pseudo_plain = pseudo.long().cpu().numpy() + pseudo_plain[pseudo_plain==255] = 400 + pseudo_plain = cb_colomap[pseudo_plain].squeeze() + + img = transforms.UnNormalize()(img)*255 + img = img.permute(1, 2, 0).long().cpu().numpy() + out = alpha*img + (1-alpha)*pseudo_plain + + return np.array(out, dtype=np.uint8) + diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..5317a1114a6aaf87a749b8bb0aabfe42e43286ca --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +python==3.10 +pytorch +torchvision +gradio +cityscapesScripts +pydensecrf +scipy \ No newline at end of file