In [ ]:
# Windows-only hack: make Graphviz's `dot` executable findable by appending
# `<entry>\graphviz` for every PATH entry that ends in `Library\bin`
# (presumably a conda environment's binary folder -- TODO confirm).
import os
if os.name == 'nt':
    # Snapshot the entries once; PATH is modified inside the loop.
    for path in os.environ.get('PATH', '').split(os.pathsep):
        if path.endswith("Library\\bin"):
            graphviz_dir = os.path.join(path, 'graphviz')
            # Re-running this notebook cell must not keep appending duplicates.
            if graphviz_dir not in os.environ['PATH'].split(os.pathsep):
                os.environ['PATH'] += os.pathsep + graphviz_dir
In [ ]:
import keras
from keras.models import Sequential
from PIL import Image
import numpy as np
In [ ]:
import keras.backend as K
# Use the channels-last layout (height, width, channels). The images loaded
# later in this notebook are PIL-derived arrays of shape (224, 224, 3), so the
# backend must treat the last axis as the channel axis.
K.set_image_data_format('channels_last')
In [ ]:
# The first call downloads the pretrained weights, which can take a while.
# Spell out the defaults: full classifier head, ImageNet weights.
pretrained = keras.applications.vgg16.VGG16(include_top=True, weights='imagenet')
In [ ]:
# Print a layer-by-layer summary of the network. Echoing the bare model
# object only displays its repr, which says nothing about the architecture.
pretrained.summary()
In [ ]:
# Take a look at the network: render its graph, including layer shapes,
# as an inline SVG produced by Graphviz's `dot`.
from IPython.display import SVG, display
from keras.utils.vis_utils import model_to_dot

dot_graph = model_to_dot(pretrained, show_shapes=True)
SVG(dot_graph.create(prog='dot', format='svg'))
In [ ]:
from keras.applications import imagenet_utils
In [ ]:
# Display the URL of the ImageNet class-index JSON (opened with urlopen below).
imagenet_utils.CLASS_INDEX_PATH
In [ ]:
# Fetch the raw bytes of the ImageNet class-index JSON.
import json
from urllib.request import urlopen

class_index_url = imagenet_utils.CLASS_INDEX_PATH
with urlopen(class_index_url) as response:
    data = response.read()
In [ ]:
# Parse the class index. Keys are the class numbers as strings ("0".."999").
# Element [1] of each value is presumably the human-readable label.
# Display the 1000 labels in class-number order.
class_dict = json.loads(data.decode())
[class_dict[key][1] for key in map(str, range(1000))]
In [ ]:
# Download the images.
from urllib.request import urlretrieve
import urllib
import os

# Local file name of the image archive; presumably a 1000-image subset of the
# ILSVRC2012 validation set, judging by the name.
dataset = 'ILSVRC2012_val_1000.tar'
def reporthook(a, b, c):
    """Print download progress on one self-overwriting line.

    Called by ``urlretrieve`` as ``reporthook(block_count, block_size, total_size)``.

    Args:
        a: Number of blocks transferred so far.
        b: Block size in bytes.
        c: Total size in bytes, or -1 when the server did not report it.
    """
    if c <= 0:
        # Unknown (or zero) total size: a percentage would be negative or a
        # ZeroDivisionError, so report the byte count instead.
        print("\rdownloading: %d bytes" % (a * b), end="")
        return
    # The final block usually overshoots the total; clamp to 100%.
    print("\rdownloading: %5.1f%%" % min(100.0, a * b * 100.0 / c), end="")
if not os.path.isfile(dataset):
    origin = "https://www.dropbox.com/s/vippynksgd8c6qt/ILSVRC2012_val_1000.tar?dl=1"
    print('Downloading data from %s' % origin)
    # Download under a temporary name and move it into place only after
    # urlretrieve returns. An interrupted download then cannot leave a
    # truncated archive that the isfile() check above would accept next run.
    partial = dataset + '.part'
    urlretrieve(origin, partial, reporthook=reporthook)
    os.replace(partial, dataset)
    # End the carriage-return progress line printed by reporthook (end="").
    print()
In [ ]:
# Extract the images from the downloaded archive into the current directory.
import tarfile
from tarfile import TarFile

# Context manager so the archive's file handle is closed after extraction.
with TarFile.open(dataset) as tar:
    if hasattr(tarfile, 'data_filter'):
        # The archive comes from the network. Where extraction filters exist
        # (PEP 706), reject members that would escape the target directory
        # (absolute paths, '..', unsafe links).
        tar.extractall(filter='data')
    else:
        tar.extractall()
In [ ]:
# Load the images: force 3-channel RGB and resize each to 224x224, so
# `imgs` ends up as an array of shape (N, 224, 224, 3).
from PIL import Image as pimage
from glob import glob

# glob() order depends on the filesystem; sort so runs are reproducible.
# `files` is paired index-by-index with the predictions later on.
files = sorted(glob('ILSVRC2012_img_val/ILSVRC2012_val_*.JPEG'))
if not files:
    raise FileNotFoundError(
        'no images matched ILSVRC2012_img_val/ILSVRC2012_val_*.JPEG; '
        'was the archive downloaded and extracted?')

imgs = []
for fn in files:
    # `with` releases the underlying file handle even if decoding fails.
    with pimage.open(fn) as img:
        rgb = img if img.mode == 'RGB' else img.convert('RGB')
        imgs.append(np.array(rgb.resize((224, 224))))
imgs = np.array(imgs)
In [ ]:
# Prepare the data in the standard ImageNet input format via
# imagenet_utils.preprocess_input (subtracting the color mean values), then
# drop the raw image array to free memory.
p_imgs = imagenet_utils.preprocess_input(imgs.astype(np.float32))
del imgs
In [ ]:
# Actually run the model over every preprocessed image (32 per batch, which
# is Keras's default).
predictions = pretrained.predict(p_imgs, batch_size=32)
In [ ]:
# Map the raw class scores back to class labels, keeping the top 5 per
# image (the default).
results = imagenet_utils.decode_predictions(predictions, top=5)
In [ ]:
from IPython.display import Image, HTML, display

# Show the first 100 images, each beside a bullet list of its predictions:
# score as a percentage, then label.
# NOTE: this import rebinds `Image`, which earlier referred to PIL's Image;
# later code needing PIL should use the `pimage` alias instead.
row_template = """
<table><tr>
<td><img width=200 src="{}" /></td>
<td><ul>{}</ul></td>
</tr>
</table>
"""
for fn, res in zip(files[:100], results[:100]):
    bullets = ["<li>{:05.2f}% : {}</li>".format(entry[2] * 100, entry[1]) for entry in res]
    display(HTML(row_template.format(fn, "".join(bullets))))