---
license: cc-by-nc-sa-4.0
datasets:
- JJMack/pokemon-classification-gen1-9
base_model:
- google/vit-base-patch16-224-in21k
tags:
- videogames
- pokemon
pipeline_tag: image-classification
---

# Model Card: Pokemon Generation 1 through 9 Image Classifier

## Model Description

The Fine-Tuned Vision Transformer (ViT) is a variant of the transformer encoder architecture, similar to BERT, that has been adapted for image classification tasks. The base model, "google/vit-base-patch16-224-in21k", is pre-trained in a supervised manner on the large ImageNet-21k dataset. The images in the pre-training dataset are resized to a resolution of 224x224 pixels, making the model suitable for a wide range of image recognition tasks.

This model was fine-tuned on an augmented version of the JJMack/pokemon-classification-gen1-9 dataset, with five additional augmented versions of each image.

I built this model to learn how to fine-tune a model, and I am writing a LinkedIn article series about the process. You can find the first article here: [Building a Real Pokédex - An AI Journey](https://www.linkedin.com/pulse/building-real-pok%C3%A9dex-ai-journey-jeremy-mack-jc3fc/?trackingId=zWK6TeRJ%2FXLAmv7BKZsQxA%3D%3D)

### Intended Uses

- **Pokemon Classification**: The primary intended use of this model is the classification of Pokemon images.

### How to use

Here is how to use this model to classify an image as one of the 1,025 Pokemon:

```python
# Use a pipeline as a high-level helper
from PIL import Image
from transformers import pipeline

img = Image.open("")  # replace with the path to your image

classifier = pipeline("image-classification", model="JJMack/pokemon_gen1_9_classifier")
classifier(img)
```
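The pipeline returns a list of `{'label', 'score'}` dictionaries sorted by score. If you want a specific number of candidates, you can pass `top_k` when calling the pipeline; a minimal sketch reusing the `classifier` and `img` from above:

```python
# Print the three highest-scoring Pokemon with their confidence scores
for pred in classifier(img, top_k=3):
    print(f"{pred['label']}: {pred['score']:.3f}")
```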
Alternatively, you can load the model and processor directly:

```python
# Load model directly
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

img = Image.open("")  # replace with the path to your image

model = AutoModelForImageClassification.from_pretrained("JJMack/pokemon_gen1_9_classifier")
processor = ViTImageProcessor.from_pretrained("JJMack/pokemon_gen1_9_classifier")

with torch.no_grad():
    inputs = processor(images=img, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits

predicted_label = logits.argmax(-1).item()
model.config.id2label[predicted_label]
```

### Limitations

- **Specialized Task Fine-Tuning**: While the model is adept at Pokemon image classification, its performance may vary when applied to other tasks.
- Users interested in employing this model for different tasks should explore fine-tuned versions available in the model hub for optimal results.

## Training Data

The model's training images came from [Bulbapedia](https://bulbapedia.bulbagarden.net/wiki/Main_Page). Each image in the training dataset was augmented five times using the following transforms (a sketch of how such a pipeline might be composed appears at the end of this card):

```
- RandomHorizontalFlip(p=0.5),
- RandomVerticalFlip(p=0.5),
- RandomRotation(degrees=30),
- ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
- GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
- RandomAffine(degrees=0, translate=(0.1, 0.1)),
- RandomPerspective(distortion_scale=0.5, p=0.5),
- RandomGrayscale(p=0.2),
```

### Training Stats

```
- 'eval_loss': 0.7451944351196289,
- 'eval_accuracy': 0.9221343873517787,
- 'eval_runtime': 39.6834,
- 'eval_samples_per_second': 63.755,
- 'eval_steps_per_second': 7.988
```
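### Augmentation Pipeline Sketch

A minimal sketch of how the transforms listed under Training Data might be composed with `torchvision.transforms`. The transforms and their parameters come from the list above; combining them into a single `Compose` and sampling five copies per image is an assumption, since the card does not document how they were chained:

```python
import torchvision.transforms as T
from PIL import Image

# Transforms and parameters as listed under Training Data; composing them
# into one pipeline is an assumption about how the augmentation was run.
augment = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(degrees=30),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    T.RandomPerspective(distortion_scale=0.5, p=0.5),
    T.RandomGrayscale(p=0.2),
])

img = Image.open("")  # replace with the path to a source image
augmented_copies = [augment(img) for _ in range(5)]  # five augmented versions
```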