Kaggleのタイタニックデータの解析
Kaggleの定番データセットといえば「タイタニックの生存者予測」です。今回は生存者の予測を目指して解析を行っていきたいと思います。データの可視化について詳しい説明は前回記事で書いているのでそちらを参照してください。
Titanic: Machine Learning from Disaster | Kaggle
解析環境
- Kaggleのカーネル
データの全体像の把握
データ解析をする際に最も重要なことはもちろんデータを理解することです。まずは大雑把にデータの全体像を見てみましょう。
import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns import os import warnings warnings.filterwarnings('ignore') print(os.listdir("../input")) # -> ['train.csv', 'gender_submission.csv', 'test.csv']
train = pd.read_csv('../input/train.csv') test = pd.read_csv('../input/test.csv') train.head()
いくつかの変数があることがわかります。
- PassengerId : 乗客ID
- Survived : 生存ラベル(1: 生存 0:死亡)
- Pclass : 身分(1が最高)
- Name : 乗客氏名
- Sex : 性別
- Age : 年齢
- SibSp : 乗客中の親戚の人数
- Parch : 乗客中の親もしくは子供の人数
- Ticket : チケット番号
- Fare : 乗船料
- Cabin : 客室番号
- Embarked : どこの港で乗船したか
変数の吟味
これらの変数の中で、中心となるのはもちろん'Survived'です。これにほかの変数が影響を与えうるか考えてみます。
'PassengerId', 'Ticket'
'PassengerId'はただの通し番号なので、生存には関係ない気がします。ただ、これが乗客の身分や部屋の位置に関係する可能性は一度調べる必要がありそうです。これは'Ticket'にも言えます。
'Pclass', 'Cabin', 'Fare'
'Pclass'に関しては、漠然と、身分の高い人は安全なところに客室があるのかなと思うので、関係ありそうです。これも後で解析しましょう。同様に'Cabin'と'Fare'も調べてみます。
'Sex', 'Age'
'Sex'や'Age'に関しては、欧米ではレディーファーストの精神が強いので女性や子供が優先的に救助されていた可能性があります。
'Name'
'Name'は生存とは全く関係なさそうに見えます。ただ、上の表を見てみると、乗客の名前には最初に'Mr'や'Mrs'などがついていることがわかります。これはその人の社会的な地位や状態を表しているといえるので、ここの情報だけは使える気がします。
'SibSp', 'Parch'
これらは親戚関係のパラメータですが、正直、これらのパラメータが'Survived'に効いてくるストーリーは思いつきませんでした。相関係数だけ調べて、早々に落としてもいいかもしれません。
'Embarked'
乗船地です。これはその土地の発展状況によって乗客層の経済力を示している可能性があります。
カテゴリ変数を数値へ変換していく
'Name'を変換する
'Name'の変換は少し手間です。まずは正規表現を使って敬称を取り出し、'Title'という名前の新しいカラムを作ります。
train['Title'] = train['Name'].str.extract(r'([A-Za-z]+)\.', expand=False) test['Title'] = test['Name'].str.extract(r'([A-Za-z]+)\.', expand=False)
数値に置き換える
では、'Sex', 'Embarked', 'Title'を数値に変換していきます。これにはsklearn.preprocessing.LabelEncoder
を使います。
from sklearn.preprocessing import LabelEncoder for col in ['Sex', 'Title']: le = LabelEncoder() le.fit(train[col]) train[col] = le.transform(train[col]) le.fit(test[col]) test[col] = le.transform(test[col]) train['Embarked'] = train['Embarked'].map({'S' : 1, 'C' : 2, 'Q' : 3}) test['Embarked'] = train['Embarked'].map({'S' : 1, 'C' : 2, 'Q' : 3})
多分'Embarked'はnull valueがあってうまくいかなかったので普通にmap
を使って置き換えました。
'Name'を落とす
もう名前情報はいらなくなったので落としてしまいましょう。
train.drop('Name', axis=1, inplace=True) test.drop('Name', axis=1, inplace=True)
ここまでやってきて、データの状況は下のようになっています。
Ticketを処理する
現時点ではまだ、'Ticket'が数値量になっていません。しかし、どのように数値に変換してよいのかはっきりしないので、もう少し詳しくこの変数について知ることにします。
いくつかTicketを眺めてみると、数字だけからなるチケット番号と、アルファベットが混ざるものの2つがあることがわかります。まずはこれから分離してみます。
number_ticket = train[train['Ticket'].str.match('\d+')] num_alpha_ticket = train[train['Ticket'].str.match('[A-Z]+.+')]
数字だけのチケット
数字だけのチケットnumber_ticket
についてまず見てみましょう。数字がどんなふうに分布しているか見るために、簡単なプロットをしてみます。
number_ticket['Ticket'] = number_ticket['Ticket'].apply(lambda x: int(x)) number_ticket.sort_values('Ticket', inplace=True) plt.figure() plt.plot(number_ticket['Ticket'], '-o') plt.show()
大きく分けて2つのグループに分かれている様子が見て取れます。ylim
をいじって拡大してみると、チケット番号は以下の5通りに分けられることがわかります。
- 10000以下
i*1e+4
より大きく(i+1)*1e+4
以下 (i=1,2,3)- 300000以上
このグループ分けをしたときに、生存率に差が出るか見てみましょう。
x = [1, 2, 3, 4, 5] lowest_num_ticket = number_ticket[number_ticket['Ticket'] <= 100000] sec_lowest_num_ticket = number_ticket[(number_ticket['Ticket'] > 100000) & (number_ticket['Ticket'] < 200000)] thir_lowest_num_ticket = number_ticket[(number_ticket['Ticket'] > 200000) & (number_ticket['Ticket'] < 300000)] four_lowest_num_ticket = number_ticket[(number_ticket['Ticket'] > 300000) & (number_ticket['Ticket'] < 400000)] high_num_ticket = number_ticket[number_ticket['Ticket'] > 3000000] y = [lowest_num_ticket['Survived'].mean(), sec_lowest_num_ticket['Survived'].mean(), thir_lowest_num_ticket['Survived'].mean(), four_lowest_num_ticket['Survived'].mean(), high_num_ticket['Survived'].mean() ] plt.figure() plt.bar(x, y) plt.xlabel('ticket number') plt.ylabel('Survived') plt.show()
数字が大きいほど生存率が小さくなるという傾向がありそうです。
アルファベットが入ったチケット
こちらは上に比べて分類が面倒です。正規表現を使ってうまくやっていきます。
A_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('A.+')] CA_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('C\.*A\.*.+')] PC_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('PC.+')] PP_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('PP.+')] SOTON_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('SOTON.+')] STON_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('STON.+')] LINE_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('LINE.*')] FC_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('F\.C\.(C\.)*.+')] W_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('W.+')] C_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('C.+')] SC_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('S(\.)*C.+')] SO_ticket = num_alpha_ticket[num_alpha_ticket['Ticket'].str.match('S(\.)*O.+')] other_ticket = num_alpha_ticket[ num_alpha_ticket['Ticket'].str.match( '(Fa)*(P/PP)*(S\.P)*(S\.*W)*.+' ) ] x = [i for i in range(1, 14)] y = [A_ticket['Survived'].mean(), CA_ticket['Survived'].mean(), PC_ticket['Survived'].mean() ,PP_ticket['Survived'].mean(), SOTON_ticket['Survived'].mean(), STON_ticket['Survived'].mean() ,LINE_ticket['Survived'].mean(), FC_ticket['Survived'].mean(), W_ticket['Survived'].mean() ,C_ticket['Survived'].mean(), SC_ticket['Survived'].mean(), SO_ticket['Survived'].mean() ,other_ticket['Survived'].mean() ] plt.figure() plt.bar(x, y) plt.ylabel('survived') plt.show()
このように場合分けが多岐にわたるときには、漏れがないか確認する手段を持っておくことが重要です。今回の場合は、セット型を使ってチェックが可能です。
new_set = set(A_ticket['Ticket']) | set(CA_ticket['Ticket']) |\ set(PC_ticket['Ticket']) | set(PP_ticket['Ticket']) |\ set(SOTON_ticket['Ticket']) | set(STON_ticket['Ticket']) |\ set(LINE_ticket['Ticket']) | set(FC_ticket['Ticket']) |\ set(W_ticket['Ticket']) | set(C_ticket['Ticket']) |\ set(SC_ticket['Ticket']) | set(SO_ticket['Ticket']) |\ set(other_ticket['Ticket']) set(num_alpha_ticket['Ticket']) - new_set
これで計算結果が空のセットになることを確認します。
グラフを見ると、チケットによって生存率には大きな差が出ていることがわかります。
チケットのラベリング
この結果をもとにして、チケットのラベリングを行います。
number_ticket.loc[number_ticket['Ticket'] <= 100000, 'Ticket'] = 14 number_ticket.loc[(number_ticket['Ticket'] > 100000) & (number_ticket['Ticket'] <= 200000), 'Ticket'] = 15 number_ticket.loc[(number_ticket['Ticket'] > 200000) & (number_ticket['Ticket'] <= 300000), 'Ticket'] = 13 number_ticket.loc[(number_ticket['Ticket'] > 300000) & (number_ticket['Ticket'] <= 400000), 'Ticket'] = 5 number_ticket.loc[number_ticket['Ticket'] > 3000000, 'Ticket'] = 6 num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('A.+'), 'Ticket'] = "1" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('C\.*A\.*.+'), 'Ticket'] = "8" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('PC.+'), 'Ticket'] = "16" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('PP.+'), 'Ticket'] = "18" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('SOTON.+'), 'Ticket'] = "3" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('STON.+'), 'Ticket'] = "11" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('LINE.*'), 'Ticket'] = "7" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('F\.C\.(C\.)*.+'), 'Ticket'] = "17" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('W.+'), 'Ticket'] = "4" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('C.+'), 'Ticket'] = "9" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('S(\.)*C.+'), 'Ticket'] = "12" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('S(\.)*O.+'), 'Ticket'] = "2" num_alpha_ticket.loc[num_alpha_ticket['Ticket'].str.match('[^\d](Fa)*(P/PP)*(S\.P)*(S\.*W)*.+'), 'Ticket'] = "10" num_alpha_ticket['Ticket'] = num_alpha_ticket['Ticket'].apply(lambda x: int(x)) train = pd.concat([number_ticket, num_alpha_ticket])
ここまでの処理の結果
ここまでの処理で、各変数の相関係数がどのように変化したか見てみます。
plt.figure(figsize=(10, 8)) sns.heatmap(train.corr(), annot=True, cmap='Reds') plt.show()
TicketとSurvivedに比較的強い相関が出ました。
次回は、これらのデータを使って実際に機械学習をやってみたいと思います。
- 作者: 門脇大輔,阪田隆司,保坂桂佑,平松雄司
- 出版社/メーカー: 技術評論社
- 発売日: 2019/10/09
- メディア: 単行本(ソフトカバー)
- この商品を含むブログを見る
機械学習を原理から理解する 線形分類
最近、機械学習がブームでPythonを使えばだれでも簡単に学習器を作れるようになってきましたね。Pythonはライブラリが充実しているのでモデルについて何も知らなくでも機械学習をできます。私も別に仕事で機械学習を使っているわけではないのでそのような状態に甘んじていたのですが、最近になってさすがに気持ち悪さを感じていたので勉強を始めました。
今回は簡単なモデルとして、分類問題、特に線形の二値分類問題を考えます。
対象とする問題
いくつかの特徴量が与えられていて、それに対して正解ラベルが+1, -1で与えられているような状況を考えます。考える仮説空間はhalf-spacesクラス(日本語訳がわからない)です。
仮説空間という言葉が出てくると難しく感じますが、予測に使うのは下のような関数です。
は重みとバイアスです。基本的には線形関数で、最終的にsignで符号だけ取り出すことで予測値を算出しています。
学習アルゴリズム
与えられている特徴量を、正解ラベルをとします。この時、予測が当たっている状況を数式であらわすと、
符号関数は、少し抽象的です。もう少し扱いやすい形に書き換えましょう。要は両辺で符号さえ合っていればよいので、
としてもよいでしょう。訓練データが全部で個あるような状況を考えると、重みパラメータとバイアスが満たすべき条件は、
となります。さらに条件を詰めます。まず、バイアスを重みパラメータの中に取り込んでしまいます。
次に重みパラメータをの最小値で規格化(?)します。
これをつかうことで、条件はもう少し厳しくすることができて、
これが求めるべきパラメータの満たすべき条件です。
どのように解くか
問題はこれをどう解くかです。これはよく見ると、次のような問題と同じであることがわかります。
maximize 0
s. t.
線形計画問題です。今回は目的関数はどうでもよくて、制約条件のみが問題になります。この問題の解き方は探せば見つかって、シンプレックス法(単体法)というものが有名です。
Pythonで実装する
では、Pythonで実装します。線形計画法のところはめんどくさくなってライブラリ使いました。ごめんなさい。
import numpy as np from sklearn.datasets import make_classification from sklearn.model_selection import train_test_split if __name__ == '__main__': #サンプル生成 X, y = make_classification( n_samples=500, n_classes=2, n_features=2, n_redundant=0, class_sep=3, ) y = list(map(lambda x: -1 if x==0 else 1, y)) X_train, X_test, y_train, y_test = train_test_split(X, y)
分類サンプルは、Scikit-learnのmake_classification
を使って生成しました。今回は簡単のためにサンプル数は500個、クラス数は2(二値分類)、特徴量は2種類で重要性が少ないものは0個としました。
学習アルゴリズムの本体は下のように定義します。
import pulp def solve(X, y): lp = pulp.LpProblem('lp', pulp.LpMaximize) # 解く問題を作る w1 = pulp.LpVariable('w1', 0) #重みパラメータ1 w2 = pulp.LpVariable('w2', 0) #重みパラメータ2 w3 = pulp.LpVariable('b', 0) #バイアス lp += 0 #目的関数 for i in range(len(X)): lp += y[i] * (X[i, 0] * w1 + X[i, 1] * w2 + w3) >= 1 #制約条件 lp.solve() #解く return w1.value(), w2.value(), w3.value() #各パラメータを返す
pulp
という線形計画問題を解くためのライブラリがありますのでありがたく使わせてもらいます。
学習がうまくいっているか確かめるために、テストデータを使って精度の確認を行います。
w1, w2, b = solve(X, y) acc = 0 for i in range(len(X_test)): pred = X_test[i].dot(np.array([[w1], [w2]])) + b if np.sign(pred) == np.sign(y_test[i]): acc += 1 print(acc/len(X_test))
乱数のシードを指定していないので指向のたびに値は変わりますが、おおむね9割程度の精度は出ています。
今回は、非常に単純なモデルについて書きました。その結果、構造は簡単でも9割程度の精度が出るということが確認できました(class_sep
大きくしているのでもう少し精度出てもよいと思うが...)。
ABC012 B 入浴時間 を解いた
今回は、AtCoder Beginners Contest 012 のB問題を解きましたのでまとめておきます。この問題は特別アルゴリズム等が必要なわけではなく、難しくもないのですが、ゼロ埋めが必要で、私がこれをよく忘れるので、備忘録的な感じで書きました。
beta.atcoder.jp
問題文
問題は下のようになっています。
高橋君は、お風呂で湯船に浸かった秒数を数える習慣があります。
今日は、高橋君は湯船でN秒まで数えました。
しかし、秒だと解りにくいので、何時間何分何秒、という形に直したいです。
秒数Nが与えられるので、hh:mm:ss の形式に変換しなさい。入力は以下の形式で標準入力から与えられる。
N1行目には、高橋君が湯船に浸かった秒数を表す整数N(0≦N≦86399)が与えられる。
出力
高橋君が湯船に浸かっていた時間を、hh:mm:ssの形式で、1行で出力せよ。出力の末尾には改行をいれること。
解答
算数的な感じです。まず時間から求めて、次に分、秒と求めていけばよいです。解答はゼロ埋め出ないとだめなので、ios
とiomanip
をインクルードしておきます。
#include <iostream> #include <ios> #include <iomanip> using namespace std; int main() { int n; cin >> n; int hour, min, sec; hour = n / 3600; n -= hour * 3600; min = n / 60; n -= min * 60; sec = n; cout << setfill('0') << right << setw(2) << hour << ":"; cout << setfill('0') << right << setw(2) << min << ":"; cout << setfill('0') << right << setw(2) << sec << endl; }
一応解説をしておくと、setfill
で何で埋めるか(今回は0)を指定していて、表示するものを左右どちらに寄せるかを指定しているのが次のright
の部分、そして埋めた後に最終的に何文字にするのかを指定しているのがsetw
になります。なのでsetw(4)
としたときには、hour=3
の時、表示は「0003」と4桁になります。
Kaggleのデータセットで遊んでみた 2
前回の続きで、Titanicのデータセットで分析の基礎を学びます。
年齢と生存の関係を見てみる
前回の記事の最後に見たヒストグラムでは、Ageのデータがきれいな正規分布のような形をしているのが印象的でした。果たしてこの変数がSurviveに影響してくるのか、Survivedで分けてプロットしなおしてみます。
f,ax=plt.subplots(1,2,figsize=(20,10)) train[train['Survived']==0].Age.plot.hist(ax=ax[0],bins=20,edgecolor='black',color='red') ax[0].set_title('Survived= 0') x1=list(range(0,85,5)) ax[0].set_xticks(x1) train[train['Survived']==1].Age.plot.hist(ax=ax[1],color='green',bins=20,edgecolor='black') ax[1].set_title('Survived= 1') x2=list(range(0,85,5)) ax[1].set_xticks(x2) plt.show()
分布の形状に大きな差はありません。しかし、低年齢に注目してみると死者よりも明らかに生存者が多いことがわかります。
コードの説明
コード | 動作 |
---|---|
plt.subplots(1,2,figsize=(20,10)) |
プロット領域の用意。1行2列で、大きさは(20, 10) |
train[train['Survived']==0].Age |
trainからSurvivedが0のものを抽出し、そのAgeカラムを取り出す |
plot.hist(ax=ax[0], bins=20,...) |
最初のプロット領域にヒストグラムをプロット。ビンは20個用意 |
性別と生存率の関係
性別と生存率の関係はどのようになっているでしょうか?
f,ax=plt.subplots(1,2,figsize=(18,8)) train[['Sex','Survived']].groupby(['Sex']).mean().plot.bar(ax=ax[0]) ax[0].set_title('Survived vs Sex') sns.countplot('Sex',hue='Survived',data=train,ax=ax[1]) ax[1].set_title('Sex:Survived vs Dead') plt.show()
今度は明確な差が出てきました。男性に比べると女性のほうが生存率が明らかに高いことがわかります。
コードの説明
コード | 動作 |
---|---|
.groupby(['Sex']).mean() |
'Sex'でデータをまとめる。'Survived'の行の値は平均で置き換える |
sns.countplot('Sex',hue='Survived',data=train,ax=ax[1]) |
trainデータで、横軸'Sex'縦軸'Survive'でプロット |
複数の指標を同時に比較してみる
バイオリンプロット
複数の指標を一つのグラフ上にプロットするのはバイオリンプロットが便利です。sns.violinplot
で利用できます。
f,ax=plt.subplots(1,2,figsize=(18,8)) sns.violinplot("Pclass","Age", hue="Survived", data=train,split=True,ax=ax[0], palette='pastel') ax[0].set_title('Pclass and Age vs Survived') ax[0].set_yticks(range(0,110,10)) sns.violinplot("Sex","Age", hue="Survived", data=train,split=True,ax=ax[1], palette='pastel') ax[1].set_title('Sex and Age vs Survived') ax[1].set_yticks(range(0,110,10)) plt.show()
ここから読み取れるのは、生存者数のピークは死者数のものとほぼ一致しているが、生存者数に関しては低年齢領域にもう一つ小さなピークがあるものが多いということです。また、Pclass=1では生存者数は1ピークですが、そのピーク位置は死者数と比較して低年齢側にシフトしています。
ペアプロット
sns.jointplot(x="Age", y="Survived", data=train, size=5,ratio=5, kind='kde', color='green') plt.show()
ヒートマップ
ヒートマップで各変数の相関の強さをざっくりと見ます。
plt.figure(figsize=(7,4)) sns.heatmap(train.corr(),annot=True,cmap='Reds') plt.show()
annot=True
とすることで、相関係数がグラフ中に表示されるようになります。
Djangoで家計簿のWebアプリケーションを作る 7 ビューをクラスを使って整理する
Djangoで家計簿のWebアプリケーションを作っています。
ビューが汚い
ここまで様々な機能を実装してきました。その結果、views.pyの中身がだいぶ見づらくなってしまっています。
money/views.py
import calendar import datetime from django.shortcuts import render, redirect from django.utils import timezone import matplotlib.pyplot as plt import pytz from .models import Money from .forms import SpendingForm plt.rcParams['font.family'] = 'IPAPGothic' #日本語の文字化け防止 # Create your views here. TODAY = str(timezone.now()).split('-') def index(request, year=TODAY[0], month=TODAY[1]): money = Money.objects.filter(use_date__year=year, use_date__month=month).order_by('use_date') total = index_utils.calc_month_pay(money) index_utils.format_date(money) form = SpendingForm() next_year, next_month = get_next(year, month) prev_year, prev_month = get_prev(year, month) context = {'year' : year, 'month' : month, 'prev_year' : prev_year, 'prev_month' : prev_month, 'next_year' : next_year, 'next_month' : next_month, 'money' : money, 'total' : total, 'form' : form } draw_graph(year, month) if request.method == 'POST': data = request.POST use_date = data['use_date'] cost = data['cost'] detail = data['detail'] category = data['category'] use_date = timezone.datetime.strptime(use_date, "%Y/%m/%d") tokyo_timezone = pytz.timezone('Asia/Tokyo') use_date = tokyo_timezone.localize(use_date) use_date += datetime.timedelta(hours=9) Money.objects.create( use_date = use_date, detail = detail, cost = int(cost), category = category, ) return redirect(to='/money/{}/{}'.format(year, month)) return render(request, 'money/index.html', context) #...
どうにかしましょう。
リファクタリング
ビューの機能とは無関係の部分を抽出する
まずはビュー本来の機能である、コンテクストをHTMLに送るということ以外のことをしている部分を探し出して抽出していきましょう。まずはその月の支出合計を計算する部分は抽出できそうです。同様にして、データベースから日付をとってきて表示用に整形する部分も抽出できそうです。
money/utils/index_utils.py
def calc_month_pay(money): total = 0 for m in money: total += m.cost return total def format_date(money): for m in money: date = str(m.use_date).split(' ')[0] m.use_date = '/'.join(date.split('-')[1:3]) return None
ついでに前月と次月を計算する関数もこちらに移してしまいましょう。
money/utils/index_utils.py
def calc_month_pay(money): total = 0 for m in money: total += m.cost return total def format_date(money): for m in money: date = str(m.use_date).split(' ')[0] m.use_date = '/'.join(date.split('-')[1:3]) return None def get_next(year, month): year = int(year) month = int(month) if month == 12: return str(year + 1), '1' else: return str(year), str(month + 1) def get_prev(year, month): year = int(year) month = int(month) if month == 1: return str(year - 1), '12' else: return str(year), str(month - 1)
money/views.py
def index(request, year=TODAY[0], month=TODAY[1]): money = Money.objects.filter(use_date__year=year, use_date__month=month).order_by('use_date') total = index_utils.calc_month_pay(money) index_utils.format_date(money) form = SpendingForm() next_year, next_month = index_utils.get_next(year, month) prev_year, prev_month = index_utils.get_prev(year, month) context = {'year' : year, 'month' : month, 'prev_year' : prev_year, 'prev_month' : prev_month, 'next_year' : next_year, 'next_month' : next_month, 'money' : money, 'total' : total, 'form' : form } draw_graph(year, month) if request.method == 'POST': data = request.POST use_date = data['use_date'] cost = data['cost'] detail = data['detail'] category = data['category'] use_date = timezone.datetime.strptime(use_date, "%Y/%m/%d") tokyo_timezone = pytz.timezone('Asia/Tokyo') use_date = tokyo_timezone.localize(use_date) use_date += datetime.timedelta(hours=9) Money.objects.create( use_date = use_date, detail = detail, cost = int(cost), category = category, ) return redirect(to='/money/{}/{}'.format(year, month)) return render(request, 'money/index.html', context)
ビュークラスを使う
これで少しはすっきりしましたが、通常表示されるときに実行される部分と、postを受け取ったときに実行される部分が混ざってしまっています。これを解決するには、ビューをクラスとして書きます。
money/views.py
import calendar import datetime from django.shortcuts import render, redirect from django.utils import timezone from django.views import View import matplotlib.pyplot as plt import pytz from .models import Money from .forms import SpendingForm from .utils import index_utils plt.rcParams['font.family'] = 'IPAPGothic' TODAY = str(timezone.now()).split('-') class MainView(View): def get(self, request, year=TODAY[0], month=TODAY[1]): money = Money.objects.filter(use_date__year=year, use_date__month=month).order_by('use_date') total = index_utils.calc_month_pay(money) index_utils.format_date(money) form = SpendingForm() next_year, next_month = index_utils.get_next(year, month) prev_year, prev_month = index_utils.get_prev(year, month) context = {'year' : year, 'month' : month, 'prev_year' : prev_year, 'prev_month' : prev_month, 'next_year' : next_year, 'next_month' : next_month, 'money' : money, 'total' : total, 'form' : form } draw_graph(year, month) return render(request, 'money/index.html', context) def post(self, request, year=TODAY[0], month=TODAY[1]): data = request.POST use_date = data['use_date'] cost = data['cost'] detail = data['detail'] category = data['category'] use_date = timezone.datetime.strptime(use_date, "%Y/%m/%d") tokyo_timezone = pytz.timezone('Asia/Tokyo') use_date = tokyo_timezone.localize(use_date) use_date += datetime.timedelta(hours=9) Money.objects.create( use_date = use_date, detail = detail, cost = int(cost), category = category, ) return redirect(to='/money/{}/{}'.format(year, month)) #...
ここでは最も一般的なdjango.views.View
をテンプレートビューとして使っています。この変更に伴って、urls.pyも少し変更する必要があります。
money/urls.py
from django.urls import path from . import views app_name = 'money' urlpatterns = [ path('', views.MainView.as_view(), name='index'), path('<int:year>/<int:month>', views.MainView.as_view(), name='index'), ]
このようにas_view
を付けることによってクラスをビューとして呼び出すことができます。
Kaggleのデータセットで遊んでみた 1 データの可視化
Kaggleとは、機械学習とデータサイエンスのプラットフォームのことです。このサイトでは、様々なデータを使って自分で分析を行うことができたり、データ解析のコンペティションに参加して精度を競い合ったりすることができます。
今回はKaggleの中で最初に出会うであろうTaitanicのデータセットを使ってデータ解析作業の大枠を紹介しているカーネルを紹介したいと思います。下のサイトを参考にしています。すごく良いカーネルなのでぜひ読んでみてください!
A Comprehensive ML Workflow with Python | Kaggle
作業環境
Kaggleのノートブック環境を利用しても、自分のJupyter Notebook環境を使っても構いません。構成の詳細は下のようになっています。
from sklearn.cross_validation import train_test_split from sklearn.metrics import classification_report from sklearn.metrics import confusion_matrix from sklearn.metrics import accuracy_score import matplotlib.pylab as pylab import matplotlib.pyplot as plt from pandas import get_dummies import matplotlib as mpl import seaborn as sns import pandas as pd import numpy as np import matplotlib import warnings import sklearn import scipy import numpy import json import sys import csv import os print('matplotlib: {}'.format(matplotlib.__version__)) print('sklearn: {}'.format(sklearn.__version__)) print('scipy: {}'.format(scipy.__version__)) print('seaborn: {}'.format(sns.__version__)) print('pandas: {}'.format(pd.__version__)) print('numpy: {}'.format(np.__version__)) print('Python: {}'.format(sys.version)) # 以下出力 """ /opt/conda/lib/python3.6/site-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20. "This module will be removed in 0.20.", DeprecationWarning) matplotlib: 2.2.3 sklearn: 0.19.1 scipy: 1.1.0 seaborn: 0.9.0 pandas: 0.23.4 numpy: 1.15.2 Python: 3.6.6 |Anaconda, Inc.| (default, Oct 9 2018, 12:34:16) [GCC 7.3.0] """
基本設定
この後、様々な処理をしていくにあたって、プロットについて基本的な設定をしておきます。
sns.set(style='white', context='notebook', palette='deep') pylab.rcParams['figure.figsize'] = 12,8 warnings.filterwarnings('ignore') mpl.style.use('ggplot') sns.set_style('white') %matplotlib inline
ここら辺のスタイルは各自の好みで変えてよいと思います。
データを見てみる
散布図
train = pd.read_csv('../input/train.csv') test = pd.read_csv('../input/test.csv') train_columns #以下出力 """ Index(['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp', 'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked'], dtype='object') """
このデータには「PassengerId」など12個のカラムがあります。今回予測するのは'Survived'つまりタイタニック事故で特定の人が生き残るか死ぬかを予測することになります。そのため、データの分析は基本的には'Survived'とほかの変数の間の関係を見ていくことになります。
試しに、Pclass(客の経済的な豊かさ、1が最も裕福)、年齢、乗船料とSurviveの関係を見てみます。
g = sns.FacetGrid(train, hue="Survived", col="Pclass", margin_titles=True, palette={1:"seagreen", 0:"gray"}) g=g.map(plt.scatter, "Fare", "Age",edgecolor="w").add_legend();
下のようなグラフが表示されます。
生存者と死者の間に明確な違いは見受けられません。
コードの説明
1行目のsns.FaceGrid
ではtrainからデータをとってきてプロットをすることを示し、col
に指定されたPclassごとにグラフを用意することを宣言しています。なので今回はPclassが1,2,3の3つのグラフが描画されます。hue
はSurviveの値で区別してプロットすることを意味しています。Surviveには生存(1)と死亡(0)の2種類がありますが、それぞれを何色でプロットするかはpalette
で指定されています。
実際のグラフ描画はg.map
で行われています。
コード | 動作 |
---|---|
plt.scatter |
散布図スタイルでプロットする |
"Fare" |
横軸はFare(乗船料) |
"Age" |
縦軸はAge(年齢) |
edgecolor |
ドットの淵は白(white) |
add.legend() |
凡例を入れる |
Boxプロット(箱ひげ図)
各変数がどのような値の範囲をとるのか図示します。
train.plot(kind='box', subplots=True, layout=(2,4), sharex=False, sharey=False) plt.subplots_adjust(wspace=0.5, hspace=0.6)
箱ひげ図の見方についてはWidipediaを見てください。
コードの説明
コード | 動作 |
---|---|
train.plot |
Pandasライブラリの機能でプロットする |
kind='box' |
プロットの種類は箱ひげ図 |
subplots=True |
グラフをサブプロットに分ける |
layout=(2, 4) |
2行4列にサブプロットを配置する |
sharex=False, sharey=False |
x, y軸の値の範囲を共有しない |
plt.subplots_adjust(wspace=0.5, hspace=0.6) |
グラフ同士の幅、高さの間隔調整 |
「テスト駆動開発」をPythonで書き直してみた 7
書籍「テスト駆動開発」をPythonで書き直したシリーズです。前回の記事はこちらです。
pyhaya.hatenablog.com
今回は、いよいよ多国通貨を扱うための準備に取り掛かります。
- 作者: Kent Beck,和田卓人
- 出版社/メーカー: オーム社
- 発売日: 2017/10/14
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (1件) を見る
Bank(銀行)
多国通貨を扱うためには、為替レートを使って通貨を換算する必要があります。その役割を銀行クラスを作って任せます。どのような動作を期待するのか、テストを書いて明示します。
tests/test_money.py
import sys sys.path.append('../src') import unittest from money import Money class MoneyTest(unittest.TestCase): def testMultiplication(self): five = Money.dollar(5) self.assertEqual(Money.dollar(10), five * 2) self.assertEqual(Money.dollar(15), five * 3) def testEquality(self): self.assertNotEqual(Money.franc(5), Money.dollar(5)) def testSimpleAddition(self): bank = Bank() #銀行を用意 sum_ = Money.dollar(5) + Money.dollar(5) reduced = bank.reduce(sum_, "USD") #USDに換算する self.assertEqual(reduced, Money.dollar(10)) if __name__ == '__main__': unittest.main()
銀行を表すオブジェクトを用意しておいて、reduce
メソッドで通貨の両替を行います。
明白な実装
このテストを通す明白な実装はこのようになります。
src/bank.py
from money import Money class Bank: def reduce(self, source, to): return Money.dollar(10)
Moneyの実装を再考する
ここまで書いたら、一度Moneyの実装を考えてみます。今回、両替を銀行に委譲しました。なので、通貨同士の足し算は足した時点では通貨は決定していない状態です。ただ、足される2つの通貨を持っているだけの中間状態のようなものが必要です。いわば抽象的な通貨を表すためのクラスが必要です。
Sumクラスを実装する
このようなクラスとして、Sum
クラスを作ってみます。
src/sum.py
class Sum: def __init__(self, augent, addend): self.augent = augent self.addend = addend
src/money.py
from sum import Sum class Money: def __init__(self, amount, currency): self.amount = amount self.currency = currency def __eq__(self, other): return self.__dict__ == other.__dict__ def __add__(self, other): return Sum(self, other) def __mul__(self, multiplier): return Money(self.amount * multiplier, self.currency) @staticmethod def dollar(amount): return Money(amount, 'USD') @staticmethod def franc(amount): return Money(amount, 'CHF')
__add__
メソッドはSumクラスのインスタンスを返します。Sumクラスは初期化の際に被加算数(addend)と加算数(augent)を持ちます。
Bankクラスの本実装
ここまでくれば、ロジックが固まってきたのでBankクラスのreduceメソッドを書き下せます。source
引数はSumクラスのインスタンスを引数に持つので、
src/bank.py
from money import Money class Bank: def reduce(self, source, to): amount = source.augent.amount + source.addend.amount return Money(amount, to)
まとめ
今回は多国通貨を扱うための下準備を行いました。ここまでくるとだいぶロジックが複雑になってくるのでテストを書いてどのようなロジックにしたいのか明確にしておくことが重要です。