Python: Mocking a context manager

I don't understand why I can't mock NamedTemporaryFile.name in this example:

from unittest.mock import patch
import unittest
import tempfile

def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name

class TestMock(unittest.TestCase):
    @patch('tempfile.NamedTemporaryFile')
    def test_cm(self, mock_tmp):
        mytmpname = 'abcde'
        mock_tmp.__enter__.return_value.name = mytmpname
        self.assertEqual(myfunc(), mytmpname)

Test results in:

AssertionError: <MagicMock name='NamedTemporaryFile().__enter__().name' id='140275675011280'> != 'abcde'
Telescopy answered 4/3, 2015 at 8:50

You are setting up the wrong mock: mock_tmp replaces NamedTemporaryFile itself, which is not the context manager; calling it returns the context manager. Replace your setup line with:

mock_tmp.return_value.__enter__.return_value.name = mytmpname

and your test will work.
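To see the difference between the two setups side by side, here is a minimal sketch that runs myfunc under both configurations, using unittest.mock.patch as a context manager (run_both_setups is a hypothetical helper name). The question's setup configures __enter__ on the patched NamedTemporaryFile itself, which would only be used by a with statement that skips the call; the corrected setup configures the object the call returns.

```python
import tempfile
from unittest.mock import patch


def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name


def run_both_setups(mytmpname="abcde"):
    # Setup from the question: __enter__ is configured on the factory mock itself,
    # but myfunc calls the factory first and enters its *return value*.
    with patch("tempfile.NamedTemporaryFile") as mock_tmp:
        mock_tmp.__enter__.return_value.name = mytmpname
        wrong = myfunc()

    # Corrected setup: configure the call's return value, then its __enter__.
    with patch("tempfile.NamedTemporaryFile") as mock_tmp:
        mock_tmp.return_value.__enter__.return_value.name = mytmpname
        right = myfunc()
        # The factory was called exactly once, with no arguments, as in myfunc.
        mock_tmp.assert_called_once_with()

    return wrong, right


wrong, right = run_both_setups()
print(isinstance(wrong, str), right)  # False abcde
```

The first result is an auto-generated MagicMock, which is exactly what the AssertionError in the question shows.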

Paolo answered 4/3, 2015 at 10:27

To expand on Nathaniel's answer, this code block

with tempfile.NamedTemporaryFile() as mytmp:
    return mytmp.name

effectively does three things (and, when the block is left, it also calls __exit__ on the context manager):

# Firstly, it calls NamedTemporaryFile, which returns a new context manager object.
context_manager = tempfile.NamedTemporaryFile()

# Secondly, it calls __enter__ on the context manager and binds the result to mytmp.
mytmp = context_manager.__enter__()

# Thirdly, we are now "inside" the context and can do some work.
return mytmp.name
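The expansion above can be checked with a small context manager that records each step. This is a sketch: Tracing, with_statement and manual_expansion are hypothetical names, and the hand-written expansion is simplified because it always passes None to __exit__ instead of forwarding exception details, which is enough for a block that just returns.

```python
class Tracing:
    """Hypothetical stand-in for NamedTemporaryFile that logs each protocol step."""

    def __init__(self, log):
        self.log = log
        self.name = "traced"
        log.append("__init__")

    def __enter__(self):
        self.log.append("__enter__")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.log.append("__exit__")
        return False  # do not suppress exceptions


def with_statement(log):
    with Tracing(log) as mytmp:
        return mytmp.name


def manual_expansion(log):
    context_manager = Tracing(log)                            # step 1: call the factory
    mytmp = type(context_manager).__enter__(context_manager)  # step 2: __enter__, looked up on the type
    try:
        return mytmp.name                                     # step 3: the body
    finally:
        # Simplified: a real with statement passes exception info here if the body raised.
        type(context_manager).__exit__(context_manager, None, None, None)


log_a, log_b = [], []
print(with_statement(log_a), log_a)    # traced ['__init__', '__enter__', '__exit__']
print(manual_expansion(log_b), log_b)  # traced ['__init__', '__enter__', '__exit__']
```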

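When you replace tempfile.NamedTemporaryFile with a MagicMock (which is what patch uses by default; a plain Mock does not configure magic methods such as __enter__ for you), the same steps look like this: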

context_manager = mock_tmp()
# This first line, above, will call mock_tmp().
# Therefore we need to set the return_value with
# mock_tmp.return_value

mytmp = context_manager.__enter__()
# This will call mock_tmp.return_value.__enter__() so we need to set 
# mock_tmp.return_value.__enter__.return_value

return mytmp.name
# This will access mock_tmp.return_value.__enter__.return_value.name
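The same chain can be walked by hand on a bare MagicMock, without patch, to confirm that each step returns the object the comments above name. A minimal sketch; it also shows that a real with statement calls __exit__(None, None, None) on mock_tmp.return_value, which MagicMock configures to return False by default, so exceptions are not swallowed.

```python
from unittest.mock import MagicMock

mock_tmp = MagicMock()  # plays the role of the patched tempfile.NamedTemporaryFile
mock_tmp.return_value.__enter__.return_value.name = "abcde"

# Step 1: calling the factory returns mock_tmp.return_value.
context_manager = mock_tmp()
assert context_manager is mock_tmp.return_value

# Step 2: entering it returns mock_tmp.return_value.__enter__.return_value.
mytmp = context_manager.__enter__()
assert mytmp is mock_tmp.return_value.__enter__.return_value

# Step 3: the attribute we configured is there.
assert mytmp.name == "abcde"

# A real with statement follows the same chain and also calls __exit__ on the way out.
with mock_tmp() as entered:
    assert entered.name == "abcde"
mock_tmp.return_value.__exit__.assert_called_once_with(None, None, None)
print(mock_tmp.call_count)  # 2
```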
Munoz answered 30/11, 2020 at 12:56
This answer was perfectly succinct and helped me out. Thank you. I was using a contextlib.asynccontextmanager and the only thing I needed to change to get this to work was changing __enter__ to __aenter__. – Harbot
You're my savior! ♥ – Jook
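Following up on the comment about contextlib.asynccontextmanager: since Python 3.8, MagicMock also supports __aenter__ and __aexit__ (they are created as AsyncMock objects), so the recipe carries over with __enter__ replaced by __aenter__. A minimal sketch; open_resource and resource_name are hypothetical names, and patch.object is applied to the current module so the example runs as a single script.

```python
import asyncio
import contextlib
import sys
import types
from unittest.mock import patch


@contextlib.asynccontextmanager
async def open_resource():
    # Hypothetical async context manager used by the code under test.
    yield types.SimpleNamespace(name="real")


async def resource_name():
    async with open_resource() as res:
        return res.name


def resource_name_with_mock(mocked_name="abcde"):
    this_module = sys.modules[__name__]
    # Replace open_resource where resource_name looks it up (this module's globals).
    with patch.object(this_module, "open_resource") as mock_open_resource:
        # Same chain as the synchronous case, with __aenter__ instead of __enter__.
        mock_open_resource.return_value.__aenter__.return_value.name = mocked_name
        return asyncio.run(resource_name())


print(resource_name_with_mock())     # abcde
print(asyncio.run(resource_name()))  # real (the patch has been undone)
```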

Extending Peter K's answer using pytest and the mocker fixture (provided by the pytest-mock plugin):

import tempfile


def myfunc():
    with tempfile.NamedTemporaryFile(prefix='fileprefix') as fh:
        return fh.name


def test_myfunc(mocker):
    mocker.patch('tempfile.NamedTemporaryFile').return_value.__enter__.return_value.name = 'tempfilename'
    assert myfunc() == 'tempfilename'
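The one-liner above discards the mock object that mocker.patch returns. Keeping a reference also lets the test check how myfunc called NamedTemporaryFile, for example that it passed prefix='fileprefix'. A sketch of the same check written with unittest.mock.patch directly (mocker.patch is a thin wrapper around it), so it runs without pytest-mock; check_prefix is a hypothetical helper name.

```python
import tempfile
from unittest.mock import patch


def myfunc():
    with tempfile.NamedTemporaryFile(prefix='fileprefix') as fh:
        return fh.name


def check_prefix():
    with patch('tempfile.NamedTemporaryFile') as mock_ntf:
        mock_ntf.return_value.__enter__.return_value.name = 'tempfilename'
        assert myfunc() == 'tempfilename'
        # The handle also records the call, including keyword arguments.
        mock_ntf.assert_called_once_with(prefix='fileprefix')
        return mock_ntf.call_args


print(check_prefix().kwargs)  # {'prefix': 'fileprefix'}
```

With the mocker fixture, the equivalent is mock_ntf = mocker.patch('tempfile.NamedTemporaryFile') followed by the same configuration and assertions.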
Diplostemonous answered 9/10, 2019 at 18:56

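Here is an alternative with pytest and the mocker fixture, which is also common practice. Instead of patching the attribute on the tempfile module, it replaces the tempfile name in the module where myfunc looks it up (here the test module itself, hence __name__):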

def test_myfunc(mocker):
    mock_tempfile = mocker.MagicMock(name='tempfile')
    mocker.patch(__name__ + '.tempfile', new=mock_tempfile)
    mytmpname = 'abcde'
    mock_tempfile.NamedTemporaryFile.return_value.__enter__.return_value.name = mytmpname
    assert myfunc() == mytmpname
Nonanonage answered 7/8, 2019 at 13:18

I extended hmobrienv's answer into a small working program (it needs the pytest-mock plugin, which provides the mocker fixture):

import tempfile
import pytest


def myfunc():
    with tempfile.NamedTemporaryFile(prefix="fileprefix") as fh:
        return fh.name


def test_myfunc(mocker):
    mocker.patch("tempfile.NamedTemporaryFile").return_value.__enter__.return_value.name = "tempfilename"
    assert myfunc() == "tempfilename"


if __name__ == "__main__":
    pytest.main(args=[__file__])
Microbe answered 13/1, 2021 at 0:07
It's crazy how most of the answers here explain the same thing but yours was the only one that could make me understand it. haha – Fillet

Another possibility is to use a factory to create an object that implements the context manager interface:

import unittest
import unittest.mock
import tempfile


def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name


def mock_named_temporary_file(tmpname):
    class MockNamedTemporaryFile:
        def __init__(self, *args, **kwargs):
            self.name = tmpname

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Returning None (falsy) lets exceptions raised in the with block propagate.
            pass

    return MockNamedTemporaryFile()


class TestMock(unittest.TestCase):
    @unittest.mock.patch("tempfile.NamedTemporaryFile")
    def test_cm(self, mock_tmp):
        mytmpname = "abcde"
        mock_tmp.return_value = mock_named_temporary_file(mytmpname)
        self.assertEqual(myfunc(), mytmpname)
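A shorter variant of the same idea, when no custom class is needed: contextlib.nullcontext(enter_result) (Python 3.7+) is a context manager whose __enter__ returns enter_result and whose __exit__ does nothing, so a plain object carrying a name attribute can stand in for the file. This is an alternative to the factory above, not part of the original answer; the test class name is hypothetical.

```python
import contextlib
import tempfile
import types
import unittest
import unittest.mock


def myfunc():
    with tempfile.NamedTemporaryFile() as mytmp:
        return mytmp.name


class TestMockNullcontext(unittest.TestCase):
    @unittest.mock.patch("tempfile.NamedTemporaryFile")
    def test_cm(self, mock_tmp):
        mytmpname = "abcde"
        # nullcontext's __enter__ returns the object passed in; __exit__ does nothing.
        fake_file = types.SimpleNamespace(name=mytmpname)
        mock_tmp.return_value = contextlib.nullcontext(fake_file)
        self.assertEqual(myfunc(), mytmpname)
```

The trade-off is that SimpleNamespace only has the attributes you give it, so code that also calls, say, write on the file would fail with AttributeError.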

Sansbury answered 4/2, 2022 at 20:13
This is the only working solution for my case. Thanks! – Tailband
