Pytest mocking samples
- Created - 2023/07/05
- Last updated - 2023/07/05
Using patch
src/models/person.py
import time
class Person:
    """A person with a name and an age; the age starts at 0."""

    def __init__(self, name):
        """Remember the given name and initialise the age to 0."""
        self.name, self.age = name, 0

    def say_hello(self):
        """Return a greeting containing the name.

        Prints a progress message and sleeps for five seconds first, which is
        what makes this method a candidate for mocking in tests.
        """
        print("Doing complex calculations")
        time.sleep(5)
        greeting = f"Hello, my name is {self.name}."
        return greeting

    def get_age(self):
        """Return the currently stored age."""
        current_age = self.age
        return current_age

    def set_age(self, age):
        """Replace the stored age with ``age``."""
        self.age = age
Sample test case
from unittest.mock import MagicMock
from src.models.person import Person
def test_person_class_mock():
    """Drive a MagicMock that is restricted to the Person API."""
    # spec=Person makes the mock reject attributes that Person does not have
    person_double = MagicMock(spec=Person)

    # Configure both canned return values in a single call
    person_double.configure_mock(
        **{
            "say_hello.return_value": "Mocked greeting",
            "get_age.return_value": 25,
        }
    )

    # The mocked methods hand back the configured values
    greeting = person_double.say_hello()
    age = person_double.get_age()
    assert greeting == "Mocked greeting"
    assert age == 25

    # The mock records calls, so the argument passed to set_age() can be verified
    person_double.set_age(30)
    person_double.set_age.assert_called_with(30)
How to provide a mocked implementation for a method?
- `patch()` can be used as a decorator, or as a context manager in a `with` statement. The examples below use the context-manager form.
#1 Simply specify the return value
from unittest.mock import patch, Mock
def test_2():
    """Patch Person.say_hello so that it returns a fixed value."""
    # Imported here because this sample only imports names from unittest.mock;
    # without it the reference to the Person class raises NameError.
    from src.models.person import Person

    # return_value makes patch.object install a mock that returns this string
    with patch.object(Person, "say_hello", return_value="Hello, mock!"):
        person = Person("John Doe")
        assert person.say_hello() == "Hello, mock!"
- Here, the return value is specified directly in the `patch.object()` call.
#2 Using a function to provide mocked Implementation
from unittest.mock import patch, Mock
def test_3():
    """Patch Person.say_hello with a replacement function."""
    # Imported here because this sample only imports names from unittest.mock;
    # without it the reference to the Person class raises NameError.
    from src.models.person import Person

    def mock_say_hello(self):
        # Patched onto the class, so it is called as a method and receives self
        print("Inside mock_say_hello() in test_3")
        return "Return value from a Mocked method"

    with patch.object(Person, "say_hello", mock_say_hello):
        person = Person("John Doe")
        assert person.say_hello() == "Return value from a Mocked method"

        # Methods that were not patched keep their real behaviour
        person.set_age(20)
        assert person.get_age() == 20
- Here, a function provides the mocked implementation, and its return value becomes the return value of the patched method.
#3 Auto Specification
For ensuring that the mock objects in your tests have the same api as the objects they are replacing, you can use auto-speccing. Auto-speccing can be done through the autospec argument to patch, or the create_autospec() function. Auto-speccing creates mock objects that have the same attributes and methods as the objects they are replacing, and any functions and methods (including constructors) have the same call signature as the real object.
This ensures that your mocks will fail in the same way as your production code if they are used incorrectly.
from unittest.mock import patch, Mock
def test_4():
    """Auto-spec a Person instance so the mock enforces the real Person API."""
    # create_autospec is a module-level function of unittest.mock. Accessing it
    # as an attribute of a Mock() instance only yields another unconstrained
    # Mock, so nothing would be auto-specced and the test would pass vacuously.
    from unittest.mock import create_autospec

    # Imported here because this sample only imports names from unittest.mock.
    from src.models.person import Person

    # instance=True specs an instance rather than the class;
    # spec_set=True additionally forbids setting attributes Person lacks.
    person = create_autospec(Person, spec_set=True, instance=True)
    person.say_hello.return_value = "Hello, mock!"
    assert person.say_hello() == "Hello, mock!"
More details are here - https://docs.python.org/3.9/library/unittest.mock.html#autospeccing
Another patch example
src/db/postgres.py
- This class can be used to connect to a Postgres database, execute a query, and fetch the results.
import warnings
import pandas as pd
import sqlalchemy
from sqlalchemy import exc as sa_exc
from sqlalchemy.exc import SQLAlchemyError
class Database:
    """Helper that connects to Postgres, runs a query and keeps the result.

    The result of the most recent query is whatever ``pd.read_sql`` returned;
    it is stored on ``results_df`` and handed back by ``fetch_results``.
    """

    def __init__(self, db_name, config):
        """Store the database name and connection settings.

        No connection is opened here. ``config`` is read later by ``connect``
        for the keys ``host``, ``port``, ``service``, ``user`` and ``password``.
        """
        self.engine = None
        self.results_df = None
        self.db_name = db_name
        self.config = config
        print("Actual Database.__init__")

    def connect(self):
        """Create a SQLAlchemy engine from ``self.config`` and store it on ``self.engine``.

        Raises:
            KeyError: if one of the required keys is missing from the config.
        """
        host = self.config["host"]
        port = self.config["port"]
        service = self.config["service"]
        user = self.config["user"]
        password = self.config["password"]

        # The ":password" part of the URL is left out when no password is configured
        credentials = user if password is None else f"{user}:{password}"
        self.engine = sqlalchemy.create_engine(
            f"postgresql+psycopg2://{credentials}@{host}:{port}/{service}",
            echo=True,
        )

    def execute_query(self, query, params):
        """Connect, run ``query`` and store the result on ``self.results_df``.

        Args:
            query: SQL passed to ``pd.read_sql``.
            params: Parameters forwarded to ``pd.read_sql``; ``None`` means none.

        Raises:
            SQLAlchemyError: propagated unchanged to the caller.
        """
        # No except clause: the previous handler only re-raised, and its lookup
        # of e.__dict__["orig"] raised KeyError for errors that carry no
        # ``orig`` attribute, hiding the real exception.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=sa_exc.SAWarning)
            self.connect()
            # params=None is the pd.read_sql default, so one call covers both cases
            self.results_df = pd.read_sql(query, self.engine, params=params)

    def fetch_results(self):
        """Return the stored result of the last query (``None`` before any query)."""
        return self.results_df
src/my_db_module.py
from src.db.postgres import Database
def perform_query(db_name, query):
    """Run ``query`` through a ``Database`` and return what it fetched.

    The connection settings are hard-coded for a local Postgres on port 5433
    with the ``postgres`` user and no password.
    """
    connection_settings = dict(
        host="localhost",
        port="5433",
        service="postgres",
        user="postgres",
        password=None,
    )
    database = Database(db_name, connection_settings)
    database.connect()
    # No bind parameters are passed for the query
    database.execute_query(query, None)
    # Process the results and return the desired output
    return database.fetch_results()
if __name__ == "__main__":
    # Demo run: print the first five columns listed in the information schema
    demo_query = "SELECT table_schema, table_name, column_name FROM INFORMATION_SCHEMA.COLUMNS ORDER BY 1, 2, 3 LIMIT 5"
    print(perform_query("postgres", demo_query))
Sample result
python3 src/my_db_module.py
Actual Database.__init__
2023-07-02 17:54:28,984 INFO sqlalchemy.engine.Engine select pg_catalog.version()
2023-07-02 17:54:28,984 INFO sqlalchemy.engine.Engine [raw sql] {}
2023-07-02 17:54:28,986 INFO sqlalchemy.engine.Engine select current_schema()
2023-07-02 17:54:28,986 INFO sqlalchemy.engine.Engine [raw sql] {}
2023-07-02 17:54:28,987 INFO sqlalchemy.engine.Engine show standard_conforming_strings
2023-07-02 17:54:28,987 INFO sqlalchemy.engine.Engine [raw sql] {}
2023-07-02 17:54:28,990 INFO sqlalchemy.engine.Engine BEGIN (implicit)
2023-07-02 17:54:28,990 INFO sqlalchemy.engine.Engine SELECT pg_catalog.pg_class.relname
FROM pg_catalog.pg_class JOIN pg_catalog.pg_namespace ON pg_catalog.pg_namespace.oid = pg_catalog.pg_class.relnamespace
WHERE pg_catalog.pg_class.relname = %(table_name)s AND pg_catalog.pg_class.relkind = ANY (ARRAY[%(param_1)s, %(param_2)s, %(param_3)s, %(param_4)s, %(param_5)s]) AND pg_catalog.pg_table_is_visible(pg_catalog.pg_class.oid) AND pg_catalog.pg_namespace.nspname != %(nspname_1)s
2023-07-02 17:54:28,990 INFO sqlalchemy.engine.Engine [generated in 0.00014s] {'table_name': 'SELECT table_schema, table_name, column_name FROM INFORMATION_SCHEMA.COLUMNS ORDER BY 1, 2, 3 LIMIT 5', 'param_1': 'r', 'param_2': 'p', 'param_3': 'f', 'param_4': 'v', 'param_5': 'm', 'nspname_1': 'pg_catalog'}
2023-07-02 17:54:28,992 INFO sqlalchemy.engine.Engine SELECT table_schema, table_name, column_name FROM INFORMATION_SCHEMA.COLUMNS ORDER BY 1, 2, 3 LIMIT 5
2023-07-02 17:54:28,992 INFO sqlalchemy.engine.Engine [raw sql] {}
2023-07-02 17:54:29,010 INFO sqlalchemy.engine.Engine ROLLBACK
table_schema table_name column_name
0 test brand_master brand_id
1 test brand_master brand_name
2 test brand_master excipient
3 test brand_master generic_id
4 test brand_master license_number
How to write a test for this?
import unittest
from unittest.mock import MagicMock, patch
import src.db.postgres
from src.my_db_module import perform_query
class MyModuleTestCase(unittest.TestCase):
    """Tests perform_query with the Database class swapped for a mock."""

    def test_perform_query(self):
        """perform_query connects, runs the query once and returns the fetched rows."""
        sql = "SELECT * FROM my_table"

        # Stand-in for the Database instance that perform_query creates
        fake_db = MagicMock()
        fake_db.configure_mock(
            **{
                "connect.return_value": None,
                "execute_query.return_value": None,
                "fetch_results.return_value": [("Result 1",), ("Result 2",)],
            }
        )

        # Patch the Database name inside my_db_module so that calling it yields fake_db
        with patch("src.my_db_module.Database", return_value=fake_db):
            # Call the function under test
            rows = perform_query("my_db", sql)

        # The recorded calls remain available after the patch has been undone
        self.assertEqual(rows, [("Result 1",), ("Result 2",)])
        fake_db.connect.assert_called_once()
        fake_db.execute_query.assert_called_once_with(sql, None)
        fake_db.fetch_results.assert_called_once()
import unittest
from unittest.mock import MagicMock, patch
import src.db.postgres
from src.my_db_module import perform_query
class MyModuleTestCase(unittest.TestCase):
    """Tests perform_query using the mock that patch() itself creates."""

    def test_perform_query_another_way(self):
        """Configure the patched class's return_value instead of building a mock by hand."""
        sql = "SELECT * FROM my_table"

        # patch() replaces Database in my_db_module with a MagicMock
        with patch("src.my_db_module.Database") as db_class:
            # Calling the mocked class returns this object, i.e. the "instance"
            db_instance = db_class.return_value
            db_instance.connect.return_value = None
            db_instance.execute_query.return_value = None
            db_instance.fetch_results.return_value = [
                ("Result 1",),
                ("Result 2",),
            ]

            # Call the function under test
            rows = perform_query("my_db", sql)

            # Perform assertions
            self.assertEqual(rows, [("Result 1",), ("Result 2",)])
            db_instance.connect.assert_called_once()
            db_instance.execute_query.assert_called_once_with(sql, None)
            db_instance.fetch_results.assert_called_once()
Patching a specific method
import unittest
from unittest.mock import MagicMock, patch
import src.db.postgres
from src.my_db_module import perform_query
class MyModuleTestCase(unittest.TestCase):
    """Shows patching a single method on one Database instance."""

    def test3(self):
        """Replace connect() on one instance and print what the replacement returns."""
        database = src.db.postgres.Database("db1", {})
        # Patched on the instance, so the replacement takes no self argument
        with patch.object(database, "connect", lambda: "Mocked connect"):
            print(database.connect())
pytest -s -vv -o log_cli=True tests/test_my_db_module.py
tests/test_my_db_module.py::MyModuleTestCase::test3 Actual Database.__init__
Mocked connect
PASSED
Patching multiple methods
import unittest
from unittest.mock import MagicMock, patch
import src.db.postgres
from src.my_db_module import perform_query
class MyModuleTestCase(unittest.TestCase):
    """Shows patching several methods on one Database instance."""

    def test4(self):
        """Replace connect() and execute_query() on one instance and print their results."""
        database = src.db.postgres.Database("db1", {})
        # Both patches are entered left to right and undone together on exit
        with patch.object(database, "connect", lambda: "Mocked connect"), patch.object(
            database, "execute_query", lambda: "Mocked execute_query()"
        ):
            print(database.connect())
            print(database.execute_query())
pytest -s -vv -o log_cli=True tests/test_my_db_module.py
tests/test_my_db_module.py::MyModuleTestCase::test4 Actual Database.__init__
Mocked connect
Mocked execute_query()
PASSED
Patching __init__
import unittest
from unittest.mock import MagicMock, patch
import src.db.postgres
from src.my_db_module import perform_query
class MyModuleTestCase(unittest.TestCase):
    """Shows patching Database.__init__ together with two of its methods."""

    # Kept as a staticmethod: the object patched in as __init__ is this
    # staticmethod wrapper, so the replacement is called with only
    # (db_name, config) and no instance argument.
    @staticmethod
    def mocked_init(db_name, config):
        """Replacement for Database.__init__ that only prints a marker."""
        print("Mocked init")

    # These replacements are patched onto the class, so each one receives the instance
    @patch.object(src.db.postgres.Database, "connect", lambda _self: "Mocked connect")
    @patch.object(
        src.db.postgres.Database,
        "execute_query",
        lambda _self: "Mocked execute_query()",
    )
    @patch.object(src.db.postgres.Database, "__init__", mocked_init)
    def test5(self):
        """Build a Database with everything patched and print the mocked results."""
        database = src.db.postgres.Database("db1", {})
        print(database.connect())
        print(database.execute_query())
pytest -s -vv -o log_cli=True tests/test_my_db_module.py
tests/test_my_db_module.py::MyModuleTestCase::test5 Mocked init
Mocked connect
Mocked execute_query()
PASSED
Where to patch?
src/services/my_api.py
import requests
from requests import Session
def get_name(url: str) -> str:
    """Fetch ``url`` and return the ``name`` field of its JSON response.

    The session class is looked up as ``requests.Session`` at call time, so
    tests mock it by patching ``requests.Session``.

    Raises:
        KeyError: if the JSON response has no ``name`` key.
    """
    print(f"Calling API: {url}")
    session = requests.Session()
    try:
        response = session.get(url)
        json_response = response.json()
    finally:
        # Close the session even when the request or the JSON decoding fails
        session.close()
    print("Response from API:", json_response)
    return json_response["name"]
def get_name_another_version(url: str) -> str:
    """Fetch ``url`` and return the ``name`` field of its JSON response.

    The session class is looked up as the module-level name ``Session``
    (imported with ``from requests import Session``), so tests mock it by
    patching ``Session`` in this module rather than ``requests.Session``.

    Raises:
        KeyError: if the JSON response has no ``name`` key.
    """
    print(f"Calling API: {url}")
    session = Session()
    try:
        response = session.get(url)
        json_response = response.json()
    finally:
        # Close the session even when the request or the JSON decoding fails
        session.close()
    print("Response from API:", json_response)
    return json_response["name"]
if __name__ == "__main__":
    # Demo run against the public Star Wars API
    demo_url = "https://swapi.dev/api/people/1"
    print(get_name(demo_url))
- Here, `get_name()` creates a session by using `requests.Session()`.
- `get_name_another_version()` uses `Session()`.
- Because of this difference, the test cases have to supply the mocked values in different ways.
tests/test_my_api_2.py
from unittest.mock import patch
from src.services.my_api import get_name, get_name_another_version
from pytest import fixture
@fixture(scope="function")
def fixture_url():
    """Provide the API URL shared by the tests in this module."""
    people_url = "https://swapi.dev/api/people/1"
    return people_url
def test_1(fixture_url):
    """get_name() looks up requests.Session, so that is the name to patch."""
    fake_json = {"name": "random user"}
    with patch("requests.Session") as session_class:
        # Session() -> .get(url) -> .json() is the chain get_name() walks
        fake_response = session_class.return_value.get.return_value
        fake_response.json.return_value = fake_json
        assert get_name(fixture_url) == "random user"
def test_2(fixture_url):
    """get_name_another_version() looks up Session in my_api, so patch it there."""
    fake_json = {"name": "random user"}
    with patch("src.services.my_api.Session") as session_class:
        # Session() -> .get(url) -> .json() is the chain the function walks
        fake_response = session_class.return_value.get.return_value
        fake_response.json.return_value = fake_json
        assert get_name_another_version(fixture_url) == "random user"
- Since the URL is reused, it is defined using a fixture.
- `test_1` tests `get_name()`.
- `test_2` tests `get_name_another_version()`.
- Since `get_name_another_version` uses `Session` rather than `requests.Session`, when mocking this object, we have to use the path where it is used.
- Refer this for more details - https://docs.python.org/3.9/library/unittest.mock.html#where-to-patch
| The basic principle is that you patch where an object is looked up, which is not necessarily the same place as where it is defined. |
|---|
pytest -s -vv -o log_cli=True tests/test_my_api_2.py
tests/test_my_api_2.py::test_1 Calling API: https://swapi.dev/api/people/1
Response from API: {'name': 'random user'}
PASSED
tests/test_my_api_2.py::test_2 Calling API: https://swapi.dev/api/people/1
Response from API: {'name': 'random user'}
PASSED