Pytest: Mock multiple calls of same method with different side_effect
Asked Answered
P

2

8

I have a unit test like the one below:

# utilities.py  
def get_side_effects():
    """Build the side-effect callables for three consecutive ``get_request`` calls.

    Each callable takes a single positional argument (``self``) and returns
    a distinct canned dictionary.

    Returns:
        A tuple ``(side_effect_func1, side_effect_func2, side_effect_func3)``,
        in the order the calls are expected to happen.
    """

    def side_effect_func1(self):
        # Need the "self" to do some stuff at run time.
        return {"name": "some1"}

    def side_effect_func2(self):
        # Need the "self" to do some stuff at run time.
        return {"status": "some2"}

    def side_effect_func3(self):
        # Need the "self" to do some stuff at run time.
        return {"final": "some3"}

    # The third slot must be side_effect_func3; handing out side_effect_func2
    # twice would make the third callable unreachable for the caller.
    return side_effect_func1, side_effect_func2, side_effect_func3

#################

# test_a.py
def test_endtoend():
    # One mock per side-effect function, each wired up at construction time.
    s1, s2, s3 = utilities.get_side_effects()
    m1 = mock.MagicMock(side_effect=s1)
    m2 = mock.MagicMock(side_effect=s2)
    m3 = mock.MagicMock(side_effect=s3)

    # All three patches target the same attribute and are entered left to
    # right, so the last one entered (m1) is the one in effect in the body.
    with mock.patch("a.get_request", m3), \
            mock.patch("a.get_request", m2), \
            mock.patch("a.get_request", m1):
        foo = a()  # Class to test
        result = foo.run()
    
    

As part of the foo.run() code run, get_request is called multiple times. I want a different side_effect function for each call of the get_request method, in this case side_effect_func1, side_effect_func2 and side_effect_func3. But what I'm noticing is that only the m1 mock object is active, i.e. only side_effect_func1 is invoked but not the other two. How do I achieve this?

I have also tried the approach below, but the side_effect functions don't get invoked: each call to the mock returns the function object itself instead of executing it.

# utilities.py
def get_side_effects():
    """Return the three side-effect callables as a list, in call order.

    Each callable takes a single positional argument (``self``) and returns
    its own canned dictionary.
    """

    def side_effect_func1(self):
        # Need the "self" to do some stuff at run time.
        return {"name": "some1"}

    def side_effect_func2(self):
        # Need the "self" to do some stuff at run time.
        return {"status": "some2"}

    def side_effect_func3(self):
        # Need the "self" to do some stuff at run time.
        return {"final": "some3"}

    return [side_effect_func1, side_effect_func2, side_effect_func3]

#########################
# test_a.py
def test_endtoend():
    # Hand the whole list of side-effect functions to a single mock.
    # With an iterable side_effect, each call to the mock returns the next
    # item of the iterable rather than calling it.
    m = mock.MagicMock(side_effect=utilities.get_side_effects())

    patcher = mock.patch("a.get_request", m)
    with patcher:
        foo = a()  # Class to test
        result = foo.run()
Pretermit answered 24/11, 2020 at 17:47 Comment(0)
G
12

Your first attempt doesn't work because each mock just replaced the previous one (the outer two mocks don't do anything).

Your second attempt doesn't work because side_effect is overloaded to serve a different purpose for iterables (docs):

If side_effect is an iterable then each call to the mock will return the next value from the iterable.

Instead, you could use a callable class for the side_effect, which maintains state about which underlying function to call on each consecutive call.

Basic example with two functions:

>>> class SideEffect:
...     def __init__(self, *fns):
...         self.fs = iter(fns)
...     def __call__(self, *args, **kwargs):
...         f = next(self.fs)
...         return f(*args, **kwargs)
... 
>>> def sf1():
...     print("called sf1")
...     return 1
... 
>>> def sf2():
...     print("called sf2")
...     return 2
... 
>>> def foo():
...     print("called actual func")
...     return "f"
... 
>>> with mock.patch("__main__.foo", side_effect=SideEffect(sf1, sf2)) as m:
...     first = foo()
...     second = foo()
... 
called sf1
called sf2
>>> assert first == 1
>>> assert second == 2
>>> assert m.call_count == 2
Gd answered 24/11, 2020 at 18:5 Comment(1)
Thank you for this explanation! I used your answer to make improvements on my actual code and it worked. :-)Pretermit
T
0

The following example tests the handling of the situation when some of the remote API requests cause an error (exception) and most of the requests are OK.

The example uses almost the same SideEffect class as the answer by wim. The improvement is that the supplied functions are cycled through indefinitely, so the mock can be called any number of times rather than only once per supplied function.

import asyncio
import itertools
from typing import Dict, List
from unittest import mock

from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport


class RemoteAPICaller:
    """
    A class that contains functions for sending a single request to a remote API,
    a series of such requests, and final processing of the result
    of a series of requests. During final processing, if there were unsuccessful requests,
    a warning is added to the final result.
    """

    # Warning text added (once) to the final result when any request failed.
    DATA_RETRIEVE_PROBLEMS_WARNING = "We had problems getting data from a 3rd party, so the result is clipped."
    # Number of requests issued by make_remote_requests().
    REQUESTS_COUNT = 4

    async def make_single_request(self) -> dict:
        """
        remote API request demo, uses GraphQL, source - https://gql.readthedocs.io/en/latest/transports/aiohttp.html

        Returns:
            The result of executing the ``getContinents`` query.
        """
        transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql")
        graphql_query = gql(
            """
                query getContinents {
                continents {
                    code
                    name
                }
                }
            """
        )
        async with Client(
            transport=transport,
            fetch_schema_from_transport=True,
        ) as session:
            result = await session.execute(graphql_query)
            print(f"{result=}")
        return result

    def prepare_final_result(self, all_tasks_results: List) -> Dict:
        """
        Merge the outcomes of several requests into one response dictionary.

        Args:
            all_tasks_results: one entry per request; each entry is either a
                response dictionary with a "continents" list, an exception
                instance for a failed request, or a falsy value (skipped).

        Returns:
            A dictionary with the keys "errors" (always None here),
            "warnings" (contains DATA_RETRIEVE_PROBLEMS_WARNING at most once,
            when at least one entry was an exception) and "continents_data"
            (continent code -> continent name).
        """
        response_template: Dict = {
            "errors": None,
            "warnings": [],
            "continents_data": {},
        }
        warnings = response_template["warnings"]
        continents_data = response_template["continents_data"]

        for item in all_tasks_results:

            # If there were errors when calling the remote API,
            # then a warning should be added to the response.
            # BaseException (not just Exception) is checked because
            # gather(return_exceptions=True) can also hand back
            # CancelledError, which is not an Exception subclass.
            if isinstance(item, BaseException):
                if self.DATA_RETRIEVE_PROBLEMS_WARNING not in warnings:
                    warnings.append(self.DATA_RETRIEVE_PROBLEMS_WARNING)
                continue

            if not item:
                continue

            for continent in item["continents"]:
                continents_data[continent["code"]] = continent["name"]
        print(f"{response_template=}")
        return response_template

    async def _gather_requests(self) -> List:
        """Run REQUESTS_COUNT single requests concurrently; failures are returned, not raised."""
        return await asyncio.gather(
            *(self.make_single_request() for _ in range(self.REQUESTS_COUNT)),
            return_exceptions=True,
        )

    def make_remote_requests(self) -> Dict:
        """
        Make several remote API calls and combine their results into the
        final response dictionary built by prepare_final_result().
        """
        # asyncio.run creates a fresh event loop and always closes it, even
        # when the awaited coroutine raises, so no closed loop is left
        # installed as the current one.
        all_tasks_results = asyncio.run(self._gather_requests())
        print(f"{all_tasks_results=}")
        return self.prepare_final_result(all_tasks_results)


class SideEffect:
    """
    A callable class that receives a list of functions during initialization.
    Each time the class is called, it calls the next function in the list.
    If the entire list has been traversed, that is, all functions have been called sequentially,
    then the list traversal starts over.

    Raises:
        ValueError: if no functions are supplied at construction time.
    """

    def __init__(self, *input_functions):
        # Cycling over an empty sequence would make every call fail with a
        # bare StopIteration; reject that case up front with a clear error.
        if not input_functions:
            raise ValueError("SideEffect requires at least one function to call")
        self.functions_to_call = itertools.cycle(input_functions)

    def __call__(self, *args, **kwargs):
        """Call the next function in the cycle with the given arguments and return its result."""
        function_to_call = next(self.functions_to_call)
        return function_to_call(*args, **kwargs)


def mocked_single_request_ok():
    """Return a canned successful response: a dict with a "continents" list of code/name entries."""
    continents = (
        ("AF", "Africa"),
        ("AN", "Antarctica"),
        ("AS", "Asia"),
        ("EU", "Europe"),
        ("NA", "North America"),
        ("OC", "Oceania"),
        ("SA", "South America"),
    )
    return {"continents": [{"code": code, "name": name} for code, name in continents]}


def mocked_single_request_failure():
    """Stand-in for a failed request: always raises a plain Exception."""
    failure = Exception("Some exception")
    raise failure


def test_demo():
    """
    Scenario: the remote API calls alternate between success and failure,
    so half of them succeed. Verify that the DATA_RETRIEVE_PROBLEMS_WARNING
    warning ends up in the response.
    """
    # The patch object is built first and entered afterwards; the side effect
    # switches between the OK and the failing stub on consecutive calls.
    patched_request = mock.patch(
        "tests.unit.clients.test_demo.RemoteAPICaller.make_single_request",
        side_effect=SideEffect(mocked_single_request_ok, mocked_single_request_failure),
    )
    with patched_request as m:
        all_tasks_results = RemoteAPICaller().make_remote_requests()
        assert m.call_count == RemoteAPICaller.REQUESTS_COUNT
    print(f"{all_tasks_results=}")
    expected_warning = RemoteAPICaller.DATA_RETRIEVE_PROBLEMS_WARNING
    assert expected_warning in all_tasks_results["warnings"]
Tosh answered 31/5, 2023 at 15:13 Comment(0)

© 2022 - 2024 — McMap. All rights reserved.