pyhaya’s diary

機械学習系の記事をメインで書きます

Djangoで家計簿のWebアプリケーションを作る 5 テストを書く

DjangoでWebアプリケーションを作る解説記事です。今回のトピックはテストです。今更感がすごいですが、だいぶコードが増えてきたのでテストを書いてくさびを打っておきます。

前回の記事
pyhaya.hatenablog.com

モデルのテスト

from django.test import TestCase
from money.models import Money

# Create your tests here.
class TestMoneyModel(TestCase):
    def test_db_is_empty(self):
        money = Money.objects.all()
        self.assertEqual(money.count(), 0)

これをコマンドラインから実行します。

python manage.py test

すると次のように表示されます。

Creating test database for alias 'default'...
System check identified no issues (0 silenced).
.
----------------------------------------------------------------------
Ran 1 test in 0.001s

OK
Destroying test database for alias 'default'...

テスト用にデータベースが作られ、テストが終わったら消されていることがわかります。

一つデータを作ってみたときのテストケースも追加します。

import datetime
from django.utils import timezone
from django.test import TestCase

from money.models import Money

# Create your tests here.
class TestMoneyModel(TestCase):
    def test_db_is_empty(self):
        money = Money.objects.all()
        self.assertEqual(money.count(), 0)

    def test_save_data(self):
        use_date = timezone.now()
        detail = "テスト"
        cost = 100
        category = "食費"

        Money.objects.create(
                use_date = use_date,
                detail = detail,
                cost = cost,
                category = category,
                )
        obj = Money.objects.all()
        self.assertEqual(obj.count(), 1)
        self.assertEqual(obj[0].use_date, datetime.date.today())
        self.assertEqual(obj[0].detail, detail)
        self.assertEqual(obj[0].cost, cost)
        self.assertEqual(obj[0].category, category)

これでテストが通ることを確認します。

URLのテスト

URLを指定したときに正しいビューが返ってくることを確認します。

import datetime  
from django.urls import resolve
from django.utils import timezone
from django.test import TestCase

from money.models import Money
from money.views import index

#...

class TestURL(TestCase):
    def test_URL_resolve(self):
        url = resolve('/money/')
        self.assertEqual(url.func, index)    #上のURLでindexが呼ばれるか

        url = resolve('/money/2018/11')
        self.assertEqual(url.func, index)    #上のURLでindexが呼ばれるか

このようにテストを書いておくことで、後でコードを変更したときにアプリケーションが壊れていないかをすぐに検証することができます。

次回記事はこちら
pyhaya.hatenablog.com

「テスト駆動開発」をPythonで書き直してみた 6

書籍「テスト駆動開発」をPythonで書き直してみたシリーズの第6弾です。すでに書籍のコードとは大きく乖離し始めていますが一応参考書籍は明示しておきます。過去の記事はこちらです。
pyhaya.hatenablog.com


テスト駆動開発

テスト駆動開発

通貨同士の足し算を実装する

最初に通貨の掛け算はtimesメソッドで実装していました。今回は、足し算を実装します。足し算の場合には掛け算とは異なり書けるほうもかけられるほうもMoneyオブジェクトであることに注意しなくてはいけません。

まずはテストから書きます。

テストを書く

tests/test_money.py

import sys
sys.path.append('../src')

import unittest
from money import Money

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Money.dollar(5)
        self.assertEqual(Money.dollar(10), five.times(2))
        self.assertEqual(Money.dollar(15), five.times(3))

    def testFrancMultiplication(self):
        five = Money.franc(5)
        self.assertEqual(Money.franc(10), five.times(2))
        self.assertEqual(Money.franc(15), five.times(3))

    def testEquality(self):
        self.assertNotEqual(Money.franc(5), Money.dollar(5))

    def testSimpleAddition(self):    # <- 追加
        sum_ = Money.dollar(5).plus(Money.dollar(5))
        self.assertEqual(sum_, Money.dollar(10))

if __name__ == '__main__':
    unittest.main()

テストが通るようにコードを書きなおす

テストを通します。
src/money.py

class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __eq__(self, other):
        return self.__dict__ == other.__dict__
    
    @staticmethod
    def dollar(amount):
        return Money(amount, 'USD')

    @staticmethod
    def franc(amount):
        return Money(amount, 'CHF')

    def times(self, multiplier):
        return Money(self.amount * multiplier, self.currency)

    def plus(self, addend):
        amount = self.amount + addend.amount
        return Money(amount, self.currency)

もう少し頑張る

これで通貨の掛け算、そして通貨同士の足し算を実装できたわけですが、何か違和感がありました。Pythonでは特殊メソッドを使って基本的な演算が簡単に実装できるので、これを使ったほうが自然なコードになるはずです。

実現したいことをテストで表現します。

tests/test_money.py

import sys
sys.path.append('../src')

import unittest
from money import Money

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Money.dollar(5)
        self.assertEqual(Money.dollar(10), five * 2)    # <- 変更
        self.assertEqual(Money.dollar(15), five * 3)    # <- 変更

    def testFrancMultiplication(self):
        five = Money.franc(5)
        self.assertEqual(Money.franc(10), five * 2)    # <- 変更
        self.assertEqual(Money.franc(15), five * 3)    # <- 変更

    def testEquality(self):
        self.assertNotEqual(Money.franc(5), Money.dollar(5))

    def testSimpleAddition(self):
        sum_ = Money.dollar(5) + Money.dollar(5)    # <- 変更
        self.assertEqual(sum_, Money.dollar(10))

if __name__ == '__main__':
    unittest.main()

やはりこのほうが自然な気がします。では、これに合わせてソースコードを変更します。

src/money.py

class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __add__(self, other):
        return Money(self.amount + other.amount, self.currency)

    def __mul__(self, multiplier):
        return Money(self.amount * multiplier, self.currency)
    
    @staticmethod
    def dollar(amount):
        return Money(amount, 'USD')

    @staticmethod
    def franc(amount):
        return Money(amount, 'CHF')

これで通ります。

テストを見直す

この段階でもう一度テストコードを見直してみます。すると、testFrancMultiplicationはもう必要ない気がしてきます。これは前回までで、FrancクラスとDollarクラスが分かれていたからこそ意味があったものでMoneyクラスに統合された状態ではtestMultiplicationでもう十分信頼性を確かめられています。

import sys
sys.path.append('../src')

import unittest
from money import Money

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Money.dollar(5)
        self.assertEqual(Money.dollar(10), five * 2)
        self.assertEqual(Money.dollar(15), five * 3)

    def testEquality(self):
        self.assertNotEqual(Money.franc(5), Money.dollar(5))

    def testSimpleAddition(self):
        sum_ = Money.dollar(5) + Money.dollar(5)
        self.assertEqual(sum_, Money.dollar(10))

if __name__ == '__main__':
    unittest.main()

コードはGitHub上にありますので、ご自由にお使いください。
github.com

次回記事はこちら
pyhaya.hatenablog.com

PythonでC拡張を書く

Pythonは速度で見ると早いとは言えない言語です。しかし、C言語による拡張を書くことができて、それにより速度を大幅に上昇させることができます。よく言語の速さを比較するのに使われるフィボナッチ数列を使ってピュアPythonとC拡張の速度の比較を行います。

PythonのC拡張ではよくCythonが話題に上ります。これはPythonのような文法で簡単にC拡張を書くことができるため人気があります。しかし、この記事では純粋にC言語から出発してPythonへコードを移植します。

この記事は「エキスパートPythonプログラミング」を参考にしています。

エキスパートPythonプログラミング改訂2版

エキスパートPythonプログラミング改訂2版

開発環境

  • Windows10
  • Python 3.6.5
  • Anaconda

必要となるもの

C言語コンパイルが必要になるので、gccVisual Studio等が必要になります。私は、コンパイルはWSLでやっておりますのでgccを使っています。また、Pythonで使える形に書くためにPython.hというヘッダファイルが必要になります。私のようなAnacondaでPythonをインストールしている場合にはC:/Users/(ユーザー名)/Anaconda3/include中にあります。

Pythonでの実装

高速化処理は行わずに純粋に再帰だけで書いてみます。

python

def fib(n):
    if n < 2:
        return 1
    return fib(n-1) + fib(n-2)

Cでの実装

C言語でも実装は似たような感じになります。Python実装と名前がかぶらないように名前はfibonacciに変えています。

C言語

int fibonacci(int n){
    if (n < 2){
        return 1;
    }else{
        return fibonacci(n-1) + fibonacci(n-2);
    }
}

Pythonで使える形にする

上のC言語実装だけではPythonで使えません。Pythonで使えるように下のようにコードを付け加えます。

fibonacci.c

#include <Python.h>

int fibonacci(int n){
    if (n < 2){
        return 1;
    }else{
        return fibonacci(n-1) + fibonacci(n-2);
    }
}

static PyObject* fibonacci_py(PyObject* self, PyObject* args){
    PyObject *result = NULL;
    long n;
    if (PyArg_ParseTuple(args, "l", &n)){
        if( n < 0 ){
            PyErr_SetString(PyExc_ValueError, "n must not be less than 0");
        }else{
            result = Py_BuildValue("L", fibonacci((unsigned int)n));
        }
    }

    return result;
}
static char fibonacci_docs[] = "fibonacci(n): Return nth Fibonacci sequence number computed recuesive\n"; 
                                                         
static PyMethodDef fibonacci_module_methods[] = {
    {"fibonacci", (PyCFunction)fibonacci_py,
        METH_VARARGS, fibonacci_docs},
    {NULL, NULL, 0, NULL} 
};

static struct PyModuleDef fibonacci_module_definition = {
    PyModuleDef_HEAD_INIT,
    "fibonacci",
    "Extension module that provides fibonacci sequence function",
    -1,
    fibonacci_module_methods
};

PyMODINIT_FUNC PyInit_fibonacci(void){
    Py_Initialize();
    return PyModule_Create(&fibonacci_module_definition);
}

ずいぶん長いコードになりました。このコードのほとんどはボイラープレートコードです。一つずつ見ていきます。

fibonacci_py

static PyObject* fibonacci_py(PyObject* self, PyObject* args){
    PyObject *result = NULL;
    long n;
    if (PyArg_ParseTuple(args, "l", &n)){
        if( n < 0 ){
            PyErr_SetString(PyExc_ValueError, "n must not be less than 0");
        }else{
            result = Py_BuildValue("L", fibonacci((unsigned int)n));
        }
    }

    return result;
}

このコードは、C言語の関数を、Pythonで扱えるオブジェクトを返すようにするためのコードです。Python/C APIPyObjectという型をこのために用意していて、すべての関数はこの型のポインタを返す必要があります。

PyObject* argsが関数の受け取る引数を含むタプルへのポインタになっています。nという変数を用意しておいてPyArg_ParseTuplenに入れています。"l"というのは引数がlong型であることを期待していることを示しています。

最後に、fibonacci((unsigned int)n)で数列を計算して、それをPythonで使えるオブジェクトに変換します。変換はPy_BuildValueが行います。

fibonacci_docs[]

ドキュメントです。

fibonacci_module_methods[]

static PyMethodDef fibonacci_module_methods[] = {
    {"fibonacci", (PyCFunction)fibonacci_py,
        METH_VARARGS, fibonacci_docs},
    {NULL, NULL, 0, NULL} 
};

この配列は、モジュールが提供する関数やメソッドを定義します。配列は以下の要素を含みます。

  • 関数名
  • 関数のC実装へのポインタ
  • 呼び出し規約・束縛条件を含むフラグ
  • docstring文字列へのポインタ

配列の最後に入っているのは番兵です。C実装へのポインタは、fibonacci_pyPyCFunctionへのキャストです。この関数の呼び出し規約がMETH_VARAGSで決められています。呼び出し規約にはいくつか選択肢があります。

規約 説明
METH_VARARGS パラメータとして引数リストのみを受け取る
METH_KWARDS キーワード引数を利用できる
METH_NOARGS 引数無し

fibonacci_module_definition

static struct PyModuleDef fibonacci_module_definition = {
    PyModuleDef_HEAD_INIT,
    "fibonacci",
    "Extension module that provides fibonacci sequence function",
    -1,
    fibonacci_module_methods
};

モジュール全体を定義する構造体です。最初の要素は必ずPyModuleDef_HEAD_INITを使います。第二要素はモジュール名です。第三要素はモジュールのdocstringへのポインタ、第三要素はモジュールの状態を保持するために確保されるメモリの大きさを表しています。これはほとんどの場合には-1で大丈夫で、複数のサブインタープリタや複数段階での初期化が必要な時に使います。第四要素は関数をPyModuleDefで定義した配列へのポインタです。

PyInit_fibonacci

PyMODINIT_FUNC PyInit_fibonacci(void){
    Py_Initialize();
    return PyModule_Create(&fibonacci_module_definition);
}

モジュールの初期化関数です。関数名は「PyInit_~」の形式である必要があります。

拡張モジュールをコンパイルする

できたコードをコンパイルするためにsetup.pyを使います。

setup.py

from setuptools import setup, Extension
setup(
        name = 'fibonacci',
        ext_modules = [
            Extension('fibonacci', ['fibonacci.c']),
            ]
        )

では、ビルドを行います。この作業はvenv等の仮想環境で行うことをお勧めします

pip install -e .

使ってみる

速度比較のために次のコードを使います。

import time
import fibonacci    #C拡張

def fib(n):
    if n < 2:
        return 1;
    else:
        return fib(n-1) + fib(n-2)

if __name__ == '__main__':
    start = time.time()
    fib_py = [fib(i) for i in range(35)]
    elapsed_time = time.time() - start
    print(elapsed_time)

    start = time.time()
    fib = [fibonacci.fibonacci(i) for i in range(35)]
    elapsed_time = time.time() - start
    print(elapsed_time)

これを実行すると、Pythonの関数は6秒くらいなのに対してC拡張では0.08秒程度で処理が終わります。80倍くらいの性能向上が実現されています。

さらに最適化してみる

さらに実行速度を上げるために、拡張のコードをメモ化を使って書き換えます。変更箇所はfibonacci関数のみです。

long long dp[1000];
long long fibonacci(unsigned int n){
    if (n < 2){
        return 1;
    }else{
        if (dp[n-1] != 0 && dp[n-2] != 0){
            dp[n] = dp[n-1] + dp[n-2];
            return dp[n];
        }else{
            dp[n] = fibonacci(n-1) + fibonacci(n-2);
            return dp[n];
        }
    }
}

これでビルドしなおして先ほどのコードを実行してみると、4e-5秒くらいで終わります。

AtCoder Beginners Contest (ABC) 002 C: 罠 を解いた

AtCoder Beginners Contest,通称ABCを解いてそれを解説します。自分も競技プログラミングは初心者なので、簡単な問題をわかりやすく解説していこうと思います。

もしよかったらTwitterフォローお願いします。

問題文

神の恵みで財産を築いた高橋くんですが、なんとそこには罠がありました。神は、高橋くんの発した言葉から母音 a、i、u、e、o を全て盗んでいったのです。高橋くんが発した言葉を表す文字列 W が与えられるので、周囲の人が聞く言葉を表す文字列を出力するプログラムを書いてください。

入力値は

  • 1\leq |W|\leq 30
  • Wは半角英小文字のみからなる
  • Wは少なくとも1つの母音以外の文字を含む

解法(Python3)

難易度は低く、文字列操作の良い練習問題になると思います。
文字を1つずつ確かめていって、母音なら除去、それ以外なら残すというようにしていけばOKです。

W = input()    #入力値を受け取る
result = ""
for w in W:
    if w not in ["a", "i", "u", "e", "o"]:
        result += w

print(result)

解法(C++14)

#include <iostream>
#include <algorithm>
#include <string>
#include <regex>
using namespace std;

int main() {
	string s;
	cin >> s;

	cout << regex_replace(s, regex("[aiueo]"), "") << endl;
}

C++のほうは、Pythonとは異なり、正規表現を利用しています。aiueoどれかに一致すれば空文字に置き換える処理をしています。

Djangoで家計簿のWebアプリケーションを作る 4 日ごとの支出額を可視化する

Djangoで家計簿のアプリケーション作ってみた、という記事の4つ目です。今回は日ごとの支出をmatplotlibでグラフ化します。Webアプリでグラフを作る場合にはJavascriptに便利なツールがそろっているのですが、Javascriptは現在勉強中なので今回はmatplotlibでグラフを作ってみます。

過去記事は下のリンクをどうぞ

pyhaya.hatenablog.com

pyhaya.hatenablog.com

pyhaya.hatenablog.com

開発環境

構成

views.py中にグラフを書いて保存するコードを書きます。画像の保存形式は画質がいいのでSVGを使います。この関数をビューを表示するときに動くindexが呼ばれるたびに実行するようにします。

グラフを書く

money/views.py

import calendar
import datetime
from django.shortcuts import render, redirect
from django.utils import timezone
import matplotlib.pyplot as plt
import pytz

from .models import Money
from .forms import SpendingForm

plt.rcParams['font.family'] = 'IPAPGothic'    #日本語の文字化け防止

# Create your views here.
TODAY = str(timezone.now()).split('-')

def index(request, year=TODAY[0], month=TODAY[1]):
    money = Money.objects.filter(use_date__year=year,
            use_date__month=month).order_by('use_date')
    total = 0
    for m in money:
        date = str(m.use_date).split(' ')[0]
        m.use_date = '/'.join(date.split('-')[1:3])

        total += m.cost

    form = SpendingForm()
    context = {'year' : year,
            'month' : month,
            'money' : money,
            'total' : total,
            'form' : form 
    }

    draw_graph(year, month)    #追加

    if request.method == 'POST':
        data = request.POST
        use_date = data['use_date']
        cost = data['cost']
        detail = data['detail']
        category = data['category']

        use_date = timezone.datetime.strptime(use_date, "%Y/%m/%d")
        tokyo_timezone = pytz.timezone('Asia/Tokyo')
        use_date = tokyo_timezone.localize(use_date)
        use_date += datetime.timedelta(hours=9)

        Money.objects.create(
                use_date = use_date,
                detail = detail,
                cost = int(cost),
                category = category,
                )
        return redirect(to='/money/{}/{}'.format(year, month))

    return render(request, 'money/index.html', context)

def draw_graph(year, month):    #追加
    money = Money.objects.filter(use_date__year=year,
            use_date__month=month).order_by('use_date')

    last_day = calendar.monthrange(int(year), int(month))[1] + 1
    day = [i for i in range(1, last_day)]
    cost = [0 for i in range(len(day))]
    for m in money:
        cost[int(str(m.use_date).split('-')[2])-1] += int(m.cost)
    plt.figure()
    plt.bar(day, cost, color='#00bfff', edgecolor='#0000ff')
    plt.grid(True)
    plt.xlim([0, 31])
    plt.xlabel('日付', fontsize=16)
    plt.ylabel('支出額(円)', fontsize=16)
    #staticフォルダの中にimagesというフォルダを用意しておきその中に入るようにしておく
    plt.savefig('money/static/images/bar_{}_{}.svg'.format(year, month),
            transparent=True)

    return None

今回追加したのは2か所だけです。

テンプレートを更新する

ビューに合わせてテンプレートを書き直します。

money/templates/money/index.html

<!DOCTYPE html>    
{% load static %}
<html>
    <head>
        <meta charset="utf-8">
        <title>HousekeepingBook</title>
        <link rel="stylesheet" type="text/css"
        href="{% static 'money/style.css' %}">
    </head>
    <body>
        <h1>{{ year }}年{{ month }}月</h1>
        <form action="/money/" method="post">
            {% csrf_token %}
            {{ form.as_table }}
            <input type="submit" value="送信">
        </form>
        <div class="wapper">
            <div class="main">
                <table>
                    <tr>
                        <th>日付</th>
                        <th>用途</th>
                        <th>カテゴリー</th>
                        <th>金額</th>
                    </tr>
                    {% for m in money %}
                    <tr>
                        <td>{{ m.use_date }}</td>
                        <td>{{ m.detail }}</td>          
                        <td>{{ m.category }}</td>
                        <td>{{ m.cost }}円</td>
                    </tr>
                    {% endfor %}
                </table>
                <div class="tot">
                    合計:{{ total }}円
                </div>
            </div>
            <div class="main">
                <img src="/static/images/bar_{{ year }}_{{ month }}.svg"
                width=80%
            </div>
        </div>
    </body>
</html>

imgタグの画像ファイルへのパスは、{% static %}を使うのが良いのでしょうが、今回はファイル名に{{ year }}{{ month }}といった変数が入ってくるので、ふつうに書きました。

ここまでのページの感じ

f:id:pyhaya:20181110094204p:plain

次回記事
pyhaya.hatenablog.com

LeetCodeを使ってアルゴリズムの勉強&面接対策

今回は、LeetCodeというプログラミングの学習サイト兼就活・転職サイトを紹介したいと思います。
https://camo.githubusercontent.com/4d21a2a0f2bb751bba6bae08f56fbbb87e0b0460/68747470733a2f2f63646e2d696d616765732d312e6d656469756d2e636f6d2f6d61782f313336302f312a357164504c733478395475616276514a7775376975412e706e67
LeetCode - The World's Leading Online Programming Learning Platform

どんなサイトか

プログラミングの勉強ができるサイトですが、競技プログラミングサイト的な側面と、就活サイトとしての側面も持っているサイトです。利用料は基本的に無料で、本気で職探しをしているのでなければ全然無料で満足できます。

日本でも、プログラミング学習サイトとしてProgateやpaiza(こちらは職探しもできますが)などがあります。しかし、LeetCodeはこれらの学習サイトと比べるとやや難易度が高く、全くの素人のための学習サイトではないという感じです。また、とても多機能で、このサイトだけでいろいろなことができそうな感じです。学習コンテンツに関してはアルゴリズム系が主で、言語そのもののコンテンツとしてはRubyの入門があるくらいです。

機能

アルゴリズム

基本的なアルゴリズムの勉強ができます。一例をあげると次のようなコンテンツが用意してあります。

  • Queue & Stack
  • Binary Tree
  • Hash Table

f:id:pyhaya:20181110185504p:plain

面接対策

企業の面接対策のためのコンテンツも充実しています。しかし、いわゆる「よく出題される問題」的なものは無料で使えるのですが、企業ごとの問題を見たり解いたりするには有料会員になる必要があります。問題が掲載されているのは下のような名だたる企業です。

f:id:pyhaya:20181110185644p:plain

また、疑似面接(Mock Interview)を受けることもできます。こちらだと無料会員は企業名や業種を選択できないという制約があります。

オンラインジャッジ

競技プログラミングサイトのように問題を解くことができます。言語はC, C++ Java, Ruby, Python等色々使えます。出題されている問題は、よくある競プロサイトとは少し違います。というのも、例えばAtCoderではコードは一から作りますが、LeetCodeではクラスが用意されていて、そこに回答を入力していくという感じになっています。普通、競技プログラミングではクラスを使うことはない(?)と思うのでこれは新鮮でした。

ユーザー同士の議論・交流

LeetCodeではDiscussionという欄があって、ユーザ同士が議論したりできます。議論のテーマはざっくり分かれていて下のようなジャンルがあります。
f:id:pyhaya:20181110184535p:plain
また、Articleというタブもあって、こちらはまだあまり使っていないので詳しいことは言えませんが、ユーザー(?)の書いた記事を読むことができます。多くは問題に関する記事であるように感じました。

コード実行環境(Playground)

自分の書いたコードをサイト上で実行することができます。
f:id:pyhaya:20181110190324p:plain
NewでPlaygroundを新たに作ると下のような画面になります。
f:id:pyhaya:20181110190557p:plain
言語もだいぶいろいろ選べます。書いたコードは保存して後で開くこともできます。

まとめ

上で紹介したようにLeetCodeは多機能で、プログラミングの勉強にはうってつけの機能がそろっています。特に、アルゴリズムの勉強がしたいという人には非常に便利なサイトだと思います。

「テスト駆動開発」をPythonで書き直してみた 5

書籍「テスト駆動開発」をPythonで書き直してみたシリーズの5です。前回は、DollarクラスとFrancクラスの親クラスとしてMoneyクラスを作り、重複したコードを親クラスへ引き上げました。
pyhaya.hatenablog.com

テスト駆動開発

テスト駆動開発

DollarクラスとFrancクラスを消す

テストでDollar、Francを使わないようにする

DollarクラスとFrancクラスを消すために、まずはテストコードからこれらのクラスを使っている部分をなくしていきます。

tests/test_money

import sys
sys.path.append('../src')

import unittest
from money import Money

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Money.dollar(5)
        self.assertEqual(Money.dollar(10), five.times(2))
        self.assertEqual(Money.dollar(15), five.times(3))

    def testFrancMultiplication(self):
        five = Money.franc(5)
        self.assertEqual(Money.franc(10), five.times(2))
        self.assertEqual(Money.franc(15), five.times(3))

    def testEquality(self):
        self.assertNotEqual(Money.franc(5), Money.dollar(5))

if __name__ == '__main__':
    unittest.main()

Moneyクラスを経由してDollarクラスとFrancクラスを読み込みます。

Moneyクラスを書き直す

ではテストが通るようにMoneyクラスを書き換えます。だんだんインポートが面倒になってきたのでDollarクラスとFrancクラスをmoney.pyに移してきます。

src/money.py

class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    @staticmethod
    def dollar(amount):
        return Dollar(amount)

    @staticmethod
    def franc(amount):
        return Franc(amount)


class Franc(Money):
    def __init__(self, amount, currency='CHF'):
        super().__init__(amount, currency)

    def times(self, multiplier):
        return Franc(self.amount * multiplier)


class Dollar(Money):
    def __init__(self, amount, currency='USD'):
        super().__init__(amount, currency)

    def times(self, multiplier):
        return Dollar(self.amount * multiplier)         

Moneyクラスに吸収する

ここまでくれば、FrancクラスとDollarクラスの削除まではもう一歩です。これらのクラスの違いはcurrencyフィールドの違いだけなので、次のようにしてやれば統合できそうです。

class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency    #追加

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    @staticmethod
    def dollar(amount):
        return Money(amount, 'USD')

    @staticmethod
    def franc(amount):
        return Money(amount, 'CHF')

    def times(self, multiplier):
        return Money(self.amount * multiplier, self.currency)

これで、DollarとFrancクラスの統合は完了です。最後にテストを走らせて成功するのを確認します。