In [ ]:
# Uncomment line below when using Colab (this installs OpenCV4)
# %system SwiftCV/install/install_colab.sh
%install-location $cwd/swift-install
%install '.package(path: "$cwd/FastaiNotebook_07_batchnorm")' FastaiNotebook_07_batchnorm
%install '.package(path: "$cwd/SwiftCV")' SwiftCV
In [ ]:
//export
import Path
import TensorFlow
import Python
In [ ]:
import FastaiNotebook_07_batchnorm
In [ ]:
%include "EnableIPythonDisplay.swift"
IPythonDisplay.shell.enable_matplotlib("inline")
The DataBlock API in Python is designed to help with the routine data manipulations involved in modelling: downloading data, loading it given an understanding of its layout on the filesystem, processing it, and feeding it into an ML framework like fastai. This is a data pipeline. How do we do this in Swift?
One approach is to build a set of types (structs, protocols, etc.) which represent various stages of this pipeline. By making the types generic, we could build a library that handled data for many kinds of models. However, it is sometimes a good rule of thumb, before writing generic types, to start by writing concrete types and then to notice what to abstract into a generic later. And another good rule of thumb, before writing concrete types, is to write no types at all, and to see how far you can get with a more primitive tool for composition: functions.
This notebook shows how to perform DataBlock-like operations using a lightweight functional style. This means, first, relying as much as possible on pure functions -- that is, functions which do nothing but return outputs based on their inputs, and which don't mutate values anywhere. Second, it means using Swift's support for higher-order functions (functions which take functions, like `map`, `filter`, `reduce`, and `compose`). Finally, this example relies on tuples. Like structs, tuples can have named, typed properties. Unlike structs, you don't need to name them: they can be a fast, ad-hoc way to explore the data types that you actually need, without being distracted by deciding what's a method, an initializer, etc. Swift has excellent, understated support for such a style.
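To make that concrete, here is the flavor of the style in miniature -- a toy example (not part of the pipeline) using pure functions, `filter`/`map`, and an ad-hoc named tuple:
In [ ]:
// A toy illustration of the functional style used below (hypothetical example).
let numbers = [3, 1, 4, 1, 5, 9, 2, 6]
// A pure function: its output depends only on its input, nothing is mutated.
func isEven(_ n: Int) -> Bool { return n % 2 == 0 }
// Higher-order functions: filter and map take functions as arguments.
let doubledEvens = numbers.filter(isEven).map { $0 * 2 }
// A named tuple: typed, labeled fields without declaring a struct.
let summary = (count: doubledEvens.count, values: doubledEvens)
print(summary) // (count: 3, values: [8, 4, 12])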
In [ ]:
//export
public let dataPath = Path.home/".fastai"/"data"
In [ ]:
//export
public func downloadImagenette(path: Path = dataPath, sz: String = "-320") -> Path {
    let url = "https://s3.amazonaws.com/fast-ai-imageclas/imagenette\(sz).tgz"
    let fname = "imagenette\(sz)"
    let file = path/fname
    try! path.mkdir(.p)
    if !file.exists {
        downloadFile(url, dest: (path/"\(fname).tgz").string)
        _ = "/bin/tar".shell("-xzf", (path/"\(fname).tgz").string, "-C", path.string)
    }
    return file
}
Then we write a function to collect all the files under a directory, optionally recursing into subdirectories and filtering by extension.
In [ ]:
//export
public func collectFiles(under path: Path, recurse: Bool = false, filtering extensions: [String]? = nil) -> [Path] {
    var res: [Path] = []
    for p in try! path.ls() {
        if p.kind == .directory && recurse {
            res += collectFiles(under: p.path, recurse: recurse, filtering: extensions)
        } else if extensions == nil || extensions!.contains(p.path.extension.lowercased()) {
            res.append(p.path)
        }
    }
    return res
}
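As a quick illustration of the filtering (a toy call, using the current directory so it runs anywhere):
In [ ]:
// Illustrative usage: list Swift sources in the current directory, non-recursively.
let swiftFiles = collectFiles(under: Path.cwd, recurse: false, filtering: ["swift"])
print(swiftFiles.count)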
To build our dataset, we need, at the core, only four functions, which tell us:

- how to download the dataset
- how to collect the items
- whether an item belongs to the training or the validation set
- what the label of an item is

We put those four requirements in a `DatasetConfig` protocol.
In [ ]:
//export
public protocol DatasetConfig {
    associatedtype Item
    associatedtype Label

    static func download() -> Path
    static func getItems(_ path: Path) -> [Item]
    static func isTraining(_ item: Item) -> Bool
    static func labelOf(_ item: Item) -> Label
}
Here is what we know ahead of time about how the Imagenette data is laid out on disk:
.
└── data                                   # <-- this is the fastai data root path
    ├── imagenette-320                     # <-- this is the imagenette dataset path
    │   ├── train                          # <-- the train/ and val/ dirs are our two segments
    │   │   ├── n01440764                  # <-- this is an image category _label_
    │   │   │   ├── n01440764_10026.JPEG   # <-- this is an image (a _sample_) with that label
    │   │   │   ├── n01440764_10027.JPEG
    │   │   │   ├── n01440764_10042.JPEG
    │   │   │   ...
    │   └── val
    │       └── n03888257
    │           ├── ILSVRC2012_val_00001440.JPEG
    │           ├── ILSVRC2012_val_00002508.JPEG
    │           ...
We will define one type, an `enum`, to capture this information. This "empty" `enum` will serve only as a namespace, a grouping, for pure functions representing this information. By putting this information into one type, our code is more modular: it more clearly distinguishes facts about this dataset from general-purpose data manipulators and from computations for this analysis.
Here's our Imagenette configuration type:
In [ ]:
//export
public enum ImageNette: DatasetConfig {
    public static func download() -> Path { return downloadImagenette() }

    public static func getItems(_ path: Path) -> [Path] {
        return collectFiles(under: path, recurse: true, filtering: ["jpeg", "jpg"])
    }

    public static func isTraining(_ p: Path) -> Bool {
        return p.parent.parent.basename() == "train"
    }

    public static func labelOf(_ p: Path) -> String { return p.parent.basename() }
}
From this configuration, we can get values by calling the `download` and `getItems` functions. This step would be exactly the same for all datasets following the `DatasetConfig` protocol:
In [ ]:
let path = ImageNette.download()
let allFnames = ImageNette.getItems(path)
This function will use our dataset configuration to describe a given item:
In [ ]:
//export
public func describeSample<C>(_ item: C.Item, config: C.Type) where C: DatasetConfig {
    let isTraining = C.isTraining(item)
    let label = C.labelOf(item)
    print("""
    item:      \(item)
    training?: \(isTraining)
    label:     \(label)
    """)
}
In [ ]:
describeSample(allFnames[0], config: ImageNette.self)
We can see that our `isTraining` and `labelOf` functions are working as expected.
Now we want to split our samples into training and validation sets. Since this is so routine, we define a standard function to do it: it takes an array and returns a named tuple of two arrays, one for training and one for validation.
In [ ]:
//export
public func partitionIntoTrainVal<T>(_ items: [T], isTrain: (T) -> Bool) -> (train: [T], valid: [T]) {
    return (train: items.filter(isTrain), valid: items.filter { !isTrain($0) })
}
In [ ]:
var samples = partitionIntoTrainVal(allFnames, isTrain:ImageNette.isTraining)
And verify that it works as expected:
In [ ]:
describeSample(samples.valid.randomElement()!, config: ImageNette.self)
In [ ]:
describeSample(samples.train.randomElement()!, config: ImageNette.self)
We process the labels by taking all training labels, de-duplicating them, sorting them, and then representing each label by its integer position. Those numerical labels let us define two functions: one mapping label->number and the inverse mapping number->label. The notable point is that the process which produces those functions is itself a function: its input is a list of training labels, and its output is the bidirectional label<->number mapping. That creating function is called `initState` below. Those steps are generic and might apply to other tasks, so we define another protocol for them.
In [ ]:
//export
public protocol Processor {
    associatedtype Input
    associatedtype Output

    mutating func initState(_ items: [Input])
    func process  (_ item: Input)  -> Output
    func deprocess(_ item: Output) -> Input
}
And here is the specific `CategoryProcessor` we need in this case.
In [ ]:
//export
public struct CategoryProcessor: Processor {
    private(set) public var intToLabel: [String] = []
    private(set) public var labelToInt: [String: Int] = [:]

    public init() {}

    public mutating func initState(_ items: [String]) {
        intToLabel = Array(Set(items)).sorted()
        labelToInt = Dictionary(uniqueKeysWithValues:
            intToLabel.enumerated().map { ($0.element, $0.offset) })
    }

    public func process  (_ item: String) -> Int    { return labelToInt[item]! }
    public func deprocess(_ item: Int)    -> String { return intToLabel[item] }
}
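`CategoryProcessor` is just one conformance; to see that the protocol is not tied to labels, here is a hypothetical sketch of another `Processor` (illustrative only, not used below) that shifts integers to be zero-based relative to the training minimum:
In [ ]:
// Hypothetical example of another Processor conformance, just to show the shape
// of the protocol: initState learns state from training items, process/deprocess
// are inverse mappings built from that state.
public struct MinShiftProcessor: Processor {
    private(set) public var minVal: Int = 0
    public init() {}
    public mutating func initState(_ items: [Int]) { minVal = items.min() ?? 0 }
    public func process  (_ item: Int) -> Int { return item - minVal }
    public func deprocess(_ item: Int) -> Int { return item + minVal }
}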
Let us create a label<->number mapper from the training data. First we use the `labelOf` function to get all the training labels, then we can initialize a `CategoryProcessor`.
In [ ]:
var trainLabels = samples.train.map(ImageNette.labelOf)
var labelMapper = CategoryProcessor()
labelMapper.initState(trainLabels)
The `labelMapper` now supplies the two functions. We can verify they have the required inverse relationship:
In [ ]:
var randomLabel = labelMapper.intToLabel.randomElement()!
print("label = \(randomLabel)")
var numericalizedLabel = labelMapper.process(randomLabel)
print("number = \(numericalizedLabel)")
var labelFromNumber = labelMapper.deprocess(numericalizedLabel)
print("label = \(labelFromNumber)")
Now we are in a position to give the data numerical labels. To map from a sample item (a `Path`) to a numerical label (an `Int`), we just compose our Path->label function with our label->number function. Curiously, Swift does not define its own compose function, so we define a compose operator `>|` ourselves. We can use it to create our new function as an explicit composition:
In [ ]:
//export
public func >| <A, B, C>(_ f: @escaping (A) -> B,
                         _ g: @escaping (B) -> C) -> (A) -> C {
    return { g(f($0)) }
}
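As a quick sanity check of the operator, here is a toy composition (illustrative only):
In [ ]:
// Toy check: compose two small functions explicitly, left to right.
let increment: (Int) -> Int = { $0 + 1 }
let double: (Int) -> Int = { $0 * 2 }
let incThenDouble = increment >| double
print(incThenDouble(3)) // (3 + 1) * 2 = 8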
Then we define a function which maps a raw sample (a `Path`) to a numericalized label (an `Int`):
In [ ]:
var pathToNumericalizedLabel = ImageNette.labelOf >| labelMapper.process
Now we can, if we wish, compute numericalized labels over all the training and validation items:
In [ ]:
var trainNumLabels = samples.train.map(pathToNumericalizedLabel)
var validNumLabels = samples.valid.map(pathToNumericalizedLabel)
We've gotten pretty far using just variables, functions, and function composition. But one downside is that our results are now scattered over a few different variables: `samples`, `trainNumLabels`, `validNumLabels`. We collect these values into one structure for convenience:
In [ ]:
//export
public struct SplitLabeledData<Item, Label> {
    public var train: [(x: Item, y: Label)]
    public var valid: [(x: Item, y: Label)]

    public init(train: [(x: Item, y: Label)], valid: [(x: Item, y: Label)]) {
        (self.train, self.valid) = (train, valid)
    }
}
And we can define a convenience function to build it directly from our config and a processor.
In [ ]:
//export
public func makeSLD<C, P>(config: C.Type, procL: inout P) -> SplitLabeledData<C.Item, P.Output>
where C: DatasetConfig, P: Processor, P.Input == C.Label {
    let path = C.download()
    let items = C.getItems(path)
    let samples = partitionIntoTrainVal(items, isTrain: C.isTraining)
    let trainLabels = samples.train.map(C.labelOf)
    procL.initState(trainLabels)
    let itemToProcessedLabel = C.labelOf >| procL.process
    return SplitLabeledData(train: samples.train.map { ($0, itemToProcessedLabel($0)) },
                            valid: samples.valid.map { ($0, itemToProcessedLabel($0)) })
}
In [ ]:
var procL = CategoryProcessor()
let sld = makeSLD(config: ImageNette.self, procL: &procL)
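As a quick sanity check (illustrative), we can look at the sizes of the two splits:
In [ ]:
// How many samples landed in each split?
print("train: \(sld.train.count), valid: \(sld.valid.count)")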
We can use the same compose approach to convert our images from `Path` filenames to resized images, or to add all the data augmentation we want.
In [ ]:
//export
import Foundation
import SwiftCV
First, let's open those images with OpenCV:
In [ ]:
//export
public func openImage(_ fn: Path) -> Mat {
    return imdecode(try! Data(contentsOf: fn.url))
}
And add a convenience function to have a look.
In [ ]:
//export
public func showCVImage(_ img: Mat) {
    let tensImg = Tensor<UInt8>(cvMat: img)!
    let numpyImg = tensImg.makeNumpyArray()
    plt.imshow(numpyImg)
    plt.axis("off")
    plt.show()
}
In [ ]:
showCVImage(openImage(sld.train.randomElement()!.x))
The channels are in BGR order instead of RGB, so we first switch them with OpenCV:
In [ ]:
//export
public func BGRToRGB(_ img: Mat) -> Mat {
    return cvtColor(img, nil, ColorConversionCode.COLOR_BGR2RGB)
}
Then we can resize them:
In [ ]:
//export
public func resize(_ img: Mat, size: Int) -> Mat {
    return resize(img, nil, Size(size, size), 0, 0, InterpolationFlag.INTER_LINEAR)
}
With our compose operator, the succession of transforms can be written in this pretty way:
In [ ]:
let transforms = openImage >| BGRToRGB >| { resize($0, size: 224) }
And we can have a look at one of our elements:
In [ ]:
showCVImage(transforms(sld.train.randomElement()!.x))
Now we will need tensors to train our model, so we need to convert our images and ints to tensors. Images are naturally converted to tensors of bytes.
In [ ]:
//export
public func cvImgToTensor(_ img: Mat) -> Tensor<UInt8> {
    return Tensor<UInt8>(cvMat: img)!
}
We compose our transforms with that last function to get tensors.
In [ ]:
let pathToTF = transforms >| cvImgToTensor
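As a quick check (illustrative), pushing one sample through the whole pipeline should yield a byte tensor of shape 224 x 224 x 3 (height, width, RGB channels):
In [ ]:
// One sample through the whole image pipeline: Path -> resized RGB -> Tensor<UInt8>.
print(pathToTF(sld.train[0].x).shape)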
In [ ]:
//export
// Convert an Int label to a TI (Tensor<Int32>).
public func intTOTI(_ i: Int) -> TI { return TI(Int32(i)) }
Now we define a `Batcher`, which is responsible for creating minibatches as an iterator. It has the properties you know from PyTorch (batch size, num workers, shuffle) and will gather the images in parallel.
To be able to write `for batch in Batcher(...)`, `Batcher` needs to conform to `Sequence`, which means it needs a `makeIterator` function. That function has to return another struct that conforms to `IteratorProtocol`. The only requirement there is a `next` method that returns the next batch (or `nil` if we are finished).
The code is pretty straightforward: we shuffle the dataset at the beginning of each iteration if we want to, then we apply the transforms in parallel using `concurrentMap`, which works just like `map` but across `numWorkers` threads.
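If the `Sequence`/`IteratorProtocol` machinery is unfamiliar, here is the pattern in miniature -- a toy countdown, unrelated to our data (illustrative only):
In [ ]:
// Toy example of the Sequence/IteratorProtocol pattern that Batcher follows.
struct Countdown: Sequence {
    let start: Int
    func makeIterator() -> CountdownIterator { return CountdownIterator(current: start) }
}
struct CountdownIterator: IteratorProtocol {
    var current: Int
    // next() returns the next element, or nil when the sequence is exhausted.
    mutating func next() -> Int? {
        guard current > 0 else { return nil }
        defer { current -= 1 }
        return current
    }
}
for n in Countdown(start: 3) { print(n) } // prints 3, 2, 1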
In [ ]:
//export
public struct Batcher<Item, Label, ScalarI: TensorFlowScalar, ScalarL: TensorFlowScalar>: Sequence {
    public let dataset: [(Item, Label)]
    public let xToTensor: (Item) -> Tensor<ScalarI>
    public let yToTensor: (Label) -> Tensor<ScalarL>
    public let collateFunc: (Tensor<ScalarI>, Tensor<ScalarL>) -> DataBatch<TF, TI>
    public var bs: Int = 64
    public var numWorkers: Int = 4
    public var shuffle: Bool = false

    public init(_ ds: [(Item, Label)],
                xToTensor: @escaping (Item) -> Tensor<ScalarI>,
                yToTensor: @escaping (Label) -> Tensor<ScalarL>,
                collateFunc: @escaping (Tensor<ScalarI>, Tensor<ScalarL>) -> DataBatch<TF, TI>,
                bs: Int = 64, numWorkers: Int = 4, shuffle: Bool = false) {
        (dataset, self.xToTensor, self.yToTensor, self.collateFunc) = (ds, xToTensor, yToTensor, collateFunc)
        (self.bs, self.numWorkers, self.shuffle) = (bs, numWorkers, shuffle)
    }

    public func makeIterator() -> BatchIterator<Item, Label, ScalarI, ScalarL> {
        return BatchIterator(self, numWorkers: numWorkers, shuffle: shuffle)
    }
}

public struct BatchIterator<Item, Label, ScalarI: TensorFlowScalar, ScalarL: TensorFlowScalar>: IteratorProtocol {
    public let b: Batcher<Item, Label, ScalarI, ScalarL>
    public var numWorkers: Int = 4
    private var idx: Int = 0
    private var ds: [(Item, Label)]

    public init(_ batcher: Batcher<Item, Label, ScalarI, ScalarL>, numWorkers: Int = 4, shuffle: Bool = false) {
        (b, self.numWorkers, idx) = (batcher, numWorkers, 0)
        self.ds = shuffle ? b.dataset.shuffled() : b.dataset
    }

    public mutating func next() -> DataBatch<TF, TI>? {
        guard idx < b.dataset.count else { return nil }
        let end = min(idx + b.bs, b.dataset.count)
        let samples = Array(ds[idx..<end])
        idx += b.bs
        return b.collateFunc(
            Tensor<ScalarI>(concatenating: samples.concurrentMap(nthreads: numWorkers) {
                self.b.xToTensor($0.0).expandingShape(at: 0) }),
            Tensor<ScalarL>(concatenating: samples.concurrentMap(nthreads: numWorkers) {
                self.b.yToTensor($0.1).expandingShape(at: 0) }))
    }
}
In [ ]:
// Disable OpenCV's internal threading so it doesn't interfere with concurrentMap.
SetNumThreads(0)
In [ ]:
//export
public func collateFunc(_ xb: Tensor<UInt8>, _ yb: TI) -> DataBatch<TF, TI> {
    // Convert the byte images to floats in [0, 1].
    return DataBatch(xb: TF(xb)/255.0, yb: yb)
}
In [ ]:
let batcher = Batcher(sld.train, xToTensor: pathToTF, yToTensor: intTOTI, collateFunc: collateFunc, bs:256, shuffle:true)
In [ ]:
time {
    var c = 0
    for batch in batcher { c += 1 }
}
In [ ]:
// Sequence has no `first` property, so we use first(where:) to grab one batch.
let firstBatch = batcher.first(where: { _ in true })!
In [ ]:
//export
func showTensorImage(_ img: TF) {
    let numpyImg = img.makeNumpyArray()
    plt.imshow(numpyImg)
    plt.axis("off")
    plt.show()
}
In [ ]:
showTensorImage(firstBatch.xb[0])
In [ ]:
import NotebookExport
let exporter = NotebookExport(Path.cwd/"08c_data_block_generic.ipynb")
print(exporter.export(usingPrefix: "FastaiNotebook_"))
In [ ]: