Spaces:
Runtime error
Runtime error
| from transformers import ViTImageProcessor, ViTForImageClassification | |
| import gradio as gr | |
| from datasets import load_dataset | |
| import torch | |
| import random | |
| import numpy as np | |
| import pandas as pd | |
def get_predictions(image):
    """Classify *image* with the module-level ViT model.

    Uses the module-level ``processor`` and ``model``.

    Args:
        image: One image as stored in the dataset's ``image`` column
            (presumably a PIL image — TODO confirm), passed straight to
            ``processor``.

    Returns:
        dict mapping every label name in ``model.config.id2label`` to its
        softmax probability as a plain Python ``float``, inserted from most to
        least likely.
    """
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single image is processed, so row 0 holds the class scores.
    probabilities = torch.softmax(logits[0], dim=-1)
    # Sort once so the dict iterates from the most to the least likely label.
    order = probabilities.argsort(dim=-1, descending=True).tolist()
    # Convert to Python floats so callers can compare/sum without tensors.
    probs = probabilities.tolist()
    return {model.config.id2label[idx]: probs[idx] for idx in order}
# Test split of the GeoGuessr dataset; rounds read its 'image' and 'label'
# columns.
data = load_dataset("marcelomoreno26/geoguessr", split="test")

# Fine-tuned ViT checkpoint and its matching image preprocessor.
model_name = "marcelomoreno26/vit-base-patch-16-384-geoguessr"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
length = len(data)

# Pool of candidate country names, one per line. Blank lines are skipped: an
# empty name would be offered as an option and then fail the lookup into the
# model's predictions.
# NOTE(review): names are assumed to match the model's id2label values exactly
# — TODO confirm.
with open("countries.txt", "r", encoding="utf-8") as file:
    countries = [line.strip() for line in file if line.strip()]
def get_result(selection):
    """Score the player's pick against the true country and the AI's pick.

    Reads the current round from the module-level globals ``correct_country``,
    ``model_prediction`` and ``filtered_predictions``.

    Args:
        selection: The country chosen in the radio button.

    Returns:
        ``(outcome_text, ai_markdown)``: a message naming the correct country
        and the outcome, and a markdown table of the model's confidence over
        the offered options (highest first).
    """
    # NOTE(review): these globals are shared by every visitor of the app, so
    # simultaneous players can overwrite each other's round.
    global correct_country, model_prediction, filtered_predictions

    player_hit = selection == correct_country
    ai_hit = correct_country == model_prediction
    if player_hit:
        outcome = "It's a draw!" if ai_hit else "Congratulations! You won!"
    elif ai_hit:
        outcome = "Sorry, you lost. The AI guessed it right!"
    else:
        outcome = "Sorry, you both lost."

    # Renormalise the model's probabilities over only the offered options.
    normaliser = sum([float(score) for score in filtered_predictions.values()])
    rows = []
    for country, score in filtered_predictions.items():
        rows.append((country, np.round(float(score) / normaliser, 3) * 100))

    table = pd.DataFrame(rows, columns=["Country", "Model Confidence (%)"])
    table = table.sort_values(by="Model Confidence (%)", ascending=False)
    header = f"The AI's guess was {model_prediction}\n\nAI's Results:\n"
    summary = f"The correct country was: {correct_country}\n{outcome}"
    return summary, header + table.to_markdown(index=False)
def load():
    """Start a round: pick a random image, build the options and ask the model.

    Stores the round in the module-level globals ``correct_country``,
    ``model_prediction`` and ``filtered_predictions``, which ``get_result``
    reads.

    Returns:
        ``(image, options, correct_country, model_prediction)`` where
        ``options`` is a shuffled list containing the correct country plus up
        to 4 other countries from ``countries``.
    """
    global correct_country
    global model_prediction
    global filtered_predictions
    # Randomly select an image
    i = random.randint(0, len(data) - 1)
    image = data[i]['image']
    # Assumes 'label' is a country-name string matching countries.txt and the
    # model's labels — TODO confirm against the dataset schema.
    correct_country = data[i]['label']
    # Draw distractors only from the other countries, so every round shows the
    # answer plus the same number of wrong options. (Sampling first and then
    # filtering out the answer gave 4 or 5 options depending on luck.)
    pool = [country for country in countries if country != correct_country]
    options = random.sample(pool, min(4, len(pool)))
    options.append(correct_country)
    random.shuffle(options)
    # Get model predictions, restricted to the offered options.
    predictions = get_predictions(image)
    filtered_predictions = {country: predictions[country] for country in options}
    model_prediction = max(filtered_predictions, key=filtered_predictions.get)
    return image, options, correct_country, model_prediction
def reload():
    """Start a new round for the "Get New Image" button.

    Picks a random image, builds the options, queries the model and stores the
    round in the module-level globals ``correct_country``,
    ``model_prediction`` and ``filtered_predictions`` for ``get_result``.

    Returns:
        A new ``gr.Image`` and ``gr.Radio`` for the image and option widgets,
        plus two empty strings that clear the Result text and the AI markdown
        panes.
    """
    global correct_country
    global model_prediction
    global filtered_predictions
    # Randomly select an image
    i = random.randint(0, len(data) - 1)
    image = data[i]['image']
    # Assumes 'label' is a country-name string matching countries.txt and the
    # model's labels — TODO confirm against the dataset schema.
    correct_country = data[i]['label']
    # Draw distractors only from the other countries, so every round shows the
    # answer plus the same number of wrong options. (Sampling first and then
    # filtering out the answer gave 4 or 5 options depending on luck.)
    pool = [country for country in countries if country != correct_country]
    options = random.sample(pool, min(4, len(pool)))
    options.append(correct_country)
    random.shuffle(options)
    # Get model predictions, restricted to the offered options.
    predictions = get_predictions(image)
    filtered_predictions = {country: predictions[country] for country in options}
    model_prediction = max(filtered_predictions, key=filtered_predictions.get)
    return gr.Image(image), gr.Radio(choices=options, label="Select the country:"), "", ""
# Build the game UI. A first round is loaded while the layout is built so the
# page has an image and options to show. This `with` block runs at module
# scope, so the unpacked correct_country / model_prediction become the module
# globals that get_result() reads.
with gr.Blocks() as demo:
    image, options, correct_country, model_prediction = load()

    # Title, tagline and instructions, rendered in this order.
    for intro_text in (
        "# GeoGuessr - Can You Beat the AI?",
        "Try to guess the country in the image. Can you beat the AI?",
        "## Instructions:",
        "\n1. First to Start playing press **Get New Image** at the bottom "
        "(the server needs to refresh from the cache and previous user)"
        "\n2. Select the country where you think the image was taken."
        "\n3. Review the results."
        "\n4. Play again by clicking **Get New Image**",
    ):
        gr.Markdown(intro_text)

    img = gr.Image(image)
    radio = gr.Radio(choices=options, label="Select the country:")
    ai_pred = gr.Markdown()
    text = gr.Text(label="Result")
    # Score the guess as soon as an option is picked.
    radio.select(fn=get_result, inputs=radio, outputs=[text, ai_pred])

    btn = gr.Button(value="Get New Image")
    # New round: replace image and options, clear both result panes.
    btn.click(fn=reload, inputs=None, outputs=[img, radio, text, ai_pred])

demo.launch()