AI用のライブラリ「AI4R」

Ruby用の人工知能ライブラリ「AI4R」というのがあるのでサンプルを動かしてみた。

AI4Rはニューラルネットワーク遺伝的アルゴリズム、決定木などのいくつかのAIアルゴリズムが実装されている人工知能用ライブラリです。
gemで提供されているので

gem install ai4r

とかでインストールできます。
サンプルとしてニューラルネットワークバックプロパゲーションというアルゴリズムで簡易OCRのサンプルを動かしてみました。

認識する文字

三角と四角、十字という3つのパターンを学習させ、ノイズがのった三角や四画、十字のパターンでも認識できるかどうかというものです。
図形はとりあえず

TRIANGLE = [
  [ 0,  0,  0,  0,  0,  0,  0,  5,  5,  0,  0,  0,  0,  0,  0,  0],
  [ 0,  0,  0,  0,  0,  0,  1,  9,  9,  1,  0,  0,  0,  0,  0,  0],
  [ 0,  0,  0,  0,  0,  0,  5,  5,  5,  5,  0,  0,  0,  0,  0,  0],
  [ 0,  0,  0,  0,  0,  1,  9,  1,  1,  9,  1,  0,  0,  0,  0,  0],
  [ 0,  0,  0,  0,  0,  5,  5,  0,  0,  5,  5,  0,  0,  0,  0,  0],
  [ 0,  0,  0,  0,  1,  9,  1,  0,  0,  1,  9,  1,  0,  0,  0,  0],
  [ 0,  0,  0,  0,  5,  5,  0,  0,  0,  0,  5,  5,  0,  0,  0,  0],
  [ 0,  0,  0,  1,  9,  1,  0,  0,  0,  0,  1,  9,  1,  0,  0,  0],
  [ 0,  0,  0,  5,  5,  0,  0,  0,  0,  0,  0,  5,  5,  0,  0,  0],
  [ 0,  0,  1,  9,  1,  0,  0,  0,  0,  0,  0,  1,  9,  1,  0,  0],
  [ 0,  0,  5,  5,  0,  0,  0,  0,  0,  0,  0,  0,  5,  5,  0,  0],
  [ 0,  1,  9,  1,  0,  0,  0,  0,  0,  0,  0,  0,  1,  9,  1,  0],
  [ 0,  5,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  5,  5,  0],
  [ 1,  9,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  9,  1],
  [ 5,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  5,  5],
  [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
]

というような配列で渡すことにします。数字が大きいところが色が濃いということになっています。(10=>黒、0=>白)
これのノイズがのったパターンは

TRIANGLE_WITH_NOISE = [
  [ 1,  0,  0,  0,  0,  0,  0,  1,  5,  0,  0,  1,  0,  0,  0,  0],
  [ 0,  0,  0,  0,  3,  0,  1,  9,  9,  1,  0,  0,  0,  0,  3,  0],
  [ 0,  3,  0,  0,  0,  0,  5,  1,  5,  3,  0,  0,  0,  0,  0,  7],
  [ 0,  0,  0,  7,  0,  1,  9,  1,  1,  9,  1,  0,  0,  0,  3,  0],
  [ 0,  0,  0,  0,  0,  3,  5,  0,  3,  5,  5,  0,  0,  0,  0,  0],
  [ 0,  1,  0,  0,  1,  9,  1,  0,  1,  1,  9,  1,  0,  0,  0,  0],
  [ 1,  0,  0,  0,  5,  5,  0,  0,  0,  0,  5,  5,  7,  0,  0,  3],
  [ 0,  0,  3,  3,  9,  1,  0,  0,  1,  0,  1,  9,  1,  0,  0,  0],
  [ 0,  0,  0,  5,  5,  0,  3,  7,  0,  0,  0,  5,  5,  0,  0,  0],
  [ 0,  0,  1,  9,  1,  0,  0,  0,  0,  0,  0,  1,  9,  1,  0,  0],
  [ 0,  0,  5,  5,  0,  0,  0,  0,  3,  0,  0,  0,  5,  5,  0,  0],
  [ 0,  1,  9,  1,  0,  0,  0,  0,  0,  0,  0,  0,  1,  9,  1,  0],
  [ 0,  5,  5,  0,  3,  0,  0,  3,  0,  0,  0,  0,  0,  5,  5,  0],
  [ 1,  9,  1,  0,  0,  3,  0,  0,  0,  1,  0,  0,  0,  1,  9,  1],
  [ 5,  5,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  5,  5],
  [10, 10, 10, 10,  1, 10, 10, 10, 10, 10,  1, 10, 10, 10, 10, 10]
]

こんな感じで表現されます

サンプルコード

require 'rubygems'
require 'ai4r'

#画像を数値で表したもの
require File.dirname(__FILE__) + '/training_patterns'
require File.dirname(__FILE__) + '/patterns_with_noise'
require File.dirname(__FILE__) + '/patterns_with_base_noise'


#ニューラルネットを
# 256の入力層
# 3の出力層
#でつくる
net = Ai4r::NeuralNetwork::Backpropagation.new([256, 3])

#トレーニングデータ(正解)の読み込み
tr_input = TRIANGLE.flatten.collect { |input| input.to_f / 10}
sq_input = SQUARE.flatten.collect { |input| input.to_f / 10}
cr_input = CROSS.flatten.collect { |input| input.to_f / 10}

#ニューラルネットに100回学習させる
100.times do
  net.train(tr_input, [1,0,0]) #三角なら[1,0,0]を出力しろ
  net.train(sq_input, [0,1,0]) #四角なら[0,1,0]を出力しろ
  net.train(cr_input, [0,0,1]) #十字なら[0,0,1]を出力しろ
end

# 2種類のノイズの乗ったデータを読み込む
tr_with_noise = TRIANGLE_WITH_NOISE.flatten.collect { |input| input.to_f / 10}
sq_with_noise = SQUARE_WITH_NOISE.flatten.collect { |input| input.to_f / 10}
cr_with_noise = CROSS_WITH_NOISE.flatten.collect { |input| input.to_f / 10}

tr_with_base_noise = TRIANGLE_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 10}
sq_with_base_noise = SQUARE_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 10}
cr_with_base_noise = CROSS_WITH_BASE_NOISE.flatten.collect { |input| input.to_f / 10}

# 結果を出力してみる

def result_label(result)
  if result[0] > result[1] && result[0] > result[2]
    "TRIANGLE"
  elsif result[1] > result[2] 
    "SQUARE"
  else    
    "CROSS"
  end
end

puts "学習データの場合"
puts "#{net.eval(tr_input).inspect} => #{result_label(net.eval(tr_input))}"
puts "#{net.eval(sq_input).inspect} => #{result_label(net.eval(sq_input))}"
puts "#{net.eval(cr_input).inspect} => #{result_label(net.eval(cr_input))}"
puts "ノイズ画像の場合"
puts "#{net.eval(tr_with_noise).inspect} => #{result_label(net.eval(tr_with_noise))}"
puts "#{net.eval(sq_with_noise).inspect} => #{result_label(net.eval(sq_with_noise))}"
puts "#{net.eval(cr_with_noise).inspect} => #{result_label(net.eval(cr_with_noise))}"
puts "ベースノイズの乗った画像の場合"
puts "#{net.eval(tr_with_base_noise).inspect} => #{result_label(net.eval(tr_with_base_noise))}"
puts "#{net.eval(sq_with_base_noise).inspect} => #{result_label(net.eval(sq_with_base_noise))}"
puts "#{net.eval(cr_with_base_noise).inspect} => #{result_label(net.eval(cr_with_base_noise))}"

出力

>ruby main.rb
学習データの場合
[0.974141461875309, 0.0149817330952856, 0.0400017794171882] => TRIANGLE
[0.0105839896514653, 0.980318764428595, 0.00299470632051346] => SQUARE
[0.0349560810207491, 0.0328454552188737, 0.944498535154918] => CROSS
ノイズ画像の場合
[0.832589730599948, 0.00674211260374761, 0.14708474657236] => TRIANGLE
[0.0104213152893759, 0.984316789275014, 0.0677999122461796] => SQUARE
[0.00861554380329126, 0.0438377286006615, 0.995384132322082] => CROSS
ベースノイズの乗った画像の場合
[0.75036683307126, 0.0205163231033386, 0.0744688176270517] => TRIANGLE
[0.0113511683682741, 0.918351820692977, 0.0120036385463944] => SQUARE
[0.00833986810952316, 0.021950349429871, 0.929720482647546] => CROSS

見事に学習し、ノイズが乗った画像でも三角なのか、四角なのか、十字なのか判別できるようになりましたね!
このサンプルはgemに同梱されていますので興味のある人は是非ためしてみてください。