Musheff: Mushroom Classification Model


An EfficientNet-B3 model fine-tuned to classify mushroom images into 12 classes of mushrooms popular in Russia, each class corresponding to a species and its edibility. It was fine-tuned on the 12_popular_russia_mushrooms_edible_poisonous dataset.

Model Details

Architecture

  • Base Model: EfficientNet-B3 (pretrained on ImageNet)
  • Fine-tuning Approach:
    • Frozen base layers
    • Modified classification head
      • dropout_rate: 0.3
      • num_classes: 12
  • Loss Function: CrossEntropyLoss
  • Optimizer: Adam (lr=0.001)

Performance

Split        Accuracy
Validation   96.88%
Test         95.68%
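Assuming the reported figures are top-1 accuracy (the fraction of images whose highest-scoring class matches the label), the metric can be computed from a batch of logits as below. top1_accuracy is a hypothetical helper, not part of the repository.

```python
import torch


def top1_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of rows whose argmax class equals the target label."""
    predictions = logits.argmax(dim=1)
    return (predictions == labels).float().mean().item()


# Two samples, three classes: the first prediction is correct, the second is not.
logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 0.3, 1.5]])
labels = torch.tensor([0, 1])
print(top1_accuracy(logits, labels))  # 0.5
```

To evaluate a whole split, sum the correct predictions over all batches and divide by the number of images; averaging per-batch accuracies would skew the result when the last batch is smaller.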

Training Strategy

  • Custom Split Ratio:
    • 80% Training
    • 10% Validation
    • 10% Test
  • Training Duration:
    • Maximum 15 epochs with early stopping
    • Best model checkpoint tracking
  • Overfitting Prevention:
    • Custom data splits for improved variation
    • Early stopping mechanism (patience=3 epochs, min_delta=0.001)
    • Validation performance monitoring
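The split and early-stopping rules above can be sketched in plain Python. This is a minimal sketch under stated assumptions: the split shuffles indices with a fixed seed, early stopping monitors validation loss (the card does not say which metric is tracked), and split_indices and EarlyStopping are hypothetical names rather than the repository's code.

```python
import random


def split_indices(n, ratios=(0.8, 0.1, 0.1), seed=42):
    """Shuffle range(n) and cut it into train/validation/test index lists."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_train = int(n * ratios[0])
    n_val = int(n * ratios[1])
    # Whatever remains (about 10%) goes to the test split.
    return (
        indices[:n_train],
        indices[n_train:n_train + n_val],
        indices[n_train + n_val:],
    )


class EarlyStopping:
    """Stop after `patience` epochs without an improvement larger than `min_delta`."""

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0

    def step(self, val_loss):
        """Record one epoch; return True if this epoch is a new best checkpoint."""
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return True
        self.counter += 1
        return False

    @property
    def should_stop(self):
        return self.counter >= self.patience


# Simulated validation losses: improvement stalls after the second epoch.
val_losses = [1.0, 0.5, 0.4999, 0.4995, 0.4992, 0.3]
stopper = EarlyStopping(patience=3, min_delta=0.001)
for epoch, loss in enumerate(val_losses[:15], start=1):  # at most 15 epochs
    if stopper.step(loss):
        print(f"epoch {epoch}: new best, save checkpoint")
    if stopper.should_stop:
        print(f"epoch {epoch}: early stop")
        break
```

In a real training loop, val_loss would come from evaluating the model on the validation split each epoch, and a checkpoint would be written (for example with torch.save(model.state_dict(), ...)) whenever step returns True.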

How to use

Here is how to use the Musheff model to classify a mushroom image into one of the 12 Russian classes (species and edibility):

Option 1: Using the Hugging Face Transformers Model

import random

import torch

from datasets import load_dataset
from transformers import (
    AutoImageProcessor,
    AutoModel,
)

test_dataset = load_dataset(
    "SoFa325/12_popular_russia_mushrooms_edible_poisonous", split="test"
)

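test_len = len(test_dataset)

# Pick a random sample from the test set (randrange never returns test_len itself)
random_index = random.randrange(test_len)

# Index the row first so only this one image is decoded
image = test_dataset[random_index]["image"]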

preprocessor = AutoImageProcessor.from_pretrained(
    "blasisd/musheff",
    trust_remote_code=True,
    use_fast=True,
)


model = AutoModel.from_pretrained(
    "blasisd/musheff",
    trust_remote_code=True,
    low_cpu_mem_usage=True,  # Activates memory-efficient loading
    device_map="auto",  # Distributes layers across devices
)

inputs = preprocessor(image, return_tensors="pt").to(model.device)

model.eval()
with torch.inference_mode():
    logits = model(inputs["pixel_values"])

# model predicts one of the 12 potential mushroom classes
predicted_label = logits.argmax(dim=1).item()

print(f"True label: {test_dataset[random_index]['label']}")
print(f"Predicted label: {predicted_label} ({model.config.id2label[predicted_label]})")

NOTE: You can also replace AutoModel with AutoModelForImageClassification in the code above. Musheff is registered for both auto classes, and both accept the same arguments.

Option 2: Using the PyTorch Model

Alternatively, download the repository files (model.py, config.json and musheff.pth) to a local directory and run the following from that directory:

import json
import random

import torch

from datasets import load_dataset
from torchvision import models

from model import Musheff  # model.py downloaded from the repository


# Use the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

test_dataset = load_dataset(
    "SoFa325/12_popular_russia_mushrooms_edible_poisonous", split="test"
)

test_len = len(test_dataset)

# Pick a random sample from the test set (randrange never returns test_len itself)
random_index = random.randrange(test_len)

# Index the row first so only this one image is decoded
image = test_dataset[random_index]["image"]

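with open("config.json", "r") as json_fp:
    config = json.load(json_fp)

# Musheff wraps the underlying network as `.model`;
# map_location lets GPU-saved weights load on a CPU-only machine
model = Musheff(config)
model.model.load_state_dict(torch.load("musheff.pth", map_location=device))
model.to(device=device)

# Preprocessing preset bundled with the torchvision EfficientNet-B3 weights
transform = models.EfficientNet_B3_Weights.DEFAULT.transforms()

# Ensure 3 channels in case an image is grayscale or RGBA
img = transform(image.convert("RGB"))

# Add a batch dimension: (batch_size, channels, height, width)
img = img.unsqueeze(0)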

model.eval()

with torch.inference_mode():
    logits = model(img.to(device))

# model predicts one of the 12 potential mushroom classes
predicted_label = logits.argmax(dim=1).item()

# Reuse the config loaded above; its id2label keys are strings, e.g. "0"
id2label = config["id2label"]

print(f"True label: {test_dataset[random_index]['label']}")
print(f"Predicted label: {predicted_label} ({id2label[str(predicted_label)]})")
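Since predictions should never be trusted blindly, it can help to show the few most likely classes with their softmax probabilities instead of only the argmax. A minimal sketch, assuming logits has shape (1, 12) as in the examples above; top_k_predictions is a hypothetical helper, and the dummy id2label below stands in for the real mapping from config.json.

```python
import torch


def top_k_predictions(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs for a single image."""
    probs = torch.softmax(logits, dim=1)[0]
    top_probs, top_ids = torch.topk(probs, k=min(k, probs.numel()))
    # id2label loaded from config.json has string keys, e.g. {"0": "...", ...}.
    return [
        (id2label[str(idx)], round(prob, 4))
        for idx, prob in zip(top_ids.tolist(), top_probs.tolist())
    ]


# Dummy 3-class example; real Musheff logits have 12 columns.
dummy_logits = torch.tensor([[1.0, 3.0, 0.5]])
dummy_id2label = {"0": "class_a", "1": "class_b", "2": "class_c"}
for label, prob in top_k_predictions(dummy_logits, dummy_id2label, k=2):
    print(f"{label}: {prob:.2%}")
```

With the Transformers model from Option 1, model.config.id2label uses integer keys, so look up the index directly instead of converting it to a string. Softmax scores are not calibrated probabilities of being correct, so a high score is not a safety guarantee.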

Dependencies

Install the required Python packages using either of the following methods.

Option 1: Direct installation

Run this command in your terminal (recommended inside a virtual environment):

pip install accelerate==1.8.1 datasets==3.6.0 pillow==11.1.0 torch==2.6.0 torchvision==0.21.0 transformers==4.53.1

Option 2: Requirements file

  1. Download requirements.txt from the repository
  2. Run the following in your terminal (recommended inside a virtual environment):
     pip install -r requirements.txt
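To confirm that an existing environment matches the pinned versions before running the examples, a small standard-library check can compare installed distributions against the pins. find_mismatches is a hypothetical helper, and the PINS dictionary simply mirrors the direct-installation command.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the direct-installation command (Option 1).
PINS = {
    "accelerate": "1.8.1",
    "datasets": "3.6.0",
    "pillow": "11.1.0",
    "torch": "2.6.0",
    "torchvision": "0.21.0",
    "transformers": "4.53.1",
}


def find_mismatches(pins):
    """Map each package whose installed version differs from its pin to (installed, pinned)."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        # Ignore local build suffixes such as "+cu124" on CUDA builds of torch.
        if installed is None or installed.split("+")[0] != pinned:
            mismatches[name] = (installed, pinned)
    return mismatches


for name, (installed, pinned) in find_mismatches(PINS).items():
    print(f"{name}: installed {installed}, expected {pinned}")
```

An empty output means every pinned package is installed at the expected version.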

Intended Uses & Limitations

Recommended Use Cases

  • Educational mushroom identification apps
  • Preliminary screening of common Russian mushrooms
  • Integration with foraging guide applications

Limitations

  • Geographic Specificity: Only recognizes 12 mushroom species common in Russia; any other mushroom will still be assigned one of these 12 classes
  • Safety Critical: NOT SUITABLE for consumption decisions; always consult human experts

Ethical Considerations

🚨 Critical Warning:

  • Misclassification could lead to fatal poisoning
  • Always verify predictions with certified mycologists
  • Intended for educational purposes only, not consumption safety

Model Hub

Explore fine-tuned variants on the Hugging Face Hub

Model size: 10.8M params (Safetensors, F32)