FastAI Demonstration: LeafSnap

See the full explanation behind this demonstration notebook at FastAI quick test / LeafSnap.

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive
In [2]:
from pathlib import Path
drive_path = Path('/content/gdrive/My Drive/leafsnap')
base_path = Path('/content/leafsnap-dataset')
In [3]:
!pip install "torch==1.4" "torchvision==0.5.0"
Collecting torch==1.4
  Downloading https://files.pythonhosted.org/packages/24/19/4804aea17cd136f1705a5e98a00618cb8f6ccc375ad8bfa437408e09d058/torch-1.4.0-cp36-cp36m-manylinux1_x86_64.whl (753.4MB)
     |████████████████████████████████| 753.4MB 17kB/s 
Collecting torchvision==0.5.0
  Downloading https://files.pythonhosted.org/packages/7e/90/6141bf41f5655c78e24f40f710fdd4f8a8aff6c8b7c6f0328240f649bdbe/torchvision-0.5.0-cp36-cp36m-manylinux1_x86_64.whl (4.0MB)
     |████████████████████████████████| 4.0MB 40.1MB/s 
Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0) (7.0.0)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0) (1.15.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision==0.5.0) (1.18.5)
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.6.0+cu101
    Uninstalling torch-1.6.0+cu101:
      Successfully uninstalled torch-1.6.0+cu101
  Found existing installation: torchvision 0.7.0+cu101
    Uninstalling torchvision-0.7.0+cu101:
      Successfully uninstalled torchvision-0.7.0+cu101
Successfully installed torch-1.4.0 torchvision-0.5.0
In [4]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
In [5]:
from fastai.vision import *
from fastai.metrics import error_rate
In [ ]:
# leafsnap_url = 'http://leafsnap.com/static/dataset/leafsnap-dataset'
# tar_path = download_data(leafsnap_url, ext='.tar')

leafsnap_url = 'http://leafsnap.com/static/dataset/leafsnap-dataset.tar'
!wget {leafsnap_url}

import tarfile
tar = tarfile.open('leafsnap-dataset.tar')
tar.extractall(path='leafsnap-dataset')
--2020-08-07 23:31:31--  http://leafsnap.com/static/dataset/leafsnap-dataset.tar
Resolving leafsnap.com (leafsnap.com)... 128.59.23.133
Connecting to leafsnap.com (leafsnap.com)|128.59.23.133|:80... ^C
In [6]:
import zipfile

with zipfile.ZipFile(drive_path/'leafsnap-kaggle.zip', 'r') as zip:
    zip.extractall(base_path)
In [7]:
import pandas as pd
dataset_df = pd.read_csv(base_path/'leafsnap-dataset-images.txt', sep='\t')
dataset_df.head()
Out[7]:
file_id image_path segmented_path species source
0 55497 dataset/images/lab/abies_concolor/ny1157-01-1.jpg dataset/segmented/lab/abies_concolor/ny1157-01... Abies concolor lab
1 55498 dataset/images/lab/abies_concolor/ny1157-01-2.jpg dataset/segmented/lab/abies_concolor/ny1157-01... Abies concolor lab
2 55499 dataset/images/lab/abies_concolor/ny1157-01-3.jpg dataset/segmented/lab/abies_concolor/ny1157-01... Abies concolor lab
3 55500 dataset/images/lab/abies_concolor/ny1157-01-4.jpg dataset/segmented/lab/abies_concolor/ny1157-01... Abies concolor lab
4 55501 dataset/images/lab/abies_concolor/ny1157-02-1.jpg dataset/segmented/lab/abies_concolor/ny1157-02... Abies concolor lab
In [8]:
dataset_df.columns
Out[8]:
Index(['file_id', 'image_path', 'segmented_path', 'species', 'source'], dtype='object')
In [9]:
# field-only images
# train_dataset_df = dataset_df[dataset_df.source == 'field'][['image_path', 'species']]

# all images
train_dataset_df = dataset_df[['image_path', 'species']]
In [10]:
# simplify available classes to more generic species
simplified_species = {
  'Acer campestre': 'Maple',
  'Acer ginnala': 'Maple',
  'Acer griseum': 'Maple',
  'Cedrus atlantica': 'Cedar',
  'Cedrus deodara': 'Cedar',
  'Cedrus libani': 'Cedar',
  'Broussonettia papyrifera': 'Mulberry',
  'Chamaecyparis pisifera': 'Cypress',
  'Cornus kousa': 'Dogwood',
  'Cryptomeria japonica': 'Cedar',
  'Fraxinus pennsylvanica': 'Ash',
  'Gleditsia triacanthos': 'Honeylocust',
  'Juniperus virginiana': 'Redcedar',
  'Malus hupehensis': 'Crabapple',
  'Morus rubra': 'Mulberry',
  'Picea abies': 'Spruce',
  'Picea orientalis': 'Spruce',
  'Picea pungens': 'Spruce',
  'Pinus bungeana': 'Pine',
  'Pinus cembra': 'Pine',
  'Pinus densiflora': 'Pine',
  'Pinus flexilis': 'Pine',
  'Pinus koraiensis': 'Pine',
  'Pinus parviflora': 'Pine',
  'Pinus peucea': 'Pine',
  'Pinus rigida': 'Pine',
  'Pinus strobus': 'Pine',
  'Pinus thunbergii': 'Pine',
  'Pinus wallichiana': 'Pine',
  'Platanus acerifolia': 'Plane Tree',
  'Quercus montana': 'Oak',
  'Quercus muehlenbergii': 'Oak',
  'Pseudolarix amabilis': 'Larch',
  'Quercus alba': 'Oak',
  'Ailanthus altissima': 'Tree of Heaven',
  'Taxodium distichum': 'Cypress',
  'Tsuga canadensis': 'Hemlock',
  'Ulmus hollandica': 'Elm',
  'Ulmus americana': 'Elm',
  'Euonymus europaeus': 'Splindetree',
  'Syringa vulgaris': 'Lilac',
  'Chamaecyparis thyoides': 'Cedar',
  'Acer platanoides': 'Maple',
  'Acer pseudoplatanus': 'Maple',
  'Acer palmatum': 'Maple',
  'Acer rubrum': 'Maple',
  'Acer saccharinum': 'Maple',
  'Betula lenta': 'Birch',
  'Aesculus pavi': 'Buckeye',
  'Betula jacqemontii': 'Birch',
  'Aesculus flava': 'Buckeye',
  'Amelanchier canadensis': 'Serviceberry',
  'Asimina triloba': 'Pawpaw',
  'Betula nigra': 'Birch',
  'Carpinus betulus': 'Hornbeam',
  'Carya tomentosa': 'Hickory',
  'Carya cordiformis': 'Hickory',
  'Castanea dentata': 'Chestnut',
  'Catalpa speciosa': 'Catalpa',
  'Celtis occidentalis': 'Hackberry',
  'Cercis canadensis': 'Redbud',
  'Cercidiphyllum japonicum': 'Katsura',
  'Chionanthus retusus': 'Fringetree',
  'Cladrastis lutea': 'Yellowwood',
  'Crataegus phaenopyrum': 'Hawthorn',
  'Evodia daniellii': 'Evodia',
  'Diospyros virginiana': 'Persimmon',
  'Eucommia ulmoides': 'Rubbertree',
  'Fagus grandifolia': 'Beech',
  'Fagus sylvatica': 'Beech',
  'Pinus pungens': 'Pine',
  'Pinus resinosa': 'Pine',
  'Pinus sylvestris': 'Pine',
  'Pinus taeda': 'Pine',
  'Pinus virginiana': 'Pine',
  'Fraxinus americana': 'Ash',
  'Fraxinus nigra': 'Ash',
  'Halesia tetraptera': 'Silverbell',
  'Tilia cordata': 'Linden',
  'Tilia europaea': 'Linden',
  'Tilia tomentosa': 'Linden',
  'Ilex opaca': 'Holly',
  'Juglans nigra': 'Walnut',
  'Koelreuteria paniculata': 'Goldenrain',
  'Liquidambar styraciflua': 'Sweetgum',
  'Maclura pomifera': 'Osage Orange',
  'Magnolia acuminata': 'Magnolia',
  'Magnolia denudata': 'Magnolia',
  'Magnolia soulangiana': 'Magnolia',
  'Magnolia stellata': 'Magnolia',
  'Malus baccata': 'Crabapple',
  'Morus alba': 'Mulberry',
  'Nyssa sylvatica': 'Tupelo',
  'Oxydendrum arboreum': 'Sourwood',
  'Paulownia tomentosa': 'Empress Tree',
  'Phellodendron amurense': 'Corktree',
  'Platanus occidentalis': 'Sycamore',
  'Populus deltoides': 'Cottonwood',
  'Prunus sargentii': 'Cherry',
  'Prunus serotina': 'Cherry',
  'Chionanthus virginicus': 'Fringetree',
  'Gymnocladus dioicus': 'Coffeetree',
  'Prunus serrulata': 'Cherry',
  'Prunus subhirtella': 'Cherry',
  'Prunus yedoensis': 'Cherry',
  'Ptelea trifoliata': 'Hoptree',
  'Quercus bicolor': 'Oak',
  'Quercus cerris': 'Oak',
  'Quercus coccinea': 'Oak',
  'Quercus imbricaria': 'Oak',
  'Quercus macrocarpa': 'Oak',
  'Quercus palustris': 'Oak',
  'Quercus phellos': 'Oak',
  'Quercus velutina': 'Oak',
  'Abies concolor': 'Fir',
  'Picea rubens': 'Spruce',
  'Magnolia grandiflora': 'Magnolia',
  'Syringa reticulata': 'Lilac',
  'Tilia americana': 'Linden',
  'Pseudotsuga menziesii': 'Fir',
  'Prunus cerasifera': 'Cherry',
  'Toona sinensis': 'Chinese Toon',
  'Ulmus glabra': 'Elm',
  'Rhus typhina': 'Sumac',
  'Larix decidua': 'Larch',
  'Amelanchier arborea': 'Serviceberry',
  'Amelanchier laevis': 'Serviceberry',
  'Betula alleghaniensis': 'Birch',
  'Betula populifolia': 'Birch',
  'Betula platyphylla': 'Birch',
  'Prunus spinosa': 'Cherry',
  'Acer saccharum': 'Maple',
  'Magnolia kobus': 'Magnolia',
  'Clerodendrum trichotomum': 'Glorybower',
  'Carpinus caroliniana': 'Hornbeam',
  'Carya glabra': 'Hickory',
  'Cornus drummondii': 'Dogwood',
  'Aesculus glabra': 'Buckeye',
  'Aesculus hippocastamon': 'Buckeye',
  'Catalpa bignonioides': 'Catalpa',
  'Crataegus monogyna': 'Hawthorn',
  'Photinia villosa': 'Photinia',
  'Sorbus aucuparia': 'Ash',
  'Quercus falcata': 'Oak',
  'Quercus marilandica': 'Oak',
  'Acer pensylvanicum': 'Maple',
  'Picea glauca': 'Spruce',
  'Pinus echinata': 'Pine',
  'Viburnum lantana': 'Wayfaringtree',
  'Crataegus punctata': 'Hawthorn',
  'Crataegus succulenta': 'Hawthorn',
  'Zelkova serrata': 'Zelkova',
  'Fraxinus excelsior': 'Ash',
  'Ficus carica': 'Fig',
  'Carya ovata': 'Hickory',
  'Ostrya virginiana': 'Hophornbeam',
  'Populus grandidentata': 'Aspen',
  'Populus tremuloides': 'Aspen',
  'Buxus sempervirens': 'Box',
  'Prunus pensylvanica': 'Cherry',
  'Prunus virginiana': 'Cherry',
  'Betula papyrifera': 'Birch',
  'Betula pendula': 'Birch',
  'Betula pubescens': 'Birch',
  'Acer nigrum': 'Maple',
  'Alnus incana': 'Alder',
  'Juglans regia': 'Walnut',
  'Salix caroliniana': 'Willow',
  'Ulmus parvifolia': 'Elm',
  'Ulmus rubra': 'Elm',
  'Betula uber': 'Birch',
  'Quercus shumardii': 'Oak',
  'Salix alba': 'Willow',
  'Crataegus laevigata': 'Hawthorn',
  'Cornus mas': 'Cherry',
  'Corylus colurna': 'Filbert',
  'Crataegus crus-galli': 'Hawthorn',
  'Quercus robur': 'Oak',
  'Quercus rubra': 'Oak',
  'Juglans cinerea': 'Walnut',
  'Salix babylonica': 'Willow',
  'Crataegus dilatata': 'Hawthorn',
  'Robinia pseudo-acacia': 'Locust',
  'Ulmus procera': 'Elm',
  'Ulmus pumila': 'Elm',
  'Acer negundo': 'Elder',
  'Albizia julibrissin': 'Mimosa',
  'Alnus glutinosa': 'Alder',
  'Magnolia macrophylla': 'Magnolia',
  'Magnolia tripetala': 'Magnolia',
  'Quercus michauxii': 'Oak',
  'Quercus nigra': 'Oak',
  'Abies nordmanniana': 'Fir',
  'Malus floribunda': 'Crabapple',
  'Pyrus calleryana': 'Callery Pear',
  'Cornus florida': 'Dogwood',
  'Metasequoia glyptostroboides': 'Metasequoia',
  'Magnolia virginiana': 'Magnolia',
  'Sassafras albidum': 'Sassafrass',
  'Quercus acutissima': 'Oak',
  'Celtis tenuifolia': 'Hackberry',
  'Pinus nigra': 'Pine',
  'Quercus stellata': 'Oak',
  'Quercus virginiana': 'Oak',
  'Crataegus pruinosa': 'Hawthorn',
  'Crataegus viridis': 'Hawthorn',
  'Liriodendron tulipifera': 'Tuliptree',
  'Malus angustifolia': 'Crabapple',
  'Malus coronaria': 'Crabapple',
  'Malus pumila': 'Crabapple',
  'Abies fraseri': 'Fir',
  'Salix matsudana': 'Willow',
  'Staphylea trifolia': 'Bladdernut',
  'Styrax japonica': 'Snowbell',
  'Juniperus communis': 'Juniper',
  'Salix nigra': 'Willow',
  'Stewartia pseudocamellia': 'Stewartia',
  'Styrax obassia': 'Snowbell',
  'Picea mariana': 'Spruce',
  'Ginkgo biloba': 'Ginkgo',
  'Fringretree': 'Fringetree',
}
In [11]:
train_dataset_df.replace(simplified_species, inplace=True)

simplified_species_list = list(train_dataset_df.species.unique())
simplified_species_list.sort()

print(len(simplified_species_list))
simplified_species_list
72
/usr/local/lib/python3.6/dist-packages/pandas/core/frame.py:4172: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  method=method,
Out[11]:
['Ash',
 'Aspen',
 'Beech',
 'Birch',
 'Bladdernut',
 'Buckeye',
 'Callery Pear',
 'Catalpa',
 'Cedar',
 'Cherry',
 'Chestnut',
 'Chinese Toon',
 'Coffeetree',
 'Corktree',
 'Cottonwood',
 'Crabapple',
 'Cypress',
 'Dogwood',
 'Elder',
 'Elm',
 'Empress Tree',
 'Evodia',
 'Fig',
 'Filbert',
 'Fir',
 'Fringetree',
 'Ginkgo',
 'Goldenrain',
 'Hackberry',
 'Hawthorn',
 'Hemlock',
 'Hickory',
 'Holly',
 'Honeylocust',
 'Hophornbeam',
 'Hoptree',
 'Hornbeam',
 'Katsura',
 'Larch',
 'Lilac',
 'Linden',
 'Locust',
 'Magnolia',
 'Maple',
 'Metasequoia',
 'Mimosa',
 'Mulberry',
 'Oak',
 'Osage Orange',
 'Pawpaw',
 'Persimmon',
 'Pine',
 'Plane Tree',
 'Redbud',
 'Redcedar',
 'Rubbertree',
 'Sassafrass',
 'Serviceberry',
 'Silverbell',
 'Snowbell',
 'Sourwood',
 'Spruce',
 'Stewartia',
 'Sweetgum',
 'Sycamore',
 'Tree of Heaven',
 'Tuliptree',
 'Tupelo',
 'Walnut',
 'Willow',
 'Yellowwood',
 'Zelkova']
In [ ]:
data = ImageDataBunch.from_df(base_path/'leafsnap-dataset', train_dataset_df, size=224, bs=64).normalize()
In [ ]:
data.show_batch(rows=3, figsize=(10, 10))

Training: resnet34

In [ ]:
model_rn34 = cnn_learner(data, models.resnet34, metrics=error_rate)
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/checkpoints/resnet34-333f7ec4.pth

In [ ]:
model_rn34.fit_one_cycle(4)
epoch train_loss valid_loss error_rate time
0 1.211646 0.665800 0.199255 11:05
1 0.627393 0.341471 0.108213 11:05
2 0.379345 0.234967 0.074842 11:14
3 0.260320 0.193733 0.063664 11:14
In [ ]:
model_rn34.save('leafsnap-rn34-4e')
In [ ]:
interp = ClassificationInterpretation.from_learner(model_rn34)
interp.plot_top_losses(9, figsize=(15,11))
In [ ]:
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)
In [ ]:
interp.most_confused(min_val=5)
Out[ ]:
[('Oak', 'Cherry', 16),
 ('Pine', 'Cherry', 13),
 ('Cherry', 'Oak', 10),
 ('Cherry', 'Crabapple', 7),
 ('Crabapple', 'Cherry', 7),
 ('Oak', 'Magnolia', 7),
 ('Cherry', 'Magnolia', 6),
 ('Cherry', 'Pine', 6),
 ('Birch', 'Hawthorn', 5),
 ('Cedar', 'Larch', 5),
 ('Mulberry', 'Birch', 5)]
In [ ]:
model_rn34.unfreeze()
model_rn34.fit_one_cycle(1)
epoch train_loss valid_loss error_rate time
0 0.301033 0.151590 0.050867 11:22
In [ ]:
model_rn34.save('leafsnap-rn34-4e-ft1')
In [15]:
# from https://forums.fast.ai/t/automated-learning-rate-suggester/44199
def find_appropriate_lr(model:Learner, lr_diff:int = 15, loss_threshold:float = .05, adjust_value:float = 1, plot:bool = False) -> float:
    #Run the Learning Rate Finder
    model.lr_find()
    
    #Get loss values and their corresponding gradients, and get lr values
    losses = np.array(model.recorder.losses)
    assert(lr_diff < len(losses))
    loss_grad = np.gradient(losses)
    lrs = model.recorder.lrs
    
    #Search for index in gradients where loss is lowest before the loss spike
    #Initialize right and left idx using the lr_diff as a spacing unit
    #Set the local min lr as -1 to signify if threshold is too low
    r_idx = -1
    l_idx = r_idx - lr_diff
    while (l_idx >= -len(losses)) and (abs(loss_grad[r_idx] - loss_grad[l_idx]) > loss_threshold):
        local_min_lr = lrs[l_idx]
        r_idx -= 1
        l_idx -= 1

    lr_to_use = local_min_lr * adjust_value
    
    if plot:
        # plots the gradients of the losses in respect to the learning rate change
        plt.plot(loss_grad)
        plt.plot(len(losses)+l_idx, loss_grad[l_idx],markersize=10,marker='o',color='red')
        plt.ylabel("Loss")
        plt.xlabel("Index of LRs")
        plt.show()

        plt.plot(np.log10(lrs), losses)
        plt.ylabel("Loss")
        plt.xlabel("Log 10 Transform of Learning Rate")
        loss_coord = np.interp(np.log10(lr_to_use), np.log10(lrs), losses)
        plt.plot(np.log10(lr_to_use), loss_coord, markersize=10,marker='o',color='red')
        plt.show()
        
    return lr_to_use
In [ ]:
# model_rn34.lr_find()
# model_rn34.recorder.plot()
lr = find_appropriate_lr(model_rn34, plot=True)
0.00% [0/1 00:00<00:00]
epoch train_loss valid_loss error_rate time

15.58% [60/385 01:25<07:45 0.4707]
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [ ]:
model_rn34.unfreeze()
model_rn34.fit_one_cycle(8, max_lr=lr)
epoch train_loss valid_loss error_rate time
0 0.161282 0.143692 0.045197 11:23
1 0.311188 0.244588 0.075328 11:29
2 0.239615 0.196655 0.063340 11:34
3 0.135139 0.167998 0.054755 11:27
4 0.067463 0.088164 0.029807 11:29
5 0.029889 0.058575 0.018792 11:28
6 0.019870 0.047009 0.016686 11:31
7 0.012546 0.043165 0.015714 11:38
In [ ]:
model_rn34.export(file=drive_path/'leafsnap-rn34-4e-ft1-ft8.pkl', destroy=True)
this Learner object self-destroyed - it still exists, but no longer usable

Training: resnet50

In [12]:
data = ImageDataBunch.from_df(base_path/'leafsnap-dataset', train_dataset_df, size=299, bs=32, ds_tfms=get_transforms()).normalize()
In [13]:
model_rn50 = cnn_learner(data, models.resnet50, metrics=error_rate)
Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/checkpoints/resnet50-19c8e357.pth

In [16]:
# learn.lr_find()
# learn.recorder.plot()
lr = find_appropriate_lr(model_rn50, plot=True)
0.00% [0/1 00:00<00:00]
epoch train_loss valid_loss error_rate time

11.67% [90/771 02:19<17:36 17.2454]
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [17]:
model_rn50.fit_one_cycle(8, max_lr=lr)
epoch train_loss valid_loss error_rate time
0 1.015574 0.609384 0.184999 21:40
1 0.783965 0.592988 0.181597 21:12
2 0.598352 0.441541 0.140450 21:13
3 0.469173 0.243737 0.080998 22:00
4 0.286007 0.140592 0.048275 22:42
5 0.167208 0.079739 0.026729 22:11
6 0.119546 0.059385 0.021707 22:33
7 0.092457 0.053715 0.019925 22:32
In [18]:
model_rn50.save('leafsnap-rn50-8e')
In [19]:
interp = ClassificationInterpretation.from_learner(model_rn50)
interp.most_confused(min_val=5)
Out[19]:
[('Oak', 'Cherry', 20), ('Pine', 'Cherry', 6), ('Cherry', 'Oak', 5)]
In [20]:
interp.plot_top_losses(9, figsize=(15,11))
In [21]:
model_rn50.export(file=drive_path/'leafsnap-rn50-8e.pkl', destroy=True)
this Learner object self-destroyed - it still exists, but no longer usable