CIFAR-10 是另外一個 dataset，和 MNIST 一樣有十種類別（飛機、汽車、鳥、貓、鹿、狗、青蛙、馬、船、卡車），每張圖片是 32×32 的彩色（3 個 channel）圖片。

https://www.cs.toronto.edu/~kriz/cifar.html


In [ ]:
import keras
from keras.models import Sequential
from PIL import Image
import numpy as np
import tarfile

In [ ]:
# Download the dataset archive (skipped when it is already on disk).
url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
import os
import urllib
from urllib.request import urlretrieve


def reporthook(a, b, c):
    """Print download progress on a single, continuously rewritten line.

    urlretrieve calls this with (block_count, block_size, total_size).
    total_size may be -1 (or 0) when the server does not report a size.
    """
    if c <= 0:
        # Total size unknown: a percentage is meaningless (and c == 0 would
        # divide by zero), so report the bytes received so far instead.
        print("\rdownloading: %d bytes" % (a * b), end="")
        return
    # The last block is usually partial, so a*b can overshoot c; clamp to 100%.
    print("\rdownloading: %5.1f%%" % min(a * b * 100.0 / c, 100.0), end="")


tar_gz = "cifar-10-python.tar.gz"
if not os.path.isfile(tar_gz):
    print('Downloading data from %s' % url)
    # Download under a temporary name and rename only after success, so an
    # interrupted download is not mistaken for a complete archive next run.
    partial_path = tar_gz + ".part"
    urlretrieve(url, partial_path, reporthook=reporthook)
    os.replace(partial_path, tar_gz)
    print()  # finish the "\r"-updated progress line

In [ ]:
# Load the dataset from the downloaded archive.
# The archive only has train and test splits; there is no validation split.
import pickle


def _load_batch(tarf, name):
    """Read one pickled CIFAR-10 batch from the open tar archive *tarf*.

    Returns (X, y): X is the batch's ``data`` divided by 255 and cast to
    float32; y is the batch's ``labels`` as an int32 array.
    """
    print("load", name)
    with tarf.extractfile(name) as f:
        # NOTE(review): pickle can execute arbitrary code while loading; only
        # use this on an archive from a trusted source (the official URL).
        result = pickle.load(f, encoding='latin1')
    # Divide in float64 first (as before), then cast, so values are unchanged.
    return np.float32(result['data'] / 255), np.int32(result['labels'])


tar_gz = "cifar-10-python.tar.gz"
with tarfile.open(tar_gz) as tarf:
    # Convert each batch to float32 as soon as it is read and concatenate once,
    # instead of keeping every float64 row alive in a Python list.
    batches = [_load_batch(tarf, "cifar-10-batches-py/data_batch_%d" % i)
               for i in range(1, 6)]
    train_X = np.concatenate([X for X, _ in batches])
    train_y = np.concatenate([y for _, y in batches])
    del batches  # release the per-batch copies before loading the test set
    test_X, test_y = _load_batch(tarf, "cifar-10-batches-py/test_batch")
# One-hot encode the integer labels over 10 classes.
train_Y = np.eye(10)[train_y]
test_Y = np.eye(10)[test_y]

In [ ]:
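# or, use Keras's built-in loader instead. Note that its output format differs
# from the loader above: X is uint8 image-shaped data (by default channels
# last, (N, 32, 32, 3)) that is NOT scaled to [0, 1], and y has shape (N, 1).
# (`keras.utils.np_utils` was removed in newer Keras; `to_categorical` is
# available directly from `keras.utils`.)
# from keras.datasets import cifar10
# from keras.utils import to_categorical
# (train_X, train_y), (test_X, test_y) = cifar10.load_data()
# train_Y = to_categorical(train_y, 10)
# test_Y = to_categorical(test_y, 10)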

查看一下資料


In [ ]:
# Shape of the flattened training matrix: (number of images, values per image).
np.shape(train_X)

In [ ]:
# Values per flattened image: channels x height x width (colour image).
3 * (32 * 32)

In [ ]:
from IPython.display import display
def showX(X):
    """Display flattened CIFAR images side by side as one horizontal strip.

    X holds rows of 3072 values laid out channel-first as 3 x 32 x 32
    (a single 3072-value row also works), presumably scaled to [0, 1].
    Values are scaled by 255, rounded and clipped to uint8 for display.
    Shows the strip via IPython ``display`` and returns None.
    """
    # Round before the uint8 cast: astype truncates, so float32 round-off that
    # lands just below an integer (e.g. from k/255*255) would become k-1.
    int_X = np.rint(X * 255).clip(0, 255).astype('uint8')
    # N*3072 -> N*3*32*32 -> N*32*32*3 (channels last, as PIL expects)
    int_X_reshape = np.moveaxis(int_X.reshape(-1, 3, 32, 32), 1, 3)
    # -> 32*N*32*3 -> 32 * 32N * 3: lay the N images out next to each other
    int_X_reshape = int_X_reshape.swapaxes(0, 1).reshape(32, -1, 3)
    display(Image.fromarray(int_X_reshape))
# Preview the first 20 training samples: first the images, then their
# numeric labels, then the corresponding class names.
showX(train_X[0:20])
print(train_y[0:20])
# Class names in the order used for lookup below: label i -> name_array[i].
name_array = np.array("飛機、汽車、鳥、貓、鹿、狗、青蛙、馬、船、卡車".split('、'))
print(name_array[train_y[0:20]])

Q：練習——分別用 logistic regression 和 CNN 對 CIFAR-10 做分類（參考答案在下面的 cell）。


In [ ]:
# 參考答案
# %load q_cifar10_logistic.py

In [ ]:
# 參考答案
# %load q_cifar10_cnn.py

In [ ]: