pyhaya’s diary

プログラミング、特にPythonについての記事を書きます。Djangoや機械学習などホットな話題をわかりやすく説明していきたいと思います。

「テスト駆動開発」をPythonで書き直してみた

Kent Beck著、和田卓人訳の「テスト駆動開発」、言わずと知れたテスト駆動開発(TDD)の名著です。この本でテスト駆動開発を勉強したという人もきっと多いはず。この本では主にJavaを使ってTDDを説明しています。

テスト駆動開発

テスト駆動開発

今回は、テスト駆動開発Pythonで行ってみたいと思います。
目次

Pythonでのテスト環境

PythonにはJavaのテストライブラリJUnitに影響を受けたテストライブラリであるunittestというライブラリがある。これを使ってテストを行いながら開発を進めていく。

今回の開発環境は以下の通り

何を開発するか

この記事では、多国通貨を実装する。このプログラムは次のようなことができるようになることを想定している。

  • 通貨の計算(足し算や掛け算)
  • 他の通貨への変換(USD -> CHFなど)
  • 異なる通貨どうしの計算(5 USD + 10 CHFなど)

ディレクトリ構成

最初のディレクトリ構成は下のようになっている。srcにはソースコードを入れていく。testsには名前の通りテストコードを入れていく

.
├── src
└── tests

まずはテストを書いてTDDのサイクルを回し始める

まずは米ドル(USD)を実装することから始める。USDを使って何ができるべきだろうか?

TODOリスト

  • $5 * 2 = $10

まずはこれから実現してみる。プログラムがどのような動作をしたらコードができたことになるかそれをテストに書き下す。Pythonunittestでは次のようにしてテストを書く。

tests/test_money.py

import unittest

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Dollar(5)
        five.times(2)
        self.assertEqual(10, five.amount)

if __name__ == '__main__':
    unittest.main()

必要なのは、テストクラスがunittest.TestCaseを継承することです。全テストを実行するにはunittest.main()を書いておきます。これを実行すると当たり前ですが、エラーが出る。Dollarクラスなんてどこにも定義していないから当たり前。

E
======================================================================
ERROR: testMultiplication (__main__.MoneyTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/mnt/c/Users/owner/Documents/Python/Codes/TDD/tests/test_money.py", line 5, in testMultiplication
    five = Dollar(5)
NameError: name 'Dollar' is not defined

----------------------------------------------------------------------
Ran 1 test in 0.003s

FAILED (errors=1)

テストを通す

一番簡単な方法で通す

まずはコードの質とか正しい実装とかは意識せずにとにかく通してみる。

src/dollar.py

class Dollar:     # テストコードを見るとDollarクラスが必要
    def __init__(self, something):     # テストコードをみると初期化で変数を1つとるみたい  
        self.amount = 10     # テストコードでamountというフィールドにアクセスしてるから必要
                                                                                    
    def times(self, multiplier):                                                    
        pass        # 空実装にしておく

テストコードのほうも少し直す。PythonではJavaと違ってパッケージ指定がないからソースコードがどこにあるかsys.path.appendで教えてあげなければいけない。

tests/test_money.py

import sys
sys.path.append('../src')

import unittest
from dollar import Dollar

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Dollar(5)
        five.times(2)
        self.assertEqual(10, five.amount)

if __name__ == '__main__':
    unittest.main()

この状態でテストコードを実行すると下のようになって通る。

.
----------------------------------------------------------------------
Ran 1 test in 0.001s

OK

このソースコードであれば通るのは当たり前です。これで、正しくソースコードを書けばテストが通ることが確認できました

コードを正しく直す

上の実装は、テストは通りますが、正しくはありません。five.times(3)とした瞬間に崩壊します。Dollarクラスが初期化の際に受け取る変数は金額であるので、somethingamountであるべきです。また、timesメソッドでamountフィールドが10になるので、timesメソッドの中でamountの値が10に変更されます。

ここまでの考察をソースコードに反映させると、

src/dollar.py

class Dollar:     
    def __init__(self, amount):     # something -> amountに変更
        self.amount = amount
             
    def times(self, multiplier):                                                    
        self.amount = 10    # timesメソッドを使うとamountが10になるように変更

これでもテストは通ります。でも値「10」がべた書きのままです。この10はいったいどこから出てきたのでしょうか?テストコードの意図を考えると、これはもともとの金額(5 USD)にtimesの引数(2)をかけて計算されたものであることがわかります。なので、ソースコードは以下のようになります。

src/dollar.py

class Dollar:     
    def __init__(self, amount):
        self.amount = amount

    def times(self, multiplier): 
        self.amount *= multiplier

べた書きは解消され、テストも通ります。きちんとコードが一般化されたか確かめるためにテストを追加してみます。

tests/test_money.py

import sys
sys.path.append('../src')

import unittest
from dollar import Dollar

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Dollar(5)
        five.times(2)
        self.assertEqual(10, five.amount)
        five.times(3)
        self.assertEqual(15, five.amount)

if __name__ == '__main__':
    unittest.main()

実行してみると

F
======================================================================
FAIL: testMultiplication (__main__.MoneyTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/mnt/c/Users/owner/Documents/Python/Codes/TDD/tests/test_money.py", line 13, in testMultiplication
    self.assertEqual(15, five.amount)
AssertionError: 15 != 30

----------------------------------------------------------------------
Ran 1 test in 0.002s

FAILED (failures=1)

エラーが出ます。15になってほしいところが30になってしまっています。これはtimesメソッドの副作用です。上のfive.times(2)amountを変更してしまっており、それがfive.times(3)にも影響を与えてしまっています。

副作用をどうにかする

副作用を直すにはどうすればよいでしょうか?amountフィールドを変更しなければよいわけですから、

def times(self, multiplier):
    return self.amount * multiplier

という解決方法もあります。しかしこれだと計算結果をさらに計算に使うというのがやりづらいです。最善の方法は、新しいamountフィールドを持ったDollarオブジェクトを返すことです。つまり、コードは下のようになります。

tests/test_money.py

import unittest
from dollar import Dollar

class MoneyTest(unittest.TestCase):
    def testMultiplication(self):
        five = Dollar(5)
        product = five.times(2)
        self.assertEqual(10, product.amount)
        product = five.times(3)
        self.assertEqual(15, product.amount)

if __name__ == '__main__':
    unittest.main()


src/dollar.py

class Dollar:     
    def __init__(self, amount):
        self.amount = amount
    
    def times(self, multiplier): 
        return Dollar(self.amount * multiplier)

まとめ

テスト駆動開発」で取り上げられている多国通貨をPythonで実装し始めました。今回はPythonでのテスト駆動開発の雰囲気をつかむところまで書きました。ここまでのところだとJavaの場合とあまり差はありません。次回はここからさらにDollarクラスを作りこんでいきます。

ソースコードは↓にあります。
github.com

コミットは基本的に記事の更新ごとに行っていきますので、gitが使える方は目的の場所までチェックアウトして利用してください。

テスト駆動開発

テスト駆動開発