---
license: apache-2.0
base_model: google/vit-base-patch16-224-in21k
tags:
- generated_from_trainer
datasets:
- imagefolder
metrics:
- f1
model-index:
- name: Pokemon-classification-1stGen
  results:
  - task:
      name: Image Classification
      type: image-classification
    dataset:
      name: imagefolder
      type: imagefolder
      config: default
      split: train
      args: default
    metrics:
    - name: F1
      type: f1
      value: 0.9272453917274858
---


# Pokemon-classification-1stGen

This model is a fine-tuned version of [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k) on the [Dusduo/1stGen-Pokemon-Images](https://huggingface.co/datasets/Dusduo/1stGen-Pokemon-Images) dataset.
It was trained to distinguish between the Pokémon of the [1st Generation](https://en.wikipedia.org/wiki/List_of_generation_I_Pok%C3%A9mon).
It achieves the following results on the evaluation set:
- Loss: 0.4182
- F1: 0.9272

A demo of the model is [hosted on Spaces](https://huggingface.co/spaces/Dusduo/GottaClassifyEmAll).
Feel free to check it out!


## Model description

A Vision Transformer (ViT) fine-tuned for Pokémon image classification.
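The model can be used with the standard `transformers` image-classification pipeline. A minimal sketch, assuming the model's Hub id is `Dusduo/Pokemon-classification-1stGen` and that `transformers`, `torch`, and `Pillow` are installed:

```python
from transformers import pipeline
from PIL import Image

# Load the fine-tuned checkpoint from the Hub
# ("Dusduo/Pokemon-classification-1stGen" is the assumed model id).
classifier = pipeline("image-classification", model="Dusduo/Pokemon-classification-1stGen")

# In practice, load your own picture: image = Image.open("my_pokemon.png")
image = Image.new("RGB", (224, 224))

# Returns a list of {"label": ..., "score": ...} dicts, best match first.
predictions = classifier(image)
print(predictions[0])
```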

## Intended uses & limitations

This model is intended to classify Pokémon from the 1st Generation only. When given images of Pokémon from later generations, its outputs will not be usable as such.
Moreover, the model was not designed to handle non-Pokémon images, or images containing several entities.
However, an additional layer can mitigate the risk of misclassifying non-Pokémon images by analyzing the spread of the output probabilities (i.e. how confused the model is); such a layer can be found in my implementation, available [here](https://github.com/A-Duss/GottaClassifyEmAll).
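One simple way to measure that spread is the entropy of the predicted class probabilities: a flat distribution over the 151 classes suggests the input is likely not a 1st-Generation Pokémon. This is an illustrative sketch, not the exact filter used in the linked implementation, and the threshold value is an assumption:

```python
import math

def is_confident(scores, entropy_threshold=2.0):
    """Flag predictions whose probability spread is too flat.

    `scores` is a list of class probabilities (summing to ~1). High
    entropy means the model is spreading its belief across many classes,
    which often happens on out-of-distribution (non-Pokémon) inputs.
    The threshold of 2.0 nats is illustrative only.
    """
    entropy = -sum(p * math.log(p) for p in scores if p > 0)
    return entropy < entropy_threshold

# A peaked distribution passes; a near-uniform one over 151 classes does not.
print(is_confident([0.9, 0.05, 0.05]))   # True  (entropy ≈ 0.39)
print(is_confident([1 / 151] * 151))     # False (entropy ≈ 5.02)
```

In a full pipeline this check would run on the softmaxed logits before trusting the top-1 label.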

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 6.56462271373806e-05
- train_batch_size: 4
- eval_batch_size: 16
- seed: 42
- gradient_accumulation_steps: 4
- total_train_batch_size: 16
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 7

### Training results

| Training Loss | Epoch | Step | Validation Loss | F1     |
|:-------------:|:-----:|:----:|:---------------:|:------:|
| 4.3698        | 1.0   | 527  | 3.2781          | 0.5784 |
| 2.3225        | 2.0   | 1055 | 1.6644          | 0.7368 |
| 1.1907        | 3.0   | 1582 | 0.9749          | 0.8475 |
| 0.6947        | 4.0   | 2110 | 0.6765          | 0.8939 |
| 0.4827        | 5.0   | 2637 | 0.5290          | 0.9171 |
| 0.3515        | 6.0   | 3165 | 0.4530          | 0.9195 |
| 0.3074        | 6.99  | 3689 | 0.4182          | 0.9272 |


### Framework versions

- Transformers 4.35.2
- Pytorch 2.2.0.dev20231126+cu118
- Datasets 2.15.0
- Tokenizers 0.15.0