最近、機械学習がブームでPythonを使えばだれでも簡単に学習器を作れるようになってきましたね。Pythonはライブラリが充実しているのでモデルについて何も知らなくでも機械学習をできます。私も別に仕事で機械学習を使っているわけではないのでそのような状態に甘んじていたのですが、最近になってさすがに気持ち悪さを感じていたので勉強を始めました。
今回は簡単なモデルとして、分類問題、特に線形の二値分類問題を考えます。
対象とする問題
いくつかの特徴量が与えられていて、それに対して正解ラベルが+1, -1で与えられているような状況を考えます。考える仮説空間はhalf-spacesクラス(日本語訳がわからない)です。
仮説空間という言葉が出てくると難しく感じますが、予測に使うのは下のような関数です。
は重みとバイアスです。基本的には線形関数で、最終的にsignで符号だけ取り出すことで予測値
を算出しています。
学習アルゴリズム
与えられている特徴量を、正解ラベルを
とします。この時、予測が当たっている状況を数式であらわすと、
符号関数は、少し抽象的です。もう少し扱いやすい形に書き換えましょう。要は両辺で符号さえ合っていればよいので、
としてもよいでしょう。訓練データが全部で個あるような状況を考えると、重みパラメータとバイアスが満たすべき条件は、
となります。さらに条件を詰めます。まず、バイアスを重みパラメータの中に取り込んでしまいます。
次に重みパラメータをの最小値
で規格化(?)します。
これをつかうことで、条件はもう少し厳しくすることができて、
これが求めるべきパラメータの満たすべき条件です。
どのように解くか
問題はこれをどう解くかです。これはよく見ると、次のような問題と同じであることがわかります。
maximize 0
s. t.
線形計画問題です。今回は目的関数はどうでもよくて、制約条件のみが問題になります。この問題の解き方は探せば見つかって、シンプレックス法(単体法)というものが有名です。
Pythonで実装する
では、Pythonで実装します。線形計画法のところはめんどくさくなってライブラリ使いました。ごめんなさい。
import numpy as np from sklearn.datasets import make_classification from sklearn.model_selection import train_test_split if __name__ == '__main__': #サンプル生成 X, y = make_classification( n_samples=500, n_classes=2, n_features=2, n_redundant=0, class_sep=3, ) y = list(map(lambda x: -1 if x==0 else 1, y)) X_train, X_test, y_train, y_test = train_test_split(X, y)
分類サンプルは、Scikit-learnのmake_classification
を使って生成しました。今回は簡単のためにサンプル数は500個、クラス数は2(二値分類)、特徴量は2種類で重要性が少ないものは0個としました。
学習アルゴリズムの本体は下のように定義します。
import pulp def solve(X, y): lp = pulp.LpProblem('lp', pulp.LpMaximize) # 解く問題を作る w1 = pulp.LpVariable('w1', 0) #重みパラメータ1 w2 = pulp.LpVariable('w2', 0) #重みパラメータ2 w3 = pulp.LpVariable('b', 0) #バイアス lp += 0 #目的関数 for i in range(len(X)): lp += y[i] * (X[i, 0] * w1 + X[i, 1] * w2 + w3) >= 1 #制約条件 lp.solve() #解く return w1.value(), w2.value(), w3.value() #各パラメータを返す
pulp
という線形計画問題を解くためのライブラリがありますのでありがたく使わせてもらいます。
学習がうまくいっているか確かめるために、テストデータを使って精度の確認を行います。
w1, w2, b = solve(X, y) acc = 0 for i in range(len(X_test)): pred = X_test[i].dot(np.array([[w1], [w2]])) + b if np.sign(pred) == np.sign(y_test[i]): acc += 1 print(acc/len(X_test))
乱数のシードを指定していないので指向のたびに値は変わりますが、おおむね9割程度の精度は出ています。
今回は、非常に単純なモデルについて書きました。その結果、構造は簡単でも9割程度の精度が出るということが確認できました(class_sep
大きくしているのでもう少し精度出てもよいと思うが...)。