pyhaya’s diary

機械学習系の記事をメインで書きます

Kaggleのタイタニックデータの解析

Kaggleの定番データセットといえば「タイタニックの生存者予測」です。今回は生存者の予測を目指して解析を行っていきたいと思います。データの可視化について詳しい説明は前回記事で書いているのでそちらを参照してください。
Titanic: Machine Learning from Disaster | Kaggle

解析環境

データの全体像の把握

データ解析をする際に最も重要なことはもちろんデータを理解することです。まずは大雑把にデータの全体像を見てみましょう。

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import os
import warnings
warnings.filterwarnings('ignore')
print(os.listdir("../input"))

# -> ['train.csv', 'gender_submission.csv', 'test.csv']

   

train = pd.read_csv('../input/train.csv')
test = pd.read_csv('../input/test.csv')
train.head()

f:id:pyhaya:20181118104522p:plain

いくつかの変数があることがわかります。

  • PassengerId : 乗客ID
  • Survived : 生存ラベル(1: 生存 0:死亡)
  • Pclass : 身分(1が最高)
  • Name : 乗客氏名
  • Sex : 性別
  • Age : 年齢
  • SibSp : 乗客中の親戚の人数
  • Parch : 乗客中の親もしくは子供の人数
  • Ticket : チケット番号
  • Fare : 乗船料
  • Cabin : 客室番号
  • Embarked : どこの港で乗船したか

変数の吟味

これらの変数の中で、中心となるのはもちろん'Survived'です。これにほかの変数が影響を与えうるか考えてみます。

'PassengerId', 'Ticket'

'PassengerId'はただの通し番号なので、生存には関係ない気がします。ただ、これが乗客の身分や部屋の位置に関係する可能性は一度調べる必要がありそうです。これは'Ticket'にも言えます。

'Pclass', 'Cabin', 'Fare'

'Pclass'に関しては、漠然と、身分の高い人は安全なところに客室があるのかなと思うので、関係ありそうです。これも後で解析しましょう。同様に'Cabin'と'Fare'も調べてみます。

'Sex', 'Age'

'Sex'や'Age'に関しては、欧米ではレディーファーストの精神が強いので女性や子供が優先的に救助されていた可能性があります。

'Name'

'Name'は生存とは全く関係なさそうに見えます。ただ、上の表を見てみると、乗客の名前には最初に'Mr'や'Mrs'などがついていることがわかります。これはその人の社会的な地位や状態を表しているといえるので、ここの情報だけは使える気がします。

'SibSp', 'Parch'

これらは親戚関係のパラメータですが、正直、これらのパラメータが'Survived'に効いてくるストーリーは思いつきませんでした。相関係数だけ調べて、早々に落としてもいいかもしれません。

'Embarked'

乗船地です。これはその土地の発展状況によって乗客層の経済力を示している可能性があります。

カテゴリ変数を数値へ変換していく

'Name'を変換する

'Name'の変換は少し手間です。まずは正規表現を使って敬称を取り出し、'Title'という名前の新しいカラムを作ります。

train['Title'] = train['Name'].str.extract(r'([A-Za-z]+)\.', expand=False)
test['Title'] = test['Name'].str.extract(r'([A-Za-z]+)\.', expand=False)
数値に置き換える

では、'Sex', 'Embarked', 'Title'を数値に変換していきます。これにはsklearn.preprocessing.LabelEncoderを使います。

from sklearn.preprocessing import LabelEncoder

for col in ['Sex', 'Title']:
    le = LabelEncoder()
    le.fit(train[col])
    train[col] = le.transform(train[col])
    
    le.fit(test[col])
    test[col] = le.transform(test[col])


train['Embarked'] = train['Embarked'].map({'S' : 1, 'C' : 2, 'Q' : 3})
test['Embarked'] = train['Embarked'].map({'S' : 1, 'C' : 2, 'Q' : 3})

多分'Embarked'はnull valueがあってうまくいかなかったので普通にmapを使って置き換えました。

'Name'を落とす

もう名前情報はいらなくなったので落としてしまいましょう。

train.drop('Name', axis=1, inplace=True)
test.drop('Name', axis=1, inplace=True)


ここまでやってきて、データの状況は下のようになっています。
f:id:pyhaya:20181118181123p:plain

Ticketを処理する

現時点ではまだ、'Ticket'が数値量になっていません。しかし、どのように数値に変換してよいのかはっきりしないので、もう少し詳しくこの変数について知ることにします。

いくつかTicketを眺めてみると、数字だけからなるチケット番号と、アルファベットが混ざるものの2つがあることがわかります。まずはこれから分離してみます。

number_ticket = train[train['Ticket'].str.match('\d+')]
num_alpha_ticket = train[train['Ticket'].str.match('[A-Z]+.+')]

数字だけのチケット

数字だけのチケットnumber_ticketについてまず見てみましょう。数字がどんなふうに分布しているか見るために、簡単なプロットをしてみます。

number_ticket['Ticket'] = number_ticket['Ticket'].apply(lambda x: int(x))
number_ticket.sort_values('Ticket', inplace=True)

plt.figure()
plt.plot(number_ticket['Ticket'], '-o')
plt.show()

f:id:pyhaya:20181118192749p:plain

大きく分けて2つのグループに分かれている様子が見て取れます。ylimをいじって拡大してみると、チケット番号は以下の5通りに分けられることがわかります。

  • 10000以下
  • i*1e+4より大きく(i+1)*1e+4以下 (i=1,2,3)
  • 300000以上

このグループ分けをしたときに、生存率に差が出るか見てみましょう。

x = [1, 2, 3, 4, 5]

lowest_num_ticket = number_ticket[number_ticket['Ticket'] <= 100000]
sec_lowest_num_ticket = number_ticket[(number_ticket['Ticket'] > 100000) & (number_ticket['Ticket'] < 200000)]
thir_lowest_num_ticket = number_ticket[(number_ticket['Ticket'] > 200000) & (number_ticket['Ticket'] < 300000)]
four_lowest_num_ticket = number_ticket[(number_ticket['Ticket'] > 300000) & (number_ticket['Ticket'] < 400000)]
high_num_ticket = number_ticket[number_ticket['Ticket'] > 3000000]
y = [lowest_num_ticket['Survived'].mean(), sec_lowest_num_ticket['Survived'].mean(),
     thir_lowest_num_ticket['Survived'].mean(), four_lowest_num_ticket['Survived'].mean(),
     high_num_ticket['Survived'].mean()
    ]

plt.figure()
plt.bar(x, y)
plt.xlabel('ticket number')
plt.ylabel('Survived')
plt.show()

f:id:pyhaya:20181118194155p:plain

数字が大きいほど生存率が小さくなるという傾向がありそうです。

アルファベットが入ったチケット

こちらは上に比べて分類が面倒です。正規表現を使ってうまくやっていきます。

A_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('A.+')]
CA_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('C\.*A\.*.+')]
PC_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('PC.+')]
PP_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('PP.+')]
SOTON_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('SOTON.+')]
STON_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('STON.+')]
LINE_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('LINE.*')]
FC_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('F\.C\.(C\.)*.+')]
W_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('W.+')]
C_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('C.+')]
SC_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('S(\.)*C.+')]
SO_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('S(\.)*O.+')]
other_ticket = num_alpha_ticket[
    num_alpha_ticket['Ticket'].str.match(
        '(Fa)*(P/PP)*(S\.P)*(S\.*W)*.+'
    )
]

x = [i for i in range(1, 14)]
y = [A_ticket['Survived'].mean(), CA_ticket['Survived'].mean(), PC_ticket['Survived'].mean()
    ,PP_ticket['Survived'].mean(), SOTON_ticket['Survived'].mean(), STON_ticket['Survived'].mean()
    ,LINE_ticket['Survived'].mean(), FC_ticket['Survived'].mean(), W_ticket['Survived'].mean()
    ,C_ticket['Survived'].mean(), SC_ticket['Survived'].mean(), SO_ticket['Survived'].mean()
    ,other_ticket['Survived'].mean()
    ]

plt.figure()
plt.bar(x, y)
plt.ylabel('survived')
plt.show()

f:id:pyhaya:20181118194601p:plain

このように場合分けが多岐にわたるときには、漏れがないか確認する手段を持っておくことが重要です。今回の場合は、セット型を使ってチェックが可能です。

new_set = set(A_ticket['Ticket']) | set(CA_ticket['Ticket']) |\
        set(PC_ticket['Ticket']) | set(PP_ticket['Ticket']) |\
        set(SOTON_ticket['Ticket']) | set(STON_ticket['Ticket']) |\
        set(LINE_ticket['Ticket']) | set(FC_ticket['Ticket']) |\
        set(W_ticket['Ticket']) | set(C_ticket['Ticket']) |\
        set(SC_ticket['Ticket']) | set(SO_ticket['Ticket']) |\
        set(other_ticket['Ticket'])
set(num_alpha_ticket['Ticket']) - new_set

これで計算結果が空のセットになることを確認します。

グラフを見ると、チケットによって生存率には大きな差が出ていることがわかります。

チケットのラベリング

この結果をもとにして、チケットのラベリングを行います。

number_ticket.loc[number_ticket['Ticket'] <= 100000, 'Ticket'] = 14
number_ticket.loc[(number_ticket['Ticket'] > 100000) & (number_ticket['Ticket'] <= 200000), 'Ticket'] = 15
number_ticket.loc[(number_ticket['Ticket'] > 200000) & (number_ticket['Ticket'] <= 300000), 'Ticket'] = 13
number_ticket.loc[(number_ticket['Ticket'] > 300000) & (number_ticket['Ticket'] <= 400000), 'Ticket'] = 5
number_ticket.loc[number_ticket['Ticket'] > 3000000, 'Ticket'] = 6
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('A.+'), 'Ticket'] = "1"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('C\.*A\.*.+'), 'Ticket'] = "8"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('PC.+'), 'Ticket'] = "16"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('PP.+'), 'Ticket'] = "18"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('SOTON.+'), 'Ticket'] = "3"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('STON.+'), 'Ticket'] = "11"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('LINE.*'), 'Ticket'] = "7"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('F\.C\.(C\.)*.+'), 'Ticket'] = "17"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('W.+'), 'Ticket'] = "4"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('C.+'), 'Ticket'] = "9"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('S(\.)*C.+'), 'Ticket'] = "12"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('S(\.)*O.+'), 'Ticket'] = "2"
num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('[^\d](Fa)*(P/PP)*(S\.P)*(S\.*W)*.+'), 'Ticket'] = "10"

num_alpha_ticket['Ticket'] = num_alpha_ticket['Ticket'].apply(lambda x: int(x))
train = pd.concat([number_ticket, num_alpha_ticket])

ここまでの処理の結果

ここまでの処理で、各変数の相関係数がどのように変化したか見てみます。

plt.figure(figsize=(10, 8))
sns.heatmap(train.corr(), annot=True, cmap='Reds')
plt.show()

f:id:pyhaya:20181122001403p:plain

TicketとSurvivedに比較的強い相関が出ました。

次回は、これらのデータを使って実際に機械学習をやってみたいと思います。

Kaggleで勝つデータ分析の技術

Kaggleで勝つデータ分析の技術

機械学習を原理から理解する 線形分類

最近、機械学習がブームでPythonを使えばだれでも簡単に学習器を作れるようになってきましたね。Pythonはライブラリが充実しているのでモデルについて何も知らなくでも機械学習をできます。私も別に仕事で機械学習を使っているわけではないのでそのような状態に甘んじていたのですが、最近になってさすがに気持ち悪さを感じていたので勉強を始めました。

今回は簡単なモデルとして、分類問題、特に線形の二値分類問題を考えます。

対象とする問題

いくつかの特徴量が与えられていて、それに対して正解ラベルが+1, -1で与えられているような状況を考えます。考える仮説空間はhalf-spacesクラス(日本語訳がわからない)です。

仮説空間という言葉が出てくると難しく感じますが、予測に使うのは下のような関数です。

 y = sign(\mathbf{w}\cdot\mathbf{x} + b)

 \mathbf{w}, bは重みとバイアスです。基本的には線形関数で、最終的にsignで符号だけ取り出すことで予測値 yを算出しています。

学習アルゴリズム

与えられている特徴量を \mathbf{x}、正解ラベルを tとします。この時、予測が当たっている状況を数式であらわすと、

 t = sign(\mathbf{w}\cdot\mathbf{x} + b)

符号関数は、少し抽象的です。もう少し扱いやすい形に書き換えましょう。要は両辺で符号さえ合っていればよいので、

 t(\mathbf{w}\cdot\mathbf{x} + b) > 0

としてもよいでしょう。訓練データが全部で n個あるような状況を考えると、重みパラメータとバイアスが満たすべき条件は、

 t_i(\mathbf{w}\cdot\mathbf{x}_i + b_i) > 0\ \ \ \ (i=1,2,\cdots n)

となります。さらに条件を詰めます。まず、バイアスを重みパラメータの中に取り込んでしまいます。

  t_i(\mathbf{w}\cdot\mathbf{x}_i) > 0\ \ \ \ (\mathbf{w}=(w_1, w_2, \cdots, b), \mathbf{x}_i=(x_{i1}, x_{i2}, \cdots, 1))

次に重みパラメータを t(\mathbf{w}\cdot\mathbf{x})の最小値\gammaで規格化(?)します。

 \displaystyle \bar{w}=\frac{w}{\gamma}

これをつかうことで、条件はもう少し厳しくすることができて、

 t_i(\mathbf{\bar{w}}\cdot\mathbf{x}_i) = t_i(\mathbf{w}\cdot\mathbf{x}_i)\times\frac{1}{\gamma} \ge 1

これが求めるべきパラメータの満たすべき条件です。

どのように解くか

問題はこれをどう解くかです。これはよく見ると、次のような問題と同じであることがわかります。

maximize 0
s. t.  t_i(\mathbf{w}\cdot\mathbf{x}_i)\ge 1

線形計画問題です。今回は目的関数はどうでもよくて、制約条件のみが問題になります。この問題の解き方は探せば見つかって、シンプレックス法(単体法)というものが有名です。

Pythonで実装する

では、Pythonで実装します。線形計画法のところはめんどくさくなってライブラリ使いました。ごめんなさい。

import numpy as np 
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

if __name__ == '__main__':
    #サンプル生成
    X, y = make_classification(
            n_samples=500,
            n_classes=2,
            n_features=2,
            n_redundant=0,
            class_sep=3,
            )

    y = list(map(lambda x: -1 if x==0 else 1, y))
    X_train, X_test, y_train, y_test = train_test_split(X, y)

分類サンプルは、Scikit-learnのmake_classificationを使って生成しました。今回は簡単のためにサンプル数は500個、クラス数は2(二値分類)、特徴量は2種類で重要性が少ないものは0個としました。

学習アルゴリズムの本体は下のように定義します。

import pulp

def solve(X, y):
    lp = pulp.LpProblem('lp', pulp.LpMaximize)    # 解く問題を作る
    w1 = pulp.LpVariable('w1', 0)    #重みパラメータ1
    w2 = pulp.LpVariable('w2', 0)    #重みパラメータ2
    w3 = pulp.LpVariable('b', 0)    #バイアス

    lp += 0    #目的関数
    for i in range(len(X)):
        lp += y[i] * (X[i, 0] * w1 + X[i, 1] * w2 + w3) >= 1    #制約条件

    lp.solve()    #解く
    return w1.value(), w2.value(), w3.value()    #各パラメータを返す

pulpという線形計画問題を解くためのライブラリがありますのでありがたく使わせてもらいます。

学習がうまくいっているか確かめるために、テストデータを使って精度の確認を行います。

w1, w2, b = solve(X, y)   
acc = 0
for i in range(len(X_test)):
    pred = X_test[i].dot(np.array([[w1], [w2]])) + b
    if np.sign(pred) == np.sign(y_test[i]):
        acc += 1

print(acc/len(X_test))

乱数のシードを指定していないので指向のたびに値は変わりますが、おおむね9割程度の精度は出ています。


今回は、非常に単純なモデルについて書きました。その結果、構造は簡単でも9割程度の精度が出るということが確認できました(class_sep大きくしているのでもう少し精度出てもよいと思うが...)。

ABC012 B 入浴時間 を解いた

今回は、AtCoder Beginners Contest 012 のB問題を解きましたのでまとめておきます。この問題は特別アルゴリズム等が必要なわけではなく、難しくもないのですが、ゼロ埋めが必要で、私がこれをよく忘れるので、備忘録的な感じで書きました。
beta.atcoder.jp

問題文

問題は下のようになっています。

高橋君は、お風呂で湯船に浸かった秒数を数える習慣があります。
今日は、高橋君は湯船でN秒まで数えました。
しかし、秒だと解りにくいので、何時間何分何秒、という形に直したいです。
秒数Nが与えられるので、hh:mm:ss の形式に変換しなさい。

入力は以下の形式で標準入力から与えられる。

N

1行目には、高橋君が湯船に浸かった秒数を表す整数N(0≦N≦86399)が与えられる。

出力
高橋君が湯船に浸かっていた時間を、hh:mm:ssの形式で、1行で出力せよ。出力の末尾には改行をいれること。

解答

算数的な感じです。まず時間から求めて、次に分、秒と求めていけばよいです。解答はゼロ埋め出ないとだめなので、iosiomanipをインクルードしておきます。

#include <iostream>
#include <ios>
#include <iomanip>
using namespace std;

int main() {
	int n; cin >> n;
	int hour, min, sec;

	hour = n / 3600;
	n -= hour * 3600;
	min = n / 60;
	n -= min * 60;
	sec = n;

	cout << setfill('0') << right << setw(2) << hour << ":";
	cout << setfill('0') << right << setw(2) << min << ":";
	cout << setfill('0') << right << setw(2) << sec << endl;
}

一応解説をしておくと、setfillで何で埋めるか(今回は0)を指定していて、表示するものを左右どちらに寄せるかを指定しているのが次のrightの部分、そして埋めた後に最終的に何文字にするのかを指定しているのがsetwになります。なのでsetw(4)としたときには、hour=3の時、表示は「0003」と4桁になります。

Kaggleのデータセットで遊んでみた 2 

前回の続きで、Titanicのデータセットで分析の基礎を学びます。

年齢と生存の関係を見てみる

前回の記事の最後に見たヒストグラムでは、Ageのデータがきれいな正規分布のような形をしているのが印象的でした。果たしてこの変数がSurviveに影響してくるのか、Survivedで分けてプロットしなおしてみます。

f,ax=plt.subplots(1,2,figsize=(20,10))
train[train['Survived']==0].Age.plot.hist(ax=ax[0],bins=20,edgecolor='black',color='red')
ax[0].set_title('Survived= 0')
x1=list(range(0,85,5))
ax[0].set_xticks(x1)
train[train['Survived']==1].Age.plot.hist(ax=ax[1],color='green',bins=20,edgecolor='black')
ax[1].set_title('Survived= 1')
x2=list(range(0,85,5))
ax[1].set_xticks(x2)
plt.show()

f:id:pyhaya:20181116000547p:plain
分布の形状に大きな差はありません。しかし、低年齢に注目してみると死者よりも明らかに生存者が多いことがわかります。

コードの説明

コード 動作
plt.subplots(1,2,figsize=(20,10)) プロット領域の用意。1行2列で、大きさは(20, 10)
train[train['Survived']==0].Age trainからSurvivedが0のものを抽出し、そのAgeカラムを取り出す
plot.hist(ax=ax[0], bins=20,...) 最初のプロット領域にヒストグラムをプロット。ビンは20個用意

性別と生存率の関係

性別と生存率の関係はどのようになっているでしょうか?

f,ax=plt.subplots(1,2,figsize=(18,8))
train[['Sex','Survived']].groupby(['Sex']).mean().plot.bar(ax=ax[0])
ax[0].set_title('Survived vs Sex')
sns.countplot('Sex',hue='Survived',data=train,ax=ax[1])
ax[1].set_title('Sex:Survived vs Dead')
plt.show()

f:id:pyhaya:20181116072422p:plain
今度は明確な差が出てきました。男性に比べると女性のほうが生存率が明らかに高いことがわかります。

コードの説明

コード 動作
.groupby(['Sex']).mean() 'Sex'でデータをまとめる。'Survived'の行の値は平均で置き換える
sns.countplot('Sex',hue='Survived',data=train,ax=ax[1]) trainデータで、横軸'Sex'縦軸'Survive'でプロット

複数の指標を同時に比較してみる

バイオリンプロット

複数の指標を一つのグラフ上にプロットするのはバイオリンプロットが便利です。sns.violinplotで利用できます。

f,ax=plt.subplots(1,2,figsize=(18,8))
sns.violinplot("Pclass","Age", hue="Survived", data=train,split=True,ax=ax[0],
              palette='pastel')
ax[0].set_title('Pclass and Age vs Survived')
ax[0].set_yticks(range(0,110,10))
sns.violinplot("Sex","Age", hue="Survived", data=train,split=True,ax=ax[1],
              palette='pastel')
ax[1].set_title('Sex and Age vs Survived')
ax[1].set_yticks(range(0,110,10))
plt.show()

f:id:pyhaya:20181116073730p:plain
ここから読み取れるのは、生存者数のピークは死者数のものとほぼ一致しているが、生存者数に関しては低年齢領域にもう一つ小さなピークがあるものが多いということです。また、Pclass=1では生存者数は1ピークですが、そのピーク位置は死者数と比較して低年齢側にシフトしています。

ペアプロット

sns.jointplot(x="Age", y="Survived", data=train, size=5,ratio=5, kind='kde', color='green')
plt.show()

f:id:pyhaya:20181116075802p:plain

ヒートマップ

ヒートマップで各変数の相関の強さをざっくりと見ます。

plt.figure(figsize=(7,4)) 
sns.heatmap(train.corr(),annot=True,cmap='Reds')
plt.show()

f:id:pyhaya:20181116080951p:plain
annot=Trueとすることで、相関係数がグラフ中に表示されるようになります。

Djangoで家計簿のWebアプリケーションを作る 7 ビューをクラスを使って整理する

Djangoで家計簿のWebアプリケーションを作っています。

ビューが汚い

ここまで様々な機能を実装してきました。その結果、views.pyの中身がだいぶ見づらくなってしまっています。

money/views.py

import calendar
import datetime
from django.shortcuts import render, redirect
from django.utils import timezone
import matplotlib.pyplot as plt
import pytz
from .models import Money
from .forms import SpendingForm
plt.rcParams['font.family'] = 'IPAPGothic'    #日本語の文字化け防止
# Create your views here.
TODAY = str(timezone.now()).split('-')

def index(request, year=TODAY[0], month=TODAY[1]):
    money = Money.objects.filter(use_date__year=year,                                                                                                                               use_date__month=month).order_by('use_date')

    total = index_utils.calc_month_pay(money)
    index_utils.format_date(money)
    form = SpendingForm()
    next_year, next_month = get_next(year, month)
    prev_year, prev_month = get_prev(year, month)

    context = {'year' : year,
            'month' : month,
            'prev_year' : prev_year,
            'prev_month' : prev_month,
            'next_year' : next_year,
            'next_month' : next_month,
            'money' : money,
            'total' : total,
            'form' : form
            }

    draw_graph(year, month)

    if request.method == 'POST':
        data = request.POST
        use_date = data['use_date']
        cost = data['cost']
        detail = data['detail']
        category = data['category']
        use_date = timezone.datetime.strptime(use_date, "%Y/%m/%d")
        tokyo_timezone = pytz.timezone('Asia/Tokyo')
        use_date = tokyo_timezone.localize(use_date)
        use_date += datetime.timedelta(hours=9)

        Money.objects.create(
                use_date = use_date,
                detail = detail,
                cost = int(cost),
                category = category,
                )
        return redirect(to='/money/{}/{}'.format(year, month))

    return render(request, 'money/index.html', context)

#...

どうにかしましょう。

リファクタリング

ビューの機能とは無関係の部分を抽出する

まずはビュー本来の機能である、コンテクストをHTMLに送るということ以外のことをしている部分を探し出して抽出していきましょう。まずはその月の支出合計を計算する部分は抽出できそうです。同様にして、データベースから日付をとってきて表示用に整形する部分も抽出できそうです。

money/utils/index_utils.py

def calc_month_pay(money):
    total = 0
    for m in money:
        total += m.cost

    return total

def format_date(money):
    for m in money:
        date = str(m.use_date).split(' ')[0]
        m.use_date = '/'.join(date.split('-')[1:3])

    return None

ついでに前月と次月を計算する関数もこちらに移してしまいましょう。

money/utils/index_utils.py

def calc_month_pay(money):
    total = 0
    for m in money:
        total += m.cost

    return total

def format_date(money):
    for m in money:
        date = str(m.use_date).split(' ')[0]
        m.use_date = '/'.join(date.split('-')[1:3])

    return None


def get_next(year, month):
    year = int(year)
    month = int(month)

    if month == 12:
        return str(year + 1), '1'
    else:
        return str(year), str(month + 1)

def get_prev(year, month):
    year = int(year)
    month = int(month)
    if month == 1:
        return str(year - 1), '12'
    else:
        return str(year), str(month - 1)

money/views.py

def index(request, year=TODAY[0], month=TODAY[1]):
    money = Money.objects.filter(use_date__year=year,                                                                                                                               use_date__month=month).order_by('use_date')

    total = index_utils.calc_month_pay(money)
    index_utils.format_date(money)
    form = SpendingForm()
    next_year, next_month = index_utils.get_next(year, month)
    prev_year, prev_month = index_utils.get_prev(year, month)

    context = {'year' : year,
            'month' : month,
            'prev_year' : prev_year,
            'prev_month' : prev_month,
            'next_year' : next_year,
            'next_month' : next_month,
            'money' : money,
            'total' : total,
            'form' : form
            }

    draw_graph(year, month)

    if request.method == 'POST':
        data = request.POST
        use_date = data['use_date']
        cost = data['cost']
        detail = data['detail']
        category = data['category']
        use_date = timezone.datetime.strptime(use_date, "%Y/%m/%d")
        tokyo_timezone = pytz.timezone('Asia/Tokyo')
        use_date = tokyo_timezone.localize(use_date)
        use_date += datetime.timedelta(hours=9)

        Money.objects.create(
                use_date = use_date,
                detail = detail,
                cost = int(cost),
                category = category,
                )
        return redirect(to='/money/{}/{}'.format(year, month))

    return render(request, 'money/index.html', context)

ビュークラスを使う

これで少しはすっきりしましたが、通常表示されるときに実行される部分と、postを受け取ったときに実行される部分が混ざってしまっています。これを解決するには、ビューをクラスとして書きます。

money/views.py

import calendar    
import datetime
from django.shortcuts import render, redirect
from django.utils import timezone
from django.views import View
import matplotlib.pyplot as plt
import pytz

from .models import Money
from .forms import SpendingForm
from .utils import index_utils

plt.rcParams['font.family'] = 'IPAPGothic'

TODAY = str(timezone.now()).split('-')

class MainView(View):
    def get(self, request, year=TODAY[0], month=TODAY[1]):
        money = Money.objects.filter(use_date__year=year,    
                            use_date__month=month).order_by('use_date')

        total = index_utils.calc_month_pay(money)
        index_utils.format_date(money)
        form = SpendingForm()
        next_year, next_month = index_utils.get_next(year, month)
        prev_year, prev_month = index_utils.get_prev(year, month)

        context = {'year' : year,
                'month' : month,
                'prev_year' : prev_year,
                'prev_month' : prev_month,
                'next_year' : next_year,
                'next_month' : next_month,
                'money' : money,
                'total' : total,
                'form' : form
                }

        draw_graph(year, month)

        return render(request, 'money/index.html', context)

    def post(self, request, year=TODAY[0], month=TODAY[1]):
        data = request.POST
        use_date = data['use_date']
        cost = data['cost']
        detail = data['detail']
        category = data['category']

        use_date = timezone.datetime.strptime(use_date, "%Y/%m/%d")
        tokyo_timezone = pytz.timezone('Asia/Tokyo')
        use_date = tokyo_timezone.localize(use_date)
        use_date += datetime.timedelta(hours=9)

        Money.objects.create(
                use_date = use_date,
                detail = detail,
                cost = int(cost),
                category = category,
                )
        return redirect(to='/money/{}/{}'.format(year, month)) 

#...

ここでは最も一般的なdjango.views.Viewをテンプレートビューとして使っています。この変更に伴って、urls.pyも少し変更する必要があります。

money/urls.py

from django.urls import path
from . import views

app_name = 'money'
urlpatterns = [
        path('', views.MainView.as_view(), name='index'),
        path('<int:year>/<int:month>', views.MainView.as_view(), name='index'),
        ]

このようにas_viewを付けることによってクラスをビューとして呼び出すことができます。

Kaggleのデータセットで遊んでみた 1 データの可視化

Kaggleとは、機械学習とデータサイエンスのプラットフォームのことです。このサイトでは、様々なデータを使って自分で分析を行うことができたり、データ解析のコンペティションに参加して精度を競い合ったりすることができます。

今回はKaggleの中で最初に出会うであろうTaitanicのデータセットを使ってデータ解析作業の大枠を紹介しているカーネルを紹介したいと思います。下のサイトを参考にしています。すごく良いカーネルなのでぜひ読んでみてください!
A Comprehensive ML Workflow with Python | Kaggle

作業環境

Kaggleのノートブック環境を利用しても、自分のJupyter Notebook環境を使っても構いません。構成の詳細は下のようになっています。

from sklearn.cross_validation import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
from pandas import get_dummies
import matplotlib as mpl
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib
import warnings
import sklearn
import scipy
import numpy
import json
import sys
import csv
import os

print('matplotlib: {}'.format(matplotlib.__version__))
print('sklearn: {}'.format(sklearn.__version__))
print('scipy: {}'.format(scipy.__version__))
print('seaborn: {}'.format(sns.__version__))
print('pandas: {}'.format(pd.__version__))
print('numpy: {}'.format(np.__version__))
print('Python: {}'.format(sys.version))

# 以下出力
"""
/opt/conda/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
  "This module will be removed in 0.20.", DeprecationWarning)
matplotlib: 2.2.3
sklearn: 0.19.1
scipy: 1.1.0
seaborn: 0.9.0
pandas: 0.23.4
numpy: 1.15.2
Python: 3.6.6 |Anaconda, Inc.| (default, Oct  9 2018, 12:34:16) 
[GCC 7.3.0]
"""

基本設定

この後、様々な処理をしていくにあたって、プロットについて基本的な設定をしておきます。

sns.set(style='white', context='notebook', palette='deep')
pylab.rcParams['figure.figsize'] = 12,8
warnings.filterwarnings('ignore')
mpl.style.use('ggplot')
sns.set_style('white')
%matplotlib inline

ここら辺のスタイルは各自の好みで変えてよいと思います。

データを見てみる

散布図

train = pd.read_csv('../input/train.csv')
test = pd.read_csv('../input/test.csv')

train_columns
#以下出力
"""
Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp',
       'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'],
      dtype='object')
"""

このデータには「PassengerId」など12個のカラムがあります。今回予測するのは'Survived'つまりタイタニック事故で特定の人が生き残るか死ぬかを予測することになります。そのため、データの分析は基本的には'Survived'とほかの変数の間の関係を見ていくことになります。

試しに、Pclass(客の経済的な豊かさ、1が最も裕福)、年齢、乗船料とSurviveの関係を見てみます。

g = sns.FacetGrid(train, hue="Survived", col="Pclass", margin_titles=True,
                  palette={1:"seagreen", 0:"gray"})
g=g.map(plt.scatter, "Fare", "Age",edgecolor="w").add_legend();

下のようなグラフが表示されます。
f:id:pyhaya:20181115231859p:plain
生存者と死者の間に明確な違いは見受けられません。

コードの説明

1行目のsns.FaceGridではtrainからデータをとってきてプロットをすることを示し、colに指定されたPclassごとにグラフを用意することを宣言しています。なので今回はPclassが1,2,3の3つのグラフが描画されます。hueはSurviveの値で区別してプロットすることを意味しています。Surviveには生存(1)と死亡(0)の2種類がありますが、それぞれを何色でプロットするかはpaletteで指定されています。

実際のグラフ描画はg.mapで行われています。

コード 動作
plt.scatter 散布図スタイルでプロットする
"Fare" 横軸はFare(乗船料)
"Age" 縦軸はAge(年齢)
edgecolor ドットの淵は白(white)
add.legend() 凡例を入れる

Boxプロット(箱ひげ図)

各変数がどのような値の範囲をとるのか図示します。

train.plot(kind='box', subplots=True, layout=(2,4), sharex=False, sharey=False)
plt.subplots_adjust(wspace=0.5, hspace=0.6)

f:id:pyhaya:20181115233910p:plain

箱ひげ図の見方についてはWidipediaを見てください。

コードの説明
コード 動作
train.plot Pandasライブラリの機能でプロットする
kind='box' プロットの種類は箱ひげ図
subplots=True グラフをサブプロットに分ける
layout=(2, 4) 2行4列にサブプロットを配置する
sharex=False, sharey=False x, y軸の値の範囲を共有しない
plt.subplots_adjust(wspace=0.5, hspace=0.6) グラフ同士の幅、高さの間隔調整

ヒストグラム

ヒストグラムは、各変数値にどのくらいのデータがあるのか知るのに便利です。

train.hist(figsize=(15, 20))
plt.show()

f:id:pyhaya:20181115235614p:plainf:id:pyhaya:20181115235618p:plainf:id:pyhaya:20181115235611p:plain

「テスト駆動開発」をPythonで書き直してみた 7

書籍「テスト駆動開発」をPythonで書き直したシリーズです。前回の記事はこちらです。
pyhaya.hatenablog.com

今回は、いよいよ多国通貨を扱うための準備に取り掛かります。

テスト駆動開発

テスト駆動開発

Bank(銀行)

多国通貨を扱うためには、為替レートを使って通貨を換算する必要があります。その役割を銀行クラスを作って任せます。どのような動作を期待するのか、テストを書いて明示します。

tests/test_money.py

import sys
sys.path.append('../src')

import unittest
from money import Money

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Money.dollar(5)
        self.assertEqual(Money.dollar(10), five * 2)
        self.assertEqual(Money.dollar(15), five * 3)

    def testEquality(self):
        self.assertNotEqual(Money.franc(5), Money.dollar(5))

    def testSimpleAddition(self):
        bank = Bank()    #銀行を用意
        sum_ = Money.dollar(5) + Money.dollar(5)
        reduced = bank.reduce(sum_, "USD")    #USDに換算する
        self.assertEqual(reduced, Money.dollar(10))

if __name__ == '__main__':
    unittest.main()

銀行を表すオブジェクトを用意しておいて、reduceメソッドで通貨の両替を行います。

明白な実装

このテストを通す明白な実装はこのようになります。

src/bank.py

from money import Money

class Bank:
    def reduce(self, source, to):
        return Money.dollar(10)

Moneyの実装を再考する

ここまで書いたら、一度Moneyの実装を考えてみます。今回、両替を銀行に委譲しました。なので、通貨同士の足し算は足した時点では通貨は決定していない状態です。ただ、足される2つの通貨を持っているだけの中間状態のようなものが必要です。いわば抽象的な通貨を表すためのクラスが必要です。

Sumクラスを実装する

このようなクラスとして、Sumクラスを作ってみます。

src/sum.py

class Sum:
    def __init__(self, augent, addend):
        self.augent = augent
        self.addend = addend

src/money.py

from sum import Sum

class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __add__(self, other):
        return Sum(self, other)

    def __mul__(self, multiplier):
        return Money(self.amount * multiplier, self.currency)
    
    @staticmethod
    def dollar(amount):
        return Money(amount, 'USD')

    @staticmethod
    def franc(amount):
        return Money(amount, 'CHF')

__add__メソッドはSumクラスのインスタンスを返します。Sumクラスは初期化の際に被加算数(addend)と加算数(augent)を持ちます。

Bankクラスの本実装

ここまでくれば、ロジックが固まってきたのでBankクラスのreduceメソッドを書き下せます。source引数はSumクラスのインスタンスを引数に持つので、

src/bank.py

from money import Money

class Bank:
    def reduce(self, source, to):
        amount = source.augent.amount + source.addend.amount
        return Money(amount, to)

まとめ

今回は多国通貨を扱うための下準備を行いました。ここまでくるとだいぶロジックが複雑になってくるのでテストを書いてどのようなロジックにしたいのか明確にしておくことが重要です。