Unit Testing in Python for Beginners

Speaker: 毛毛

Outline

  • Overview of Software Testing Levels
  • Unit Test Concepts and Details
  • Unit Testing in Practice - pytest
  • Unit Testing in Practice - mock
  • Appendix

Overview of Software Testing Levels

Ordered from smallest to largest scope:

  1. Unit Testing
  2. Integration Testing
  3. System Testing
  4. Acceptance Testing

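Unit Test Concepts and Details

Test the smallest executable unit, such as a function or a method.

Focus on your own test target; do not test unrelated code.

If the test target depends on external resources, isolate it with techniques such as mocking.

Keep every unit test (UT) independent: a UT that runs earlier must not affect the result of a UT that runs later.

Keep UT results reproducible: every run should give the same result, regardless of the environment.

Use UTs to demonstrate the spec and the usage.

Give UTs meaningful names.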
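The independence and naming points above can be sketched as follows. The add_item function is a hypothetical example, not part of the original slides; the point is only that each test builds its own state and that the test name states the scenario and the expected outcome.

```python
# Hypothetical code under test (not from the original slides).
def add_item(cart, item):
    """Append item to cart and return the new cart size."""
    cart.append(item)
    return len(cart)


def test_add_item_to_empty_cart_returns_one():
    cart = []                      # fresh state created inside the test
    assert add_item(cart, "apple") == 1


def test_add_item_to_cart_with_one_item_returns_two():
    cart = ["apple"]               # does not rely on the previous test
    assert add_item(cart, "pear") == 2
```

Because neither test reads state left behind by the other, they give the same result in any execution order.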

A UT has three parts, in this order:

  • Setup (optional)
  • Execute the target (required)
  • Verify the result (required)
In [ ]:
# test_sample.py

from random import randint

from unittest.mock import patch
# from mock import patch


def get_score():
    return randint(1, 10)


def pass_the_test():
    score = get_score()
    if score >= 6:
        print("Pass")
        return True
    else:
        print("Not Pass")
        return False

# ...
In [ ]:
# ...

@patch("test_sample.get_score")
def test_pass(mock_get_score):
    mock_get_score.return_value = 8
    
    result = pass_the_test()

    assert result # assert result == True


@patch("test_sample.get_score")
def test_not_pass(mock_get_score):
    mock_get_score.return_value = 2
    
    result = pass_the_test()

    assert not result # assert result == False

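Build good habits:

  • Every time you find a bug, add a UT that covers it
  • Do not celebrate too early when a newly written UT passes; first suspect that the test itself may be wrong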

Unit Testing in Practice - pytest

Installation

  • pip install pytest

A simple example

In [ ]:
# test_sample.py


def double(number):
    return number * 2


def test_double():
    assert double(10) == 20

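Running

Ways to run the tests

[Note]

  • When several files are run together, test files must not share the same file name, even if they live in different directories
  • [Solution]
    • Add an __init__.py to each subdirectory where the duplicated name occurs

Common options

  • -v
    • Print the detailed file path and test function name
  • -k
    • Run only the test classes and test functions whose names match the given expression (matched as a substring of the name)
  • -s
    • Show the stdout produced during the run on the screen
  • -x
    • Stop as soon as one test fails

-v demo: print the detailed file path and test function name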

In [ ]:
# test_sample_2.py


def add(num_1, num_2):
    return num_1 + num_2


def test_add():
    assert add(2, 4) == 6

Current directory structure

Execution result

-k demo: run only the test classes and test functions whose names match the given expression

In [ ]:
# test_sample_3.py


def sub(num_1, num_2):
    return num_1 - num_2


class MyMath(object):
    def triple(self, number):
        return number * 3
    
    def add_five(self, number):
        return number + 5
    

def test_sub():
    assert sub(15, 3) == 12
    

class TestMyMath(object):
    def test_triple(self):
        assert MyMath().triple(10) == 30
        
    def test_add_five(self):
        assert MyMath().add_five(10) == 15

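Current directory structure

Run only the tests whose names contain TestMyMath

Run only the tests whose names contain test_a

Run only the tests whose names contain TestMyMath and test_a

-s demo: show the stdout produced during the run on the screen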

In [ ]:
# test_sample_4.py


def say_hi(name):
    print("Say Hi to {}".format(name))
    return True
    
    
def test_say_hi():
    assert say_hi("Maomao")

Current directory structure

Result without -s

Result with -s

-x demo: stop as soon as one test fails

In [ ]:
# test_sample_5.py


def always_false():
    return False


def test_always_false():
    assert always_false()

Current directory structure

Result without -x

Result with -x

Download the examples and exercises: https://github.com/win911/UT_class

  • Switch to the for_pytest folder

Example 1: verify the return value

In [ ]:
# my_math.py


def fibonacci(num):
    if num == 0 or num == 1:
        return 1
    else:
        return fibonacci(num-2) + fibonacci(num-1)
In [ ]:
# test_my_math.py

from my_math import fibonacci


def test_fibonacci():
    assert fibonacci(0) == 1
    assert fibonacci(1) == 1
    assert fibonacci(2) == 2
    assert fibonacci(3) == 3
    assert fibonacci(4) == 5
    assert fibonacci(5) == 8

Execution result

Exercise 1: verify the return value

In [ ]:
# my_math.py


def is_multiples_of_three(num):
    if num % 3 == 0:
        return True
    else:
        return False
In [ ]:
# test_my_math.py

from my_math import is_multiples_of_three


def test_is_multiples_of_three():
    # TODO: verify by 15, 23, 42, 51 and 67
    pass

Example 2: verify whether content or state has changed

In [ ]:
# my_math.py

default_cache = {
    0: 1,
    1: 1
}


def fibonacci(num, cache=default_cache):
    if num not in cache:
        if num == 0 or num == 1:
            value = 1
        else:
            value = fibonacci(num-2, cache) + fibonacci(num-1, cache)
        cache[num] = value
    return cache[num]
In [ ]:
# test_my_math.py

from my_math import fibonacci


class Testfibonacci(object):
    def test_value(self):
        assert fibonacci(0) == 1
        assert fibonacci(1) == 1
        assert fibonacci(2) == 2
        assert fibonacci(3) == 3
        assert fibonacci(4) == 5
        assert fibonacci(5) == 8
        
    def test_cache(self):
        def _verify_result(num, cache):
            for i in range(num+1):
                assert i in cache
        
        # === Case 1 ===
        cache = {
            0: 1,
            1: 1
        }
        fibonacci(3, cache=cache)
        _verify_result(3, cache)
        
        # === Case 2 ===
        cache = {}
        fibonacci(5, cache=cache)
        _verify_result(5, cache)
        

Execution result
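One detail worth checking in this example: default_cache is a module-level dict, so every call that relies on the default argument shares, and mutates, the same state. The sketch below restates fibonacci from my_math.py and shows why test_cache passes its own dict instead of relying on the default.

```python
default_cache = {0: 1, 1: 1}


def fibonacci(num, cache=default_cache):
    # Same logic as my_math.fibonacci in the example above.
    if num not in cache:
        if num == 0 or num == 1:
            value = 1
        else:
            value = fibonacci(num - 2, cache) + fibonacci(num - 1, cache)
        cache[num] = value
    return cache[num]


fibonacci(5)                      # uses, and mutates, the shared default cache
assert sorted(default_cache) == [0, 1, 2, 3, 4, 5]

own_cache = {}
fibonacci(3, cache=own_cache)     # an explicit dict keeps the test isolated
assert sorted(own_cache) == [0, 1, 2, 3]
```

Passing a fresh dict per case keeps the UT independent of whatever earlier tests left in default_cache.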

Exercise 2: verify whether content or state has changed

In [ ]:
# my_math.py


def is_multiples_of_three(num):
    if num % 3 == 0:
        return True
    else:
        return False


def insert_number(my_list, number_list):
    for num in number_list:
        if is_multiples_of_three(num):
            my_list.append(num)

In [ ]:
# test_my_math.py

from my_math import insert_number


def test_insert_number():
    # TODO: Case 1~6
    pass

Example 3: verify that the expected exception is raised

In [ ]:
# my_math.py

default_cache = {
    0: 1,
    1: 1
}


def fibonacci(num, cache=default_cache):
    if not isinstance(num, int) or not isinstance(cache, dict):
        raise ValueError("Usage: fibonacci(num=<int>, cache=<dict>)")

    if num not in cache:
        if num == 0 or num == 1:
            value = 1
        else:
            value = fibonacci(num-2, cache) + fibonacci(num-1, cache)
        cache[num] = value
    return cache[num]
In [ ]:
# test_my_math.py

from pytest import raises

from my_math import fibonacci


class Testfibonacci(object):
    # ...
        
    def test_invalid_parameter(self):
        with raises(ValueError) as e:
            fibonacci("123")
        assert str(e.value) == "Usage: fibonacci(num=<int>, cache=<dict>)"
        
        with raises(ValueError) as e:
            fibonacci(123, cache=[])
        assert str(e.value) == "Usage: fibonacci(num=<int>, cache=<dict>)"

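Execution result

Note:

  • Inside a with raises(...) block, once any line raises an error, the code after that line in the block is not executed
  • Therefore, never put assertions inside the with raises(...) block: they are never reached, so the UT verifies nothing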
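The note above can be demonstrated without pytest. The sketch below swaps in the standard library's unittest assertRaises context manager as a stand-in for pytest.raises; for this purpose it behaves the same way, except that the caught exception is exposed as .exception rather than .value.

```python
import unittest

# Stand-in for pytest.raises (an assumption made to keep the sketch
# stdlib-only): assertRaises is used here as a context manager.
case = unittest.TestCase()
steps = []

with case.assertRaises(ValueError) as ctx:
    steps.append("before")
    int("not a number")          # raises ValueError here
    steps.append("after")        # never executed

# Assertions belong after the with block.
assert steps == ["before"]
assert "invalid literal" in str(ctx.exception)
```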

Exercise 3: verify that the expected exception is raised

In [ ]:
# my_math.py


def is_multiples_of_three(num):
    if not isinstance(num, int):
        raise ValueError("The parameter MUST be an integer.")
    
    if num % 3 == 0:
        return True
    else:
        return False
In [ ]:
# test_my_math.py

from pytest import raises

from my_math import is_multiples_of_three


def test_invalid_parameter_for_is_multiples_of_three():
    # TODO
    pass

Unit Testing in Practice - mock

mock is a library for testing in Python. It allows you to replace parts of your system under test with mock objects and make assertions about how they have been used.

Installation

  • For Python 2.7 to Python 3.2
    • pip install mock
    • >>> import mock
  • For Python 3.3 and later
    • Built-in module
    • >>> from unittest import mock

Download the examples and exercises: https://github.com/win911/UT_class

  • Switch to the for_mock folder

Example 1: mock the return value of get_today_info in order to verify the return value of is_today_my_birthday

In [ ]:
# about_time.py

from datetime import datetime


def get_today_info():
    today = datetime.utcnow()
    return today.month, today.day


def is_today_my_birthday(birthday):
    month, day = get_today_info()

    try:
        birthday = datetime.strptime(birthday, "%m-%d")
    except Exception as e:
        raise ValueError(str(e))
    
    if month == birthday.month and day == birthday.day:
        return True
    return False

Tips:

  • Use patch to specify that get_today_info should be mocked
  • Use return_value to set the value that the mocked get_today_info returns
In [ ]:
# test_is_today_my_birthday.py

from unittest.mock import patch
# from mock import patch

from about_time import is_today_my_birthday


@patch("about_time.get_today_info")
def test_is_birthday(mock_get_today_info):
    mock_get_today_info.return_value = (1, 20)
    assert is_today_my_birthday("01-20")


@patch("about_time.get_today_info")
def test_is_not_birthday(mock_get_today_info):
    mock_get_today_info.return_value = (10, 20)
    assert not is_today_my_birthday("01-20")

Execution result

Brain teaser: if the program is changed as follows, which patch path is correct?

In [ ]:
# utils.py

from datetime import datetime


def get_today_info():
    today = datetime.utcnow()
    return today.month, today.day
In [ ]:
# about_time.py

from datetime import datetime

from utils import get_today_info


def is_today_my_birthday(birthday):
    month, day = get_today_info()

    try:
        birthday = datetime.strptime(birthday, "%m-%d")
    except Exception as e:
        raise ValueError(str(e))
    
    if month == birthday.month and day == birthday.day:
        return True
    return False
In [ ]:
# test_is_today_my_birthday.py

from unittest.mock import patch
# from mock import patch

from about_time import is_today_my_birthday


@patch("utils.get_today_info")
def test_is_birthday_candidate_1(mock_get_today_info):
    mock_get_today_info.return_value = (1, 20)
    assert is_today_my_birthday("01-20")


@patch("about_time.get_today_info")
def test_is_birthday_candidate_2(mock_get_today_info):
    mock_get_today_info.return_value = (1, 20)
    assert is_today_my_birthday("01-20")

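Execution result

What happened?

patch("utils.get_today_info") did not replace the function that the code under test actually calls.

Why?

Although get_today_info() is defined in utils.py, about_time.py runs from utils import get_today_info, which binds the name get_today_info inside about_time itself. You can think of it as about_time.py keeping its own copy of get_today_info(). is_today_my_birthday() looks the name up in about_time, so about_time.get_today_info is the target we have to patch!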
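The lookup rule described above can be reproduced in a single file. The sketch below builds two throwaway in-memory modules; the names demo_utils, demo_about_time, and today are hypothetical stand-ins for utils.py, about_time.py, and the code under test.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical stand-in for utils.py.
demo_utils = types.ModuleType("demo_utils")
exec("def get_today_info():\n    return (12, 31)\n", demo_utils.__dict__)
sys.modules["demo_utils"] = demo_utils

# Hypothetical stand-in for about_time.py, using "from ... import ...".
demo_about_time = types.ModuleType("demo_about_time")
exec(
    "from demo_utils import get_today_info\n"
    "def today():\n"
    "    return get_today_info()\n",
    demo_about_time.__dict__,
)
sys.modules["demo_about_time"] = demo_about_time

# Patching the defining module does not change the name that
# demo_about_time already bound at import time.
with patch("demo_utils.get_today_info", return_value=(1, 20)):
    assert demo_about_time.today() == (12, 31)

# Patching the name where it is looked up takes effect.
with patch("demo_about_time.get_today_info", return_value=(1, 20)):
    assert demo_about_time.today() == (1, 20)
```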

Another brain teaser: if the program is changed again as follows, what will the result be?

In [ ]:
# utils.py

from datetime import datetime


def get_today_info():
    today = datetime.utcnow()
    return today.month, today.day
In [ ]:
# about_time.py

from datetime import datetime

import utils


def is_today_my_birthday(birthday):
    month, day = utils.get_today_info()

    try:
        birthday = datetime.strptime(birthday, "%m-%d")
    except Exception as e:
        raise ValueError(str(e))
    
    if month == birthday.month and day == birthday.day:
        return True
    return False
In [ ]:
# test_is_today_my_birthday.py

from unittest.mock import patch
# from mock import patch

from about_time import is_today_my_birthday


@patch("utils.get_today_info")
def test_is_birthday_candidate_1(mock_get_today_info):
    mock_get_today_info.return_value = (1, 20)
    assert is_today_my_birthday("01-20")


@patch("about_time.utils.get_today_info")
def test_is_birthday_candidate_2(mock_get_today_info):
    mock_get_today_info.return_value = (1, 20)
    assert is_today_my_birthday("01-20")

Execution result

Note:

  • If patch seems to have no effect, do not panic; check the target path again : )
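For the import utils variant, both candidate paths reach the same module object, so both patches take effect. A minimal sketch, again with hypothetical in-memory modules (demo_utils2 and demo_about_time2) standing in for utils.py and about_time.py:

```python
import sys
import types
from unittest.mock import patch

# Hypothetical stand-in for utils.py.
demo_utils2 = types.ModuleType("demo_utils2")
exec("def get_today_info():\n    return (12, 31)\n", demo_utils2.__dict__)
sys.modules["demo_utils2"] = demo_utils2

# Hypothetical stand-in for about_time.py, using "import ..." this time.
demo_about_time2 = types.ModuleType("demo_about_time2")
exec(
    "import demo_utils2\n"
    "def today():\n"
    "    return demo_utils2.get_today_info()\n",
    demo_about_time2.__dict__,
)
sys.modules["demo_about_time2"] = demo_about_time2

# The function is looked up on the module object at call time,
# so patching the attribute on that module works through either path.
with patch("demo_utils2.get_today_info", return_value=(1, 20)):
    assert demo_about_time2.today() == (1, 20)

with patch("demo_about_time2.demo_utils2.get_today_info", return_value=(1, 20)):
    assert demo_about_time2.today() == (1, 20)
```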

Exercise 1-1: mock the return value of get_scores_from_db in order to verify the return value of grade

In [ ]:
# student.py

from random import randint


def get_scores_from_db(name):
    return [randint(1, 100) for i in range(3)]


def grade(name):
    scores = get_scores_from_db(name)
    total_score = sum(scores)
    quotient = total_score // len(scores)

    if quotient >= 80:
        return "A"
    elif quotient >= 60:
        return "B"
    else:
        return "C"
In [ ]:
# test_grade.py

from unittest.mock import patch
# from mock import patch

from student import grade


# TODO: test_grade_A, test_grade_B, test_grade_C

Exercise 1-2: mock the return value of get_scores_from_db in order to verify the return value of grade

In [ ]:
# utils.py

from random import randint


def get_scores_from_db(name):
    return [randint(1, 100) for i in range(3)]
In [ ]:
# student.py

from utils import get_scores_from_db


def grade(name):
    scores = get_scores_from_db(name)
    total_score = sum(scores)
    quotient = total_score // len(scores)

    if quotient >= 80:
        return "A"
    elif quotient >= 60:
        return "B"
    else:
        return "C"
In [ ]:
# test_grade.py

from unittest.mock import patch
# from mock import patch

from student import grade


# TODO: test_grade_A, test_grade_B, test_grade_C

Exercise 1-3: mock the return value of get_scores_from_db in order to verify the return value of grade

In [ ]:
# utils.py

from random import randint


def get_scores_from_db(name):
    return [randint(1, 100) for i in range(3)]
In [ ]:
# student.py

import utils


def grade(name):
    scores = utils.get_scores_from_db(name)
    total_score = sum(scores)
    quotient = total_score // len(scores)

    if quotient >= 80:
        return "A"
    elif quotient >= 60:
        return "B"
    else:
        return "C"
In [ ]:
# test_grade.py

from unittest.mock import patch
# from mock import patch

from student import grade


# TODO: test_grade_A, test_grade_B, test_grade_C

Example 2-1: verify that get_scores_from_db is called

In [ ]:
# student.py

from random import randint


def get_scores_from_db(name):
    return [randint(1, 100) for i in range(3)]


def grade(name):
    scores = get_scores_from_db(name)
    total_score = sum(scores)
    quotient = total_score // len(scores)

    if quotient >= 80:
        return "A"
    elif quotient >= 60:
        return "B"
    else:
        return "C"

Tips:

  • Use patch to specify that get_scores_from_db should be mocked
  • For unittest.mock
    • For Python 3.3 to 3.5: use call_count to verify that get_scores_from_db is called exactly once
    • For Python 3.6 and later: use assert_called_once to verify that get_scores_from_db is called exactly once
  • For mock
    • Use assert_called_once to verify that get_scores_from_db is called exactly once
In [ ]:
# test_grade.py

from unittest.mock import patch
# from mock import patch

from student import grade


@patch("student.get_scores_from_db")
def test_success_case(mock_get_scores):
    mock_get_scores.return_value = [80, 90, 100]

    grade("Maomao")

    mock_get_scores.assert_called_once()
    # assert mock_get_scores.call_count == 1


@patch("student.get_scores_from_db")
def test_failed_case(mock_get_scores):
    mock_get_scores.return_value = [80, 90, 100]

    grade("Maomao")
    grade("Abby")

    mock_get_scores.assert_called_once()
    # assert mock_get_scores.call_count == 1

Execution result
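The difference between the passing and the failing case can be reproduced with a bare Mock object, without patch. This is a sketch assuming Python 3.6 or later (where assert_called_once is available); mock_get_scores here is a plain Mock, not the patched function from the example.

```python
from unittest.mock import Mock

mock_get_scores = Mock(return_value=[80, 90, 100])

mock_get_scores("Maomao")
mock_get_scores.assert_called_once()       # passes: exactly one call so far
assert mock_get_scores.call_count == 1     # equivalent check for Python 3.3 to 3.5

mock_get_scores("Abby")                    # a second call is recorded
second_call_detected = False
try:
    mock_get_scores.assert_called_once()   # raises AssertionError now
except AssertionError:
    second_call_detected = True

assert second_call_detected
assert mock_get_scores.call_count == 2
```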

Example 2-2: verify that get_scores_from_db is called with the correct arguments

In [ ]:
# student.py

from random import randint


def get_scores_from_db(name):
    return [randint(1, 100) for i in range(3)]


def grade(name):
    scores = get_scores_from_db(name)
    total_score = sum(scores)
    quotient = total_score // len(scores)

    if quotient >= 80:
        return "A"
    elif quotient >= 60:
        return "B"
    else:
        return "C"

Tips:

  • Use patch to specify that get_scores_from_db should be mocked
  • Use assert_called_once_with to verify that get_scores_from_db is called exactly once and with the expected arguments
In [ ]:
# test_grade.py

from unittest.mock import patch
# from mock import patch

from student import grade


@patch("student.get_scores_from_db")
def test_success_case(mock_get_scores):
    mock_get_scores.return_value = [80, 90, 100]

    grade("Maomao")
    
    mock_get_scores.assert_called_once_with("Maomao")
    

@patch("student.get_scores_from_db")
def test_failed_case(mock_get_scores):
    mock_get_scores.return_value = [80, 90, 100]

    grade("Maomao")
    
    mock_get_scores.assert_called_once_with("Abby")

Execution result
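As a stdlib-only sketch of the same check, a plain Mock records its arguments; assert_called_once_with compares them, and call_args exposes the recorded call for inspection. The mock below is a bare Mock rather than the patched function from the example.

```python
from unittest.mock import Mock, call

mock_get_scores = Mock(return_value=[80, 90, 100])
mock_get_scores("Maomao")

mock_get_scores.assert_called_once_with("Maomao")   # passes
assert mock_get_scores.call_args == call("Maomao")  # the recorded call itself

wrong_argument_detected = False
try:
    mock_get_scores.assert_called_once_with("Abby")  # wrong argument
except AssertionError:
    wrong_argument_detected = True

assert wrong_argument_detected
```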

Example 2-3: verify that unexpected functions are not called

In [ ]:
# party.py

def give_apple():
    pass


def give_guava():
    pass


def give_banana():
    pass


def give_grape():
    pass


def lucky_draw(color):
    if color == "Red":
        give_apple()
    elif color == "Green":
        give_guava()
    elif color == "Yellow":
        give_banana()
    else:
        give_grape()

Tips:

  • Use patch to specify that the related functions should be mocked
  • Use assert_not_called to verify that some of the functions are not called
In [ ]:
# test_lucky_draw.py

from unittest.mock import patch
# from mock import patch

from party import lucky_draw


@patch("party.give_grape")
@patch("party.give_banana")
@patch("party.give_guava")
@patch("party.give_apple")
def test_get_red_ball(mock_give_apple, mock_give_guava, mock_give_banana, mock_give_grape):
    lucky_draw("Red")
    
    mock_give_apple.assert_called_once() # assert mock_give_apple.call_count == 1
    mock_give_guava.assert_not_called()
    mock_give_banana.assert_not_called()
    mock_give_grape.assert_not_called()

Execution result

Note:

  • When several patch(...) decorators are stacked, the mock created by the decorator closest to the function comes first in the argument list
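The ordering rule in the note can be verified with two standard library functions as patch targets. os.getcwd and os.getpid are chosen only for convenience; check_order is a hypothetical helper.

```python
import os
from unittest.mock import patch


@patch("os.getpid")    # outermost decorator -> last mock argument
@patch("os.getcwd")    # closest to the function -> first mock argument
def check_order(mock_getcwd, mock_getpid):
    # While the patches are active, the os attributes are the mocks.
    return mock_getcwd is os.getcwd, mock_getpid is os.getpid


assert check_order() == (True, True)
```

Decorators are applied bottom-up, which is why the argument list reads in the reverse order of the decorator stack.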

Exercise 2:

  • Verify that the expected function is called
    • get_user_order
  • Verify that the expected function is called with the correct arguments
    • give_fried_chicken, give_french_fries, give_custard_tarts
  • Verify that unexpected functions are not called
    • give_fried_chicken, give_french_fries, give_custard_tarts
In [ ]:
# KFC.py

from random import randint, choice


def give_fried_chicken(number):
    pass


def give_french_fries(number):
    pass


def give_custard_tarts(number):
    pass


def get_user_order():
    return randint(1, 10), choice(["Fried chicken", "French fries", "Custard Tarts"])


def process_user_order():
    number, item = get_user_order()
    if item == "Fried chicken": # buy five, get one free
        number += number // 5
        give_fried_chicken(number)
    elif item == "French fries":
        give_french_fries(number)
    elif item == "Custard Tarts": # buy three, get one free
        number += number // 3
        give_custard_tarts(number)
In [ ]:
# test_process_user_order.py

from unittest.mock import patch
# from mock import patch

from KFC import process_user_order


# TODO:
# - test_order_two_fried_chicken, test_order_seven_fried_chicken, test_order_ten_fried_chicken
# - test_order_two_french_fries, test_order_seven_french_fries, test_order_ten_french_fries
# - test_order_two_custard_tarts, test_order_seven_custard_tarts, test_order_ten_custard_tarts

Example 3: make is_today_birthday raise an exception in order to verify the behavior of celebrate_for_customer

In [ ]:
# restaurant.py

def is_today_birthday(birthday):
    pass


def prepare_a_cake():
    pass


def print_error(error):
    pass


def celebrate_for_customer(customer_birthday):
    try:
        if is_today_birthday(customer_birthday):
            prepare_a_cake()
    except ValueError as e:
        print_error(e)

Tips:

  • Use patch to specify that is_today_birthday should be mocked
  • Use side_effect to control the behavior of the mocked is_today_birthday
In [ ]:
# test_celebrate_for_customer.py

from unittest.mock import patch
# from mock import patch

from restaurant import celebrate_for_customer


@patch("restaurant.print_error")
@patch("restaurant.is_today_birthday")
def test_invalid_input(mock_is_today_birthday, mock_print_error):
    mock_is_today_birthday.side_effect = ValueError("Invalid input")
    # Otherwise, you can write like this (more powerful)
    # def fake_is_today_birthday(birthday):
    #    raise ValueError("Invalid input")
    # mock_is_today_birthday.side_effect = fake_is_today_birthday
    
    celebrate_for_customer("01-01")
    
    mock_print_error.assert_called_once() # assert mock_print_error.call_count == 1

Execution result
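side_effect accepts more than an exception. The sketch below uses bare Mock objects to show the three forms documented for unittest.mock: an exception, an iterable, and a function. The variable names are illustrative only.

```python
from unittest.mock import Mock

# 1. An exception instance (or class): raised on every call.
raises_error = Mock(side_effect=ValueError("Invalid input"))
caught = None
try:
    raises_error("01-01")
except ValueError as error:
    caught = str(error)
assert caught == "Invalid input"

# 2. An iterable: each call returns the next item.
sequence = Mock(side_effect=[True, False])
assert sequence() is True
assert sequence() is False

# 3. A function: called with the same arguments; its result is returned.
by_function = Mock(side_effect=lambda birthday: birthday == "01-20")
assert by_function("01-20") is True
assert by_function("10-20") is False
```

The function form is the more powerful variant mentioned in the code comment above, because the fake can inspect its arguments.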

Exercise 3: make cook raise an exception in order to verify the behavior of front_desk

In [ ]:
# KFC.py

def get_order():
    pass


def cook(order):
    pass


def close_door():
    pass


def notify_manager():
    pass


def front_desk():
    order = get_order()
    try:
        food = cook(order)
        return food
    except RuntimeError as e:
        if "Sold Out" in str(e):
            close_door()
        else:
            notify_manager()
In [ ]:
# test_front_desk.py

from unittest.mock import patch
# from mock import patch

from KFC import front_desk


# TODO: test_sold_out, test_other_runtime_error

Example 4-1: mock quality_checker in order to verify that check_quality behaves correctly

In [ ]:
# factory.py

class BadProductQualityError(Exception):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs) # For python3
        # super(BadProductQualityError, self).__init__(*args, **kwargs) # For python2

class Product():
    def __init__(self, quality_checker):
        self.quality_checker = quality_checker
        self.color = "Red"
        self.smell = "Rose"
        self.length = 10
        self.width = 10
        self.height = 10

    def create(self):        
        pass

    def check_quality(self):
        if self.quality_checker.checkpoint_1(self.length, self.width, self.height) < 95:
            raise BadProductQualityError("Shape")
        if self.quality_checker.checkpoint_2(self.color) < 90:
            raise BadProductQualityError("Color")
        if self.quality_checker.checkpoint_3(self.smell) < 98:
            raise BadProductQualityError("Smell")

Tips:

  • Use MagicMock to stand in for a class
    • .<attribute_name> = <value>
    • .<method_name>.return_value = <value>
    • .<method_name>.side_effect = <exception>
    • .<method_name>.side_effect = <function>
    • .<method_name>.assert_not_called()
    • ...
In [ ]:
# test_product.py

from unittest.mock import MagicMock
# from mock import MagicMock

from pytest import raises

from factory import Product, BadProductQualityError


def test_bad_product_quality_because_of_shape():
    mock_quality_checker = MagicMock()
    mock_quality_checker.checkpoint_1.return_value = 94

    p = Product(mock_quality_checker)
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Shape" in str(e.value)


def test_bad_product_quality_because_of_color():
    mock_quality_checker = MagicMock()
    mock_quality_checker.checkpoint_1.return_value = 95
    mock_quality_checker.checkpoint_2.return_value = 89

    p = Product(mock_quality_checker)
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Color" in str(e.value)

# ...
In [ ]:
# ...

def test_bad_product_quality_because_of_smell():
    mock_quality_checker = MagicMock()
    mock_quality_checker.checkpoint_1.return_value = 95
    mock_quality_checker.checkpoint_2.return_value = 90
    mock_quality_checker.checkpoint_3.return_value = 97

    p = Product(mock_quality_checker)
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Smell" in str(e.value)


def test_good_product_quality():
    mock_quality_checker = MagicMock()
    mock_quality_checker.checkpoint_1.return_value = 95
    mock_quality_checker.checkpoint_2.return_value = 90
    mock_quality_checker.checkpoint_3.return_value = 98

    p = Product(mock_quality_checker)
    p.create()
    p.check_quality()

Execution result
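A plain MagicMock accepts any attribute name, so a typo in a test goes unnoticed. One optional safeguard is the spec argument of MagicMock. The sketch below assumes a hypothetical QualityChecker class describing the interface that Product expects; the slides do not define such a class for this example.

```python
from unittest.mock import MagicMock


class QualityChecker:
    """Hypothetical interface that Product expects (not in the slides)."""

    def checkpoint_1(self, length, width, height):
        raise NotImplementedError


loose = MagicMock()
loose.checkpoint_1.return_value = 94
assert loose.checkpoint_1(10, 10, 10) == 94
loose.checkpont_1(10, 10, 10)        # typo, but silently accepted

strict = MagicMock(spec=QualityChecker)
strict.checkpoint_1.return_value = 94
assert strict.checkpoint_1(10, 10, 10) == 94

typo_detected = False
try:
    strict.checkpont_1               # not on QualityChecker -> AttributeError
except AttributeError:
    typo_detected = True
assert typo_detected
```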

Example 4-2: mock ProductQualityChecker in order to verify that check_quality behaves correctly

In [ ]:
# factory.py

from random import randint


class BadProductQualityError(Exception):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs) # For python3
        # super(BadProductQualityError, self).__init__(*args, **kwargs) # For python2


class ProductQualityChecker():
    @staticmethod
    def checkpoint_1(length, width, height):
        return randint(1, 100)
    
    @staticmethod
    def checkpoint_2(color):
        return randint(1, 100)

    @staticmethod
    def checkpoint_3(smell):
        return randint(1, 100)

# ...
In [ ]:
# ...

class Product():
    def __init__(self):
        self.color = "Red"
        self.smell = "Rose"
        self.length = 10
        self.width = 10
        self.height = 10

    def create(self):        
        pass

    def check_quality(self):
        if ProductQualityChecker.checkpoint_1(self.length, self.width, self.height) < 95:
            raise BadProductQualityError("Shape")
        if ProductQualityChecker.checkpoint_2(self.color) < 90:
            raise BadProductQualityError("Color")
        if ProductQualityChecker.checkpoint_3(self.smell) < 98:
            raise BadProductQualityError("Smell")

Tips:

  • The object that patch hands you is already a MagicMock
In [ ]:
# test_product.py

from unittest.mock import patch, MagicMock
# from mock import patch, MagicMock

from pytest import raises

from factory import Product, BadProductQualityError


@patch("factory.ProductQualityChecker")
def test_bad_product_quality_because_of_shape(mock_quality_checker):
    mock_quality_checker.checkpoint_1.return_value = 94

    p = Product()
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Shape" in str(e.value)

@patch("factory.ProductQualityChecker")
def test_bad_product_quality_because_of_color(mock_quality_checker):
    mock_quality_checker.checkpoint_1.return_value = 95
    mock_quality_checker.checkpoint_2.return_value = 89

    p = Product()
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Color" in str(e.value)

# ...
In [ ]:
# ...

@patch("factory.ProductQualityChecker")
def test_bad_product_quality_because_of_smell(mock_quality_checker):
    mock_quality_checker.checkpoint_1.return_value = 95
    mock_quality_checker.checkpoint_2.return_value = 90
    mock_quality_checker.checkpoint_3.return_value = 97

    p = Product()
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Smell" in str(e.value)


@patch("factory.ProductQualityChecker")
def test_good_product_quality(mock_quality_checker):
    mock_quality_checker.checkpoint_1.return_value = 95
    mock_quality_checker.checkpoint_2.return_value = 90
    mock_quality_checker.checkpoint_3.return_value = 98

    p = Product()
    p.create()
    p.check_quality()

Execution result
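The tip that patch hands you a MagicMock can be checked directly. The sketch below patches os.getcwd, a standard library function chosen only for convenience, and uses patch as a context manager instead of a decorator.

```python
import os
from unittest.mock import MagicMock, patch

with patch("os.getcwd") as mock_getcwd:
    # By default patch replaces the target with a MagicMock.
    assert isinstance(mock_getcwd, MagicMock)
    mock_getcwd.return_value = "/fake/dir"
    assert os.getcwd() == "/fake/dir"

# The original function is restored when the with block ends.
assert not isinstance(os.getcwd, MagicMock)
```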

Example 4-3: mock the relevant methods of ProductQualityChecker in order to verify that check_quality behaves correctly

In [ ]:
# factory.py

from random import randint


class BadProductQualityError(Exception):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs) # For python3
        # super(BadProductQualityError, self).__init__(*args, **kwargs) # For python2


class ProductQualityChecker():
    @staticmethod
    def checkpoint_1(length, width, height):
        return randint(1, 100)
    
    @staticmethod
    def checkpoint_2(color):
        return randint(1, 100)

    @staticmethod
    def checkpoint_3(smell):
        return randint(1, 100)

# ...
In [ ]:
# ...

class Product():
    def __init__(self):
        self.color = "Red"
        self.smell = "Rose"
        self.length = 10
        self.width = 10
        self.height = 10

    def create(self):        
        pass

    def check_quality(self):
        if ProductQualityChecker.checkpoint_1(self.length, self.width, self.height) < 95:
            raise BadProductQualityError("Shape")
        if ProductQualityChecker.checkpoint_2(self.color) < 90:
            raise BadProductQualityError("Color")
        if ProductQualityChecker.checkpoint_3(self.smell) < 98:
            raise BadProductQualityError("Smell")
In [ ]:
# test_product.py

from unittest.mock import patch
# from mock import patch

from pytest import raises

from factory import Product, BadProductQualityError


@patch("factory.ProductQualityChecker.checkpoint_1")
def test_bad_product_quality_because_of_shape(mock_checkpoint_1):
    mock_checkpoint_1.return_value = 94

    p = Product()
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Shape" in str(e.value)

@patch("factory.ProductQualityChecker.checkpoint_2")
@patch("factory.ProductQualityChecker.checkpoint_1")
def test_bad_product_quality_because_of_color(mock_checkpoint_1, mock_checkpoint_2):
    mock_checkpoint_1.return_value = 95
    mock_checkpoint_2.return_value = 89

    p = Product()
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Color" in str(e.value)

# ...
In [ ]:
# ...

@patch("factory.ProductQualityChecker.checkpoint_3")
@patch("factory.ProductQualityChecker.checkpoint_2")
@patch("factory.ProductQualityChecker.checkpoint_1")
def test_bad_product_quality_because_of_smell(mock_checkpoint_1, mock_checkpoint_2, mock_checkpoint_3):
    mock_checkpoint_1.return_value = 95
    mock_checkpoint_2.return_value = 90
    mock_checkpoint_3.return_value = 97

    p = Product()
    p.create()
    with raises(BadProductQualityError) as e:
        p.check_quality()
    assert "Smell" in str(e.value)


@patch("factory.ProductQualityChecker.checkpoint_3")
@patch("factory.ProductQualityChecker.checkpoint_2")
@patch("factory.ProductQualityChecker.checkpoint_1")
def test_good_product_quality(mock_checkpoint_1, mock_checkpoint_2, mock_checkpoint_3):
    mock_checkpoint_1.return_value = 95
    mock_checkpoint_2.return_value = 90
    mock_checkpoint_3.return_value = 98

    p = Product()
    p.create()
    p.check_quality()

Execution result
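An alternative to the dotted string path is patch.object, which takes the class and the attribute name separately. This is a sketch of that variant, not the approach used in the slides; the class below is a reduced copy of ProductQualityChecker from factory.py.

```python
from random import randint
from unittest.mock import patch


class ProductQualityChecker:
    # Reduced copy of the class in factory.py above.
    @staticmethod
    def checkpoint_1(length, width, height):
        return randint(1, 100)


with patch.object(ProductQualityChecker, "checkpoint_1", return_value=94) as mock_cp1:
    assert ProductQualityChecker.checkpoint_1(10, 10, 10) == 94
    mock_cp1.assert_called_once_with(10, 10, 10)

# Outside the block the real static method is back.
assert 1 <= ProductQualityChecker.checkpoint_1(10, 10, 10) <= 100
```

Because the class object is passed directly, there is no module path to get wrong; the class still has to be the one the code under test looks up.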

Exercise 4: mock the return value of get_user in order to verify that display_user_info behaves correctly

In [ ]:
# account.py

def get_user(user_id):
    pass


def display_user_info(user_id):
    user = get_user(user_id)
    html = ""
    html += "<h1>{} ({})</h1>".format(user.username, "Adult" if user.is_adult() else "Child")
    return html
In [ ]:
# test_display_user_info.py

from unittest.mock import patch, MagicMock
# from mock import patch, MagicMock

from account import display_user_info


# TODO: test_adult, test_child

Appendix

Important topics that these slides do not cover; please look them up yourself:

  • setup, teardown
  • fixtures
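As a starting point for the first item, here is a minimal sketch of pytest's xunit-style hooks. pytest calls setup_method before, and teardown_method after, every test method in the class; the class and its contents are hypothetical. Fixtures are declared with the pytest.fixture decorator and are not shown here.

```python
class TestWithSetup:
    # pytest runs setup_method before each test method
    # and teardown_method after it.
    def setup_method(self, method):
        self.cart = ["apple"]      # fresh state for every test

    def teardown_method(self, method):
        self.cart = None           # release whatever setup created

    def test_cart_starts_with_one_item(self):
        assert len(self.cart) == 1

    def test_append_does_not_leak_between_tests(self):
        self.cart.append("pear")
        assert self.cart == ["apple", "pear"]
```

Since setup_method rebuilds self.cart each time, the two tests stay independent regardless of their order.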

Once writing UTs has become a habit, consider learning about TDD (Test-Driven Development)