pyhaya’s diary

機械学習系の記事をメインで書きます

「テスト駆動開発」をPythonで書き直してみた 4

書籍「テスト駆動開発」をPythonで書き直してみたシリーズです。過去の記事はこちらです。

「テスト駆動開発」をPythonで書き直してみた - pyhaya’s diary

「テスト駆動開発」をPythonで書き直してみた 2 - pyhaya’s diary

pyhaya.hatenablog.com

テスト駆動開発

テスト駆動開発

スイスフランの実装を行う

前回の最後にはテストコードは次のようになっていました。

tests/test_money.py

import sys
sys.path.append('../src')

import unittest

from dollar import Dollar

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Dollar(5)
        self.assertEqual(Dollar(10), five.times(2))
        self.assertEqual(Dollar(15), five.times(3))

if __name__ == '__main__':
    unittest.main()

今回は、この米ドルの実装と同様に、スイスフランの実装を付け加えてみたいと思います。

まずはテストから

テストコードは米ドルのものをほとんど同じでよいはずです。

tests/test_money.py

import sys
sys.path.append('../src')

import unittest

from dollar import Dollar
from franc import Franc

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Dollar(5)
        self.assertEqual(Dollar(10), five.times(2))
        self.assertEqual(Dollar(15), five.times(3))

    def testFrancMultiplication(self):
        five = Franc(5)
        self.assertEqual(Franc(10), five.times(2))
        self.assertEqual(Franc(15), five.times(3))

if __name__ == '__main__':
    unittest.main()

このような重複は後で消去すべきですが、今はこのままにしておいて、実装を完成させます。

スイスフランクラスを書く

テストが米ドルと同じなので、当然実装も同じになってきます。Java風にファイルを分けて実装します。

src/franc.py

class Franc:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):  
        return self.__dict__ == other.__dict__

    def times(self, multiplier):
        return Franc(self.amount * multiplier)

ここまででテストを走らせると、きちんと通ります。

重複をどうにかする

テストをきちんと通したうえで、今回生じた大量の重複をどうにかする方法を考えます。この問題の最善の解決方法はスーパークラスを用意することです。

スーパークラスの空実装を継承させる

まずは、スーパークラスMoneyを空で作って、それをDollarFrancクラスに継承させてテストが引き続き通ることを確認します。

src/money.py

class Money:
    pass

src/franc.py

import money

class Franc(money.Money):
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other): 
        return self.__dict__ == other.__dict__

    def times(self, multiplier):
        return Franc(self.amount * multiplier)

src/dollar.py

import money

class Dollar(money.Money):
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other): 
        return self.__dict__ == other.__dict__

    def times(self, multiplier):
        return Dollar(self.amount * multiplier)

共通項をスーパークラスに引き上げる

DollarFrancクラスを見てみると、全く同じコードがあるのに気づきます。これらはスーパークラスに引き上げてしまいましょう。

src/money.py

class Money:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other): 
        return self.__dict__ == other.__dict__

src/franc.py

import money

class Franc(money.Money):
    def __init__(self, amount):
        super().__init__(amount)

    def times(self, multiplier):
        return Franc(self.amount * multiplier)

src/dollar.py

import money

class Dollar(money.Money):
    def __init__(self, amount):
        super().__init__(amount)

    def times(self, multiplier):
        return Dollar(self.amount * multiplier)

timesメソッドは引き上げられそうで引き上げられない。これは今は放っておく。

テストを追加する

ドルのほかにフランが追加されたことによって確かめなければいけないことが出てきた。それは5ドルと5フランがちゃんと等しくないと判断されるかという点である。この疑念をテストに書くと次のようになる。

tests/test_money.py

import sys
sys.path.append('../src')

import unittest

from dollar import Dollar
from dollar import Franc

class MoneyTest(unittest.TestCase):
    #...
    def testEquality(self):
        self.assertNotEqual(Franc(5), Dollar(5))

なんとこのテストは失敗する。

テストを通す

これを解決する方法は何だろうか?一つの方法は、各通貨クラスに通貨情報を持たせることです。これをやってみましょう。

src/money.py

class Money:
    def __init__(self, amount, currency):
        self.amount = amount
        self.currency = currency

    def __eq__(self, other): 
        return self.__dict__ == other.__dict__

src/franc.py

import money

class Franc(money.Money):
    def __init__(self, amount, currency='CHF'):
        super().__init__(amount, currency)

    def times(self, multiplier):
        return Franc(self.amount * multiplier)

src/dollar.py

import money

class Dollar(money.Money):
    def __init__(self, amount, currency='USD'):
        super().__init__(amount, currency)

    def times(self, multiplier):
        return Dollar(self.amount * multiplier)

これで通る。このようにクラスフィールドを増やしたことがのちにtimesメソッドを引き上げることにつながる。これは次回に回すことにします。