JJMack committed on Commit b8ebbd7 · verified · Parent(s): bf60d5d

Update README.md: updating the model card

Files changed (1): README.md +75 -1
tags:
- videogames
- pokemon
pipeline_tag: image-classification
---

# Model Card: Pokemon Generation 1 through 9 Image Classifier

## Model Description
The Fine-Tuned Vision Transformer (ViT) is a variant of the transformer encoder architecture, similar to BERT, that has been adapted for image classification tasks. This specific model, "google/vit-base-patch16-224-in21k", was pre-trained in a supervised manner on ImageNet-21k, a substantial collection of images. The images in the pre-training dataset are resized to a resolution of 224x224 pixels, making the model suitable for a wide range of image recognition tasks.
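The model name encodes the input geometry: "patch16-224" means a 224x224 input image is split into 16x16 patches, so the encoder consumes a fixed-length sequence of patch tokens. A quick sanity check of those numbers:

```python
# ViT "patch16-224" geometry: how many patch tokens the encoder sees
image_size = 224
patch_size = 16

patches_per_side = image_size // patch_size   # 224 / 16 = 14
num_patches = patches_per_side ** 2           # 14 * 14 = 196 patch tokens
patch_values = patch_size * patch_size * 3    # 16 * 16 * 3 = 768 RGB values per patch

print(num_patches, patch_values)  # 196 768
```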

The model was trained on the JJMack/pokemon-classification-gen1-9 dataset, with 5 additional augmented versions of each image.

I built this model to learn how to fine-tune a model, and I am writing a LinkedIn article series about the process. You can find the first article here: [Building a Real Pokédex - An AI Journey](https://www.linkedin.com/pulse/building-real-pok%C3%A9dex-ai-journey-jeremy-mack-jc3fc/?trackingId=zWK6TeRJ%2FXLAmv7BKZsQxA%3D%3D)

### Intended Uses
- **Pokemon Classification**: The primary intended use of this model is for the classification of Pokemon images.

### How to use
Here is how to use this model to classify an image as one of the 1025 Pokemon:

```python
# Use a pipeline as a high-level helper
from PIL import Image
from transformers import pipeline

img = Image.open("<path_to_image_file>")
classifier = pipeline("image-classification", model="JJMack/pokemon_gen1_9_classifier")
print(classifier(img))
```

<hr>

```python
# Load the model directly
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

img = Image.open("<path_to_image_file>")
model = AutoModelForImageClassification.from_pretrained("JJMack/pokemon_gen1_9_classifier")
processor = ViTImageProcessor.from_pretrained("JJMack/pokemon_gen1_9_classifier")

with torch.no_grad():
    inputs = processor(images=img, return_tensors="pt")
    outputs = model(**inputs)

logits = outputs.logits
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
```
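The snippet above takes only the argmax. To rank several candidate Pokemon instead, the logits can be passed through a softmax to get probabilities. A minimal sketch on dummy values, where the toy 4-class `labels` list is a stand-in for the model's 1025-entry `id2label` mapping:

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(logits, labels, k=3):
    # Pair each label with its probability and sort descending
    probs = softmax(logits)
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# Dummy 4-class logits standing in for the model's 1025-class output
logits = [2.0, 0.5, 1.0, -1.0]
labels = ["pikachu", "bulbasaur", "charmander", "squirtle"]
print(top_k(logits, labels, k=2))  # pikachu ranks first
```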

### Limitations
- **Specialized Task Fine-Tuning**: While the model is adept at Pokemon image classification, its performance may vary when applied to other tasks.
- Users interested in employing this model for different tasks should explore fine-tuned versions available in the model hub for optimal results.

## Training Data

The model's training data came from [Bulbapedia](https://bulbapedia.bulbagarden.net/wiki/Main_Page). Each image in the training dataset was augmented 5 times with the following augmentations:
```
RandomHorizontalFlip(p=0.5),
RandomVerticalFlip(p=0.5),
RandomRotation(degrees=30),
ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
RandomAffine(degrees=0, translate=(0.1, 0.1)),
RandomPerspective(distortion_scale=0.5, p=0.5),
RandomGrayscale(p=0.2),
```
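The augmentation loop itself is simple: each source image passes through the random transform list five times, producing five independent copies. A pure-Python sketch of that loop, with toy flip functions on a nested-list "image" standing in for the torchvision transforms listed above:

```python
import random

def hflip(img):
    # Horizontal flip: reverse each row of the pixel grid
    return [row[::-1] for row in img]

def vflip(img):
    # Vertical flip: reverse the order of the rows
    return img[::-1]

def maybe(transform, p):
    # Apply `transform` with probability p, like the p= arguments above
    return lambda img: transform(img) if random.random() < p else img

def make_copies(img, n=5):
    # One randomly augmented copy per pass, n copies per source image
    pipeline = [maybe(hflip, 0.5), maybe(vflip, 0.5)]
    copies = []
    for _ in range(n):
        out = img
        for transform in pipeline:
            out = transform(out)
        copies.append(out)
    return copies

img = [[1, 2], [3, 4]]
copies = make_copies(img, n=5)
print(len(copies))  # 5
```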

### Training Stats
```
eval_loss: 0.7451944351196289
eval_accuracy: 0.9221343873517787
eval_runtime: 39.6834
eval_samples_per_second: 63.755
eval_steps_per_second: 7.988
```
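The throughput figures above also imply the evaluation setup: runtime times samples-per-second gives the evaluation set size, and samples-per-second over steps-per-second gives the per-step batch size. A quick check:

```python
stats = {
    "eval_runtime": 39.6834,
    "eval_samples_per_second": 63.755,
    "eval_steps_per_second": 7.988,
}

# Total evaluated samples = runtime * samples/second
n_samples = stats["eval_runtime"] * stats["eval_samples_per_second"]
# Per-step batch size = samples/second / steps/second
batch_size = stats["eval_samples_per_second"] / stats["eval_steps_per_second"]

print(round(n_samples))   # 2530 evaluation samples
print(round(batch_size))  # batch size of 8
```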

<hr>