Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Jonas Rheiner
		
	commited on
		
		
					Commit 
							
							·
						
						4b77aea
	
1
								Parent(s):
							
							710c658
								
update
Browse filesThis view is limited to 50 files because it contains too many changes.  
							See raw diff
- app.py +206 -64
- dataset_examples/africa/3962011747224020_1024.jpg +0 -0
- dataset_examples/africa/3973126792769679_1024.jpg +0 -0
- dataset_examples/africa/4471109009586514_1024.jpg +0 -0
- dataset_examples/asia/106261888221766_1024.jpg +0 -0
- dataset_examples/asia/138321512570044_1024.jpg +0 -0
- dataset_examples/asia/147206360658971_1024.jpg +0 -0
- dataset_examples/europe/1423684677989158_1024.jpg +0 -0
- dataset_examples/europe/1483506425323136_1024.jpg +0 -0
- dataset_examples/europe/150428493699453_1024.jpg +0 -0
- dataset_examples/north america/1000276020376482_1024.jpg +0 -0
- dataset_examples/north america/757125001639938_1024.jpg +0 -0
- dataset_examples/north america/843371313196684_1024.jpg +0 -0
- dataset_examples/oceania/100141636075718_1024.jpg +0 -0
- dataset_examples/oceania/1899604010244250_1024.jpg +0 -0
- dataset_examples/oceania/821397982117703_1024.jpg +0 -0
- dataset_examples/south america/103386512798674_1024.jpg +0 -0
- dataset_examples/south america/1677652415776082_1024.jpg +0 -0
- dataset_examples/south america/327973242016483_1024.jpg +0 -0
- examples/1000276020376482_1024.jpg +0 -0
- examples/100141636075718_1024.jpg +0 -0
- examples/103386512798674_1024.jpg +0 -0
- examples/106261888221766_1024.jpg +0 -0
- examples/138321512570044_1024.jpg +0 -0
- examples/1423684677989158_1024.jpg +0 -0
- examples/147206360658971_1024.jpg +0 -0
- examples/1483506425323136_1024.jpg +0 -0
- examples/150428493699453_1024.jpg +0 -0
- examples/1677652415776082_1024.jpg +0 -0
- examples/1899604010244250_1024.jpg +0 -0
- examples/327973242016483_1024.jpg +0 -0
- examples/3962011747224020_1024.jpg +0 -0
- examples/3973126792769679_1024.jpg +0 -0
- examples/4471109009586514_1024.jpg +0 -0
- examples/757125001639938_1024.jpg +0 -0
- examples/821397982117703_1024.jpg +0 -0
- examples/843371313196684_1024.jpg +0 -0
- kerger-test-images/Africa_Botswana_-24.358520377382_23.5184910801.jpg +0 -0
- kerger-test-images/Africa_Kenya_-0.21870999999999_37.023791.jpg +0 -0
- kerger-test-images/Africa_Madagascar_-16.078452454738_46.73369803641.jpg +0 -0
- kerger-test-images/Africa_South Africa_-23.590135077274_28.785944164821.jpg +0 -0
- kerger-test-images/Africa_Tanzania_-3.3676537025657_36.716512872377.jpg +0 -0
- kerger-test-images/Africa_Uganda_1.1212866787272_33.915204986261.jpg +0 -0
- kerger-test-images/Asia_Israel_31.708865303742_34.94966916063.jpg +0 -0
- kerger-test-images/Asia_Japan_35.381304970616_134.65860211972.jpg +0 -0
- kerger-test-images/Asia_Pakistan_24.910493840503_69.506229024537.jpg +0 -0
- kerger-test-images/Asia_Russia_54.597757883015_48.163689656865.jpg +0 -0
- kerger-test-images/Asia_Russia_56.018311493214_38.359778952407.jpg +0 -0
- kerger-test-images/Asia_Russia_60.27835356798_29.754665851696.jpg +0 -0
- kerger-test-images/Asia_Thailand_19.824843951089_99.694080339609.jpg +0 -0
    	
        app.py
    CHANGED
    
    | @@ -4,92 +4,208 @@ from transformers import CLIPProcessor, CLIPModel | |
| 4 | 
             
            import torch
         | 
| 5 | 
             
            import itertools
         | 
| 6 | 
             
            import os
         | 
|  | |
| 7 |  | 
| 8 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 9 | 
             
            processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
         | 
|  | |
|  | |
|  | |
| 10 |  | 
| 11 | 
            -
            continents = ["Africa", "Asia", "Europe", | 
|  | |
| 12 | 
             
            countries_per_continent = {
         | 
| 13 | 
             
                "Africa": [
         | 
| 14 | 
            -
                    "Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon", | 
| 15 | 
            -
                    "Central African Republic", " | 
| 16 | 
            -
                    "Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon", | 
| 17 | 
            -
                    "Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia", | 
| 18 | 
            -
                    "Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique", | 
| 19 | 
            -
                    "Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles", | 
| 20 | 
            -
                    "Sierra Leone", "Somalia", "South Africa", " | 
| 21 | 
             
                    "Tunisia", "Uganda", "Zambia", "Zimbabwe"
         | 
| 22 | 
             
                ],
         | 
| 23 | 
             
                "Asia": [
         | 
| 24 | 
            -
                    "Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei", | 
| 25 | 
            -
                    "Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq", | 
| 26 | 
            -
                    "Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon", | 
| 27 | 
            -
                    "Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan", | 
| 28 | 
            -
                    "Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea", | 
| 29 | 
            -
                    "Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey", | 
| 30 | 
             
                    "Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen"
         | 
| 31 | 
             
                ],
         | 
| 32 | 
             
                "Europe": [
         | 
| 33 | 
            -
                    "Albania", " | 
| 34 | 
            -
                    "Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France", | 
| 35 | 
            -
                    "Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan", | 
| 36 | 
            -
                    "Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco", | 
| 37 | 
            -
                    "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania", | 
| 38 | 
            -
                    "Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland", | 
| 39 | 
            -
                    "Turkey", "Ukraine", "United Kingdom" | 
| 40 | 
             
                ],
         | 
| 41 | 
             
                "North America": [
         | 
| 42 | 
            -
                    "Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba", | 
| 43 | 
            -
                    "Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras", | 
| 44 | 
            -
                    "Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia", | 
| 45 | 
             
                    "Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States"
         | 
| 46 | 
             
                ],
         | 
| 47 | 
             
                "Oceania": [
         | 
| 48 | 
            -
                    "Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand", | 
| 49 | 
             
                    "Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu"
         | 
| 50 | 
             
                ],
         | 
| 51 | 
             
                "South America": [
         | 
| 52 | 
            -
                    "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay", | 
| 53 | 
             
                    "Peru", "Suriname", "Uruguay", "Venezuela"
         | 
| 54 | 
             
                ]
         | 
| 55 | 
             
            }
         | 
| 56 | 
            -
            countries = list(set(itertools.chain.from_iterable( | 
|  | |
| 57 |  | 
| 58 | 
            -
             | 
| 59 | 
            -
             | 
| 60 | 
            -
                " | 
| 61 | 
            -
                " | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 62 | 
             
            }
         | 
| 63 |  | 
|  | |
| 64 | 
             
            def predict(input_img):
         | 
| 65 | 
            -
                inputs = processor(text=[f"A photo from { | 
|  | |
|  | |
| 66 | 
             
                with torch.no_grad():
         | 
| 67 | 
            -
                    outputs =  | 
| 68 | 
             
                    logits_per_image = outputs.logits_per_image
         | 
| 69 | 
             
                    probs = logits_per_image.softmax(dim=-1)
         | 
| 70 | 
             
                    pred_id = probs.argmax().cpu().item()
         | 
| 71 | 
            -
                continent_probs = {label: prob for label, | 
| 72 | 
            -
             | 
| 73 | 
            -
             | 
| 74 | 
             
                predicted_continent_countries = countries_per_continent[continents[pred_id]]
         | 
| 75 | 
            -
                inputs = processor(text=[f"A photo from { | 
|  | |
|  | |
| 76 | 
             
                with torch.no_grad():
         | 
| 77 | 
            -
                    outputs =  | 
| 78 | 
             
                    logits_per_image = outputs.logits_per_image
         | 
| 79 | 
             
                    probs = logits_per_image.softmax(dim=-1)
         | 
| 80 | 
            -
                country_probs = {label: prob for label, prob in zip( | 
| 81 | 
            -
             | 
| 82 | 
             
                return continent_probs, country_probs
         | 
| 83 |  | 
| 84 | 
            -
             | 
| 85 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 86 | 
             
                continent_probs, country_probs = predict(input_img)
         | 
| 87 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 88 |  | 
| 89 | 
            -
            def next_versus_image():
         | 
| 90 | 
            -
                # VERSUS_GT["continent"] = "test"
         | 
| 91 | 
            -
                # VERSUS_GT["country"] = "test"
         | 
| 92 | 
            -
                return "versus_images/[email protected]"
         | 
| 93 |  | 
| 94 | 
             
            def get_example_images(dir):
         | 
| 95 | 
             
                image_extensions = (".jpg", ".jpeg", ".png")
         | 
| @@ -100,39 +216,65 @@ def get_example_images(dir): | |
| 100 | 
             
                            image_files.append(os.path.join(root, file))
         | 
| 101 | 
             
                return image_files
         | 
| 102 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 103 | 
             
            demo = gr.Blocks()
         | 
| 104 | 
             
            with demo:
         | 
| 105 | 
             
                with gr.Tab("Image Geolocation Demo"):
         | 
| 106 | 
             
                    with gr.Row():
         | 
| 107 | 
             
                        with gr.Column():
         | 
| 108 | 
            -
                            image = gr.Image(label="Image", type="pil", | 
|  | |
| 109 | 
             
                            predict_btn = gr.Button("Predict")
         | 
| 110 | 
            -
                            example_images = get_example_images(" | 
| 111 | 
            -
                            example_images.extend(get_example_images("versus_images"))
         | 
| 112 | 
            -
                            gr.Examples(examples=example_images, | 
|  | |
| 113 | 
             
                        with gr.Column():
         | 
| 114 | 
             
                            continents_label = gr.Label(label="Continents")
         | 
| 115 | 
            -
                            country_label = gr.Label( | 
|  | |
| 116 | 
             
                            # continents_label.select(predict_country, inputs=[image, continents_label], outputs=country_label)
         | 
| 117 | 
            -
                predict_btn.click(predict, inputs=image, outputs=[ | 
|  | |
| 118 |  | 
| 119 | 
             
                with gr.Tab("Versus Mode"):
         | 
|  | |
| 120 | 
             
                    with gr.Row():
         | 
| 121 | 
             
                        with gr.Column():
         | 
| 122 | 
            -
                            versus_image = gr.Image( | 
| 123 | 
            -
             | 
| 124 | 
            -
                             | 
| 125 | 
            -
             | 
|  | |
|  | |
| 126 | 
             
                            with gr.Row():
         | 
| 127 | 
             
                                next_img_btn = gr.Button("Try new image")
         | 
| 128 | 
             
                                versus_btn = gr.Button("Submit guess")
         | 
| 129 | 
             
                        with gr.Column():
         | 
| 130 | 
            -
                            versus_output = gr. | 
| 131 | 
            -
                             | 
| 132 | 
            -
                             | 
| 133 | 
            -
             | 
| 134 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 135 |  | 
| 136 |  | 
| 137 | 
             
            if __name__ == "__main__":
         | 
| 138 | 
            -
                demo.launch()
         | 
|  | |
| 4 | 
             
            import torch
         | 
| 5 | 
             
            import itertools
         | 
| 6 | 
             
            import os
         | 
| 7 | 
            +
            import plotly.graph_objects as go
         | 
| 8 |  | 
| 9 | 
            +
             | 
| 10 | 
            +
            CUDA_AVAILABLE = torch.cuda.is_available()
         | 
| 11 | 
            +
            print(f"CUDA={CUDA_AVAILABLE}")
         | 
| 12 | 
            +
            device = "cuda" if CUDA_AVAILABLE else "cpu"
         | 
| 13 | 
            +
            print(f"count={torch.cuda.device_count()}")
         | 
| 14 | 
            +
            print(f"current={torch.cuda.get_device_name(torch.cuda.current_device())}")
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            continent_model = CLIPModel.from_pretrained("model-checkpoints/continent")
         | 
| 17 | 
            +
            country_model = CLIPModel.from_pretrained("model-checkpoints/country")
         | 
| 18 | 
             
            processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
         | 
| 19 | 
            +
            continent_model = continent_model.to(device)
         | 
| 20 | 
            +
            country_model = country_model.to(device)
         | 
| 21 | 
            +
             | 
| 22 |  | 
| 23 | 
            +
            continents = ["Africa", "Asia", "Europe",
         | 
| 24 | 
            +
                          "North America", "Oceania", "South America"]
         | 
| 25 | 
             
            countries_per_continent = {
         | 
| 26 | 
             
                "Africa": [
         | 
| 27 | 
            +
                    "Algeria", "Angola", "Benin", "Botswana", "Burkina Faso", "Burundi", "Cabo Verde", "Cameroon",
         | 
| 28 | 
            +
                    "Central African Republic", "Congo", "Democratic Republic of the Congo",
         | 
| 29 | 
            +
                    "Djibouti", "Egypt", "Equatorial Guinea", "Eritrea", "Eswatini", "Ethiopia", "Gabon",
         | 
| 30 | 
            +
                    "Gambia", "Ghana", "Guinea", "Guinea-Bissau", "Ivory Coast", "Kenya", "Lesotho", "Liberia",
         | 
| 31 | 
            +
                    "Libya", "Madagascar", "Malawi", "Mali", "Mauritania", "Mauritius", "Morocco", "Mozambique",
         | 
| 32 | 
            +
                    "Namibia", "Niger", "Nigeria", "Rwanda", "Sao Tome and Principe", "Senegal", "Seychelles",
         | 
| 33 | 
            +
                    "Sierra Leone", "Somalia", "South Africa", "Sudan", "Tanzania", "Togo",
         | 
| 34 | 
             
                    "Tunisia", "Uganda", "Zambia", "Zimbabwe"
         | 
| 35 | 
             
                ],
         | 
| 36 | 
             
                "Asia": [
         | 
| 37 | 
            +
                    "Afghanistan", "Armenia", "Azerbaijan", "Bahrain", "Bangladesh", "Bhutan", "Brunei",
         | 
| 38 | 
            +
                    "Cambodia", "China", "Cyprus", "Georgia", "India", "Indonesia", "Iran", "Iraq",
         | 
| 39 | 
            +
                    "Israel", "Japan", "Jordan", "Kazakhstan", "Kuwait", "Kyrgyzstan", "Laos", "Lebanon",
         | 
| 40 | 
            +
                    "Malaysia", "Maldives", "Mongolia", "Myanmar", "Nepal", "North Korea", "Oman", "Pakistan",
         | 
| 41 | 
            +
                    "Palestine", "Philippines", "Qatar", "Russia", "Saudi Arabia", "Singapore", "South Korea",
         | 
| 42 | 
            +
                    "Sri Lanka", "Syria", "Taiwan", "Tajikistan", "Thailand", "Timor-Leste", "Turkey",
         | 
| 43 | 
             
                    "Turkmenistan", "United Arab Emirates", "Uzbekistan", "Vietnam", "Yemen"
         | 
| 44 | 
             
                ],
         | 
| 45 | 
             
                "Europe": [
         | 
| 46 | 
            +
                    "Albania", "Armenia", "Austria", "Azerbaijan", "Belarus", "Belgium", "Bosnia and Herzegovina",
         | 
| 47 | 
            +
                    "Bulgaria", "Croatia", "Cyprus", "Czech Republic", "Denmark", "Estonia", "Finland", "France",
         | 
| 48 | 
            +
                    "Georgia", "Germany", "Greece", "Hungary", "Iceland", "Ireland", "Italy", "Kazakhstan",
         | 
| 49 | 
            +
                    "Kosovo", "Latvia", "Liechtenstein", "Lithuania", "Luxembourg", "Malta", "Moldova", "Monaco",
         | 
| 50 | 
            +
                    "Montenegro", "Netherlands", "North Macedonia", "Norway", "Poland", "Portugal", "Romania",
         | 
| 51 | 
            +
                    "Russia", "San Marino", "Serbia", "Slovakia", "Slovenia", "Spain", "Sweden", "Switzerland",
         | 
| 52 | 
            +
                    "Turkey", "Ukraine", "United Kingdom"
         | 
| 53 | 
             
                ],
         | 
| 54 | 
             
                "North America": [
         | 
| 55 | 
            +
                    "Antigua and Barbuda", "Bahamas", "Barbados", "Belize", "Canada", "Costa Rica", "Cuba",
         | 
| 56 | 
            +
                    "Dominica", "Dominican Republic", "El Salvador", "Grenada", "Guatemala", "Haiti", "Honduras",
         | 
| 57 | 
            +
                    "Jamaica", "Mexico", "Nicaragua", "Panama", "Saint Kitts and Nevis", "Saint Lucia",
         | 
| 58 | 
             
                    "Saint Vincent and the Grenadines", "Trinidad and Tobago", "United States"
         | 
| 59 | 
             
                ],
         | 
| 60 | 
             
                "Oceania": [
         | 
| 61 | 
            +
                    "Australia", "Fiji", "Kiribati", "Marshall Islands", "Micronesia", "Nauru", "New Zealand",
         | 
| 62 | 
             
                    "Palau", "Papua New Guinea", "Samoa", "Solomon Islands", "Tonga", "Tuvalu", "Vanuatu"
         | 
| 63 | 
             
                ],
         | 
| 64 | 
             
                "South America": [
         | 
| 65 | 
            +
                    "Argentina", "Bolivia", "Brazil", "Chile", "Colombia", "Ecuador", "Guyana", "Paraguay",
         | 
| 66 | 
             
                    "Peru", "Suriname", "Uruguay", "Venezuela"
         | 
| 67 | 
             
                ]
         | 
| 68 | 
             
            }
         | 
| 69 | 
            +
            countries = list(set(itertools.chain.from_iterable(
         | 
| 70 | 
            +
                countries_per_continent.values())))
         | 
| 71 |  | 
| 72 | 
            +
            INTIAL_VERSUS_IMAGE = "versus_images/Europe_Germany_49.069183_10.319444_im2gps3k.jpg"
         | 
| 73 | 
            +
            INITAL_VERSUS_STATE = {
         | 
| 74 | 
            +
                "image": INTIAL_VERSUS_IMAGE,
         | 
| 75 | 
            +
                "continent": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[0],
         | 
| 76 | 
            +
                "country": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[1],
         | 
| 77 | 
            +
                "lat": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[2],
         | 
| 78 | 
            +
                "lon": INTIAL_VERSUS_IMAGE.split("/")[-1].split("_")[3],
         | 
| 79 | 
            +
                "score": {
         | 
| 80 | 
            +
                    "HUMAN": 0,
         | 
| 81 | 
            +
                    "AI": 0
         | 
| 82 | 
            +
                },
         | 
| 83 | 
            +
                "idx": 0
         | 
| 84 | 
             
            }
         | 
| 85 |  | 
| 86 | 
            +
             | 
| 87 | 
             
            def predict(input_img):
         | 
| 88 | 
            +
                inputs = processor(text=[f"A photo from {
         | 
| 89 | 
            +
                                   geo}." for geo in continents], images=input_img, return_tensors="pt", padding=True)
         | 
| 90 | 
            +
                inputs = inputs.to(device)
         | 
| 91 | 
             
                with torch.no_grad():
         | 
| 92 | 
            +
                    outputs = continent_model(**inputs)
         | 
| 93 | 
             
                    logits_per_image = outputs.logits_per_image
         | 
| 94 | 
             
                    probs = logits_per_image.softmax(dim=-1)
         | 
| 95 | 
             
                    pred_id = probs.argmax().cpu().item()
         | 
| 96 | 
            +
                continent_probs = {label: prob for label,
         | 
| 97 | 
            +
                                   prob in zip(continents, probs.tolist()[0])}
         | 
| 98 | 
            +
             | 
| 99 | 
             
                predicted_continent_countries = countries_per_continent[continents[pred_id]]
         | 
| 100 | 
            +
                inputs = processor(text=[f"A photo from {
         | 
| 101 | 
            +
                                   geo}." for geo in predicted_continent_countries], images=input_img, return_tensors="pt", padding=True)
         | 
| 102 | 
            +
                inputs = inputs.to(device)
         | 
| 103 | 
             
                with torch.no_grad():
         | 
| 104 | 
            +
                    outputs = country_model(**inputs)
         | 
| 105 | 
             
                    logits_per_image = outputs.logits_per_image
         | 
| 106 | 
             
                    probs = logits_per_image.softmax(dim=-1)
         | 
| 107 | 
            +
                country_probs = {label: prob for label, prob in zip(
         | 
| 108 | 
            +
                    predicted_continent_countries, probs.tolist()[0])}
         | 
| 109 | 
             
                return continent_probs, country_probs
         | 
| 110 |  | 
| 111 | 
            +
             | 
| 112 | 
            +
            def make_versus_map(human_country, model_country, versus_state):
         | 
| 113 | 
            +
                fig = go.Figure()
         | 
| 114 | 
            +
                fig.add_trace(go.Scattergeo(
         | 
| 115 | 
            +
                    lon=[versus_state["lon"]],
         | 
| 116 | 
            +
                    lat=[versus_state["lat"]],
         | 
| 117 | 
            +
                    text=["📷"],
         | 
| 118 | 
            +
                    mode='text+markers',
         | 
| 119 | 
            +
                    hoverinfo='text',
         | 
| 120 | 
            +
                    hovertext=f"Photo taken in {versus_state['country']}, {
         | 
| 121 | 
            +
                        versus_state['continent']}",
         | 
| 122 | 
            +
                    marker=dict(size=14, color='#00B945'),
         | 
| 123 | 
            +
                    showlegend=False
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                ))
         | 
| 126 | 
            +
                if human_country == model_country:
         | 
| 127 | 
            +
                    fig.add_trace(go.Scattergeo(
         | 
| 128 | 
            +
                        locations=[human_country],
         | 
| 129 | 
            +
                        locationmode='country names',
         | 
| 130 | 
            +
                        text=["🧑🤖"],
         | 
| 131 | 
            +
                        mode='text',
         | 
| 132 | 
            +
                        hoverinfo='location',
         | 
| 133 | 
            +
                        showlegend=False
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    ))
         | 
| 136 | 
            +
                else:
         | 
| 137 | 
            +
                    fig.add_trace(go.Scattergeo(
         | 
| 138 | 
            +
                        locations=[human_country],
         | 
| 139 | 
            +
                        locationmode='country names',
         | 
| 140 | 
            +
                        text=["🧑"],
         | 
| 141 | 
            +
                        mode='text',
         | 
| 142 | 
            +
                        hoverinfo='location',
         | 
| 143 | 
            +
                        showlegend=False
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    ))
         | 
| 146 | 
            +
                    fig.add_trace(go.Scattergeo(
         | 
| 147 | 
            +
                        locations=[model_country],
         | 
| 148 | 
            +
                        locationmode='country names',
         | 
| 149 | 
            +
                        text=["🤖"],
         | 
| 150 | 
            +
                        mode='text',
         | 
| 151 | 
            +
                        hoverinfo='location',
         | 
| 152 | 
            +
                        showlegend=False
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    ))
         | 
| 155 | 
            +
                fig.update_geos(
         | 
| 156 | 
            +
                    visible=True, resolution=110,
         | 
| 157 | 
            +
                    showcountries=True, countrycolor="grey", fitbounds="locations", projection_type="natural earth",
         | 
| 158 | 
            +
                )
         | 
| 159 | 
            +
                return fig
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            def versus_mode_inputs(input_img, human_continent, human_country, versus_state):
         | 
| 163 | 
            +
                human_points = 0
         | 
| 164 | 
            +
                model_points = 0
         | 
| 165 | 
            +
                if human_country == versus_state["country"]:
         | 
| 166 | 
            +
                    country_result = "✅"
         | 
| 167 | 
            +
                    human_points += 2
         | 
| 168 | 
            +
                else:
         | 
| 169 | 
            +
                    country_result = "❌"
         | 
| 170 | 
            +
                if human_continent == versus_state["continent"]:
         | 
| 171 | 
            +
                    continent_result = "✅"
         | 
| 172 | 
            +
                    human_points += 1
         | 
| 173 | 
            +
                else:
         | 
| 174 | 
            +
                    continent_result = "❌"
         | 
| 175 | 
            +
                human_result = f"The photo is from **{versus_state['country']}** {
         | 
| 176 | 
            +
                    country_result} in **{versus_state['continent']}** {continent_result}"
         | 
| 177 | 
            +
                human_score_update = f"+{
         | 
| 178 | 
            +
                    human_points} points" if human_points > 0 else "Wrong guess, try a new image."
         | 
| 179 | 
            +
                versus_state['score']['HUMAN'] += human_points
         | 
| 180 | 
            +
             | 
| 181 | 
             
                continent_probs, country_probs = predict(input_img)
         | 
| 182 | 
            +
                model_country = max(country_probs, key=country_probs.get)
         | 
| 183 | 
            +
                model_continent = max(continent_probs, key=continent_probs.get)
         | 
| 184 | 
            +
                if model_country == versus_state["country"]:
         | 
| 185 | 
            +
                    model_country_result = "✅"
         | 
| 186 | 
            +
                    model_points += 2
         | 
| 187 | 
            +
                else:
         | 
| 188 | 
            +
                    model_country_result = "❌"
         | 
| 189 | 
            +
                if model_continent == versus_state["continent"]:
         | 
| 190 | 
            +
                    model_continent_result = "✅"
         | 
| 191 | 
            +
                    model_points += 1
         | 
| 192 | 
            +
                else:
         | 
| 193 | 
            +
                    model_continent_result = "❌"
         | 
| 194 | 
            +
                model_score_update = f"+{
         | 
| 195 | 
            +
                    model_points} points" if model_points > 0 else "The model was wrong, seems the world is not yet doomed."
         | 
| 196 | 
            +
                versus_state['score']['AI'] += model_points
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                map = make_versus_map(human_country, model_country, versus_state)
         | 
| 199 | 
            +
                return f"""
         | 
| 200 | 
            +
            ## {human_result}
         | 
| 201 | 
            +
            ### The AI 🤖 thinks this photo is from **{model_country}** {model_country_result} in **{model_continent}** {model_continent_result}
         | 
| 202 | 
            +
             | 
| 203 | 
            +
            🧑 {human_score_update}  
         | 
| 204 | 
            +
            🤖 {model_score_update}
         | 
| 205 | 
            +
             | 
| 206 | 
            +
            ### Score     🧑 {versus_state['score']['HUMAN']} : {versus_state['score']['AI']} 🤖
         | 
| 207 | 
            +
            """, continent_probs, country_probs, map, versus_state
         | 
| 208 |  | 
|  | |
|  | |
|  | |
|  | |
| 209 |  | 
| 210 | 
             
            def get_example_images(dir):
         | 
| 211 | 
             
                image_extensions = (".jpg", ".jpeg", ".png")
         | 
|  | |
| 216 | 
             
                            image_files.append(os.path.join(root, file))
         | 
| 217 | 
             
                return image_files
         | 
| 218 |  | 
| 219 | 
            +
             | 
| 220 | 
            +
            def next_versus_image(versus_state):
         | 
| 221 | 
            +
                images = get_example_images("versus_images")
         | 
| 222 | 
            +
                versus_state["idx"] += 1
         | 
| 223 | 
            +
                if versus_state["idx"] > len(images):
         | 
| 224 | 
            +
                    versus_state["idx"] = 0
         | 
| 225 | 
            +
                versus_image = images[versus_state["idx"]]
         | 
| 226 | 
            +
                versus_state["continent"] = versus_image.split("/")[-1].split("_")[0]
         | 
| 227 | 
            +
                versus_state["country"] = versus_image.split("/")[-1].split("_")[1]
         | 
| 228 | 
            +
                versus_state["lat"] = versus_image.split("/")[-1].split("_")[2]
         | 
| 229 | 
            +
                versus_state["lon"] = versus_image.split("/")[-1].split("_")[3]
         | 
| 230 | 
            +
                versus_state["image"] = versus_image
         | 
| 231 | 
            +
                return versus_image, versus_state, None, None
         | 
| 232 | 
            +
             | 
| 233 | 
             
            demo = gr.Blocks()
         | 
| 234 | 
             
            with demo:
         | 
| 235 | 
             
                with gr.Tab("Image Geolocation Demo"):
         | 
| 236 | 
             
                    with gr.Row():
         | 
| 237 | 
             
                        with gr.Column():
         | 
| 238 | 
            +
                            image = gr.Image(label="Image", type="pil",
         | 
| 239 | 
            +
                                             sources=["upload", "clipboard"])
         | 
| 240 | 
             
                            predict_btn = gr.Button("Predict")
         | 
| 241 | 
            +
                            example_images = get_example_images("kerger-test-images")
         | 
| 242 | 
            +
                            # example_images.extend(get_example_images("versus_images"))
         | 
| 243 | 
            +
                            gr.Examples(examples=example_images,
         | 
| 244 | 
            +
                                        inputs=image, examples_per_page=24)
         | 
| 245 | 
             
                        with gr.Column():
         | 
| 246 | 
             
                            continents_label = gr.Label(label="Continents")
         | 
| 247 | 
            +
                            country_label = gr.Label(
         | 
| 248 | 
            +
                                num_top_classes=5, label="Top countries")
         | 
| 249 | 
             
                            # continents_label.select(predict_country, inputs=[image, continents_label], outputs=country_label)
         | 
| 250 | 
            +
                predict_btn.click(predict, inputs=image, outputs=[
         | 
| 251 | 
            +
                                  continents_label, country_label])
         | 
| 252 |  | 
| 253 | 
             
                with gr.Tab("Versus Mode"):
         | 
| 254 | 
            +
                    versus_state = gr.State(value=INITAL_VERSUS_STATE)
         | 
| 255 | 
             
                    with gr.Row():
         | 
| 256 | 
             
                        with gr.Column():
         | 
| 257 | 
            +
                            versus_image = gr.Image(
         | 
| 258 | 
            +
                                INITAL_VERSUS_STATE["image"], interactive=False)
         | 
| 259 | 
            +
                            continent_selection = gr.Radio(
         | 
| 260 | 
            +
                                continents, label="Continents", info="Where was this image taken? (1 Point)")
         | 
| 261 | 
            +
                            country_selection = gr.Dropdown(countries, label="Countries", info="Can you guess the exact country? (2 Points)"
         | 
| 262 | 
            +
                                                            ),
         | 
| 263 | 
             
                            with gr.Row():
         | 
| 264 | 
             
                                next_img_btn = gr.Button("Try new image")
         | 
| 265 | 
             
                                versus_btn = gr.Button("Submit guess")
         | 
| 266 | 
             
                        with gr.Column():
         | 
| 267 | 
            +
                            versus_output = gr.Markdown()
         | 
| 268 | 
            +
                            # with gr.Accordion("View Map", open=False):
         | 
| 269 | 
            +
                            map = gr.Plot(label="Locations")
         | 
| 270 | 
            +
                            with gr.Accordion("Full Model Output", open=False):
         | 
| 271 | 
            +
                                continents_label = gr.Label(label="Continents")
         | 
| 272 | 
            +
                                country_label = gr.Label(
         | 
| 273 | 
            +
                                    num_top_classes=5, label="Top countries")
         | 
| 274 | 
            +
                    next_img_btn.click(next_versus_image, inputs=[versus_state], outputs=[versus_image, versus_state, continent_selection, country_selection[0]])
         | 
| 275 | 
            +
                    versus_btn.click(versus_mode_inputs, inputs=[versus_image, continent_selection, country_selection[0], versus_state], outputs=[
         | 
| 276 | 
            +
                                     versus_output, continents_label, country_label, map, versus_state])
         | 
| 277 |  | 
| 278 |  | 
| 279 | 
             
            if __name__ == "__main__":
         | 
| 280 | 
            +
                demo.launch()
         | 
    	
        dataset_examples/africa/3962011747224020_1024.jpg
    DELETED
    
    | Binary file (145 kB) | 
|  | 
    	
        dataset_examples/africa/3973126792769679_1024.jpg
    DELETED
    
    | Binary file (69.5 kB) | 
|  | 
    	
        dataset_examples/africa/4471109009586514_1024.jpg
    DELETED
    
    | Binary file (147 kB) | 
|  | 
    	
        dataset_examples/asia/106261888221766_1024.jpg
    DELETED
    
    | Binary file (131 kB) | 
|  | 
    	
        dataset_examples/asia/138321512570044_1024.jpg
    DELETED
    
    | Binary file (78.1 kB) | 
|  | 
    	
        dataset_examples/asia/147206360658971_1024.jpg
    DELETED
    
    | Binary file (96.2 kB) | 
|  | 
    	
        dataset_examples/europe/1423684677989158_1024.jpg
    DELETED
    
    | Binary file (84.3 kB) | 
|  | 
    	
        dataset_examples/europe/1483506425323136_1024.jpg
    DELETED
    
    | Binary file (46.8 kB) | 
|  | 
    	
        dataset_examples/europe/150428493699453_1024.jpg
    DELETED
    
    | Binary file (140 kB) | 
|  | 
    	
        dataset_examples/north america/1000276020376482_1024.jpg
    DELETED
    
    | Binary file (66 kB) | 
|  | 
    	
        dataset_examples/north america/757125001639938_1024.jpg
    DELETED
    
    | Binary file (162 kB) | 
|  | 
    	
        dataset_examples/north america/843371313196684_1024.jpg
    DELETED
    
    | Binary file (257 kB) | 
|  | 
    	
        dataset_examples/oceania/100141636075718_1024.jpg
    DELETED
    
    | Binary file (84.2 kB) | 
|  | 
    	
        dataset_examples/oceania/1899604010244250_1024.jpg
    DELETED
    
    | Binary file (46.2 kB) | 
|  | 
    	
        dataset_examples/oceania/821397982117703_1024.jpg
    DELETED
    
    | Binary file (75.7 kB) | 
|  | 
    	
        dataset_examples/south america/103386512798674_1024.jpg
    DELETED
    
    | Binary file (135 kB) | 
|  | 
    	
        dataset_examples/south america/1677652415776082_1024.jpg
    DELETED
    
    | Binary file (94.5 kB) | 
|  | 
    	
        dataset_examples/south america/327973242016483_1024.jpg
    DELETED
    
    | Binary file (92.1 kB) | 
|  | 
    	
        examples/1000276020376482_1024.jpg
    DELETED
    
    | Binary file (66 kB) | 
|  | 
    	
        examples/100141636075718_1024.jpg
    DELETED
    
    | Binary file (84.2 kB) | 
|  | 
    	
        examples/103386512798674_1024.jpg
    DELETED
    
    | Binary file (135 kB) | 
|  | 
    	
        examples/106261888221766_1024.jpg
    DELETED
    
    | Binary file (131 kB) | 
|  | 
    	
        examples/138321512570044_1024.jpg
    DELETED
    
    | Binary file (78.1 kB) | 
|  | 
    	
        examples/1423684677989158_1024.jpg
    DELETED
    
    | Binary file (84.3 kB) | 
|  | 
    	
        examples/147206360658971_1024.jpg
    DELETED
    
    | Binary file (96.2 kB) | 
|  | 
    	
        examples/1483506425323136_1024.jpg
    DELETED
    
    | Binary file (46.8 kB) | 
|  | 
    	
        examples/150428493699453_1024.jpg
    DELETED
    
    | Binary file (140 kB) | 
|  | 
    	
        examples/1677652415776082_1024.jpg
    DELETED
    
    | Binary file (94.5 kB) | 
|  | 
    	
        examples/1899604010244250_1024.jpg
    DELETED
    
    | Binary file (46.2 kB) | 
|  | 
    	
        examples/327973242016483_1024.jpg
    DELETED
    
    | Binary file (92.1 kB) | 
|  | 
    	
        examples/3962011747224020_1024.jpg
    DELETED
    
    | Binary file (145 kB) | 
|  | 
    	
        examples/3973126792769679_1024.jpg
    DELETED
    
    | Binary file (69.5 kB) | 
|  | 
    	
        examples/4471109009586514_1024.jpg
    DELETED
    
    | Binary file (147 kB) | 
|  | 
    	
        examples/757125001639938_1024.jpg
    DELETED
    
    | Binary file (162 kB) | 
|  | 
    	
        examples/821397982117703_1024.jpg
    DELETED
    
    | Binary file (75.7 kB) | 
|  | 
    	
        examples/843371313196684_1024.jpg
    DELETED
    
    | Binary file (257 kB) | 
|  | 
    	
        kerger-test-images/Africa_Botswana_-24.358520377382_23.5184910801.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Africa_Kenya_-0.21870999999999_37.023791.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Africa_Madagascar_-16.078452454738_46.73369803641.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Africa_South Africa_-23.590135077274_28.785944164821.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Africa_Tanzania_-3.3676537025657_36.716512872377.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Africa_Uganda_1.1212866787272_33.915204986261.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Asia_Israel_31.708865303742_34.94966916063.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Asia_Japan_35.381304970616_134.65860211972.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Asia_Pakistan_24.910493840503_69.506229024537.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Asia_Russia_54.597757883015_48.163689656865.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Asia_Russia_56.018311493214_38.359778952407.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Asia_Russia_60.27835356798_29.754665851696.jpg
    ADDED
    
    |   | 
    	
        kerger-test-images/Asia_Thailand_19.824843951089_99.694080339609.jpg
    ADDED
    
    |   |