---
license: mit
language: en
library_name: pytorch
tags:
- pytorch
- tabular-classification
- pokemon
- finance
- scikit-learn
- shap
---
# Pokémon TCG Price Predictor
This repository contains a PyTorch model that analyzes Pokémon card features to identify cards likely to rise significantly in price (defined here as a 30% increase within 6 months).
This model is the backend for the **[PokePrice Gradio Demo](https://huggingface.co/spaces/OffWorldTensor/PokePrice)**.
## Model Description
The model is a simple Multi-Layer Perceptron (MLP) implemented in PyTorch. It takes features of a Pokémon card as input (such as rarity, type, and historical price data) and outputs a single logit. Applying a sigmoid function to this logit gives a probability score for the price rising.
- **Model type:** Tabular Binary Classification
- **Architecture:** `PricePredictor` (MLP)
- **Framework:** PyTorch
- **Training Data:** A custom dataset derived from the PokemonTCG/pokemon-tcg-data repository, augmented with pricing history.
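As a small illustration of the logit-to-probability step described above, the sketch below applies a sigmoid in plain Python and thresholds the result at 0.5. The helper names `sigmoid` and `classify` are hypothetical and are not part of this repository; in practice `torch.sigmoid` does the same job on tensors.

```python
import math
from typing import Tuple


def sigmoid(logit: float) -> float:
    """Logistic function, written to avoid overflow for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def classify(logit: float, threshold: float = 0.5) -> Tuple[float, bool]:
    """Return (probability, predicted_rise) for a single model logit.

    The 0.5 threshold is an assumption taken from the usage example in
    this card, not a tuned decision boundary.
    """
    probability = sigmoid(logit)
    return probability, probability > threshold


# Illustrative logits only; these are not real model outputs.
for example_logit in (-2.0, 0.0, 2.0):
    p, rises = classify(example_logit)
    print(f"logit={example_logit:+.1f} -> probability={p:.4f}, rise={rises}")
```

Because the sigmoid maps a logit of 0 to a probability of exactly 0.5, thresholding the probability at 0.5 is equivalent to checking whether the logit is positive.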
## How to Use
To use this model, you will need `torch`, `safetensors`, `joblib`, `scikit-learn`, `pandas`, and `huggingface_hub`. You can download the model artifacts directly from the Hub.
First, ensure you have `network.py` (which defines the model class) in your working directory.
```python
import json

import joblib
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Make sure you have network.py in the same directory
from network import PricePredictor

# Placeholder: replace with the repository ID that hosts the model files.
REPO_ID = "your-username/pokemon-price-predictor"
MODEL_FILENAME = "model.safetensors"
CONFIG_FILENAME = "config.json"
SCALER_FILENAME = "scaler.pkl"

print("Downloading model files from the Hub...")
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME)
scaler_path = hf_hub_download(repo_id=REPO_ID, filename=SCALER_FILENAME)
print("Downloads complete.")

# The config stores the expected feature columns and the input width.
with open(config_path, "r") as f:
    config = json.load(f)
feature_columns = config["feature_columns"]
input_size = config["input_size"]

# Rebuild the network and load the trained weights.
model = PricePredictor(input_size=input_size)
model.load_state_dict(load_file(model_path))
model.eval()

# Load the scaler that is applied to the features before inference.
scaler = joblib.load(scaler_path)

# Example card; any feature not listed here is filled with 0.0 below.
data_to_predict = {
    'rawPrice': [10.0], 'gradedPriceTen': [100.0], 'gradedPriceNine': [50.0],
}
input_df = pd.DataFrame(data_to_predict)

# Add missing feature columns, then put the columns in the expected order.
missing_cols = set(feature_columns) - set(input_df.columns)
for c in missing_cols:
    input_df[c] = 0.0
input_df = input_df[feature_columns]

input_scaled = scaler.transform(input_df.values)
input_tensor = torch.tensor(input_scaled, dtype=torch.float32)

# Run inference without tracking gradients.
with torch.no_grad():
    logits = model(input_tensor)
    probability = torch.sigmoid(logits).item()

print("\nPrediction for the input card:")
print(f" - Probability of 30% price rise in 6 months: {probability:.4f}")
if probability > 0.5:
    print(" - Prediction: Price WILL LIKELY rise.")
else:
    print(" - Prediction: Price WILL LIKELY NOT rise.")
```
## Model Explainability
To understand the model's decisions, SHAP (SHapley Additive exPlanations) values were computed.
### Global Feature Importance
This plot shows the average impact of each feature on the model's output magnitude. Features at the top are most influential.
![SHAP global feature importance plot](https://cdn-uploads.huggingface.co/production/uploads/68b20687b24f311b7de2242d/WMXEn5Ond1zo4B6hvsqLN.png)
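For readers unfamiliar with this kind of plot, the sketch below shows how a global ranking is commonly derived from per-sample SHAP values: take the mean absolute value per feature and sort in descending order. It assumes the plot above is the usual mean-absolute-SHAP bar chart; the function `global_importance` and the numbers are purely illustrative and are not the values behind the plot.

```python
from typing import List, Sequence, Tuple


def global_importance(
    shap_values: Sequence[Sequence[float]],
    feature_names: Sequence[str],
) -> List[Tuple[str, float]]:
    """Rank features by mean absolute SHAP value (largest first).

    shap_values has one row per sample and one column per feature.
    """
    n_rows = len(shap_values)
    totals = [0.0] * len(feature_names)
    for row in shap_values:
        for j, value in enumerate(row):
            totals[j] += abs(value)
    means = [total / n_rows for total in totals]
    return sorted(zip(feature_names, means), key=lambda kv: kv[1], reverse=True)


# Hypothetical SHAP values for two samples (made up for illustration).
names = ["rawPrice", "gradedPriceTen", "gradedPriceNine"]
values = [
    [0.5, -1.0, 0.25],
    [-0.5, 2.0, 0.25],
]
ranking = global_importance(values, names)
for name, score in ranking:
    print(f"{name}: {score:.2f}")
```

Taking absolute values before averaging matters: a feature that pushes predictions up for some cards and down for others would otherwise average out to near zero despite being influential.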
## Limitations and Bias
- The model is trained on historical data and may not predict future trends accurately, especially in a volatile market.
- The definition of "price rise" is fixed at 30% over 6 months. The model is not trained for other thresholds or timeframes.
- The dataset may have inherent biases related to card popularity, set releases, or data collection artifacts.
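To make the fixed target definition concrete, here is a minimal sketch of how a binary label could be derived from two prices under the "30% over 6 months" rule. The exact labeling procedure used for training is not documented here, so the function name `price_rise_label`, the inclusive comparison, and the handling of non-positive prices are all assumptions.

```python
def price_rise_label(price_now: float, price_later: float,
                     threshold: float = 0.30) -> int:
    """Return 1 if the relative price change meets the threshold, else 0.

    price_later is assumed to be the price 6 months after price_now.
    Treating a rise of exactly `threshold` as positive is an assumption.
    """
    if price_now <= 0:
        raise ValueError("price_now must be positive")
    relative_change = (price_later - price_now) / price_now
    return int(relative_change >= threshold)


# Illustrative prices only.
print(price_rise_label(10.0, 14.0))  # 40% rise
print(price_rise_label(10.0, 12.0))  # 20% rise
```

A model trained against this label says nothing about smaller gains or different horizons; predicting, for example, a 10% rise over 3 months would require relabeling the data and retraining.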
## Author
Callum Anderson