---
license: mit
datasets:
- PedroSampaio/fruits-360
language:
- en
base_model:
- google/efficientnet-b0
pipeline_tag: image-classification
tags:
- pytorch
- torchvision
- efficientnet
- image-classification
- fruits
- fruits-360
- transfer-learning
- neptune-ai
widget:
  # Example image URLs from the web - replace if you have better ones
  - src: https://images.unsplash.com/photo-1573246123790-a64e870b8b1a?ixlib=rb-1.2.1&auto=format&fit=crop&w=640 # Example Apple
    example_title: Apple Example
  - src: https://images.unsplash.com/photo-1528825871115-3581a5377919?ixlib=rb-1.2.1&auto=format&fit=crop&w=640 # Example Banana
    example_title: Banana Example
---

[DEMO APP](https://huggingface.co/spaces/bhumong/fruit-classifier-app)

# Fruit Classifier - EfficientNet-B0 (Fruits-360 Merged)

This repository contains a fruit image classifier: an **EfficientNet-B0** model fine-tuned with PyTorch and torchvision. It was trained on the **Fruits-360 dataset**, modified so that specific fruit variants are merged into broader categories (e.g., "Apple Red 1" and "Apple 6" both become "Apple"), giving **76** distinct classes.
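
The class merging described above can be sketched as follows. This is a minimal illustration, not the script used to build this model: the rule "take the first word of the variant name" and the helper names `merge_label` and `build_class_index` are assumptions chosen to reproduce the "Apple Red 1" / "Apple 6" example, and the real merge rule may differ.

```python
def merge_label(variant_name: str) -> str:
    """Map a Fruits-360 variant name to a broader category.

    Assumption: the category is the first word of the variant name.
    The rule actually used for this model may differ.
    """
    return variant_name.strip().split()[0]


def build_class_index(variant_names):
    """Return the sorted merged class list and a variant -> class index map."""
    merged = sorted({merge_label(name) for name in variant_names})
    class_to_idx = {name: i for i, name in enumerate(merged)}
    variant_to_idx = {
        name: class_to_idx[merge_label(name)] for name in variant_names
    }
    return merged, variant_to_idx


# Illustrative variant names only, not the full dataset listing
variants = ["Apple Red 1", "Apple 6", "Banana", "Banana Red"]
classes, variant_to_idx = build_class_index(variants)
print(classes)
print(variant_to_idx)
```

With a mapping like this, every variant folder contributes training images to a single merged class index.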

Training progress and metrics were tracked using **Neptune.ai**.

## Model Description

*   **Architecture:** EfficientNet-B0 (pre-trained on ImageNet)
*   **Fine-tuning Strategy:** Transfer learning. The pre-trained base model's weights were frozen, and only the final classifier layer was replaced and trained on the target dataset.
*   **Framework:** PyTorch / torchvision
*   **Task:** Image Classification
*   **Dataset:** Fruits-360 (Merged Classes)
*   **Number of Classes:** 76

## Intended Uses & Limitations

*   **Intended Use:** Classifying images of fruits belonging to one of the 76 merged categories derived from the Fruits-360 dataset. Suitable for educational purposes, demonstrations, or as a baseline for further development.
*   **Limitations:**
    *   Trained *only* on the Fruits-360 dataset. Performance on images significantly different from this dataset (e.g., different lighting, backgrounds, occlusions, fruit varieties not present) is not guaranteed.
    *   Only recognizes the specific 76 merged classes it was trained on.
    *   Performance may vary depending on input image quality.
    *   Not intended for safety-critical applications without rigorous testing and validation.
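
Because the model always outputs one of its trained classes, a caller may want to withhold low-confidence predictions. The sketch below shows one simple way to do that. The helper `predict_or_reject`, the 0.5 threshold, and the three-entry `id2label` are hypothetical, and softmax confidence is only a rough signal, not a reliable detector of out-of-distribution images.

```python
import torch


def predict_or_reject(logits, id2label, threshold=0.5):
    """Return (label, confidence), with label None when confidence is low.

    Hypothetical helper: the threshold is an illustrative default and
    should be tuned on held-out data.
    """
    probs = torch.softmax(logits, dim=0)
    top_prob, top_idx = torch.max(probs, dim=0)
    confidence = top_prob.item()
    if confidence < threshold:
        return None, confidence
    return id2label.get(str(top_idx.item()), "Unknown Label"), confidence


# Toy mapping and logits for illustration only
id2label = {"0": "Apple", "1": "Banana", "2": "Cherry"}
confident_label, _ = predict_or_reject(torch.tensor([4.0, 0.5, 0.2]), id2label)
uncertain_label, _ = predict_or_reject(torch.tensor([0.4, 0.5, 0.3]), id2label)
print(confident_label, uncertain_label)
```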

## How to Use

You can load the model and its configuration directly from the Hugging Face Hub using `torch`, `torchvision`, and `huggingface_hub`.

```python
import torch
import torchvision.models as models
from torchvision.models import EfficientNet_B0_Weights # Or the specific version used
from PIL import Image
from torchvision import transforms
import json
import requests
from huggingface_hub import hf_hub_download
import os

# --- 1. Define Model Loading Function ---
def load_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.json"):
    """Loads model state_dict and config from Hugging Face Hub."""

    # Download config file
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, 'r') as f:
        config = json.load(f)

    num_labels = config['num_labels']
    id2label = config['id2label'] # Load label mapping

    # Instantiate the correct architecture (EfficientNet-B0)
    # Load architecture without pre-trained weights, as we'll load our fine-tuned ones
    model = models.efficientnet_b0(weights=None)

    # Modify the classifier head to match the number of classes used during training
    num_ftrs = model.classifier[1].in_features
    model.classifier[1] = torch.nn.Linear(num_ftrs, num_labels)

    # Download model weights
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

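    # Load the state dict onto whichever device is available.
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs (requires a PyTorch version with this flag).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    state_dict = torch.load(model_path, map_location=device, weights_only=True)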
    model.load_state_dict(state_dict)

    model.eval() # Set to evaluation mode
    print(f"Model loaded successfully from {repo_id} and set to evaluation mode.")
    return model, config, id2label

# --- 2. Define Preprocessing ---
# Use the same transformations as validation during training
IMG_SIZE = (224, 224) # Standard EfficientNet input size
# ImageNet channel statistics, commonly used with ImageNet-pretrained EfficientNet
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

preprocess = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# --- 3. Load Model ---
repo_id_to_load = "Bhumong/fruit-classifier-efficientnet-b0" # Your repo ID
model, config, id2label = load_model_from_hf(repo_id_to_load)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)


# --- 4. Prepare Input Image ---
# Example: Load an image file (replace with your image path)
image_path = "path/to/your/fruit_image.jpg" # <-- REPLACE WITH YOUR IMAGE PATH

if not os.path.exists(image_path):
    print(f"Warning: Image path not found: {image_path}")
    print("Skipping prediction. Please provide a valid image path.")
    input_batch = None
else:
    try:
        img = Image.open(image_path).convert("RGB")
        input_tensor = preprocess(img)
        # Add batch dimension (model expects batches)
        input_batch = input_tensor.unsqueeze(0)
        input_batch = input_batch.to(device)
    except Exception as e:
        print(f"Error processing image {image_path}: {e}")
        input_batch = None

# --- 5. Make Prediction ---
if input_batch is not None:
    with torch.no_grad(): # Disable gradient calculations for inference
        output = model(input_batch)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        top_prob, top_catid = torch.max(probabilities, dim=0)

    predicted_label_index = top_catid.item()
    # Use the id2label mapping loaded from config
    predicted_label = id2label.get(str(predicted_label_index), "Unknown Label")
    confidence = top_prob.item()

    print(f"\nPrediction for: {os.path.basename(image_path)}")
    print(f"Predicted Label Index: {predicted_label_index}")
    print(f"Predicted Label: {predicted_label}")
    print(f"Confidence: {confidence:.4f}")