rubyで単純パーセプトロン
目的
機械学習を勉強するにあたって単純パーセプトロンから入るのは通説らしいので、僕も単純パーセプトロンから機械学習を入門します。
理論を勉強しつつ、実際に実装してみることで理解を深めることが目的です。
注意
学習中であるため、この記事に書かれている内容の正しさは保証しません。
間違っている箇所を見つけたら指摘していただけると幸いです。
識別関数
識別関数とは何らかのデータを特定のクラスタに識別する関数です。
例えば、次のようなデータを考えます。
青色のデータと赤色のデータの間に直線を引くことでデータを識別できそうです。
実際に直線を引いてみます。
直線より上側にあるデータが青色のデータ、下側にあるデータが赤色のデータと識別することができました。
これを式で表してみます。
直線を式で表すと
になります。
直線より上側は青色のデータ、下側は赤色のデータなので
青色のデータは
となり、赤色のデータは
となります。
直線の傾きや切片を決定する
が決まれば直線が一意に決まります。
今回はこの直線が識別関数となります。
単純パーセプトロンではこの
を学習によって求めます。
単純パーセプトロンによる学習
単純パーセプトロンについてのおさらいです。
単純パーセプトロンは複数の入力を受け取り、値が何らかの値に達すると発火する識別器です。
雑に図で表すとこんな感じです。
この図に先ほどの識別関数を当てはめてみます。
この図を式で表します。
を重みと呼びます。
次に実際に学習方法を考えます。
学習の方法は
- 重み()にランダムな値を設定する
- 学習データをパーセプトロンに投げ、出力が正しいか比較する
- 出力が間違っていれば、重みを更新する
- 重みの更新が収束するまで2〜3を繰り返す
重みの更新
学習データを読み込み出力結果が間違っていれば重みの更新を行います。
重みの更新は、正しい結果を出しそうな方向に少しづつずらしていきます。
この「重みを正しそうな方向に少しづつずらす」ために誤差関数というものを定義します。
誤差関数とは、「予測した結果が実際の結果とどのくらいずれているか」を出力する関数です。
誤差関数はずれが大きいほど大きな値を返すのであればどのような関数でも大丈夫ですが、
単純パーセプトロンの場合、誤差関数は
と定義する場合が多いようです。
※はデータの正解ラベルで正解の場合を不正解の場合となります。
この誤差関数をグラフで表してみます。誤差関数は「正解とのズレが大きいほど大きな値を返す」という特性から、
2次関数のようなグラフを描きます。
重みの更新は最終的にこのグラフでいう極小値を目指します。
では、どうすれば極小値に近づくように重みを更新でききるのでしょうか。
傾きを見て、傾きが負の場合正の方向へ、傾きが正の場合負の方向へ重みをずらしていけば、
重みを極小値へ近づけていけそうです。
これを式で表してみます。
このままだと、の値が大きくなりすぎるので、調整するために
学習率をかけてあげます。
この式を展開していきます。
ここで、の場合は更新自体を行わないので、についてのみ考えます。
よって更新式は
となります。
rubyで実装
理解を深めるためにrubyで実装しました。
グラフの描画はgnuplotを使っています。
以下コードと実行結果となります。
GitHub - ogidow/Simple-perceptron-ruby
require "gnuplot" # w1 + w2x + w3y = 0 class Perceptron def initialize(learning_rate, num_data) @w_vec = {} @learning_rate = learning_rate @num_data = num_data end def init_w_vec @num_data.times do @w_vec = {w1: rand(-10..10), w2: rand(-10..10), w3: rand(-10..10)} end end def predict (data) #p w_vec @w_vec[:w1] * 1 + @w_vec[:w2] * data[:x1] + @w_vec[:w3] * data[:x2] end def update (data) @w_vec = {w1: @w_vec[:w1] + @learning_rate * 1 * data[:label], w2: @w_vec[:w2] + @learning_rate * data[:x2] * data[:label], w3: @w_vec[:w3] + @learning_rate * data[:x1] * data[:label]} end def train (datas) update_count = 0 datas.each do |data| result = predict(data) if result * data[:label] < 0 update(data) update_count += 1 end end update_count end def draw_graph(x1, y1, x2, y2, title) Gnuplot.open do |gp| Gnuplot::Plot.new(gp) do |plot| plot.xlabel "x" plot.ylabel "y" plot.title title plot.data << Gnuplot::DataSet.new([x1, y1]) do |ds| ds.with = "points" ds.notitle end plot.data << Gnuplot::DataSet.new([x2, y2]) do |ds| ds.with = "lines" ds.notitle end end end end def run #学習用データ作成 init_w_vec datas = [] @num_data.times do if rand > 0.5 datas.push({x1: rand(1..100) * -1, x2: rand(1..100) * 1, label: 1}) else datas.push({x1: rand(1..100) * 1, x2: rand(1..100) * -1, label: -1}) end end ##学習開始 ##収束条件:重みの更新がなくなったら update_count = 0 5000.times do update_count = train datas break if update_count == 0 end # w1 + x * w2 + y * w3 = 0 # 分離直線の傾きと切片を求める slope = -1 * @w_vec[:w2] / @w_vec[:w3] interecept = -1 * @w_vec[:w1] / @w_vec[:w3] #学習した分離直線を出力 puts "y = #{slope}x + #{-1 * interecept}" # 学習した分離直線を描画 x =[] y =[] (-100..100).each do |i| x.push i y.push i * slope + interecept end draw_graph(datas.map{|v| v[:x1] }, datas.map{|v| v[:x2]}, x, y, "perceptron") end end perceptron = Perceptron.new(0.2, 100) perceptron.run
見事に分離直線を引くことができました!!