pyhaya’s diary

機械学習系の記事をメインで書きます

「テスト駆動開発」をPythonで書き直してみた

[更新]2022-03-24 Pythonのバージョンを上げ、テストライブラリをunittestからpytestへ変更

Kent Beck著、和田卓人訳の「テスト駆動開発」、言わずと知れたテスト駆動開発(TDD)の名著です。この本でテスト駆動開発を勉強したという人もきっと多いはず。この本では主にJavaを使ってTDDを説明しています。

今回は、テスト駆動開発Pythonで行ってみたいと思います。
目次

Pythonでのテスト環境

Pythonでよく使われるテストライブラリにpytestというライブラリがある。これを使ってテストを行いながら開発を進めていく。

今回の開発環境は以下の通り

何を開発するか

この記事では、多国通貨を実装する。このプログラムは次のようなことができるようになることを想定している。

  • 通貨の計算(足し算や掛け算)
  • 他の通貨への変換(USD -> CHFなど)
  • 異なる通貨どうしの計算(5 USD + 10 CHFなど)

ディレクトリ構成

最初のディレクトリ構成は下のようになっている。srcにはソースコードを入れていく。testsには名前の通りテストコードを入れていく

.
├── src
└── tests

まずはテストを書いてTDDのサイクルを回し始める

まずは米ドル(USD)を実装することから始める。USDを使って何ができるべきだろうか?

TODOリスト

  • $5 * 2 = $10

まずはこれから実現してみる。プログラムがどのような動作をしたらコードができたことになるかそれをテストに書き下す。Pythonunittestでは次のようにしてテストを書く。

tests/test_money.py

class TestMoney:
    def test_multiplication(self):
        five = Dollar(5)
        five.times(2)
        assert 10 == five.amount

これをpython -m pytestによって実行すると当たり前ですが、エラーが出る。Dollarクラスなんてどこにも定義していないから当たり前。

============================= test session starts =============================
platform linux -- Python 3.10.2, pytest-7.1.1, pluggy-1.0.0
rootdir: /home/yudai/Documents/python/tdd
collected 1 item                                                              

tests/test_money.py F                                                   [100%]

================================== FAILURES ===================================
________________________ TestMoney.test_multiplication ________________________

self = <test_money.TestMoney object at 0x7f344bae7d60>

    def test_multiplication(self):
>       five = Dollar(5)
E       NameError: name 'Dollar' is not defined

tests/test_money.py:6: NameError
=========================== short test summary info ===========================
FAILED tests/test_money.py::TestMoney::test_multiplication - NameError: name...
============================== 1 failed in 0.03s ==============================

テストを通す

一番簡単な方法で通す

まずはコードの質とか正しい実装とかは意識せずにとにかく通してみる。

src/dollar.py

class Dollar:     # テストコードを見るとDollarクラスが必要
    def __init__(self, something):     # テストコードをみると初期化で変数を1つとるみたい  
        self.amount = 10     # テストコードでamountというフィールドにアクセスしてるから必要
                                                                                    
    def times(self, multiplier):                                                    
        pass        # 空実装にしておく

テストコードのほうも少し直す。PythonではJavaと違ってパッケージ指定がないからソースコードがどこにあるかsys.path.appendで教えてあげなければいけない。

tests/test_money.py

from src.dollar import Dollar

class TestMoney:
    def test_multiplication(self):
        five = Dollar(5)
        five.times(2)
        assert 10 == five.amount

この状態でテストコードを実行すると下のようになって通る。

============================= test session starts =============================
platform linux -- Python 3.10.2, pytest-7.1.1, pluggy-1.0.0
rootdir: /home/yudai/Documents/python/tdd
collected 1 item                                                              

tests/test_money.py .                                                   [100%]

============================== 1 passed in 0.01s ==============================

このソースコードであれば通るのは当たり前です。これで、正しくソースコードを書けばテストが通ることが確認できました

コードを正しく直す

上の実装は、テストは通りますが、正しくはありません。five.times(3)とした瞬間に崩壊します。Dollarクラスが初期化の際に受け取る変数は金額であるので、somethingamountであるべきです。また、timesメソッドでamountフィールドが10になるので、timesメソッドの中でamountの値が10に変更されます。

ここまでの考察をソースコードに反映させると、

src/dollar.py

class Dollar:     
    def __init__(self, amount):     # something -> amountに変更
        self.amount = amount
             
    def times(self, multiplier):                                                    
        self.amount = 10    # timesメソッドを使うとamountが10になるように変更

これでもテストは通ります。でも値「10」がべた書きのままです。この10はいったいどこから出てきたのでしょうか?テストコードの意図を考えると、これはもともとの金額(5 USD)にtimesの引数(2)をかけて計算されたものであることがわかります。なので、ソースコードは以下のようになります。

src/dollar.py

class Dollar:     
    def __init__(self, amount):
        self.amount = amount

    def times(self, multiplier): 
        self.amount *= multiplier

べた書きは解消され、テストも通ります。きちんとコードが一般化されたか確かめるためにテストを追加してみます。

tests/test_money.py

from src.dollar import Dollar

class TestMoney:                                                                       
    def test_multiplication(self):                                                     
        five = Dollar(5)                                                               
        five.times(2)                                                                  
                                                                                       
        assert 10 == five.amount                                                       
                                                                                       
        five.times(3)                                                                  
                                                                                       
        assert 15 == five.amount  

実行してみると

================================= test session starts =================================
platform linux -- Python 3.10.2, pytest-7.1.1, pluggy-1.0.0
rootdir: /home/yudai/Documents/python/tdd
collected 1 item                                                                      

tests/test_money.py F                                                           [100%]

====================================== FAILURES =======================================
____________________________ TestMoney.test_multiplication ____________________________

self = <test_money.TestMoney object at 0x7f59f6f27e80>

    def test_multiplication(self):
        five = Dollar(5)
        five.times(2)
    
        assert 10 == five.amount
    
        five.times(3)
    
>       assert 15 == five.amount
E       assert 15 == 30
E        +  where 30 = <src.dollar.Dollar object at 0x7f59f6f68a90>.amount

tests/test_money.py:13: AssertionError
=============================== short test summary info ===============================
FAILED tests/test_money.py::TestMoney::test_multiplication - assert 15 == 30
================================== 1 failed in 0.03s ==================================

エラーが出ます。15になってほしいところが30になってしまっています。これはtimesメソッドの副作用です。上のfive.times(2)amountを変更してしまっており、それがfive.times(3)にも影響を与えてしまっています。

副作用をどうにかする

副作用を直すにはどうすればよいでしょうか?amountフィールドを変更しなければよいわけですから、

def times(self, multiplier):
    return self.amount * multiplier

という解決方法もあります。しかしこれだと計算結果をさらに計算に使うというのがやりづらいです。最善の方法は、新しいamountフィールドを持ったDollarオブジェクトを返すことです。つまり、コードは下のようになります。

tests/test_money.py

from src.dollar import Dollar

class TestMoney:                                                                       
    def test_multiplication(self):                                                     
        five = Dollar(5)                                                               
                                                                                       
        product = five.times(2)                                                        
        assert 10 == product.amount                                                    
                                                                                       
        product = five.times(3)                                                        
        assert 15 == product.amount   


src/dollar.py

class Dollar:     
    def __init__(self, amount):
        self.amount = amount
    
    def times(self, multiplier): 
        return Dollar(self.amount * multiplier)

まとめ

テスト駆動開発」で取り上げられている多国通貨をPythonで実装し始めました。今回はPythonでのテスト駆動開発の雰囲気をつかむところまで書きました。ここまでのところだとJavaの場合とあまり差はありません。次回はここからさらにDollarクラスを作りこんでいきます。