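First, require the modules used in this example: nyaplot for plotting, ruby_brain for the neural network, and ruby_brain/dataset/mnist/data, which provides access to the MNIST dataset.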


In [1]:
require 'nyaplot'
require 'ruby_brain'
require 'ruby_brain/dataset/mnist/data'


Out[1]:
true

Next, add an argmax method to the Array class. It returns the index of the largest element in the array, and is used in the following steps to turn a 10-element output vector into a digit label.


In [2]:
class Array
  def argmax
    max_i, max_val = 0, self.first
    self.each_with_index do |v, i|
      max_val, max_i = v, i if v > max_val
    end
    max_i
  end
end


Out[2]:
:argmax
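
As a quick check of what argmax does, here is a minimal sketch that computes the same result with built-in Array methods, without reopening Array. The helper name argmax_of and the sample vectors are made up for illustration. Like the method above, it returns the first index when the maximum value occurs more than once.

```ruby
# Hypothetical standalone helper (not part of RubyBrain):
# index of the first occurrence of the largest element.
def argmax_of(values)
  values.index(values.max)
end

# A one-hot vector for the digit 5, in the style of the :output rows.
one_hot_five = [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
puts argmax_of(one_hot_five)   # prints 5

# With a tie, the first maximal position wins.
tied_scores = [0.01, 0.02, 0.9, 0.9, 0.05]
puts argmax_of(tied_scores)    # prints 2
```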

Download the MNIST dataset from THE MNIST DATABASE of handwritten digits if the dataset files do not already exist in the working directory, then load the training and test data into Ruby hashes.


In [3]:
dataset = RubyBrain::DataSet::Mnist::data
training_dataset = dataset.first
test_dataset = dataset.last

# each dataset is a Hash with :input and :output keys
training_dataset.keys # => [:input, :output]
test_dataset.keys # => [:input, :output]

# :input of training_dataset has 60000(samples) x 784(28 * 28 input pixels)
training_dataset[:input].size       # => 60000
training_dataset[:input].first.size # => 784

# :output of training_dataset has 60000(samples) x 10(classes 0~9)
training_dataset[:output].size       # => 60000
training_dataset[:output].first.size # => 10

# :input of test_dataset has 10000(samples) x 784(28 * 28 input pixels)
test_dataset[:input].size       # => 10000
test_dataset[:input].first.size # => 784

# :output of test_dataset has 10000(samples) x 10(classes 0~9)
test_dataset[:output].size       # => 10000
test_dataset[:output].first.size # => 10


Out[3]:
10
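
The sizes above suggest that each :output row is a 10-element vector with one position per digit class. The following sketch assumes that these rows are one-hot encoded, which the later use of argmax implies but the source does not state outright. The helper names one_hot and consistent_dataset? and the toy data are hypothetical, and whether the real rows hold integers or floats is not checked here.

```ruby
# Hypothetical helper: encode a digit label as a one-hot vector.
def one_hot(label, num_classes = 10)
  Array.new(num_classes) { |i| i == label ? 1 : 0 }
end

# Hypothetical helper: every input row must have a paired output row,
# and all rows must have the expected width.
def consistent_dataset?(dataset, input_size, output_size)
  inputs = dataset[:input]
  outputs = dataset[:output]
  inputs.size == outputs.size &&
    inputs.all? { |row| row.size == input_size } &&
    outputs.all? { |row| row.size == output_size }
end

# Toy stand-in for the real data: two blank 28x28 images labelled 5 and 0.
toy_dataset = {
  input: [Array.new(28 * 28, 0.0), Array.new(28 * 28, 0.0)],
  output: [one_hot(5), one_hot(0)]
}

puts one_hot(5).inspect                           # prints [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
puts consistent_dataset?(toy_dataset, 784, 10)    # prints true
```

The same check could be run against training_dataset and test_dataset with the sizes 784 and 10 listed in the comments above.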

In this example, we use only the first 5000 samples of training_dataset, because RubyBrain is slow and training on the full training_dataset takes a very long time. NUM_TRAIN_DATA sets how many images, counted from the start of the dataset, are used as training data; here it is 5000.


In [4]:
# use only first 5000 samples for training
NUM_TRAIN_DATA = 5000
training_input = training_dataset[:input][0..(NUM_TRAIN_DATA-1)]
training_supervisor = training_dataset[:output][0..(NUM_TRAIN_DATA-1)]
# use full test dataset
test_input = test_dataset[:input]
test_supervisor = test_dataset[:output]
nil
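
The range slice [0..(NUM_TRAIN_DATA-1)] can also be written with Array#first, which reads a little more directly. The sketch below uses a small made-up array rather than the MNIST data, and first_samples is a hypothetical helper name. It also shows one difference worth knowing: with a count of 0 the range form becomes [0..-1], which returns the whole array.

```ruby
# Hypothetical helper: take the first n samples of a dataset column.
def first_samples(data, n)
  data.first(n)
end

samples = (0...10).to_a
n = 4

# The three forms agree for any positive n.
puts samples[0..(n - 1)] == first_samples(samples, n)   # prints true
puts samples.take(n) == first_samples(samples, n)       # prints true

# Asking for more than is available returns the whole array.
puts first_samples(samples, 100).size                   # prints 10

# Edge case n = 0: 0..-1 spans the whole array, first(0) is empty.
puts samples[0..(0 - 1)].size                           # prints 10
puts first_samples(samples, 0).size                     # prints 0
```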

You can inspect a few entries of training_input and training_supervisor with the following code.


In [5]:
5.times do |s|
  x = []
  y = []
  z = []
  28.times do |i|
    28.times do |j|
      x.push(j)
      y.push(-i)
      z.push(training_input[s][i*28+j])
    end
  end
  puts "training_input[#{s}] : "
  plot1 = Nyaplot::Plot.new
  mnistm = plot1.add(:heatmap, x, y, z)
  mnistm.stroke_width("0")
  mnistm.height(1.0)
  mnistm.width(1.0)
  plot1.legend(true)
  plot1.show
  puts "training_supervisor[#{s}] : #{training_supervisor[s].argmax}\n"
  puts "-------------------------------------------------------------------------------------\n"
end


training_input[0] : 
training_supervisor[0] : 5

-------------------------------------------------------------------------------------

training_input[1] : 
training_supervisor[1] : 0

-------------------------------------------------------------------------------------

training_input[2] : 
training_supervisor[2] : 4

-------------------------------------------------------------------------------------

training_input[3] : 
training_supervisor[3] : 1

-------------------------------------------------------------------------------------

training_input[4] : 
training_supervisor[4] : 9

-------------------------------------------------------------------------------------

Out[5]:
5
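
The plotting loop relies on the pixel at row i and column j being stored at index i*28+j of the flat 784-element array. As a minimal sketch of that same row-major layout, the code below renders an image as text without Nyaplot. The helper ascii_image, the synthetic test image, and the choice of half the maximum value as the threshold are assumptions made for illustration; the actual value range of the MNIST pixels is not assumed.

```ruby
SIDE = 28

# Hypothetical helper: turn a flat, row-major pixel array into text rows.
# Pixels brighter than half of the maximum value are drawn as '#'.
def ascii_image(pixels, side = SIDE)
  threshold = pixels.max / 2.0
  pixels.each_slice(side).map do |row|
    row.map { |value| value > threshold ? '#' : '.' }.join
  end
end

# Synthetic image: a vertical bar in column 14, i.e. index i * 28 + 14.
image = Array.new(SIDE * SIDE) { |k| (k % SIDE) == 14 ? 1.0 : 0.0 }

rows = ascii_image(image)
puts rows.first(3)
puts "rows: #{rows.size}, columns: #{rows.first.size}"
```

Calling ascii_image(training_input[0]) would give a rough text preview of the first training image under the same layout assumption.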

Next, construct and initialize the network. Each image has 784 (28x28) pixels and belongs to one of 10 classes (0..9), so the network needs 784 input units and 10 output units. With one hidden layer of 50 units, the structure is [784, 50, 10]. You can construct it with the following code.


In [6]:
# network structure [784, 50, 10]
network = RubyBrain::Network.new([training_input.first.size, 50, training_supervisor.first.size])
# learning rate is 0.7
network.learning_rate = 0.7
# initialize network
network.init_network
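
To get a feel for the size of this network, the sketch below counts the connection weights implied by a fully connected [784, 50, 10] structure. It is plain arithmetic on the layer sizes and does not call RubyBrain. The helper names are hypothetical, and the bias count assumes one bias per non-input unit, which may or may not match how RubyBrain stores its parameters.

```ruby
# Hypothetical helper: number of weights between each pair of adjacent layers.
def weight_counts(layers)
  layers.each_cons(2).map { |from, to| from * to }
end

# Hypothetical helper: one bias per unit in every layer except the input layer.
# This is an assumption about the model, not a statement about RubyBrain.
def bias_count(layers)
  layers.drop(1).sum
end

layers = [784, 50, 10]

puts weight_counts(layers).inspect   # prints [39200, 500]
puts weight_counts(layers).sum       # prints 39700
puts bias_count(layers)              # prints 60
```

Almost all of the weights sit between the input and hidden layers, which is one reason training on 784-pixel inputs is slow in pure Ruby.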

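Run the training. Standard output is redirected to the null device during learning, so anything the training step prints is discarded.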


In [7]:
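# temporarily redirect stdout to the null device
org_stdout = $stdout
$stdout = File.open(File::NULL, "w")
# name=value here is a plain local assignment; the values are passed positionally
network.learn(training_input, training_supervisor, max_training_count=100, tolerance=0.0004, monitoring_channels=[:best_params_training])
# restore stdout
$stdout = org_stdout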


Out[7]:
#<IRuby::OStream:0x005650cabe9620 @session=#<IRuby::Session:0x005650cabea5c0 @sockets={:publish=>#<ZMQ::Socket::Pub:0x005650cabea160>, :reply=>#<ZMQ::Socket::Router:0x005650cabea390>, :stdin=>#<ZMQ::Socket::Router:0x005650cabe9fd0>}, @session="212cd2dd-89a1-4b65-8b62-f948ffaf7920", @hmac=c8c2f7df1301a9acb8b6525521d8b7d2dc895622a8fea3776aa3c8ab12a1a171, @last_recvd_msg={:idents=>["A27A3F615B364779AE90CA4F32C28EFA"], :header=>{"version"=>"5.0", "msg_id"=>"3F707483601B4A35AF6F9DE453F23281", "session"=>"A27A3F615B364779AE90CA4F32C28EFA", "username"=>"username", "msg_type"=>"execute_request"}, :parent_header=>{}, :metadata=>{}, :content=>{"user_expressions"=>{}, "code"=>"org_stdout = $stdout\n$stdout = File.open(File::NULL, \"w\")\nnetwork.learn(training_input, training_supervisor, max_training_count=100, tolerance=0.0004, monitoring_channels=[:best_params_training])\n$stdout = org_stdout", "allow_stdin"=>true, "silent"=>false, "stop_on_error"=>true, "store_history"=>true}, :buffers=>nil}>, @name=:stdout>
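
The redirect-and-restore pattern in the cell above leaves $stdout pointing at the null device if the training call raises. A minimal sketch of a safer variant, using a hypothetical helper named silence_stdout, restores the stream in an ensure clause and returns the value of the block instead of the stream object:

```ruby
# Hypothetical helper: run a block with $stdout pointed at the null device,
# and always restore the original stream afterwards.
def silence_stdout
  original = $stdout
  $stdout = File.open(File::NULL, "w")
  yield
ensure
  # Close the null-device handle, but never the original stream.
  $stdout.close unless $stdout.equal?(original)
  $stdout = original
end

result = silence_stdout do
  puts "this line is discarded"
  42
end

puts result   # prints 42
```

Under this sketch the training call would be wrapped as silence_stdout { network.learn(...) } with the same arguments as above.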

The trained network is now ready, and you can evaluate it.

You can compute its accuracy on the test dataset with the following code.


In [8]:
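results = []
test_input.each_with_index do |input, i|
  # expected digit, taken from the supervisor vector
  supervisor_label = test_supervisor[i].argmax
  # predicted digit is the output unit with the largest value
  predicted_label = network.get_forward_outputs(input).argmax
  results << (supervisor_label == predicted_label)
end

# fraction of test samples classified correctly
puts "accuracy: #{results.count(true).to_f/results.size}"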


accuracy: 0.9271
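
A single accuracy figure does not show which digits get confused with each other. The sketch below, which works on plain arrays of labels and does not call RubyBrain, computes the same accuracy ratio and also counts each (expected, predicted) pair that was wrong. The helper names and the five-element toy label lists are made up for illustration.

```ruby
# Hypothetical helper: fraction of positions where the labels agree.
def accuracy(expected_labels, predicted_labels)
  return 0.0 if expected_labels.empty?
  correct = expected_labels.zip(predicted_labels).count { |want, got| want == got }
  correct.to_f / expected_labels.size
end

# Hypothetical helper: count how often each [expected, predicted] pair was wrong.
def confusion_counts(expected_labels, predicted_labels)
  counts = Hash.new(0)
  expected_labels.zip(predicted_labels).each do |want, got|
    counts[[want, got]] += 1 unless want == got
  end
  counts
end

# Toy labels: four correct predictions and one 4 mistaken for a 9.
expected = [7, 2, 1, 0, 4]
predicted = [7, 2, 1, 0, 9]

puts accuracy(expected, predicted)                   # prints 0.8
puts confusion_counts(expected, predicted)[[4, 9]]   # prints 1
```

With the real data, the two label lists could be built by mapping argmax over test_supervisor and over the network outputs.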

Review the actual images alongside the predicted classes for the first 21 test samples.


In [9]:
test_input[0..20].each_with_index do |input, s|
  x = []
  y = []
  z = []
  28.times do |i|
    28.times do |j|
      x.push(j)
      y.push(-i)
      z.push(test_input[s][i*28+j])
    end
  end
  puts "test_input[#{s}] : "
  plot1 = Nyaplot::Plot.new
  mnistm = plot1.add(:heatmap, x, y, z)
  mnistm.stroke_width("0")
  mnistm.height(1.0)
  mnistm.width(1.0)
  plot1.legend(true)
  plot1.show
  puts "test_supervisor[#{s}] : #{test_supervisor[s].argmax}\n"
  puts "predicated_class : #{network.get_forward_outputs(input).argmax}"
  puts "-------------------------------------------------------------------------------------\n"
end
nil


test_input[0] : 
test_supervisor[0] : 7

predicated_class : 7
-------------------------------------------------------------------------------------

test_input[1] : 
test_supervisor[1] : 2

predicated_class : 2
-------------------------------------------------------------------------------------

test_input[2] : 
test_supervisor[2] : 1

predicated_class : 1
-------------------------------------------------------------------------------------

test_input[3] : 
test_supervisor[3] : 0

predicated_class : 0
-------------------------------------------------------------------------------------

test_input[4] : 
test_supervisor[4] : 4

predicated_class : 4
-------------------------------------------------------------------------------------

test_input[5] : 
test_supervisor[5] : 1

predicated_class : 1
-------------------------------------------------------------------------------------

test_input[6] : 
test_supervisor[6] : 4

predicated_class : 4
-------------------------------------------------------------------------------------

test_input[7] : 
test_supervisor[7] : 9

predicated_class : 9
-------------------------------------------------------------------------------------

test_input[8] : 
test_supervisor[8] : 5

predicated_class : 2
-------------------------------------------------------------------------------------

test_input[9] : 
test_supervisor[9] : 9

predicated_class : 9
-------------------------------------------------------------------------------------

test_input[10] : 
test_supervisor[10] : 0

predicated_class : 0
-------------------------------------------------------------------------------------

test_input[11] : 
test_supervisor[11] : 6

predicated_class : 6
-------------------------------------------------------------------------------------

test_input[12] : 
test_supervisor[12] : 9

predicated_class : 9
-------------------------------------------------------------------------------------

test_input[13] : 
test_supervisor[13] : 0

predicated_class : 0
-------------------------------------------------------------------------------------

test_input[14] : 
test_supervisor[14] : 1

predicated_class : 1
-------------------------------------------------------------------------------------

test_input[15] : 
test_supervisor[15] : 5

predicated_class : 5
-------------------------------------------------------------------------------------

test_input[16] : 
test_supervisor[16] : 9

predicated_class : 9
-------------------------------------------------------------------------------------

test_input[17] : 
test_supervisor[17] : 7

predicated_class : 7
-------------------------------------------------------------------------------------

test_input[18] : 
test_supervisor[18] : 3

predicated_class : 3
-------------------------------------------------------------------------------------

test_input[19] : 
test_supervisor[19] : 4

predicated_class : 4
-------------------------------------------------------------------------------------

test_input[20] : 
test_supervisor[20] : 9

predicated_class : 9
-------------------------------------------------------------------------------------
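
Scanning a long listing by eye for mistakes is error-prone, so it can help to list only the samples where the two labels differ. The sketch below uses the 21 supervisor and predicted labels copied from the output above; the helper name misclassified_indices is hypothetical.

```ruby
# Hypothetical helper: indices where the expected and predicted labels differ.
def misclassified_indices(expected, predicted)
  expected.each_index.select { |i| expected[i] != predicted[i] }
end

# Labels for test_input[0..20], copied from the output above.
expected  = [7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9]
predicted = [7, 2, 1, 0, 4, 1, 4, 9, 2, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9]

misclassified_indices(expected, predicted).each do |i|
  puts "test_input[#{i}] : expected #{expected[i]}, predicted #{predicted[i]}"
end
# prints: test_input[8] : expected 5, predicted 2
```

Among these 21 samples only test_input[8] is misclassified: a 5 that the network reads as a 2.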