jlincar commited on
Commit
8eba60b
ยท
1 Parent(s): 5f35f7d
Files changed (1) hide show
  1. app.py +226 -0
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ from datasets import load_dataset
4
+ import random
5
+ import torch
6
+
7
+ # Load model from Hugging Face
8
+ model_name = "Jordiett/convnextv2-geoguessr"
9
+ processor = AutoImageProcessor.from_pretrained(model_name)
10
+ model = AutoModelForImageClassification.from_pretrained(model_name)
11
+
12
+ # Load dataset
13
+ print("Loading GeoGuessr dataset...")
14
+ dataset = load_dataset("marcelomoreno26/geoguessr", split="test")
15
+ print(f"Loaded {len(dataset)} test images")
16
+
17
+ # List of countries
18
+ countries = list(model.config.id2label.values())
19
+
20
+ # Game state
21
+ class GameState:
22
+ def __init__(self):
23
+ self.player_score = 0
24
+ self.ai_score = 0
25
+ self.rounds = 0
26
+ self.current_image = None
27
+ self.correct_country = None
28
+ self.ai_prediction = None
29
+ self.options = []
30
+ self.used_indices = []
31
+
32
+ game = GameState()
33
+
34
+ def get_ai_prediction(image):
35
+ """Get AI prediction"""
36
+ inputs = processor(images=image, return_tensors="pt")
37
+ with torch.no_grad():
38
+ outputs = model(**inputs)
39
+ logits = outputs.logits
40
+ predicted_id = logits.argmax(-1).item()
41
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
42
+
43
+ # Top 3 predictions
44
+ top3_prob, top3_idx = torch.topk(probabilities, 3)
45
+ top3_countries = [(model.config.id2label[idx.item()], prob.item())
46
+ for idx, prob in zip(top3_idx, top3_prob)]
47
+
48
+ return model.config.id2label[predicted_id], top3_countries
49
+
50
+ def generate_options(correct_country, ai_prediction):
51
+ """Generate 4 options (1 correct + 3 incorrect)"""
52
+ options = [correct_country]
53
+
54
+ # Add AI prediction if it's wrong (makes it more interesting)
55
+ if ai_prediction != correct_country and ai_prediction not in options:
56
+ options.append(ai_prediction)
57
+
58
+ # Fill remaining slots with random countries
59
+ other_countries = [c for c in countries if c not in options]
60
+ needed = 4 - len(options)
61
+ options.extend(random.sample(other_countries, needed))
62
+
63
+ random.shuffle(options)
64
+ return options
65
+
66
+ def new_round():
67
+ """Start a new round with a random image from the dataset"""
68
+ # Select a random image that hasn't been used yet
69
+ available_indices = [i for i in range(len(dataset)) if i not in game.used_indices]
70
+
71
+ if len(available_indices) == 0:
72
+ # Reset if all images have been used
73
+ game.used_indices = []
74
+ available_indices = list(range(len(dataset)))
75
+
76
+ idx = random.choice(available_indices)
77
+ game.used_indices.append(idx)
78
+
79
+ # Get image and label from dataset
80
+ sample = dataset[idx]
81
+ image = sample["image"]
82
+ game.correct_country = sample["label"]
83
+ game.current_image = image
84
+
85
+ # Get AI prediction
86
+ ai_pred, top3 = get_ai_prediction(image)
87
+ game.ai_prediction = ai_pred
88
+
89
+ # Generate options
90
+ game.options = generate_options(game.correct_country, ai_pred)
91
+
92
+ # AI confidence message
93
+ ai_confidence = f"The AI predicts: **{ai_pred}**\n\n"
94
+ ai_confidence += "Top 3 predictions:\n"
95
+ for country, prob in top3:
96
+ ai_confidence += f"- {country}: {prob*100:.1f}%\n"
97
+
98
+ return (
99
+ image,
100
+ "๐ŸŒ **Where do you think this image is from?**\n\n" + ai_confidence,
101
+ gr.update(choices=game.options, value=None, visible=True),
102
+ gr.update(visible=True),
103
+ f"๐ŸŽฎ Player: {game.player_score} | ๐Ÿค– AI: {game.ai_score} | ๐ŸŽฏ Rounds: {game.rounds}"
104
+ )
105
+
106
+ def check_answer(player_choice):
107
+ """Check player's answer"""
108
+ if player_choice is None:
109
+ return "โš ๏ธ Please select an option!", gr.update(visible=True)
110
+
111
+ game.rounds += 1
112
+
113
+ # Check if player is correct
114
+ player_correct = (player_choice == game.correct_country)
115
+ if player_correct:
116
+ game.player_score += 1
117
+
118
+ # Check if AI is correct
119
+ ai_correct = (game.ai_prediction == game.correct_country)
120
+ if ai_correct:
121
+ game.ai_score += 1
122
+
123
+ # Result message
124
+ result = f"## ๐ŸŽฏ Round {game.rounds} Result\n\n"
125
+ result += f"**Correct country:** {game.correct_country}\n\n"
126
+
127
+ if player_correct and ai_correct:
128
+ result += "๐ŸŽ‰ **It's a tie!** Both you and the AI got it right!\n"
129
+ elif player_correct:
130
+ result += "๐Ÿ† **You win!** The AI was wrong.\n"
131
+ elif ai_correct:
132
+ result += "๐Ÿค– **AI wins!** You were wrong.\n"
133
+ else:
134
+ result += "โŒ **Both failed!**\n"
135
+
136
+ result += f"\n**Your answer:** {player_choice} {'โœ…' if player_correct else 'โŒ'}\n"
137
+ result += f"**AI prediction:** {game.ai_prediction} {'โœ…' if ai_correct else 'โŒ'}\n"
138
+
139
+ # Calculate win rate
140
+ if game.rounds > 0:
141
+ player_rate = (game.player_score / game.rounds) * 100
142
+ ai_rate = (game.ai_score / game.rounds) * 100
143
+ result += f"\n**Your accuracy:** {player_rate:.1f}% ({game.player_score}/{game.rounds})\n"
144
+ result += f"**AI accuracy:** {ai_rate:.1f}% ({game.ai_score}/{game.rounds})\n"
145
+
146
+ return result, gr.update(visible=True)
147
+
148
+ def reset_game():
149
+ """Reset the game"""
150
+ game.player_score = 0
151
+ game.ai_score = 0
152
+ game.rounds = 0
153
+ game.used_indices = []
154
+ return (
155
+ None,
156
+ "๐ŸŽฎ **Game reset!** Click 'New Round' to start playing.",
157
+ gr.update(choices=[], value=None, visible=False),
158
+ gr.update(visible=False),
159
+ "๐ŸŽฎ Player: 0 | ๐Ÿค– AI: 0 | ๐ŸŽฏ Rounds: 0",
160
+ ""
161
+ )
162
+
163
+ # Gradio Interface
164
+ with gr.Blocks(theme=gr.themes.Soft(), title="GeoGuessr: Player vs AI") as demo:
165
+ gr.Markdown("""
166
+ # ๐ŸŒ GeoGuessr: Player vs AI
167
+
168
+ Compete against an AI trained with ConvNeXt V2 to guess countries from Google Street View images!
169
+
170
+ **How to play:**
171
+ 1. Click "๐ŸŽฎ New Round" to load a random image from the GeoGuessr dataset
172
+ 2. Choose one of the 4 proposed countries
173
+ 3. Click "โœ… Check Answer" to see if you beat the AI!
174
+
175
+ **Model:** ConvNeXt V2 Base (61% accuracy, 51.77% F1-macro)
176
+ """)
177
+
178
+ with gr.Row():
179
+ with gr.Column(scale=2):
180
+ image_display = gr.Image(type="pil", label="๐Ÿ“ธ Street View Image", interactive=False)
181
+ start_btn = gr.Button("๐ŸŽฎ New Round", variant="primary", size="lg")
182
+
183
+ with gr.Column(scale=1):
184
+ scoreboard = gr.Markdown("๐ŸŽฎ Player: 0 | ๐Ÿค– AI: 0 | ๐ŸŽฏ Rounds: 0")
185
+ reset_btn = gr.Button("๐Ÿ”„ Reset Game", variant="secondary")
186
+
187
+ question = gr.Markdown("โฌ‡๏ธ Click 'New Round' to start!")
188
+
189
+ options = gr.Radio(
190
+ choices=[],
191
+ label="๐ŸŒ Select the country:",
192
+ visible=False
193
+ )
194
+
195
+ submit_btn = gr.Button("โœ… Check Answer", variant="primary", visible=False)
196
+
197
+ result = gr.Markdown("")
198
+
199
+ # Events
200
+ start_btn.click(
201
+ fn=new_round,
202
+ inputs=[],
203
+ outputs=[image_display, question, options, submit_btn, scoreboard]
204
+ )
205
+
206
+ submit_btn.click(
207
+ fn=check_answer,
208
+ inputs=[options],
209
+ outputs=[result, start_btn]
210
+ )
211
+
212
+ reset_btn.click(
213
+ fn=reset_game,
214
+ outputs=[image_display, question, options, submit_btn, scoreboard, result]
215
+ )
216
+
217
+ gr.Markdown("""
218
+ ---
219
+ **Dataset:** [GeoGuessr by marcelomoreno26](https://huggingface.co/datasets/marcelomoreno26/geoguessr)
220
+ **Model:** [ConvNeXt V2 GeoGuessr by Jordiett](https://huggingface.co/Jordiett/convnextv2-geoguessr)
221
+
222
+ Images are randomly selected from the test set of the GeoGuessr dataset.
223
+ """)
224
+
225
+ if __name__ == "__main__":
226
+ demo.launch()