Using the high-level transfer learning APIs, you can easily customize pre-trained models for feature extraction or fine-tuning.
In this notebook, we use a pre-trained Inception-V1 model. We freeze its first few layers, replace the classifier on top, and then fine-tune the model. Finally, we use the fine-tuned model to solve the dogs-vs-cats classification problem.
Download the dogs-vs-cats training dataset and extract it.
The following commands copy about 1100 images of cats and dogs into demo/cats and demo/dogs, respectively.
mkdir -p demo/dogs
mkdir -p demo/cats
cp train/cat.7* demo/cats
cp train/dog.7* demo/dogs
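For readers who prefer to stay in Python, the copy step can be sketched with the standard library. This is a minimal sketch, not part of the original notebook: `copy_subset` is a hypothetical helper, and it assumes the extracted images sit flat in a `train/` directory with names like `cat.7.jpg` and `dog.70.jpg`.

```python
import glob
import os
import shutil
import tempfile


def copy_subset(train_dir, demo_dir, prefix="7"):
    """Copy <train_dir>/<cls>.<prefix>* into <demo_dir>/<cls>s for cls in (cat, dog).

    Mirrors `mkdir -p demo/cats demo/dogs` plus `cp train/cat.7* demo/cats` etc.
    Returns the number of files copied per class.
    """
    counts = {}
    for cls in ("cat", "dog"):
        dst = os.path.join(demo_dir, cls + "s")        # demo/cats, demo/dogs
        os.makedirs(dst, exist_ok=True)                 # same effect as `mkdir -p`
        pattern = os.path.join(train_dir, "%s.%s*" % (cls, prefix))
        files = sorted(glob.glob(pattern))
        for path in files:
            shutil.copy(path, dst)                      # same effect as `cp`
        counts[cls] = len(files)
    return counts


if __name__ == "__main__":
    # Tiny self-check on empty placeholder files in a throwaway directory.
    root = tempfile.mkdtemp()
    train = os.path.join(root, "train")
    os.makedirs(train)
    for name in ("cat.7.jpg", "cat.70.jpg", "cat.8.jpg", "dog.712.jpg", "dog.17.jpg"):
        open(os.path.join(train, name), "w").close()
    print(copy_subset(train, os.path.join(root, "demo")))  # {'cat': 2, 'dog': 1}
```

If, as assumed here, each class is numbered 0 to 12,499, the `7*` pattern matches indices 7, 70 to 79, 700 to 799 and 7000 to 7999, i.e. 1 + 10 + 100 + 1,000 = 1,111 files per class, which is where "about 1100" comes from.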
Download the pre-trained Inception-V1 model from Analytics Zoo. Alternatively, you may also download a pre-trained Caffe, TensorFlow, or Keras model.
In [1]:
import re
from bigdl.nn.criterion import CrossEntropyCriterion
from pyspark.ml import Pipeline
from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType, StringType
from zoo.common.nncontext import *
from zoo.feature.image import *
from zoo.pipeline.api.keras.layers import Dense, Input, Flatten
from zoo.pipeline.api.keras.models import *
from zoo.pipeline.api.net import *
from zoo.pipeline.nnframes import *
In [2]:
sc = init_nncontext("ImageTransferLearningExample")
Manually set model_path and image_path for training:
model_path = path to the pre-trained model file (e.g. path/to/model/bigdl_inception-v1_imagenet_0.4.0.model)
image_path = glob pattern matching the training images (e.g. path/to/data/dogs-vs-cats/demo/*/*)
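The `file://` prefix in `image_path` marks a local-filesystem URI. For an absolute local directory the URI normally has three slashes (`file:///...`). A minimal sketch for building it, assuming a POSIX absolute path; `image_glob_uri` is a hypothetical helper, not an Analytics Zoo function.

```python
from pathlib import Path


def image_glob_uri(data_dir):
    """Return a file:// URI glob matching <data_dir>/demo/<class>/<image>."""
    path = Path(data_dir)
    if not path.is_absolute():
        # A file URI can only be built from an absolute path.
        raise ValueError("data_dir must be absolute: %r" % data_dir)
    # Append the wildcards after as_uri() so "*" is not percent-encoded.
    return (path / "demo").as_uri() + "/*/*"


print(image_glob_uri("/data/dogs-vs-cats"))  # file:///data/dogs-vs-cats/demo/*/*
```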
In [3]:
model_path = "path/to/model/bigdl_inception-v1_imagenet_0.4.0.model"
image_path = "file://path/to/data/dogs-vs-cats/demo/*/*"
imageDF = NNImageReader.readImages(image_path, sc)
In [4]:
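# row[0] is the image's origin (its file path); extract e.g. "cat.7.jpg" from it
getName = udf(lambda row:
              re.search(r'(cat|dog)\.([\d]*)\.jpg', row[0], re.IGNORECASE).group(0),
              StringType())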
getLabel = udf(lambda name: 1.0 if name.startswith('cat') else 2.0, DoubleType())
labelDF = imageDF.withColumn("name", getName(col("image"))) \
.withColumn("label", getLabel(col('name')))
(trainingDF, validationDF) = labelDF.randomSplit([0.9, 0.1])
labelDF.select("name","label").show(10)
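The two UDFs and the 90/10 split can be checked outside Spark with plain Python. This sketch reuses the same regular expression and labeling rule; `get_name`, `get_label` and `random_split` are hypothetical helpers, and `random_split` is a simplified per-row sampler, not Spark's `randomSplit` implementation, so split sizes are only approximately 90/10.

```python
import random
import re

# Same pattern as the getName UDF: matches e.g. "cat.7.jpg" inside a full path.
NAME_PATTERN = re.compile(r'(cat|dog)\.([\d]*)\.jpg', re.IGNORECASE)


def get_name(origin):
    """Plain-Python version of getName: extract 'cat.N.jpg' or 'dog.N.jpg' from a path."""
    match = NAME_PATTERN.search(origin)
    if match is None:
        raise ValueError("no cat/dog file name in %r" % origin)
    return match.group(0)


def get_label(name):
    """Plain-Python version of getLabel: 1-based class ids, cat -> 1.0, dog -> 2.0."""
    return 1.0 if name.startswith('cat') else 2.0


def random_split(rows, weights, seed=0):
    """Each row independently lands in split i with probability weights[i] / sum(weights)."""
    rng = random.Random(seed)
    total = float(sum(weights))
    cumulative, acc = [], 0.0
    for w in weights:
        acc += w / total
        cumulative.append(acc)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        # First bucket whose cumulative bound exceeds u; the last bucket absorbs rounding.
        idx = next((i for i, b in enumerate(cumulative) if u < b), len(weights) - 1)
        splits[idx].append(row)
    return splits
```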
We fine-tune a pre-trained model by removing the last few layers, freezing the first few layers, and adding some new layers.
In [5]:
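# Resize to 256x256, center-crop to 224x224, subtract per-channel means,
# then convert the image into the tensor the model expects.
transformer = ChainedPreprocessing(
    [RowToImageFeature(), ImageResize(256, 256), ImageCenterCrop(224, 224),
     ImageChannelNormalize(123.0, 117.0, 104.0), ImageMatToTensor(), ImageFeatureToTensor()])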
We use the Net API to load a pre-trained model; it supports models saved by Analytics Zoo, BigDL, Torch, Caffe, and TensorFlow. Please refer to the Net API Guide for details.
In [6]:
full_model = Net.load_bigdl(model_path)
Here we print all the model layers so that you can choose which layer(s) to remove.
When a model is loaded using Net, we can use the new_graph(outputs) API to define a Model whose output is the layer(s) named in the parameter.
In [7]:
for layer in full_model.layers:
    print(layer.name())
model = full_model.new_graph(["pool5/drop_7x7_s1"])
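Conceptually, choosing a new output layer cuts away everything above it. The toy sketch below assumes a purely sequential network for simplicity (Inception-V1 is actually a branching graph, so the real API walks a DAG). Apart from `pool4/3x3_s2` and `pool5/drop_7x7_s1`, the layer names are illustrative placeholders, and `truncate_at` is a hypothetical helper, not the library's `new_graph`.

```python
# Illustrative layer names; print full_model.layers to see the real list.
LAYERS = ["input", "conv1/7x7_s2", "pool4/3x3_s2", "pool5/7x7_s1",
          "pool5/drop_7x7_s1", "loss3/classifier", "prob"]


def truncate_at(layer_names, output):
    """Keep every layer from the input up to and including `output`."""
    if output not in layer_names:
        raise ValueError("unknown layer: %s" % output)
    return layer_names[:layer_names.index(output) + 1]


print(truncate_at(LAYERS, "pool5/drop_7x7_s1")[-1])  # pool5/drop_7x7_s1
```

Everything after the chosen output (here the original ImageNet classifier) is dropped, which is why a new `Dense(2)` head is attached later.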
The returned model's output layer is "pool5/drop_7x7_s1".
We freeze the layers from the input up to and including pool4/3x3_s2.
In [8]:
model.freeze_up_to(["pool4/3x3_s2"])
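Freezing means the optimizer leaves those layers' weights unchanged while the remaining layers keep learning. A toy sketch, assuming one scalar weight per layer and a plain SGD update for illustration (the notebook does not specify the optimizer); the learning rate 0.003 matches the classifier's setting, and the helper names are hypothetical.

```python
def frozen_set(layer_names, last_frozen):
    """Names of all layers from the input up to and including `last_frozen`."""
    return set(layer_names[:layer_names.index(last_frozen) + 1])


def sgd_step(weights, grads, frozen, lr=0.003):
    """One plain SGD update that skips frozen layers."""
    return {name: w if name in frozen else w - lr * grads[name]
            for name, w in weights.items()}
```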
In [9]:
inputNode = Input(name="input", shape=(3, 224, 224))
inception = model.to_keras()(inputNode)
flatten = Flatten()(inception)
logits = Dense(2)(flatten)
lrModel = Model(inputNode, logits)
classifier = NNClassifier(lrModel, CrossEntropyCriterion(), transformer) \
    .setLearningRate(0.003).setBatchSize(40).setMaxEpoch(1).setFeaturesCol("image")
pipeline = Pipeline(stages=[classifier])
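BigDL's `CrossEntropyCriterion` combines `LogSoftMax` with `ClassNLLCriterion` and expects 1-based class indices, which is why the labels are 1.0 (cat) and 2.0 (dog) and why the new head is `Dense(2)`. The per-sample loss and the argmax prediction can be sketched in plain Python; `cross_entropy` and `predict` are illustrative helpers, not library calls.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample; `label` is 1-based (1.0 = cat, 2.0 = dog)."""
    m = max(logits)
    # Numerically stable log-sum-exp.
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum - logits[int(label) - 1]


def predict(logits):
    """1-based index of the largest logit, as a float like the prediction column."""
    return float(logits.index(max(logits)) + 1)
```

With equal logits the loss is ln 2 for either label; as the correct logit grows relative to the other, the loss approaches 0.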
Transfer learning should finish in a few minutes.
In [10]:
catdogModel = pipeline.fit(trainingDF)
predictionDF = catdogModel.transform(validationDF).cache()
In [11]:
predictionDF.select("name","label","prediction").sort("label", ascending=False).show(10)
predictionDF.select("name","label","prediction").show(10)
correct = predictionDF.filter("label=prediction").count()
overall = predictionDF.count()
accuracy = correct * 1.0 / overall
print("Test Error = %g " % (1.0 - accuracy))
As we can see, the model from transfer learning can achieve over 95% accuracy on the validation set.
We select a few images from each predicted class, display them, and print their prediction results.
A prediction of 1.0 means cat; a prediction of 2.0 means dog.
In [12]:
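samplecat=predictionDF.filter(predictionDF.prediction==1.0).sort("label", ascending=True).limit(3).collect()
sampledog=predictionDF.filter(predictionDF.prediction==2.0).sort("label", ascending=False).limit(3).collect()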
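The `sampledog` query keeps rows predicted as dog, sorts them by true label in descending order, and takes the first three, so correctly classified dogs (label 2.0) come first and a misclassified cat only appears if there are fewer than three correct ones. A plain-Python sketch of that filter/sort/limit chain over rows represented as dicts (an assumption for illustration; `sample_predicted` is hypothetical):

```python
def sample_predicted(rows, predicted, k=3, descending=True):
    """Mirror filter(prediction == predicted).sort("label", ascending=not descending).limit(k)."""
    hits = [r for r in rows if r["prediction"] == predicted]
    hits.sort(key=lambda r: r["label"], reverse=descending)
    return hits[:k]
```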
In [13]:
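from IPython.display import Image, display
for cat in samplecat:
    print("prediction:", cat.prediction)
    # origin is assumed to look like "file:/path/to/cat.N.jpg"; [5:] drops the "file:" scheme
    display(Image(cat.image.origin[5:]))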
In [14]:
for dog in sampledog:
    print("prediction:", dog.prediction)
    display(Image(dog.image.origin[5:]))
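Displaying an image requires a local file path, while the first field of each image struct (`row[0]` in `getName` above) is its source URI. Slicing off a fixed-length prefix assumes one exact URI form; parsing is more robust to both `file:/...` and `file:///...`. A sketch with a hypothetical `origin_to_path` helper, assuming local (non-HDFS) images:

```python
from urllib.parse import unquote, urlparse


def origin_to_path(origin):
    """Convert 'file:/a/b.jpg', 'file:///a/b.jpg' or a bare '/a/b.jpg' into a local path."""
    parsed = urlparse(origin)
    if parsed.scheme not in ("", "file"):
        raise ValueError("not a local file URI: %s" % origin)
    # Undo percent-encoding such as %20 for spaces.
    return unquote(parsed.path)
```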