Pytest mocking samples

Using MagicMock and patch

src/models/person.py
import time


class Person:
    """Simple person model used as the target of the mocking examples.

    ``say_hello`` is deliberately slow (it sleeps for 5 seconds), which is
    what makes it a good candidate for mocking in tests.
    """

    def __init__(self, name):
        # A new person starts at age 0 until set_age is called.
        self.name, self.age = name, 0

    def say_hello(self):
        """Return a greeting after simulating an expensive computation."""
        print("Doing complex calculations")
        time.sleep(5)  # stand-in for costly work
        greeting = f"Hello, my name is {self.name}."
        return greeting

    def get_age(self):
        """Return the currently stored age."""
        return self.age

    def set_age(self, age):
        """Replace the stored age with ``age``."""
        self.age = age

Sample test case

from unittest.mock import MagicMock

from src.models.person import Person


def test_person_class_mock():
    """A spec'd MagicMock stands in for Person without running its real methods."""
    # spec=Person limits the mock to attributes that exist on the Person
    # class; instance attributes assigned in __init__ (e.g. name) are not included.
    fake_person = MagicMock(spec=Person)

    # Canned return values for the two query methods.
    fake_person.say_hello.return_value = "Mocked greeting"
    fake_person.get_age.return_value = 25

    # The mock answers with the canned values instead of doing real work.
    assert fake_person.say_hello() == "Mocked greeting"
    assert fake_person.get_age() == 25

    # set_age is recorded rather than executed; verify the argument it received.
    fake_person.set_age(30)
    fake_person.set_age.assert_called_with(30)

How to provide a mocked implementation for a method?

#1 Simply specify the return value

from unittest.mock import patch, Mock

def test_2():
    """Patch ``Person.say_hello`` with a fixed return value inside a ``with`` block."""
    # Bring the class under test into scope; this snippet's header only
    # imports from unittest.mock.
    from src.models.person import Person

    # return_value makes the replacement MagicMock return this string
    # instead of sleeping and building the real greeting.
    with patch.object(Person, "say_hello", return_value="Hello, mock!"):
        person = Person("John Doe")
        assert person.say_hello() == "Hello, mock!"

#2 Using a function to provide a mocked implementation

from unittest.mock import patch, Mock

def test_3():
    """Replace ``Person.say_hello`` with a plain function inside a ``with`` block."""
    # Bring the class under test into scope; this snippet's header only
    # imports from unittest.mock.
    from src.models.person import Person

    def mock_say_hello(self):
        # Installed on the class, so it is bound like a normal method and
        # receives the instance as ``self``.
        print("Inside mock_say_hello() in test_3")
        return "Return value from a Mocked method"

    with patch.object(Person, "say_hello", mock_say_hello):
        person = Person("John Doe")
        assert person.say_hello() == "Return value from a Mocked method"

        # Only say_hello is replaced; the other methods keep their real behavior.
        person.set_age(20)
        assert person.get_age() == 20

#3 Auto Specification

To ensure that the mock objects in your tests have the same API as the objects they replace, use auto-speccing. Auto-speccing is enabled either with the autospec argument to patch or with the create_autospec() function. It creates mock objects that have the same attributes and methods as the objects they replace, and any functions and methods (including constructors) have the same call signature as the real object.

This ensures that your mocks will fail in the same way as your production code if they are used incorrectly.

from unittest.mock import patch, Mock

def test_4():
    """Show that ``create_autospec`` yields a mock that mirrors ``Person``'s API.

    ``create_autospec`` is a function in ``unittest.mock``, not a method of
    ``Mock``: calling ``Mock().create_autospec(...)`` merely returns an
    auto-created child Mock and silently ignores the spec arguments.
    """
    from unittest.mock import create_autospec

    import pytest

    from src.models.person import Person

    # instance=True: mock an instance (methods already bound, no ``self``).
    # spec_set=True: reading or setting attributes Person lacks is an error.
    person = create_autospec(Person, spec_set=True, instance=True)

    person.say_hello.return_value = "Hello, mock!"
    assert person.say_hello() == "Hello, mock!"

    # The autospecced method enforces the real signature (no extra arguments)...
    with pytest.raises(TypeError):
        person.say_hello("unexpected argument")

    # ...and attributes that Person does not define are rejected.
    with pytest.raises(AttributeError):
        person.no_such_method

More details: https://docs.python.org/3.9/library/unittest.mock.html#autospeccing


Another patch example

src/db/postgres.py
import warnings

import pandas as pd
import sqlalchemy
from sqlalchemy import exc as sa_exc
from sqlalchemy.exc import SQLAlchemyError


class Database:
    """Thin wrapper around a SQLAlchemy engine that runs queries via pandas.

    ``config`` must provide the keys ``host``, ``port``, ``service``,
    ``user`` and ``password`` (``password`` may be ``None``). ``service`` is
    used as the database name in the connection URL.
    """

    def __init__(self, db_name, config):
        self.engine = None
        self.results_df = None
        self.db_name = db_name
        self.config = config

        print("Actual Database.__init__")

    @staticmethod
    def _build_url(config):
        """Return the ``postgresql+psycopg2`` URL described by ``config``.

        The user name and password are percent-encoded so that characters
        such as ``@``, ``:``, ``/`` or spaces cannot corrupt the URL.
        """
        from urllib.parse import quote

        host = config["host"]
        port = config["port"]
        service = config["service"]
        user = quote(config["user"], safe="")
        password = config["password"]

        if password is None:
            credentials = user
        else:
            credentials = f"{user}:{quote(password, safe='')}"
        return f"postgresql+psycopg2://{credentials}@{host}:{port}/{service}"

    def connect(self):
        """Create an engine from ``self.config`` and store it in ``self.engine``."""
        self.engine = sqlalchemy.create_engine(
            self._build_url(self.config),
            echo=True,
        )

    def execute_query(self, query, params):
        """Run ``query`` and store the resulting DataFrame in ``self.results_df``.

        Connects first only if no engine exists yet, so an explicit
        ``connect()`` beforehand does not create a second engine.

        Args:
            query: SQL text passed to ``pandas.read_sql``.
            params: Bind parameters for the query, or ``None`` for none.

        Raises:
            sqlalchemy.exc.SQLAlchemyError: propagated unchanged.
        """
        with warnings.catch_warnings():
            # The filter is only active inside this block, so the calls that
            # can emit SAWarning must run inside it too.
            warnings.simplefilter("ignore", category=sa_exc.SAWarning)

            if self.engine is None:
                self.connect()

            # read_sql's default for params is None, so one call covers both cases.
            self.results_df = pd.read_sql(query, self.engine, params=params)

    def fetch_results(self):
        """Return the DataFrame from the last ``execute_query`` call, or ``None``."""
        return self.results_df
src/my_db_module.py
from src.db.postgres import Database


def perform_query(db_name, query, config=None):
    """Run ``query`` through a ``Database`` and return the fetched results.

    Args:
        db_name: Name handed to ``Database``.
        query: SQL text to execute; no bind parameters are passed.
        config: Optional connection settings with keys ``host``, ``port``,
            ``service``, ``user`` and ``password``. When omitted, a local
            PostgreSQL server on port 5433 is used (user ``postgres``,
            database ``postgres``, no password).

    Returns:
        Whatever ``Database.fetch_results`` returns after the query runs.
    """
    # NOTE(review): the database actually connected to is config["service"];
    # db_name is only stored on the Database object.
    if config is None:
        config = {
            "host": "localhost",
            "port": "5433",
            "service": "postgres",
            "user": "postgres",
            "password": None,
        }

    db = Database(db_name, config)
    db.connect()
    db.execute_query(query, None)
    return db.fetch_results()


if __name__ == "__main__":
    # Ad-hoc smoke run: list a few columns from the information schema.
    query = (
        "SELECT table_schema, table_name, column_name "
        "FROM INFORMATION_SCHEMA.COLUMNS ORDER BY 1, 2, 3 LIMIT 5"
    )
    results = perform_query("postgres", query)
    print(results)

Sample result

python3 src/my_db_module.py
Actual Database.__init__
2023-07-02 17:54:28,984 INFO sqlalchemy.engine.Engine select pg_catalog.version()
2023-07-02 17:54:28,984 INFO sqlalchemy.engine.Engine [raw sql] {}
2023-07-02 17:54:28,986 INFO sqlalchemy.engine.Engine select current_schema()
2023-07-02 17:54:28,986 INFO sqlalchemy.engine.Engine [raw sql] {}
2023-07-02 17:54:28,987 INFO sqlalchemy.engine.Engine show standard_conforming_strings
2023-07-02 17:54:28,987 INFO sqlalchemy.engine.Engine [raw sql] {}
2023-07-02 17:54:28,990 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-07-02 17:54:28,990 INFO sqlalchemy.engine.Engine SELECT pg_catalog.pg_class.relname
FROM pg_catalog.pg_class JOIN pg_catalog.pg_namespace ON pg_catalog.pg_namespace.oid = pg_catalog.pg_class.relnamespace
WHERE pg_catalog.pg_class.relname = %(table_name)s AND pg_catalog.pg_class.relkind = ANY (ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, %(param_4)s, %(param_5)s]) AND pg_catalog.pg_table_is_visible(pg_catalog.pg_class.oid) AND pg_catalog.pg_namespace.nspname != %(nspname_1)s
2023-07-02 17:54:28,990 INFO sqlalchemy.engine.Engine [generated in 0.00014s] {'table_name': 'SELECT table_schema, table_name, column_name FROM INFORMATION_SCHEMA.COLUMNS ORDER BY 1, 2, 3 LIMIT 5', 'param_1': 'r', 'param_2': 'p', 'param_3': 'f', 'param_4': 'v', 'param_5': 'm', 'nspname_1': 'pg_catalog'}
2023-07-02 17:54:28,992 INFO sqlalchemy.engine.Engine SELECT table_schema, table_name, column_name FROM INFORMATION_SCHEMA.COLUMNS ORDER BY 1, 2, 3 LIMIT 5
2023-07-02 17:54:28,992 INFO sqlalchemy.engine.Engine [raw sql] {}
2023-07-02 17:54:29,010 INFO sqlalchemy.engine.Engine ROLLBACK
  table_schema    table_name     column_name
0       test      brand_master        brand_id
1       test      brand_master      brand_name
2       test      brand_master       excipient
3       test      brand_master      generic_id
4       test      brand_master  license_number

How to write a test for this?

import unittest
from unittest.mock import MagicMock, patch

import src.db.postgres
from src.my_db_module import perform_query


class MyModuleTestCase(unittest.TestCase):
    def test_perform_query(self):
        """perform_query should drive a Database through connect/execute/fetch."""
        expected_rows = [("Result 1",), ("Result 2",)]

        # Stand-in for the Database instance that perform_query constructs.
        fake_db = MagicMock()
        fake_db.connect.return_value = None
        fake_db.execute_query.return_value = None
        fake_db.fetch_results.return_value = expected_rows

        # Replace the name ``Database`` as looked up inside src.my_db_module,
        # so that calling Database(...) there hands back fake_db.
        with patch("src.my_db_module.Database", return_value=fake_db):
            outcome = perform_query("my_db", "SELECT * FROM my_table")

            self.assertEqual(outcome, expected_rows)
            fake_db.connect.assert_called_once()
            fake_db.execute_query.assert_called_once_with(
                "SELECT * FROM my_table", None
            )
            fake_db.fetch_results.assert_called_once()
import unittest
from unittest.mock import MagicMock, patch

import src.db.postgres
from src.my_db_module import perform_query

class MyModuleTestCase(unittest.TestCase):
    def test_perform_query_another_way(self):
        """Configure the patched class's ``return_value`` instead of a separate mock.

        ``patch`` replaces ``src.my_db_module.Database`` with a MagicMock; the
        object perform_query gets back from ``Database(...)`` is that mock's
        ``return_value``.
        """
        with patch("src.my_db_module.Database") as db_mock:
            db_instance = db_mock.return_value
            db_instance.connect.return_value = None
            db_instance.execute_query.return_value = None
            db_instance.fetch_results.return_value = [
                ("Result 1",),
                ("Result 2",),
            ]

            # Call the function under test
            result = perform_query("my_db", "SELECT * FROM my_table")

            # Perform assertions
            self.assertEqual(result, [("Result 1",), ("Result 2",)])
            db_mock.assert_called_once()
            db_instance.connect.assert_called_once()
            db_instance.execute_query.assert_called_once_with(
                "SELECT * FROM my_table", None
            )
            db_instance.fetch_results.assert_called_once()

Patching a specific method

import unittest
from unittest.mock import MagicMock, patch

import src.db.postgres
from src.my_db_module import perform_query

class MyModuleTestCase(unittest.TestCase):
    def test3(self):
        """Patch ``connect`` on a single instance with a plain callable."""
        db1 = src.db.postgres.Database("db1", {})

        # Patching the instance (not the class) affects only db1. The lambda
        # takes no ``self`` because functions stored on an instance are not bound.
        with patch.object(db1, "connect", lambda: "Mocked connect"):
            result = db1.connect()
            print(result)
            self.assertEqual(result, "Mocked connect")

        # On exit the instance-level override is removed, so lookups fall
        # back to the real Database.connect defined on the class.
        self.assertNotIn("connect", vars(db1))
pytest -s -vv -o log_cli=True tests/test_my_db_module.py
tests/test_my_db_module.py::MyModuleTestCase::test3 Actual Database.__init__
Mocked connect
PASSED

Patching multiple methods

import unittest
from unittest.mock import MagicMock, patch

import src.db.postgres
from src.my_db_module import perform_query

class MyModuleTestCase(unittest.TestCase):
    def test4(self):
        """Patch two instance methods at once with plain callables."""
        db1 = src.db.postgres.Database("db1", {})

        # Several patchers can share one ``with`` statement; this is
        # equivalent to nesting them.
        with patch.object(db1, "connect", lambda: "Mocked connect"), patch.object(
            db1, "execute_query", lambda: "Mocked execute_query()"
        ):
            connect_result = db1.connect()
            query_result = db1.execute_query()
            print(connect_result)
            print(query_result)

            self.assertEqual(connect_result, "Mocked connect")
            self.assertEqual(query_result, "Mocked execute_query()")
pytest -s -vv -o log_cli=True tests/test_my_db_module.py
tests/test_my_db_module.py::MyModuleTestCase::test4 Actual Database.__init__
Mocked connect
Mocked execute_query()
PASSED

Patching __init__

import unittest
from unittest.mock import MagicMock, patch

import src.db.postgres
from src.my_db_module import perform_query

class MyModuleTestCase(unittest.TestCase):
    @staticmethod
    def mocked_init(db_name, config):
        """Stand-in for ``Database.__init__`` that skips all real setup.

        Declared as a staticmethod so it takes no ``self``: once installed as
        ``Database.__init__``, only the constructor arguments are passed to it.
        """
        print("Mocked init")

    # Each patch is active only while test5 runs. Because the replacement
    # object is given explicitly, no mocks are injected as extra arguments.
    @patch.object(src.db.postgres.Database, "connect", lambda x: "Mocked connect")
    @patch.object(src.db.postgres.Database, "execute_query", lambda x: "Mocked execute_query()")
    @patch.object(src.db.postgres.Database, "__init__", mocked_init)
    def test5(self):
        db1 = src.db.postgres.Database("db1", {})

        connect_result = db1.connect()
        query_result = db1.execute_query()
        print(connect_result)
        print(query_result)

        self.assertEqual(connect_result, "Mocked connect")
        self.assertEqual(query_result, "Mocked execute_query()")
        # The real __init__ never ran, so the instance has no config attribute.
        self.assertFalse(hasattr(db1, "config"))
pytest -s -vv -o log_cli=True tests/test_my_db_module.py
tests/test_my_db_module.py::MyModuleTestCase::test5 Mocked init
Mocked connect
Mocked execute_query()
PASSED

Where to patch?

src/services/my_api.py
import requests
from requests import Session


def get_name(url: str) -> str:
    """Fetch ``url`` and return the ``"name"`` field of its JSON body.

    ``Session`` is looked up on the ``requests`` module at call time, so tests
    can patch ``requests.Session`` directly.
    """
    print(f"Calling API: {url}")
    session = requests.Session()
    try:
        response = session.get(url)
        json_response = response.json()
    finally:
        # Release the session's pooled connections even if the call fails.
        session.close()
    print("Response from API:", json_response)
    return json_response["name"]


def get_name_another_version(url: str) -> str:
    """Fetch ``url`` and return the ``"name"`` field of its JSON body.

    Uses the name ``Session`` imported into this module, so tests must patch
    ``src.services.my_api.Session`` rather than ``requests.Session``.
    """
    print(f"Calling API: {url}")
    session = Session()
    try:
        response = session.get(url)
        json_response = response.json()
    finally:
        # Release the session's pooled connections even if the call fails.
        session.close()
    print("Response from API:", json_response)
    return json_response["name"]


if __name__ == "__main__":
    # Manual run against the live SWAPI endpoint.
    person_url = "https://swapi.dev/api/people/1"
    print(get_name(person_url))

tests/test_my_api_2.py
from unittest.mock import patch
from src.services.my_api import get_name, get_name_another_version
from pytest import fixture


@fixture
def fixture_url():
    """Provide the SWAPI URL for person #1 (function scope: one per test)."""
    url = "https://swapi.dev/api/people/1"
    return url


def test_1(fixture_url):
    """get_name resolves ``requests.Session`` at call time, so patch it there."""
    fake_payload = {"name": "random user"}

    with patch("requests.Session") as session_cls:
        # Session() -> session; session.get(...) -> response; response.json() -> payload
        fake_response = session_cls.return_value.get.return_value
        fake_response.json.return_value = fake_payload

        assert get_name(fixture_url) == "random user"


def test_2(fixture_url):
    """get_name_another_version uses the module's own ``Session`` name; patch that."""
    fake_payload = {"name": "random user"}

    with patch("src.services.my_api.Session") as session_cls:
        # Session() -> session; session.get(...) -> response; response.json() -> payload
        fake_response = session_cls.return_value.get.return_value
        fake_response.json.return_value = fake_payload

        assert get_name_another_version(fixture_url) == "random user"
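The basic principle is that you patch where an object is looked up, which is not necessarily the same place as where it is defined. get_name looks up Session on the requests module each time it runs, so test_1 patches requests.Session. get_name_another_version instead uses the name Session that src.services.my_api bound at import time (from requests import Session), so test_2 must patch src.services.my_api.Session; patching requests.Session would not affect it.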
pytest -s -vv -o log_cli=True tests/test_my_api_2.py
tests/test_my_api_2.py::test_1 Calling API: https://swapi.dev/api/people/1
Response from API: {'name': 'random user'}
PASSED
tests/test_my_api_2.py::test_2 Calling API: https://swapi.dev/api/people/1
Response from API: {'name': 'random user'}
PASSED