In [2]:
import numpy as np
from keras.datasets import imdb
from keras.preprocessing import sequence
from keras.models import Sequential
import keras.layers as kl
from keras.optimizers import Adam
from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from kopt import CompileFN, KMongoTrials, test_fn
# 1. Data function: returns the training tuple followed by the test tuple
def data(max_features=5000, maxlen=80):
    """Load the IMDB dataset and pad every sequence to a fixed length.

    Args:
        max_features: forwarded to `imdb.load_data` as `num_words`.
        maxlen: forwarded to `sequence.pad_sequences` as `maxlen`.

    Returns:
        A pair `(train, test)`:
        * `train` is `(x, y, max_features)`, restricted to the first 100
          training examples;
        * `test` is the full `(x, y)` test split.
    """
    train_split, test_split = imdb.load_data(num_words=max_features)
    train_tokens, train_labels = train_split
    test_tokens, test_labels = test_split

    train_padded = sequence.pad_sequences(train_tokens, maxlen=maxlen)
    test_padded = sequence.pad_sequences(test_tokens, maxlen=maxlen)

    # Only the first 100 training examples are handed on; max_features rides
    # along as the third element of the training tuple.
    train = (train_padded[:100], train_labels[:100], max_features)
    test = (test_padded, test_labels)
    return train, test
In [3]:
# 2. Model function: builds and returns a compiled Keras model
def model(train_data, lr=0.001,
          embedding_dims=128, rnn_units=64,
          dropout=0.2):
    """Build and compile an Embedding -> LSTM -> Dense(sigmoid) classifier.

    Args:
        train_data: training tuple; only its third element (index 2) is read
            and used as the input dimension of the embedding layer.
        lr: learning rate passed to the Adam optimizer.
        embedding_dims: output dimension of the embedding layer.
        rnn_units: number of LSTM units.
        dropout: used for both `dropout` and `recurrent_dropout` of the LSTM.

    Returns:
        The compiled `Sequential` model (binary cross-entropy loss,
        accuracy metric).
    """
    # The vocabulary size travels with the training data (index 2).
    vocab_size = train_data[2]

    net = Sequential([
        kl.Embedding(vocab_size, embedding_dims),
        kl.LSTM(rnn_units, dropout=dropout, recurrent_dropout=dropout),
        kl.Dense(1, activation='sigmoid'),
    ])
    net.compile(optimizer=Adam(lr=lr),
                loss='binary_crossentropy',
                metrics=['accuracy'])
    return net
In [4]:
# Build the objective function that the hyper-parameter search evaluates
db_name = "imdb"
exp_name = "myexp1"
objective = CompileFN(
    db_name, exp_name,
    data_fn=data,
    model_fn=model,
    # which metric to optimize for, and the direction ("max" = maximize it)
    loss_metric="acc",
    loss_metric_mode="max",
    # use 20% of the training data for the validation set
    valid_split=.2,
    # checkpoint the best model and save the results as .json
    # (in addition to mongoDB); both are stored under save_dir
    save_model='best',
    save_results=True,
    save_dir="./saved_models/",
)
In [5]:
# Define the hyper-parameter ranges.
# See https://github.com/hyperopt/hyperopt/wiki/FMin for more info.
#
# The top-level keys group the values by consumer: the "data" and "model"
# entries use the same names as the keyword arguments of the data() and
# model() functions. Plain values are fixed; hp.* expressions are searched.
hyper_params = {
    "data": {
        "max_features": 100,
        "maxlen": 80,
    },
    "model": {
        # hp.loguniform takes its bounds in log space, hence np.log:
        # the sampled learning rate lies in 0.0001 - 0.01
        "lr": hp.loguniform("m_lr", np.log(1e-4), np.log(1e-2)),
        "embedding_dims": hp.choice("m_emb", (64, 128)),
        "rnn_units": 64,
        "dropout": hp.uniform("m_do", 0, 0.5),
    },
    # presumably forwarded to the Keras fit call -- confirm against kopt
    "fit": {
        "epochs": 20
    }
}
In [6]:
# Test model training, on a small subset for one epoch, before launching the
# full search.
# NOTE(review): the subset size and epoch count are decided inside test_fn,
# not in this notebook -- confirm against the kopt documentation.
test_fn(objective, hyper_params)
In [7]:
# Sequential hyper-parameter optimization: the results live in a plain
# hyperopt Trials object (no database involved)
trials = Trials()
best = fmin(fn=objective,
            space=hyper_params,
            algo=tpe.suggest,
            max_evals=2,
            trials=trials)
In [73]:
import tempfile
import os
import subprocess
# Launch a local mongod server; the database files and the experiment
# results each get their own fresh temporary directory
mongodb_path = tempfile.mkdtemp()
results_path = tempfile.mkdtemp()
# NOTE(review): --noprealloc may be rejected by newer mongod releases --
# confirm against the installed MongoDB version.
proc_args = [
    "mongod",
    "--dbpath={}".format(mongodb_path),
    "--noprealloc",
    "--port=22334",
]
print("starting mongod", proc_args)
mongodb_proc = subprocess.Popen(proc_args, cwd=mongodb_path)
In [74]:
# Start the worker
from kopt.utils import merge_dicts
proc_args_worker = ["hyperopt-mongo-worker",
                    "--mongo=localhost:22334/imdb",
                    "--poll-interval=0.1"]
# The worker must be able to import modules from the current working
# directory, so put it first on PYTHONPATH. Entries that are already set in
# the environment are kept (appended) instead of being overwritten.
worker_pythonpath = os.pathsep.join(
    entry for entry in (os.getcwd(), os.environ.get("PYTHONPATH")) if entry
)
mongo_worker_proc = subprocess.Popen(
    proc_args_worker,
    env=merge_dicts(os.environ, {"PYTHONPATH": worker_pythonpath}),
)
In [75]:
## In order for pickling of the functions to work, they must be imported
## from a module other than __main__.
## They are implemented in model.py and data.py.
## The modules are imported under aliases so that the notebook-level `model`
## and `data` functions are not replaced by module objects.
import model as model_module
import data as data_module
objective.data_fn = data_module.data
objective.model_fn = model_module.model
# store the results of the parallel run in the temporary results directory
objective.save_dir = results_path
In [76]:
# Parallel hyper-parameter optimization (the results are saved to MongoDB).
# Follow the hyperopt guide:
# https://github.com/hyperopt/hyperopt/wiki/Parallelizing-Evaluations-During-Search-via-MongoDB
# KMongoTrials extends hyperopt.MongoTrials with convenience methods
trials = KMongoTrials(db_name, exp_name, ip="localhost", port=22334)
best = fmin(fn=objective,
            space=hyper_params,
            algo=tpe.suggest,
            max_evals=2,
            trials=trials)
In [91]:
# Number of submitted trials
# (`trials` is whichever trials object was assigned last)
len(trials)
Out[91]:
In [78]:
# All the trial information in one tidy pd.DataFrame
trials.as_df()
Out[78]:
In [79]:
# Load the model of the best trial.
# It is bound to `best_model` rather than `model`, because `model` is already
# used earlier in this notebook for the model-building function.
best_model = trials.load_model(trials.best_trial_tid())
best_model
Out[79]:
In [83]:
# See the training history of the best model
# (the trial is looked up by the id that best_trial_tid() returns)
train_hist = trials.train_history(trials.best_trial_tid())
train_hist
Out[83]:
In [89]:
# Close the processes (hyperopt worker first, then mongodb).
# kill() only sends the signal; wait() then collects each child's exit
# status so that no defunct process is left behind.
mongo_worker_proc.kill()
mongo_worker_proc.wait()
mongodb_proc.kill()
mongodb_proc.wait()