---
license: mit
datasets:
- PedroSampaio/fruits-360
language:
- en
base_model:
- google/efficientnet-b0
pipeline_tag: image-classification
tags:
- pytorch
- torchvision
- efficientnet
- image-classification
- fruits
- fruits-360
- transfer-learning
- neptune-ai
widget:
  # Example image URLs for the hosted inference widget
  - src: https://images.unsplash.com/photo-1573246123790-a64e870b8b1a?ixlib=rb-1.2.1&auto=format&fit=crop&w=640
    example_title: Apple Example
  - src: https://images.unsplash.com/photo-1528825871115-3581a5377919?ixlib=rb-1.2.1&auto=format&fit=crop&w=640
    example_title: Banana Example
---
# Fruit Classifier - EfficientNet-B0 (Fruits-360 Merged)

[Demo app](https://huggingface.co/spaces/bhumong/fruit-classifier-app)
This repository contains a fruit image classification model based on a fine-tuned **EfficientNet-B0** architecture, built with PyTorch and torchvision. The model was trained on the **Fruits-360** dataset, with specific fruit variants merged into broader categories (e.g., "Apple Red 1" and "Apple 6" merged into "Apple"), resulting in **76** distinct classes.
Training progress and metrics were tracked using **Neptune.ai**.
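The exact merge mapping is not published with this card. The sketch below shows one naive way such a collapse could be implemented, assuming the broad category is simply the first word of the Fruits-360 folder name; the helper `merge_class_name` is illustrative only, not the rule actually used for training.

```python
# Illustrative only: the exact merge rule used for training is not published.
# This naive heuristic collapses a Fruits-360 folder name to a broad category
# by keeping its first word, e.g. "Apple Red 1" -> "Apple".
def merge_class_name(variant_name: str) -> str:
    return variant_name.split()[0]

print({name: merge_class_name(name)
       for name in ["Apple Red 1", "Apple 6", "Banana"]})
# {'Apple Red 1': 'Apple', 'Apple 6': 'Apple', 'Banana': 'Banana'}
```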
## Model Description
* **Architecture:** EfficientNet-B0 (pre-trained on ImageNet)
* **Fine-tuning Strategy:** Transfer learning. The pre-trained backbone's weights were frozen, and only the final classifier layer was replaced and trained on the target dataset (see the sketch after this list).
* **Framework:** PyTorch / torchvision
* **Task:** Image Classification
* **Dataset:** Fruits-360 (Merged Classes)
* **Number of Classes:** 76
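
As a reference for how the freeze-and-replace strategy above is typically set up with torchvision, here is a minimal sketch; the optimizer and learning rate are illustrative assumptions, not the values used to train this model.

```python
import torch
import torchvision.models as models
from torchvision.models import EfficientNet_B0_Weights

# Load EfficientNet-B0 with ImageNet weights and freeze the backbone
model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head; this is the only part that is trained
num_ftrs = model.classifier[1].in_features
model.classifier[1] = torch.nn.Linear(num_ftrs, 76)  # 76 merged classes

# Optimize only the new head (lr is illustrative, not the training value)
optimizer = torch.optim.Adam(model.classifier[1].parameters(), lr=1e-3)
```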
## Intended Uses & Limitations
* **Intended Use:** Classifying images of fruits belonging to one of the 76 merged categories derived from the Fruits-360 dataset. Suitable for educational purposes, demonstrations, or as a baseline for further development.
* **Limitations:**
* Trained *only* on the Fruits-360 dataset. Performance on images significantly different from this dataset (e.g., different lighting, backgrounds, occlusions, fruit varieties not present) is not guaranteed.
* Only recognizes the specific 76 merged classes it was trained on.
* Performance may vary depending on input image quality.
* Not intended for safety-critical applications without rigorous testing and validation.
## How to Use
You can load the model and its configuration directly from the Hugging Face Hub using `torch`, `torchvision`, and `huggingface_hub`.
```python
import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms
import json
from huggingface_hub import hf_hub_download
import os
# --- 1. Define Model Loading Function ---
def load_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.json"):
    """Loads the model state_dict and config from the Hugging Face Hub."""
    # Download and parse the config file
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, "r") as f:
        config = json.load(f)
    num_labels = config["num_labels"]
    id2label = config["id2label"]  # index -> class-name mapping

    # Instantiate the architecture without pre-trained weights,
    # since the fine-tuned weights are loaded below
    model = models.efficientnet_b0(weights=None)

    # Replace the classifier head to match the number of classes used during training
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)

    # Download the fine-tuned weights and load the state dict
    # (map_location handles CPU/GPU as needed)
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()  # set to evaluation mode

    print(f"Model loaded successfully from {repo_id} and set to evaluation mode.")
    return model, config, id2label
# --- 2. Define Preprocessing ---
# Use the same transformations as validation during training
IMG_SIZE = (224, 224)  # standard EfficientNet-B0 input size
# ImageNet statistics, matching the pre-training normalization
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# --- 3. Load Model ---
repo_id_to_load = "Bhumong/fruit-classifier-efficientnet-b0" # Your repo ID
model, config, id2label = load_model_from_hf(repo_id_to_load)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# --- 4. Prepare Input Image ---
# Example: load an image file (replace with your image path)
image_path = "path/to/your/fruit_image.jpg"  # <-- REPLACE WITH YOUR IMAGE PATH
if not os.path.exists(image_path):
    print(f"Warning: Image path not found: {image_path}")
    print("Skipping prediction. Please provide a valid image path.")
    input_batch = None
else:
    try:
        img = Image.open(image_path).convert("RGB")
        input_tensor = preprocess(img)
        # Add a batch dimension (the model expects batches)
        input_batch = input_tensor.unsqueeze(0).to(device)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        input_batch = None
# --- 5. Make Prediction ---
if input_batch is not None:
    with torch.no_grad():  # disable gradient tracking for inference
        output = model(input_batch)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    top_prob, top_catid = torch.max(probabilities, dim=0)
    predicted_label_index = top_catid.item()
    # Use the id2label mapping loaded from config (JSON keys are strings)
    predicted_label = id2label.get(str(predicted_label_index), "Unknown Label")
    confidence = top_prob.item()

    print(f"\nPrediction for: {os.path.basename(image_path)}")
    print(f"Predicted Label Index: {predicted_label_index}")
    print(f"Predicted Label: {predicted_label}")
    print(f"Confidence: {confidence:.4f}")
```
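
To look beyond the single best class, `torch.topk` can rank the full probability vector; this small extension assumes `probabilities` and `id2label` from step 5 above are still in scope.

```python
# Show the top-5 predictions (assumes `probabilities` and `id2label`
# from step 5 above)
top5_prob, top5_catid = torch.topk(probabilities, k=5)
for prob, catid in zip(top5_prob, top5_catid):
    label = id2label.get(str(catid.item()), "Unknown Label")
    print(f"{label}: {prob.item():.4f}")
```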