Image augmentation enlarges a training dataset (which is especially helpful for small datasets) by applying different transformations to the existing images. In this notebook we demonstrate how to perform image augmentation with the Analytics Zoo APIs.
In [1]:
# `os` is imported explicitly because later cells use it; do not rely on the
# star import below happening to re-export it.
import os

from zoo.common.nncontext import init_nncontext
# Keep this star import ahead of the IPython import below, so a same-named
# export from zoo cannot shadow IPython's `Image` / `display`.
from zoo.feature.image import *
import cv2
import numpy as np
from IPython.display import Image, display

# Spark context used later by ImageSet.read(..., sc, 2) and sc.parallelize(...).
sc = init_nncontext("Image Augmentation Example")
In [2]:
# Resolve the sample-image locations once. os.environ[...] raises a KeyError naming
# ANALYTICS_ZOO_HOME when it is unset, instead of a cryptic `None + str` TypeError.
zoo_home = os.environ["ANALYTICS_ZOO_HOME"]
image_dir = os.path.join(zoo_home, "apps", "image-augmentation", "image")
test_image_path = os.path.join(image_dir, "test.jpg")

# create LocalImageSet from an image
local_image_set = ImageSet.read(test_image_path)
# create LocalImageSet from an image folder (trailing separator kept, as in the original path)
local_image_set = ImageSet.read(os.path.join(image_dir, ""))
# create LocalImageSet from a list of already-decoded images
image = cv2.imread(test_image_path)
if image is None:  # cv2.imread reports a missing/unreadable file by returning None
    raise IOError("cannot read image: " + test_image_path)
local_image_set = LocalImageSet([image])
print(local_image_set.get_image())
print('isDistributed: ', local_image_set.is_distributed(), ', isLocal: ', local_image_set.is_local())
In [3]:
# Resolve the sample-image locations once; a missing ANALYTICS_ZOO_HOME now fails
# with a KeyError naming the variable rather than a `None + str` TypeError.
zoo_home = os.environ["ANALYTICS_ZOO_HOME"]
image_dir = os.path.join(zoo_home, "apps", "image-augmentation", "image")
test_image_path = os.path.join(image_dir, "test.jpg")

# create DistributedImageSet from an image (2 is passed as the partition count)
distributed_image_set = ImageSet.read(test_image_path, sc, 2)
# create DistributedImageSet from an image folder (trailing separator kept, as in the original path)
distributed_image_set = ImageSet.read(os.path.join(image_dir, ""), sc, 2)
# create DistributedImageSet from an image RDD and a matching label RDD
image = cv2.imread(test_image_path)
if image is None:  # cv2.imread reports a missing/unreadable file by returning None
    raise IOError("cannot read image: " + test_image_path)
input_image_rdd = sc.parallelize([image], 2)
input_label_rdd = sc.parallelize([np.array([1.0])], 2)
distributed_image_set = DistributedImageSet(input_image_rdd, input_label_rdd)
# Read the data back out of the ImageSet; named separately so the input RDDs
# above are not shadowed.
images_rdd = distributed_image_set.get_image()
labels_rdd = distributed_image_set.get_label()
print(images_rdd)
print(labels_rdd)
print('isDistributed: ', distributed_image_set.is_distributed(), ', isLocal: ', distributed_image_set.is_local())
print('total images:', images_rdd.count())
In [4]:
# Test image shared by every augmentation demo below. os.environ[...] fails with a
# KeyError naming ANALYTICS_ZOO_HOME if it is unset (instead of `None + str`).
zoo_home = os.environ["ANALYTICS_ZOO_HOME"]
path = os.path.join(zoo_home, "apps", "image-augmentation", "image", "test.jpg")

def transform_display(transformer, image_set):
    """Apply ``transformer`` to ``image_set`` and display the first resulting image inline.

    Args:
        transformer: an Analytics Zoo image transformer; it is called with the ImageSet.
        image_set: the ImageSet to transform (e.g. the result of ``ImageSet.read``).

    The image is JPEG-encoded in memory instead of being written to a fixed
    ``/tmp/tmp.jpg`` file. That file was not portable, could be clobbered by a
    concurrent run, and (since the write result was ignored) could show a stale
    image from an earlier cell if the write failed.

    Raises:
        RuntimeError: if OpenCV cannot encode the transformed image.
    """
    transformed = transformer(image_set)
    # to_chw=False presumably yields HWC arrays, the layout OpenCV expects.
    first_image = transformed.get_image(to_chw=False)[0]
    ok, encoded = cv2.imencode(".jpg", first_image)
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the transformed image")
    display(Image(data=encoded.tobytes(), format="jpeg"))
In [5]:
image_set = ImageSet.read(path)
# Brightness with arguments (0.0, 32.0) — presumably a brightness delta range.
brightness = ImageBrightness(0.0, 32.0)
transform_display(brightness, image_set)
In [6]:
image_set = ImageSet.read(path)
# Hue with arguments (-18.0, 18.0) — presumably a hue-shift range.
transformer = ImageHue(-18.0, 18.0)
transform_display(transformer, image_set)
In [7]:
image_set = ImageSet.read(path)
# Saturation with arguments (10.0, 20.0) — presumably a saturation range.
transformer = ImageSaturation(10.0, 20.0)
transform_display(transformer, image_set)
In [8]:
image_set = ImageSet.read(path)
# Channel order with default settings (presumably reorders the color channels).
transformer = ImageChannelOrder()
transform_display(transformer, image_set)
In [9]:
image_set = ImageSet.read(path)
# Color jitter with its default settings.
transformer = ImageColorJitter()
transform_display(transformer, image_set)
In [10]:
image_set = ImageSet.read(path)
# Resize to 300 x 300 (both arguments equal, so their order does not matter here).
transformer = ImageResize(300, 300)
transform_display(transformer, image_set)
In [11]:
image_set = ImageSet.read(path)
# Aspect-preserving scale to size 200, capped by max_size=3000.
transformer = ImageAspectScale(200, max_size=3000)
transform_display(transformer, image_set)
In [12]:
image_set = ImageSet.read(path)
# Random aspect-preserving scale; candidate sizes [100, 300], capped by max_size=3000.
transformer = ImageRandomAspectScale([100, 300], max_size=3000)
transform_display(transformer, image_set)
In [13]:
image_set = ImageSet.read(path)
# Per-channel normalization; presumably three means (20, 30, 40) followed by three
# stds (2, 3, 4) — confirm the channel order in the ImageChannelNormalize docs.
transformer = ImageChannelNormalize(20.0, 30.0, 40.0, 2.0, 3.0, 4.0)
transform_display(transformer, image_set)
In [14]:
%%time
print("PixelNormalize takes nearly one and a half minutes. Please wait a moment.")
means = [2.0] * 3 * 500 * 375
transformer = ImagePixelNormalize(means)
image_set = ImageSet.read(path)
transform_display(transformer, image_set)
In [15]:
image_set = ImageSet.read(path)
# Center crop to 200 x 200.
transformer = ImageCenterCrop(200, 200)
transform_display(transformer, image_set)
In [16]:
image_set = ImageSet.read(path)
# Random crop of size 200 x 200.
transformer = ImageRandomCrop(200, 200)
transform_display(transformer, image_set)
In [17]:
image_set = ImageSet.read(path)
# Fixed crop of the box (0, 0)-(200, 200); the trailing False presumably means the
# coordinates are absolute pixels rather than normalized — confirm in the docs.
transformer = ImageFixedCrop(0.0, 0.0, 200.0, 200.0, False)
transform_display(transformer, image_set)
In [18]:
image_set = ImageSet.read(path)
# Filler over the region (0.0, 0.0)-(0.5, 0.5) with value 255 (region presumably
# in normalized coordinates — confirm in the ImageFiller docs).
transformer = ImageFiller(0.0, 0.0, 0.5, 0.5, 255)
transform_display(transformer, image_set)
In [19]:
image_set = ImageSet.read(path)
# Expand with per-channel fill means (123, 117, 104) and max_expand_ratio=2.0.
transformer = ImageExpand(means_r=123, means_g=117, means_b=104,
                          max_expand_ratio=2.0)
transform_display(transformer, image_set)
In [20]:
image_set = ImageSet.read(path)
# Horizontal flip.
transformer = ImageHFlip()
transform_display(transformer, image_set)
In [ ]: