In this tutorial, we will show you how to use a HuggingFace dataset with fastai to train a model for image classification. We will use the Beans dataset, which consists of images of bean plants labeled with one of two diseases or as healthy.
Step 1: Install the required libraries
Before starting, we need to install the required libraries. Run the following commands to install fastai and HuggingFace’s datasets:
!pip install -Uqq fastai
!pip install -Uqq datasets
Log in to the HuggingFace Hub to download the dataset.
!huggingface-cli login
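If you are working in a notebook, you can alternatively log in programmatically. A minimal sketch using huggingface_hub’s notebook_login (assuming the huggingface_hub package is installed):
from huggingface_hub import notebook_login
notebook_login()  # opens an interactive token prompt in the notebook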
Step 2: Import the required modules
import torch
from fastai.data.all import *
from fastai.vision.all import *
from datasets import load_dataset
Step 3: Load the dataset
Let’s load the Beans dataset, which contains images of bean plants taken in the field using smartphone cameras. It consists of 3 classes: 2 disease classes and the healthy class.
raw_ds = load_dataset("beans")
raw_ds
DatasetDict({
train: Dataset({
features: ['image_file_path', 'image', 'labels'],
num_rows: 1034
})
validation: Dataset({
features: ['image_file_path', 'image', 'labels'],
num_rows: 133
})
test: Dataset({
features: ['image_file_path', 'image', 'labels'],
num_rows: 128
})
})
The dataset is split into train, validation, and test sets.
Let’s see the label names.
class_names = raw_ds['train'].features['labels'].names
class_names
['angular_leaf_spot', 'bean_rust', 'healthy']
Step 4: Preprocess the dataset
Often, we need to preprocess raw data before training a model with it. The HuggingFace datasets library provides two methods for preprocessing:
map: This method applies a function to each example in the dataset, possibly in a batched manner. The function can operate on one or more columns, and the result can be stored in a new column or overwrite an existing one. map also lets you remove columns from the dataset if needed. This method is useful for preprocessing steps such as resizing images, tokenizing text, or encoding categorical features, and it caches its outputs so that they’re not computed again.
set_transform: This method sets a transform function that is applied on the fly when examples are accessed. The dataset is not modified in place; the transform runs only at access time. This is useful for data augmentation or normalization that should be applied dynamically during training, without modifying the dataset beforehand.
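For illustration, here is a minimal, hypothetical sketch of what set_transform could look like for on-the-fly resizing (we won’t use this approach in this tutorial; on_the_fly is a name chosen here for the example):
def on_the_fly(batch):
    # set_transform always receives a batch (a dict of lists), even when accessing a single example
    batch["image"] = [image.convert("RGB").resize((224, 224)) for image in batch["image"]]
    return batch

raw_ds["train"].set_transform(on_the_fly)  # applied lazily at access time; nothing is cached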
Let’s resize each image to 224x224.
def preprocess(records):
    records["image"] = [image.convert("RGB").resize((224, 224)) for image in records["image"]]
    return records
Before batching samples, we need to remove unnecessary columns from the dataset, such as image_file_path.
ds = raw_ds.map(preprocess, remove_columns=["image_file_path"], batched=True)
Loading cached processed dataset at /Users/bdsaglam/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791/cache-d335def00fc26298.arrow
Loading cached processed dataset at /Users/bdsaglam/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791/cache-9b27f1e864a73628.arrow
Loading cached processed dataset at /Users/bdsaglam/.cache/huggingface/datasets/beans/default/0.0.0/90c755fb6db1c0ccdad02e897a37969dbf070bed3755d4391e269ff70642d791/cache-417ffd63aef737e6.arrow
ds['train'][0]
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224>,
'labels': 0}
We won’t use set_transform, as fastai’s DataBlock can apply item-level and batch-level transforms, e.g. data augmentations and normalization. Moreover, when we use a pretrained model, fastai already applies the necessary transforms, such as normalization.
Step 5: Create the DataBlock
Now, we can create dataloaders for the dataset using DataBlock from fastai. As the dataset is already split into train, validation, and test sets, we don’t need to split it further. Hence, we will use a custom nosplit function that assigns all items to the training set.
def nosplit(items):
    return list(range(len(items))), []
dblock = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=lambda record: record['image'],
    get_y=lambda record: class_names[record['labels']],
    splitter=nosplit,
)
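As an aside, if we wanted fastai to handle resizing and augmentation instead of the map-based preprocessing above, the DataBlock could take item-level and batch-level transforms. A sketch (dblock_aug is a hypothetical name, not used below):
dblock_aug = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=lambda record: record['image'],
    get_y=lambda record: class_names[record['labels']],
    splitter=nosplit,
    item_tfms=Resize(224),        # applied per item on the CPU
    batch_tfms=aug_transforms(),  # applied per batch, typically on the GPU
)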
train_dl = dblock.dataloaders(ds['train']).train
valid_dl = dblock.dataloaders(ds['validation']).train
dls = DataLoaders(train_dl, valid_dl)
dls.show_batch()
len(dls.train.dataset),len(dls.valid.dataset)
(1034, 133)
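As a sanity check, we can also inspect the category vocabulary fastai inferred; it should match class_names (DataLoaders delegates attribute access to the training dataloader, so this should work on our manually wrapped dls):
dls.vocab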
Step 6: Training
Let’s fine-tune a pretrained ResNet model on our dataset using the Learner.fine_tune method.
learn = vision_learner(
    dls,
    resnet34,
    loss_func=CrossEntropyLossFlat(),
    metrics=accuracy,
)
# Find a good learning rate
learn.lr_find()
SuggestedLRs(valley=0.0004786300996784121)
# Fine-tune the model
learn.fine_tune(1, 1e-3, freeze_epochs=2)
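For context, fine_tune first trains only the randomly initialized head with the pretrained body frozen for freeze_epochs, then unfreezes and trains the whole model. Roughly, the call above corresponds to a sequence like this (a simplified sketch; fastai additionally adjusts and discriminates learning rates internally):
learn.freeze()                              # train only the new head
learn.fit_one_cycle(2, 1e-3)                # freeze_epochs=2
learn.unfreeze()                            # train the whole model
learn.fit_one_cycle(1, slice(1e-5, 1e-3))   # lower learning rates for earlier layers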
Step 7: Evaluation
After only 3 epochs, the fine-tuned model achieves 90.6% accuracy on the validation set, not too bad!
loss, accuracy = learn.validate(dl=dls.valid)
print(f"Loss {loss:.6f}\nAccuracy: {accuracy:.2%}")
Loss 0.306838
Accuracy: 90.62%
Let’s check predictions visually.
learn.show_results(dl=valid_dl)
Let’s predict and evaluate on test set.
tst_dl = dls.test_dl(ds['test'], with_labels=True)
probs, targets, preds = learn.get_preds(dl=tst_dl, with_decoded=True)
loss, accuracy = learn.validate(dl=tst_dl)
print(f"Loss {loss:.6f}\nAccuracy: {accuracy:.2%}")
Loss 0.288472
Accuracy: 89.06%
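To see which classes the model confuses and inspect its worst mistakes, we could use fastai’s ClassificationInterpretation (a sketch; from_learner accepts our custom test dataloader via its dl argument):
interp = ClassificationInterpretation.from_learner(learn, dl=tst_dl)
interp.plot_confusion_matrix()
interp.plot_top_losses(9)  # the most confidently wrong predictions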