Testing functions that use input() is not a straightforward task. This tutorial shows how to do it.
The problem
Solutions to programming contest problems are usually tested automatically, so the input is just a set of strings in a predefined format that is fed into standard input.
Suppose, for example, a problem requires us to read in N stacks of different sizes. Then the input might look like this:
3
4 1 2 3 2
1 2
0
Here, 3 is the number of stacks to read in. The first number on each subsequent line is the size of the stack (which is not required in Python, but is helpful in other languages), followed by its elements.
To extract data from the input, a few manipulations have to be made:
def get_input_stacks():
    n = int(input())
    stacks = []
    for _ in range(n):
        str_stack = input().split(' ')
        stack = [int(s) for s in str_stack]
        stacks.append(stack)
    return stacks
It doesn't take much effort to make a mistake here. In fact, I already made one in the code above. So it would be nice to have automatic tests for this kind of function too.
The solution
One of the solutions I found on Stack Overflow, when applied to my problem, looks like this:
from unittest.mock import patch
import unittest

import containers


class ContainersTestCase(unittest.TestCase):

    def test_get_input_stacks_processed_input_correctly(self):
        user_input = [
            '3',
            '4 1 2 3 2',
            '1 2',
            '0',
        ]
        expected_stacks = [
            [1, 2, 3, 2],
            [2],
            [],
        ]
        with patch('builtins.input', side_effect=user_input):
            stacks = containers.get_input_stacks()
        self.assertEqual(stacks, expected_stacks)


if __name__ == '__main__':
    unittest.main()
When I run the test, I see that it fails.
The problem is that I read in the size of the stack as an element of the stack. Instead, I should just ignore it: str_stack = input().split(' ')[1:].
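With the size field skipped as described, the parsing function and a quick check of it might look like the following sketch. It repeats the function from above with only the [1:] slice added, and it feeds in the same sample input through a patched input():

```python
from unittest.mock import patch


def get_input_stacks():
    n = int(input())
    stacks = []
    for _ in range(n):
        # Drop the leading size field; the remaining tokens are the elements.
        str_stack = input().split(' ')[1:]
        stack = [int(s) for s in str_stack]
        stacks.append(stack)
    return stacks


# Same sample input as in the problem statement above.
user_input = ['3', '4 1 2 3 2', '1 2', '0']
with patch('builtins.input', side_effect=user_input):
    stacks = get_input_stacks()
```

Here the empty stack comes out as an empty list, because '0'.split(' ')[1:] leaves no tokens to convert.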
How the solution works
This is the most interesting line in the test:
with patch('builtins.input', side_effect=user_input):
I assume you know how the with statement works. unittest.mock.patch has several keyword arguments that modify its behavior, but in this case it simply replaces the built-in input() function with a unittest.mock.MagicMock object. One of the properties of this object is that calling it returns its return_value unless its side_effect returns something else:
>>> import unittest.mock
>>> mock = unittest.mock.MagicMock()
>>> mock.return_value = '43'
>>> mock()
'43'
>>> def foo():
...     return 'something_else'
...
>>> mock.side_effect = foo
>>> mock()
'something_else'
Notice that side_effect here, unlike return_value, is a function. side_effect can also be an exception object (then the exception will be raised when the mock is called) or an iterable. If it's an iterable, the mock returns the next value from it every time it's called:
>>> mock.side_effect = [1, 2]
>>> mock()
1
>>> mock()
2
>>> mock()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/Cellar/python3/3.6.1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/unittest/mock.py", line 939, in __call__
    return _mock_self._mock_call(*args, **kwargs)
  File "/usr/local/Cellar/python3/3.6.1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/unittest/mock.py", line 998, in _mock_call
    result = next(effect)
StopIteration
>>>
This is what we are trying to achieve by mocking input(): a different predefined value each time it is called. This is why we are passing side_effect=user_input to the patch() function (side_effect is both a property of MagicMock and a keyword argument of patch()):
user_input = [
    '3',
    '4 1 2 3 2',
    '1 2',
    '0',
]
with patch('builtins.input', side_effect=user_input):
The argument 'builtins.input' contains a name that points to the input() function. What patch() does is make this name point to the MagicMock object until the end of the with block. An object in Python can have multiple names, so make sure to patch the right one (the "Where to patch" section of the unittest.mock documentation explains how). builtins is the module where the names of built-in functions are located:
>>> import builtins
>>> builtins.print('hi')
hi
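The temporary nature of the patch can be checked directly. The following is a small sketch, not part of the original test; the variable names are only illustrative. It records what builtins.input points to before, inside, and after the with block:

```python
import builtins
from unittest.mock import MagicMock, patch

# Remember the real function so we can compare after the block ends.
original_input = builtins.input

with patch('builtins.input', side_effect=['hello']) as mocked_input:
    # Inside the block the name builtins.input refers to the mock.
    inside_is_mock = isinstance(builtins.input, MagicMock)
    value = input()  # resolved through builtins, so the mock is called

# After the block the name points to the real function again.
restored = builtins.input is original_input
```

The object bound by "as" is the same MagicMock that replaced input(), so it can also be inspected afterwards, for example through mocked_input.call_count.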
Each time the function get_input_stacks() calls the patched input(), it gets the next string in the user_input list. This is the behavior we wanted.
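As mentioned earlier, side_effect can also raise exceptions, and an iterable may mix ordinary values with exception instances: members that are exceptions are raised instead of returned. A minimal sketch of how this could simulate input that ends too early is shown here; the choice of EOFError is an assumption based on what the real input() raises at end of file, and the message text is made up:

```python
from unittest.mock import MagicMock

# First call returns a string, second call raises the exception instance.
mock_input = MagicMock(side_effect=['1', EOFError('no more input')])

first = mock_input()

try:
    mock_input()
    outcome = 'returned normally'
except EOFError as exc:
    outcome = f'raised EOFError: {exc}'
```

Passing such a list to patch('builtins.input', side_effect=...) would let a test check how a function reacts when fewer lines arrive than it expects.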
Where to go from here
To get a better understanding of the capabilities offered by unittest.mock, it is best to read the manual. As an exercise, I would suggest testing a function that does the reverse of get_input_stacks() by patching print().
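One possible starting point for that exercise is sketched below. The function print_stacks() is hypothetical (it does not appear in the tutorial) and is assumed to write the same format that get_input_stacks() reads. Patching builtins.print replaces it with a mock that records every call, so the test can compare the recorded calls with the expected ones:

```python
from unittest.mock import call, patch


def print_stacks(stacks):
    # Hypothetical inverse of get_input_stacks(): writes the contest format.
    print(len(stacks))
    for stack in stacks:
        # Each line starts with the stack size, followed by its elements.
        print(' '.join(str(item) for item in [len(stack)] + stack))


with patch('builtins.print') as mock_print:
    print_stacks([[1, 2, 3, 2], [2], []])

# The mock stores one call object per print() invocation, in order.
expected_calls = [call(3), call('4 1 2 3 2'), call('1 2'), call('0')]
calls_match = mock_print.call_args_list == expected_calls
```

Inside a unittest.TestCase, the last comparison would typically become self.assertEqual(mock_print.call_args_list, expected_calls).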
Thank you for reading!
Top comments (3)
Rather than multiple tests for a function that has one call to input(), what if you have one test with multiple calls to input(), e.g.:
def myfunc():
    age = input('what is your age')
    color = input('favorite color')
    print(f'your age is {age} and you like {color}')
The mock approach you cover doesn't seem to work for more than one user input per test. What if I wanted to test myfunc() and make sure that if the inputs were '27' and 'Red', the output would be 'your age is 27 and you like Red'?
Sorry... my code was at fault. The approach you describe above works perfectly for testing myfunc()!
You explain very well, thanks for sharing, keep it up.