Require Nyaplot, RubyBrain, and RubyBrain's MNIST module. The MNIST module makes the MNIST dataset easy to use.
In [1]:
require 'nyaplot'
require 'ruby_brain'
require 'ruby_brain/dataset/mnist/data'
Out[1]:
First, add an argmax method to the Array class. It returns the index of the largest value in the array, which is helpful in the following steps.
In [2]:
class Array
  # return the index of the largest element (the first one if there are ties)
  def argmax
    max_i, max_val = 0, self.first
    self.each_with_index do |v, i|
      max_val, max_i = v, i if v > max_val
    end
    max_i
  end
end
Out[2]:
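As a quick sanity check of what argmax is expected to return, the sketch below does the same lookup with core Array methods only (index and max). It is an illustrative equivalent, not part of RubyBrain, and both sample vectors are made up.

```ruby
# A one-hot vector encoding the digit 3 (illustrative data).
one_hot = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
# Network outputs are rarely exact; the largest score still marks the class.
scores = [0.01, 0.02, 0.05, 0.10, 0.03, 0.02, 0.04, 0.70, 0.02, 0.01]

# index(max) returns the position of the first occurrence of the largest value.
label_from_one_hot = one_hot.index(one_hot.max)
label_from_scores  = scores.index(scores.max)

puts "one-hot label: #{label_from_one_hot}"
puts "score label:   #{label_from_scores}"
```

The same lookup is used twice later: once to read the expected digit from a supervisor vector, and once to read the predicted digit from the network output.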
Download the MNIST dataset from THE MNIST DATABASE of handwritten digits if the dataset files do not already exist in the working directory, then load it into a pair of Ruby hashes (training and test).
In [3]:
dataset = RubyBrain::DataSet::Mnist::data
training_dataset = dataset.first
test_dataset = dataset.last
# each dataset is a Hash with :input and :output keys
training_dataset.keys # => [:input, :output]
test_dataset.keys # => [:input, :output]
# :input of training_dataset is 60000 (samples) x 784 (28 * 28 input pixels)
training_dataset[:input].size # => 60000
training_dataset[:input].first.size # => 784
# :output of training_dataset is 60000 (samples) x 10 (classes 0..9)
training_dataset[:output].size # => 60000
training_dataset[:output].first.size # => 10
# :input of test_dataset is 10000 (samples) x 784 (28 * 28 input pixels)
test_dataset[:input].size # => 10000
test_dataset[:input].first.size # => 784
# :output of test_dataset is 10000 (samples) x 10 (classes 0..9)
test_dataset[:output].size # => 10000
test_dataset[:output].first.size # => 10
Out[3]:
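Each :output entry has 10 elements, one per digit class. The sketch below shows one common way such a vector can be built from a digit label (one-hot encoding). The helper name one_hot is hypothetical, and whether RubyBrain stores integers or floats in these vectors is an assumption not confirmed here.

```ruby
# Hypothetical helper: build a vector with a 1 at the label's index and 0 elsewhere.
def one_hot(label, num_classes = 10)
  Array.new(num_classes) { |i| i == label ? 1 : 0 }
end

vector = one_hot(5)
p vector
p vector.size
```

Reading the label back is the job of argmax: the index of the largest element is the digit.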
In this example, we use only the first 5000 samples of training_dataset, because RubyBrain is slow and learning the full training_dataset takes a very long time.
NUM_TRAIN_DATA is the number of leading images used as training data. Here it is set to 5000.
In [4]:
# use only first 5000 samples for training
NUM_TRAIN_DATA = 5000
training_input = training_dataset[:input][0..(NUM_TRAIN_DATA-1)]
training_supervisor = training_dataset[:output][0..(NUM_TRAIN_DATA-1)]
# use full test dataset
test_input = test_dataset[:input]
test_supervisor = test_dataset[:output]
nil
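The slice 0..(NUM_TRAIN_DATA-1) can also be written without the manual subtraction. The sketch below compares three equivalent forms on a small stand-in array; the names samples and num_train are illustrative only.

```ruby
samples = (0...10).to_a  # stand-in for a dataset of 10 samples
num_train = 4            # plays the role of NUM_TRAIN_DATA

by_range = samples[0..(num_train - 1)]  # inclusive range, as in the cell above
by_first = samples.first(num_train)     # same elements, no off-by-one arithmetic
by_excl  = samples[0...num_train]       # exclusive range

p by_range

# Edge case: with a count of 0 the inclusive form becomes 0..-1,
# which selects the whole array, while first(0) returns an empty array.
p samples[0..(0 - 1)].size
p samples.first(0).size
```

For any positive count the three forms select the same samples, so the choice is mostly a matter of readability.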
You can inspect some of training_input and training_supervisor with the following code.
In [5]:
5.times do |s|
  x = []
  y = []
  z = []
  # walk the 28x28 image row by row; y is negated so row 0 is drawn at the top
  28.times do |i|
    28.times do |j|
      x.push(j)
      y.push(-i)
      z.push(training_input[s][i*28+j])
    end
  end
  puts "training_input[#{s}] : "
  plot1 = Nyaplot::Plot.new
  mnistm = plot1.add(:heatmap, x, y, z)
  mnistm.stroke_width("0")
  mnistm.height(1.0)
  mnistm.width(1.0)
  plot1.legend(true)
  plot1.show
  puts "training_supervisor[#{s}] : #{training_supervisor[s].argmax}\n"
  puts "-------------------------------------------------------------------------------------\n"
end
Out[5]:
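The plotting loop relies on each image being stored row by row, so pixel (row i, column j) sits at index i*28+j of the flat 784-element array. The sketch below illustrates that mapping on a synthetic image and prints it as text instead of a heatmap. The image data, the pixel_at lambda, and the 0.5 threshold are all assumptions made for illustration.

```ruby
width = 28
height = 28

# Synthetic 28x28 image stored row by row: a vertical bar in columns 13 and 14.
image = Array.new(width * height) { |k| [13, 14].include?(k % width) ? 1.0 : 0.0 }

# Pixel (row i, column j) lives at index i * width + j in the flat array.
pixel_at = ->(i, j) { image[i * width + j] }

# each_slice(width) recovers the rows; threshold each pixel to draw it as text.
rows = image.each_slice(width).map do |row|
  row.map { |v| v > 0.5 ? "#" : "." }.join
end
puts rows.first(5)
```

A text rendering like this can be handy for a quick look at a sample when no plotting frontend is available.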
Then construct and initialize the network.
In this case, each image has 784 (28x28) pixels and there are 10 classes (0..9).
So the network structure should be [784, 50, 10], with one hidden layer of 50 units.
You can construct this structure with the following code.
In [6]:
# network structure [784, 50, 10]
network = RubyBrain::Network.new([training_input.first.size, 50, training_supervisor.first.size])
# learning rate is 0.7
network.learning_rate = 0.7
# initialize network
network.init_network
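To get a feel for the size of a [784, 50, 10] network, the sketch below counts the connection weights between adjacent layers. It assumes fully connected layers and leaves bias terms out, since how RubyBrain represents biases is not covered here.

```ruby
layers = [784, 50, 10]

# Each pair of adjacent layers is fully connected: inputs * outputs weights.
weights_per_layer = layers.each_cons(2).map { |n_in, n_out| n_in * n_out }
total_weights = weights_per_layer.sum

p weights_per_layer
p total_weights
```

Almost all of the weights sit between the input and hidden layers, so the hidden layer size is the main knob for trading capacity against training time.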
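Run the training. $stdout is temporarily redirected to the null device so that anything printed during learning does not flood the notebook.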
In [7]:
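# redirect $stdout to the null device while learning
org_stdout = $stdout
$stdout = File.open(File::NULL, "w")
# the name=value forms are ordinary positional arguments; the names only label the values
network.learn(training_input, training_supervisor, max_training_count=100, tolerance=0.0004, monitoring_channels=[:best_params_training])
$stdout = org_stdout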
Out[7]:
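If learning raises an error, the cell above leaves $stdout pointing at the null device. One possible way to guard against that is to wrap the redirection in a helper that always restores the original stream. The helper name with_silenced_stdout is hypothetical and not part of RubyBrain.

```ruby
# Hypothetical helper: run a block with $stdout pointed at the null device,
# restoring the original stream even if the block raises.
def with_silenced_stdout
  original = $stdout
  File.open(File::NULL, "w") do |null_out|
    $stdout = null_out
    begin
      yield
    ensure
      $stdout = original
    end
  end
end

result = with_silenced_stdout do
  puts "this line is discarded"
  42
end
puts "block returned #{result}"
```

With such a helper, the training call could be placed inside the block. The block form of File.open also closes the null device handle when the block finishes.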
Now the trained network is ready.
You can measure its accuracy on the test dataset with the following code.
In [8]:
results = []
test_input.each_with_index do |input, i|
  # expected digit from the supervisor vector, predicted digit from the network output
  supervisor_label = test_supervisor[i].argmax
  predicted_label = network.get_forward_outputs(input).argmax
  results << (supervisor_label == predicted_label)
end
puts "accuracy: #{results.count(true).to_f/results.size}"
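The accuracy is simply the fraction of samples whose predicted label matches the expected one. The sketch below repeats that calculation on toy labels, and also collects the indices of the misclassified samples. The label arrays are made-up stand-ins for the argmax results.

```ruby
# Toy labels standing in for argmax results (illustrative data only).
expected  = [7, 2, 1, 0, 4, 1, 4, 9]
predicted = [7, 2, 1, 0, 4, 1, 9, 9]

# Pair each expected label with its prediction and count the matches.
correct = expected.zip(predicted).count { |exp, pred| exp == pred }
accuracy = correct.to_f / expected.size

# Collect the indices of misclassified samples for later inspection.
wrong_indices = expected.each_index.reject { |i| expected[i] == predicted[i] }

puts "accuracy: #{accuracy}"
puts "misclassified indices: #{wrong_indices.inspect}"
```

Keeping the misclassified indices around makes it easy to plot only the images the network got wrong.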
Review the actual images and the predicted classes.
In [9]:
test_input[0..20].each_with_index do |input, s|
  x = []
  y = []
  z = []
  # walk the 28x28 image row by row; y is negated so row 0 is drawn at the top
  28.times do |i|
    28.times do |j|
      x.push(j)
      y.push(-i)
      z.push(input[i*28+j])
    end
  end
  puts "test_input[#{s}] : "
  plot1 = Nyaplot::Plot.new
  mnistm = plot1.add(:heatmap, x, y, z)
  mnistm.stroke_width("0")
  mnistm.height(1.0)
  mnistm.width(1.0)
  plot1.legend(true)
  plot1.show
  puts "test_supervisor[#{s}] : #{test_supervisor[s].argmax}\n"
  puts "predicted_class : #{network.get_forward_outputs(input).argmax}"
  puts "-------------------------------------------------------------------------------------\n"
end
nil
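When a prediction looks wrong, it can help to see the runner-up classes as well as the winner. The sketch below ranks the three highest scores in an output vector. The outputs array is made-up data, standing in for an array of per-class scores such as the one argmax is applied to above.

```ruby
# Illustrative output vector: one score per digit class 0..9.
outputs = [0.02, 0.01, 0.10, 0.05, 0.01, 0.03, 0.01, 0.60, 0.02, 0.15]

# Pair each score with its class index, then keep the three largest scores.
top3 = outputs.each_with_index.max_by(3) { |score, _| score }
top3.each { |score, digit| puts "digit #{digit}: #{score}" }
```

A small gap between the first and second scores suggests the network was unsure about that image.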