ensemble / README.md
yuchen4's picture
Update README.md
1c17aef verified
metadata
language:
  - en
tags:
  - xception
  - gps-prediction
  - mean
  - standard-deviation
metrics:
  - mae
  - rmse

Custom Xception Model

This is an ensemble of fine-tuned Xception models for predicting latitude and longitude from images.

Model Metadata

  • Latitude Mean: 39.95165153939056
  • Latitude Std: 0.0007248140892687559
  • Longitude Mean: -75.19139496469714
  • Longitude Std: 0.0007013685468922234

Error Metrics

  • Mean Absolute Error (MAE): 0.00020775681686579616
  • Root Mean Squared Error (RMSE): 0.0003053099508331751

Model Evaluation

import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from huggingface_hub import hf_hub_download, login
import torch.nn as nn
from datasets import load_dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from geopy.distance import geodesic
from timm import create_model


# NOTE(review): the token is empty — a valid Hugging Face access token must be
# supplied here (or via `huggingface-cli login`) for gated/private downloads.
login(token="")


class CustomXceptionModel(nn.Module):
    """Xception backbone with its classifier replaced by a coordinate head.

    The timm ``legacy_xception`` model's final fully-connected layer is
    swapped for a fresh ``nn.Linear`` emitting ``num_classes`` values
    (2 by default: normalized latitude and longitude).

    Args:
        model_name: timm model identifier for the backbone.
        num_classes: Output dimension of the regression head.
            (Bug fix: this argument was previously ignored — the head was
            hard-coded to 2 outputs.)
        metadata: Optional dict of auxiliary info stored on the instance.
    """

    def __init__(self, model_name="legacy_xception", num_classes=2, metadata=None):
        super().__init__()
        self.metadata = metadata if metadata is not None else {}
        # pretrained=False: weights come from the downloaded checkpoint, not timm.
        self.xception = create_model(model_name, pretrained=False)
        in_features = self.xception.fc.in_features
        # Honor num_classes instead of the previously hard-coded 2.
        self.xception.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return raw (normalized) coordinate predictions, shape (B, num_classes)."""
        return self.xception(x)


# Download the serialized ensemble and load it onto the available device.
# NOTE(security): the checkpoint stores whole model objects (pickle), so
# loading it executes arbitrary code — only load checkpoints from trusted repos.
model_path = hf_hub_download(repo_id="aaaimg2gps/ensemble", filename="best_bagging_models.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bug fix: map_location prevents a crash when the checkpoint was saved on a GPU
# machine but inference runs on CPU-only hardware.
model_list = torch.load(model_path, map_location=device)
for i in range(len(model_list)):
    model_list[i] = model_list[i].to(device)


# Normalization statistics (must match the values used at training time);
# targets are standardized with these and predictions de-standardized for metrics.
lat_mean = 39.95165153939056
lat_std = 0.0007248140892687559
lon_mean = -75.19139496469714
lon_std = 0.0007013685468922234

# Test dataset.
# NOTE(review): this loads the "train" split of the released images — confirm
# that split is the intended held-out evaluation set.
dataset_test = load_dataset("gydou/released_img", split="train")

# Preprocessing: resize to 224x224, convert to tensor, then normalize with the
# standard ImageNet channel means/stds.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

class GPSImageDataset(Dataset):
    """Adapts a Hugging Face image dataset into (image, normalized GPS) pairs.

    Each item is the (optionally transformed) image together with a float32
    tensor ``[lat, lon]``, standardized by the supplied mean/std statistics.
    """

    def __init__(self, hf_dataset, transform=None, lat_mean=None, lat_std=None, lon_mean=None, lon_std=None):
        self.hf_dataset = hf_dataset
        self.transform = transform
        self.latitude_mean = lat_mean
        self.latitude_std = lat_std
        self.longitude_mean = lon_mean
        self.longitude_std = lon_std

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        record = self.hf_dataset[idx]

        img = record['image']
        if self.transform:
            img = self.transform(img)

        # Standardize the raw coordinates with the training statistics.
        norm_lat = (record['Latitude'] - self.latitude_mean) / self.latitude_std
        norm_lon = (record['Longitude'] - self.longitude_mean) / self.longitude_std

        return img, torch.tensor([norm_lat, norm_lon], dtype=torch.float32)

# Wrap the HF dataset with the normalization stats; deterministic order
# (shuffle=False) so printed predictions line up with dataset rows.
test_dataset = GPSImageDataset(
    hf_dataset=dataset_test,
    transform=inference_transform,
    lat_mean=lat_mean,
    lat_std=lat_std,
    lon_mean=lon_mean,
    lon_std=lon_std
)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

def weighted_mean(outputs_list):
    """Combine ensemble predictions via distance-weighted averaging.

    Each model's weight is the inverse of its average L2 distance to the
    other models' predictions for the same sample, so outlier models are
    down-weighted and agreeing models dominate.

    Performance fix: the previous per-sample Python loop is replaced with a
    single batched ``torch.cdist`` call over the whole batch — same math,
    one pass in native code.

    Args:
        outputs_list: list of (batch_size, 2) tensors, one per model.

    Returns:
        (batch_size, 2) tensor of distance-weighted mean predictions.
    """
    # (num_models, batch, 2) -> (batch, num_models, 2) so cdist batches per sample.
    preds = torch.stack(outputs_list, dim=0).permute(1, 0, 2)
    epsilon = 1e-6  # avoids division by zero when all models agree exactly

    # Pairwise L2 distances per sample: (batch, num_models, num_models),
    # then mean over peers -> (batch, num_models).
    avg_distances = torch.cdist(preds, preds, p=2).mean(dim=2)
    weights = 1.0 / (avg_distances + epsilon)
    weights = weights / weights.sum(dim=1, keepdim=True)

    # Weighted sum over the model axis -> (batch, 2).
    return (preds * weights.unsqueeze(-1)).sum(dim=1)

def evaluate_model(model_list, dataloader, device):
    """Run the ensemble on the dataloader and report geodesic distance errors.

    De-standardizes predictions and targets with the module-level statistics
    (lat_mean/lat_std, lon_mean/lon_std), then measures the geodesic distance
    in meters between each predicted and actual coordinate and prints summary
    statistics.

    Args:
        model_list: trained models producing (batch, 2) normalized outputs.
        dataloader: yields (images, normalized gps_coords) batches.
        device: torch device the models live on.
    """
    # Bug fix: switch every model to eval mode so batch-norm/dropout layers use
    # inference behavior instead of training-batch statistics.
    for model in model_list:
        model.eval()

    # Hoisted loop invariants: de-normalization scale/offset tensors.
    scale = torch.tensor([lat_std, lon_std])
    offset = torch.tensor([lat_mean, lon_mean])

    distances = []
    with torch.no_grad():
        for images, gps_coords in dataloader:
            images = images.to(device)
            gps_coords = gps_coords.to(device)

            # Each model's prediction, then the distance-weighted ensemble mean.
            outputs_list = [model(images) for model in model_list]
            outputs_mean = weighted_mean(outputs_list)

            # De-normalize predictions and ground truth back to degrees.
            preds = outputs_mean.cpu() * scale + offset
            actuals = gps_coords.cpu() * scale + offset

            for pred, actual in zip(preds.numpy(), actuals.numpy()):
                distance = geodesic((actual[0], actual[1]), (pred[0], pred[1])).meters
                distances.append(distance)
                print(f"Predicted coordinates: ({pred[0]:.6f}, {pred[1]:.6f})")
                print(f"Actual coordinates: ({actual[0]:.6f}, {actual[1]:.6f})")
                print(f"Distance error: {distance:.2f} meters")
                print("---")

    # Robustness: an empty dataloader previously produced NaN statistics.
    if not distances:
        print("No samples evaluated.")
        return

    mean_dist = np.mean(distances)
    median_dist = np.median(distances)
    dist_std = np.std(distances)
    dist_95 = np.percentile(distances, 95)

    print("\n=== Overall Performance Evaluation ===")
    print(f'Mean distance error: {mean_dist:.2f} meters')
    print(f'Median distance error: {median_dist:.2f} meters')
    print(f'Distance standard deviation: {dist_std:.2f} meters')
    print(f'95th percentile distance error: {dist_95:.2f} meters')

# evaluation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Starting model evaluation...")
evaluate_model(model_list, test_dataloader, device)