See the full explanation behind this demonstration notebook at FastAI quick test / LeafSnap.
# Mount Google Drive inside the Colab VM at /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')
from pathlib import Path
# Drive folder that holds the Kaggle zip read further down and receives the exported models.
drive_path = Path('/content/gdrive/My Drive/leafsnap')
# Directory the Kaggle zip is extracted into and the dataset index/images are read from.
base_path = Path('/content/leafsnap-dataset')
# Pin torch/torchvision before fastai is imported.
# NOTE(review): presumably the versions this fastai release is compatible with — confirm.
!pip install "torch==1.4" "torchvision==0.5.0"
# Notebook magics: auto-reload modules on change and render matplotlib output inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
# NOTE(review): later cells use `np`, `plt`, `Learner`, `cnn_learner`, `models`, ... without
# importing them explicitly; they presumably all come from this star import — confirm.
from fastai.vision import *
from fastai.metrics import error_rate
# leafsnap_url = 'http://leafsnap.com/static/dataset/leafsnap-dataset'
# tar_path = download_data(leafsnap_url, ext='.tar')
leafsnap_url = 'http://leafsnap.com/static/dataset/leafsnap-dataset.tar'
!wget {leafsnap_url}
import tarfile
tar = tarfile.open('leafsnap-dataset.tar')
tar.extractall(path='leafsnap-dataset')
import zipfile
with zipfile.ZipFile(drive_path/'leafsnap-kaggle.zip', 'r') as zip:
zip.extractall(base_path)
import pandas as pd
# Load the tab-separated image index that ships with the dataset.
dataset_df = pd.read_csv(base_path/'leafsnap-dataset-images.txt', sep='\t')
# Notebook display expressions: preview the first rows and the column names.
dataset_df.head()
dataset_df.columns
# field-only images
# train_dataset_df = dataset_df[dataset_df.source == 'field'][['image_path', 'species']]
# all images
# Take an explicit copy: the frame is modified later on, and mutating a column
# selection of dataset_df is what triggers pandas' SettingWithCopyWarning.
train_dataset_df = dataset_df[['image_path', 'species']].copy()
# Lookup table that collapses the dataset's botanical names (keys) into coarser
# common-name classes (values), so that several species share one label.
# The labels are kept byte-for-byte: they are substituted into the training
# frame further down and therefore presumably become the model's class names.
simplified_species = {
    'Acer campestre': 'Maple',
    'Acer ginnala': 'Maple',
    'Acer griseum': 'Maple',
    'Cedrus atlantica': 'Cedar',
    'Cedrus deodara': 'Cedar',
    'Cedrus libani': 'Cedar',
    'Broussonettia papyrifera': 'Mulberry',
    'Chamaecyparis pisifera': 'Cypress',
    'Cornus kousa': 'Dogwood',
    'Cryptomeria japonica': 'Cedar',
    'Fraxinus pennsylvanica': 'Ash',
    'Gleditsia triacanthos': 'Honeylocust',
    'Juniperus virginiana': 'Redcedar',
    'Malus hupehensis': 'Crabapple',
    'Morus rubra': 'Mulberry',
    'Picea abies': 'Spruce',
    'Picea orientalis': 'Spruce',
    'Picea pungens': 'Spruce',
    'Pinus bungeana': 'Pine',
    'Pinus cembra': 'Pine',
    'Pinus densiflora': 'Pine',
    'Pinus flexilis': 'Pine',
    'Pinus koraiensis': 'Pine',
    'Pinus parviflora': 'Pine',
    'Pinus peucea': 'Pine',
    'Pinus rigida': 'Pine',
    'Pinus strobus': 'Pine',
    'Pinus thunbergii': 'Pine',
    'Pinus wallichiana': 'Pine',
    'Platanus acerifolia': 'Plane Tree',
    'Quercus montana': 'Oak',
    'Quercus muehlenbergii': 'Oak',
    'Pseudolarix amabilis': 'Larch',
    'Quercus alba': 'Oak',
    'Ailanthus altissima': 'Tree of Heaven',
    'Taxodium distichum': 'Cypress',
    'Tsuga canadensis': 'Hemlock',
    'Ulmus hollandica': 'Elm',
    'Ulmus americana': 'Elm',
    # NOTE(review): 'Splindetree' looks like a typo for 'Spindletree'; left
    # unchanged because it is a runtime class label — confirm before renaming.
    'Euonymus europaeus': 'Splindetree',
    'Syringa vulgaris': 'Lilac',
    'Chamaecyparis thyoides': 'Cedar',
    'Acer platanoides': 'Maple',
    'Acer pseudoplatanus': 'Maple',
    'Acer palmatum': 'Maple',
    'Acer rubrum': 'Maple',
    'Acer saccharinum': 'Maple',
    'Betula lenta': 'Birch',
    'Aesculus pavi': 'Buckeye',
    'Betula jacqemontii': 'Birch',
    'Aesculus flava': 'Buckeye',
    'Amelanchier canadensis': 'Serviceberry',
    'Asimina triloba': 'Pawpaw',
    'Betula nigra': 'Birch',
    'Carpinus betulus': 'Hornbeam',
    'Carya tomentosa': 'Hickory',
    'Carya cordiformis': 'Hickory',
    'Castanea dentata': 'Chestnut',
    'Catalpa speciosa': 'Catalpa',
    'Celtis occidentalis': 'Hackberry',
    'Cercis canadensis': 'Redbud',
    'Cercidiphyllum japonicum': 'Katsura',
    'Chionanthus retusus': 'Fringetree',
    'Cladrastis lutea': 'Yellowwood',
    'Crataegus phaenopyrum': 'Hawthorn',
    'Evodia daniellii': 'Evodia',
    'Diospyros virginiana': 'Persimmon',
    'Eucommia ulmoides': 'Rubbertree',
    'Fagus grandifolia': 'Beech',
    'Fagus sylvatica': 'Beech',
    'Pinus pungens': 'Pine',
    'Pinus resinosa': 'Pine',
    'Pinus sylvestris': 'Pine',
    'Pinus taeda': 'Pine',
    'Pinus virginiana': 'Pine',
    'Fraxinus americana': 'Ash',
    'Fraxinus nigra': 'Ash',
    'Halesia tetraptera': 'Silverbell',
    'Tilia cordata': 'Linden',
    'Tilia europaea': 'Linden',
    'Tilia tomentosa': 'Linden',
    'Ilex opaca': 'Holly',
    'Juglans nigra': 'Walnut',
    'Koelreuteria paniculata': 'Goldenrain',
    'Liquidambar styraciflua': 'Sweetgum',
    'Maclura pomifera': 'Osage Orange',
    'Magnolia acuminata': 'Magnolia',
    'Magnolia denudata': 'Magnolia',
    'Magnolia soulangiana': 'Magnolia',
    'Magnolia stellata': 'Magnolia',
    'Malus baccata': 'Crabapple',
    'Morus alba': 'Mulberry',
    'Nyssa sylvatica': 'Tupelo',
    'Oxydendrum arboreum': 'Sourwood',
    'Paulownia tomentosa': 'Empress Tree',
    'Phellodendron amurense': 'Corktree',
    'Platanus occidentalis': 'Sycamore',
    'Populus deltoides': 'Cottonwood',
    'Prunus sargentii': 'Cherry',
    'Prunus serotina': 'Cherry',
    'Chionanthus virginicus': 'Fringetree',
    'Gymnocladus dioicus': 'Coffeetree',
    'Prunus serrulata': 'Cherry',
    'Prunus subhirtella': 'Cherry',
    'Prunus yedoensis': 'Cherry',
    'Ptelea trifoliata': 'Hoptree',
    'Quercus bicolor': 'Oak',
    'Quercus cerris': 'Oak',
    'Quercus coccinea': 'Oak',
    'Quercus imbricaria': 'Oak',
    'Quercus macrocarpa': 'Oak',
    'Quercus palustris': 'Oak',
    'Quercus phellos': 'Oak',
    'Quercus velutina': 'Oak',
    'Abies concolor': 'Fir',
    'Picea rubens': 'Spruce',
    'Magnolia grandiflora': 'Magnolia',
    'Syringa reticulata': 'Lilac',
    'Tilia americana': 'Linden',
    'Pseudotsuga menziesii': 'Fir',
    'Prunus cerasifera': 'Cherry',
    'Toona sinensis': 'Chinese Toon',
    'Ulmus glabra': 'Elm',
    'Rhus typhina': 'Sumac',
    'Larix decidua': 'Larch',
    'Amelanchier arborea': 'Serviceberry',
    'Amelanchier laevis': 'Serviceberry',
    'Betula alleghaniensis': 'Birch',
    'Betula populifolia': 'Birch',
    'Betula platyphylla': 'Birch',
    'Prunus spinosa': 'Cherry',
    'Acer saccharum': 'Maple',
    'Magnolia kobus': 'Magnolia',
    'Clerodendrum trichotomum': 'Glorybower',
    'Carpinus caroliniana': 'Hornbeam',
    'Carya glabra': 'Hickory',
    'Cornus drummondii': 'Dogwood',
    'Aesculus glabra': 'Buckeye',
    'Aesculus hippocastamon': 'Buckeye',
    'Catalpa bignonioides': 'Catalpa',
    'Crataegus monogyna': 'Hawthorn',
    'Photinia villosa': 'Photinia',
    'Sorbus aucuparia': 'Ash',
    'Quercus falcata': 'Oak',
    'Quercus marilandica': 'Oak',
    'Acer pensylvanicum': 'Maple',
    'Picea glauca': 'Spruce',
    'Pinus echinata': 'Pine',
    'Viburnum lantana': 'Wayfaringtree',
    'Crataegus punctata': 'Hawthorn',
    'Crataegus succulenta': 'Hawthorn',
    'Zelkova serrata': 'Zelkova',
    'Fraxinus excelsior': 'Ash',
    'Ficus carica': 'Fig',
    'Carya ovata': 'Hickory',
    'Ostrya virginiana': 'Hophornbeam',
    'Populus grandidentata': 'Aspen',
    'Populus tremuloides': 'Aspen',
    'Buxus sempervirens': 'Box',
    'Prunus pensylvanica': 'Cherry',
    'Prunus virginiana': 'Cherry',
    'Betula papyrifera': 'Birch',
    'Betula pendula': 'Birch',
    'Betula pubescens': 'Birch',
    'Acer nigrum': 'Maple',
    'Alnus incana': 'Alder',
    'Juglans regia': 'Walnut',
    'Salix caroliniana': 'Willow',
    'Ulmus parvifolia': 'Elm',
    'Ulmus rubra': 'Elm',
    'Betula uber': 'Birch',
    'Quercus shumardii': 'Oak',
    'Salix alba': 'Willow',
    'Crataegus laevigata': 'Hawthorn',
    'Cornus mas': 'Cherry',
    'Corylus colurna': 'Filbert',
    'Crataegus crus-galli': 'Hawthorn',
    'Quercus robur': 'Oak',
    'Quercus rubra': 'Oak',
    'Juglans cinerea': 'Walnut',
    'Salix babylonica': 'Willow',
    'Crataegus dilatata': 'Hawthorn',
    'Robinia pseudo-acacia': 'Locust',
    'Ulmus procera': 'Elm',
    'Ulmus pumila': 'Elm',
    'Acer negundo': 'Elder',
    'Albizia julibrissin': 'Mimosa',
    'Alnus glutinosa': 'Alder',
    'Magnolia macrophylla': 'Magnolia',
    'Magnolia tripetala': 'Magnolia',
    'Quercus michauxii': 'Oak',
    'Quercus nigra': 'Oak',
    'Abies nordmanniana': 'Fir',
    'Malus floribunda': 'Crabapple',
    'Pyrus calleryana': 'Callery Pear',
    'Cornus florida': 'Dogwood',
    'Metasequoia glyptostroboides': 'Metasequoia',
    'Magnolia virginiana': 'Magnolia',
    # NOTE(review): 'Sassafrass' is spelled with a double "s" here while the
    # genus is 'Sassafras'; left unchanged because it is a runtime class label.
    'Sassafras albidum': 'Sassafrass',
    'Quercus acutissima': 'Oak',
    'Celtis tenuifolia': 'Hackberry',
    'Pinus nigra': 'Pine',
    'Quercus stellata': 'Oak',
    'Quercus virginiana': 'Oak',
    'Crataegus pruinosa': 'Hawthorn',
    'Crataegus viridis': 'Hawthorn',
    'Liriodendron tulipifera': 'Tuliptree',
    'Malus angustifolia': 'Crabapple',
    'Malus coronaria': 'Crabapple',
    'Malus pumila': 'Crabapple',
    'Abies fraseri': 'Fir',
    'Salix matsudana': 'Willow',
    'Staphylea trifolia': 'Bladdernut',
    'Styrax japonica': 'Snowbell',
    'Juniperus communis': 'Juniper',
    'Salix nigra': 'Willow',
    'Stewartia pseudocamellia': 'Stewartia',
    'Styrax obassia': 'Snowbell',
    'Picea mariana': 'Spruce',
    'Ginkgo biloba': 'Ginkgo',
    # NOTE(review): not a botanical name; presumably normalises a misspelled
    # value that shows up after substitution or in the raw data — verify.
    'Fringretree': 'Fringetree',
}
# Substitute every botanical name found in the frame with its generic label.
# Rebinding the result (instead of inplace=True) avoids mutating a frame that
# was derived from a column selection, which pandas flags with
# SettingWithCopyWarning and may apply to a temporary copy.
train_dataset_df = train_dataset_df.replace(simplified_species)
# Distinct labels that remain after simplification, in alphabetical order.
simplified_species_list = sorted(train_dataset_df.species.unique())
print(len(simplified_species_list))
# Notebook display expression: show the sorted class list.
simplified_species_list
# Build the data bunch from the (image_path, species) frame, with
# base_path/'leafsnap-dataset' passed as the dataset root, image size 224 and
# batch size 64, then normalize it.
# NOTE(review): no validation split, seed or augmentation is passed here, so
# fastai's defaults apply and the split is presumably not reproducible — confirm.
data = ImageDataBunch.from_df(base_path/'leafsnap-dataset', train_dataset_df, size=224, bs=64).normalize()
# Visual sanity check of a 3-row sample grid.
data.show_batch(rows=3, figsize=(10, 10))
# ResNet-34 learner that reports error_rate as its metric.
model_rn34 = cnn_learner(data, models.resnet34, metrics=error_rate)
# First stage: 4 one-cycle epochs.
# NOTE(review): presumably only the head trains here, since unfreeze() is
# called explicitly further down — confirm against cnn_learner's defaults.
model_rn34.fit_one_cycle(4)
model_rn34.save('leafsnap-rn34-4e')
# Inspect the first-stage model: worst predictions, confusion matrix and the
# class pairs confused at least 5 times.
interp = ClassificationInterpretation.from_learner(model_rn34)
interp.plot_top_losses(9, figsize=(15,11))
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)
interp.most_confused(min_val=5)
# Second stage: unfreeze and train one more epoch with the default learning
# rate, then checkpoint under a separate name.
model_rn34.unfreeze()
model_rn34.fit_one_cycle(1)
model_rn34.save('leafsnap-rn34-4e-ft1')
# from https://forums.fast.ai/t/automated-learning-rate-suggester/44199
def find_appropriate_lr(model:Learner, lr_diff:int = 15, loss_threshold:float = .05, adjust_value:float = 1, plot:bool = False) -> float:
    """Suggest a learning rate from a learning-rate-finder sweep.

    Runs ``model.lr_find()`` and then walks the recorded losses from the end of
    the sweep towards its start, comparing the loss gradient at two positions
    ``lr_diff`` samples apart. While the two gradients differ by more than
    ``loss_threshold`` the left position is remembered; the learning rate at
    the last remembered position, scaled by ``adjust_value``, is returned.

    Args:
        model: object exposing ``lr_find()`` and a ``recorder`` with ``losses``
            and ``lrs`` sequences of equal length.
        lr_diff: spacing, in recorded samples, between the two compared
            gradient positions. Must be smaller than the number of losses.
        loss_threshold: gradient difference below which the scan stops.
        adjust_value: multiplier applied to the selected learning rate.
        plot: when true, plot the loss gradients and the loss curve with the
            selected point highlighted.

    Returns:
        The suggested learning rate.

    Raises:
        ValueError: if ``lr_diff`` is not smaller than the number of recorded
            losses, or if the very first comparison is already within
            ``loss_threshold`` so that no candidate was ever selected.
    """
    # Run the Learning Rate Finder
    model.lr_find()

    # Get loss values and their corresponding gradients, and get lr values
    losses = np.array(model.recorder.losses)
    n_losses = len(losses)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if lr_diff >= n_losses:
        raise ValueError(
            f"lr_diff ({lr_diff}) must be smaller than the number of recorded losses ({n_losses})"
        )
    loss_grad = np.gradient(losses)
    lrs = model.recorder.lrs

    # Search for the index in the gradients where the loss is lowest before the
    # loss spike. Both indices are negative (counted from the end of the sweep)
    # and kept `lr_diff` apart while sliding left.
    # `None` marks "no candidate selected yet"; previously the name was simply
    # unbound in that case and the function died with UnboundLocalError.
    local_min_lr = None
    r_idx = -1
    l_idx = r_idx - lr_diff
    while (l_idx >= -n_losses) and (abs(loss_grad[r_idx] - loss_grad[l_idx]) > loss_threshold):
        local_min_lr = lrs[l_idx]
        r_idx -= 1
        l_idx -= 1

    if local_min_lr is None:
        raise ValueError(
            "no learning rate found: the loss gradients at the end of the sweep "
            f"already differ by at most loss_threshold ({loss_threshold}); "
            "try a lower loss_threshold or a different lr_diff"
        )

    lr_to_use = local_min_lr * adjust_value

    if plot:
        # The scan may stop one position past the left edge; clamp so the
        # marker lookup below cannot raise IndexError.
        marker_idx = max(l_idx, -n_losses)

        # plots the gradients of the losses in respect to the learning rate change
        plt.plot(loss_grad)
        plt.plot(n_losses + marker_idx, loss_grad[marker_idx], markersize=10, marker='o', color='red')
        plt.ylabel("Loss")
        plt.xlabel("Index of LRs")
        plt.show()

        # plots the loss curve against log10(lr) with the chosen rate highlighted
        plt.plot(np.log10(lrs), losses)
        plt.ylabel("Loss")
        plt.xlabel("Log 10 Transform of Learning Rate")
        loss_coord = np.interp(np.log10(lr_to_use), np.log10(lrs), losses)
        plt.plot(np.log10(lr_to_use), loss_coord, markersize=10, marker='o', color='red')
        plt.show()

    return lr_to_use
# Manual alternative to the automated suggestion below, kept for reference:
# model_rn34.lr_find()
# model_rn34.recorder.plot()
# Pick a learning rate from an LR-finder sweep (also plots the diagnostics).
lr = find_appropriate_lr(model_rn34, plot=True)
# Fine-tune for 8 more epochs, capped at the suggested rate.
# NOTE(review): the model was already unfrozen before the previous one-epoch
# run, so this unfreeze() looks redundant (harmless) — confirm.
model_rn34.unfreeze()
model_rn34.fit_one_cycle(8, max_lr=lr)
# Export the final learner to Drive. model_rn34 is not used again after this.
# NOTE(review): destroy=True presumably releases the learner's contents after export — confirm.
model_rn34.export(file=drive_path/'leafsnap-rn34-4e-ft1-ft8.pkl', destroy=True)
# Rebuild the data bunch for the larger model: image size 299, batch size 32,
# and this time with get_transforms() passed as ds_tfms. This rebinds `data`.
data = ImageDataBunch.from_df(base_path/'leafsnap-dataset', train_dataset_df, size=299, bs=32, ds_tfms=get_transforms()).normalize()
# ResNet-50 learner that reports error_rate as its metric.
model_rn50 = cnn_learner(data, models.resnet50, metrics=error_rate)
# Manual alternative to the automated suggestion below, kept for reference:
# model_rn50.lr_find()
# model_rn50.recorder.plot()
# Pick a learning rate from an LR-finder sweep (also plots the diagnostics).
lr = find_appropriate_lr(model_rn50, plot=True)
# Train for 8 one-cycle epochs capped at the suggested rate. Unlike the
# ResNet-34 run, unfreeze() is never called on this learner.
model_rn50.fit_one_cycle(8, max_lr=lr)
model_rn50.save('leafsnap-rn50-8e')
# Inspect the result: class pairs confused at least 5 times and the worst predictions.
interp = ClassificationInterpretation.from_learner(model_rn50)
interp.most_confused(min_val=5)
interp.plot_top_losses(9, figsize=(15,11))
# Export the learner to Drive. model_rn50 is not used again after this.
# NOTE(review): destroy=True presumably releases the learner's contents after export — confirm.
model_rn50.export(file=drive_path/'leafsnap-rn50-8e.pkl', destroy=True)