Unit Testing for Data Science

python
software engineering
Python testing for Machine Learning
Author

Sachin Abeywardana

Published

July 3, 2022

Introduction

Writing tests has always been poison to me. I'm still not at the point of writing tests for everything, but I have been coming around, enough to say that I am actually having fun writing them. The way I see it, the point of unit tests is to catch bugs, and to catch them early.

For die-hard Jupyter worshippers like myself, the question is: what do you mean, catch them early? You just copy and paste your tested code into a .py file and call it a day, right? Unfortunately, most of the time the code in a single Jupyter notebook is too messy for an enterprise-level monorepo. While projects like nbdev exist, introducing such a library to an existing repo is not trivial. Nonetheless, it may even be an organisational requirement to have high code coverage by testing as much as possible.

This tutorial covers some of the tricks I have picked up along the way, including best practices such as how to test large Deep Learning models. I do not claim to be a testing guru or anywhere near it.

(Image: bird chomping on biscuit and screaming meme)

Basic Unit Test structure (conventions)

Usually you would have a tests folder which contains test files named test_*.py. These files usually correspond one-to-one with whatever is in the src directory that you are testing (e.g. src/a.py would have a tests/test_a.py). Each function/class that you are testing would similarly have a def test_*() function; all testable functions must start with test_. Finally, you would usually have an assert statement inside these tests, but testing goes beyond assert statements, and they are not a necessity.

In order to run them, you can simply run pytest /path/to/folders/tests/.
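As a hypothetical illustration of these conventions, suppose src/maths.py defines a small add function; its test file would sit in the tests folder and follow the naming rules above:

# tests/test_maths.py (hypothetical example; assumes src/maths.py defines add)
from src.maths import add


def test_add():
    assert add(2, 3) == 5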

Dependency Injection

Since these tests are usually run in a CI/CD framework, it is important that they run quickly. Therefore, we should not instantiate large NLP/CV models inside a test. One way to get around this is to inject the dependency into the function.

Consider the following two functions:

import torch
from torch import nn
from torchvision import models


def create_classification_model(num_classes: int) -> nn.Module:
    model = models.resnet34(pretrained=True)
    return torch.nn.Sequential(
        *(
            list(model.children())[:-1] + [nn.Linear(512, num_classes)]
        )
    )

# don't name it with_injection, this is just for illustration
def create_classification_model_with_injection(base_model: nn.Module, num_classes: int) -> nn.Module:
    return torch.nn.Sequential(
        *(
            list(base_model.children())[:-1] + [nn.Linear(512, num_classes)]
        )
    )

Out of the two, the second is more testable, as we do not 1. need to instantiate a large model, or 2. download anything from the internet. When testing, we could pass in something as simple as test_base_model = nn.Conv2d(3, 512, kernel_size=1). While it's true we are not testing a full resnet model, we are still able to check for bugs that may be caused by running the code above.
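As a rough sketch (the dummy base model below is my own stand-in, not something from the post), such a test could look like this:

from torch import nn

# assumes create_classification_model_with_injection is importable from your src module
def test_create_classification_model_with_injection_builds():
    test_base_model = nn.Sequential(
        nn.Conv2d(3, 512, kernel_size=1),  # cheap stand-in for a resnet backbone
        nn.Identity(),                     # dropped by the [:-1] slice, like resnet's final fc
    )
    model = create_classification_model_with_injection(test_base_model, num_classes=10)
    assert isinstance(model, nn.Sequential)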

Pytest Fixtures and conftest.py

Suppose that you needed a model definition for multiple test functions. While we can instantiate a dummy model inside each test_* function, one way to write this instantiation once is to write a function called def dummy_model() -> nn.Module and decorate it with @pytest.fixture. Once this is done, we can pass it into the test functions as an argument, and pytest will take care of passing in an instantiated version. If this model definition is required in other files for testing, we can move it into a conftest.py, which makes it accessible to all files in that tests directory. Here is an example of a dummy transformer model and tokenizer in a conftest.py file.

import pathlib

import pytest
import transformers


@pytest.fixture
def model() -> transformers.PreTrainedModel:
    config = transformers.DistilBertConfig(
        vocab_size=4,  # must be the same as the vocab size in the tokenizer
        n_layers=1,
        n_heads=1,
        dim=4,
        hidden_dim=4,
    )
    model = transformers.DistilBertModel(config)
    return model


@pytest.fixture
def tokenizer(tmp_path: pathlib.Path) -> transformers.PreTrainedTokenizer:
    with open(tmp_path / "vocab.txt", "w") as f:
        f.write("[CLS]\n[SEP]\n[MASK]\n[UNK]\n")

    tokenizer = transformers.DistilBertTokenizer(tmp_path / "vocab.txt")
    return tokenizer

@pytest.fixture
def test_sentences() -> list[str]:
    return [
        "Never gonna give you up",
        "Never gonna let you down",
        "Never gonna run around and desert you",
    ]

And the usage in a test file (not conftest) is shown below:

def test_model_output(model, tokenizer, test_sentences):
    values = model(**tokenizer(test_sentences, return_tensors="pt", padding=True))
    assert len(values.last_hidden_state) == len(test_sentences)

Mocking

Depending on complexity and use case, you may not want to construct a dummy object. Instead, we may create unittest.mock.Mock (or MagicMock) objects. The magic of these objects is that you can call arbitrary methods on them (apart from a few reserved assert_* methods), meaning you do not need to implement the methods associated with those instances.
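As a quick illustration (my own toy example), arbitrary attribute and method calls on a Mock simply work, and every call is recorded:

from unittest import mock

m = mock.Mock()
m.predict([1, 2, 3])                          # a method we never defined; it just works
m.predict.assert_called_once_with([1, 2, 3])  # and the call was recorded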

Let’s consider the function create_classification_model_with_injection. In this case, instead of creating a fake test model, let’s do the following:

from unittest import mock


def test_create_classification_model_with_injection():
    # MagicMock (rather than a plain Mock) supports magic methods such as __iter__,
    # which the list(...) call inside the function requires
    mock_model = mock.MagicMock()
    create_classification_model_with_injection(mock_model, 10)

    mock_model.children.assert_called_once()

In the above, what we are testing is that the children method of the model was called. This means that any future implementation would require children to be called, unless the tests are changed. I will refer you to this excellent blog for further magic you can do with mock classes.

Before moving on, I want to stress the point that unit testing does not need to be about matching inputs to expected outputs.

Patching

Some functions require you to perform actions that you cannot test. Downloading is one such example. Suppose I have this function:

# in models.py
from transformers import AutoModel, AutoTokenizer


def get_model_and_tokenizer(model_name: str):
    model = AutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

One way to test this is to “patch” the AutoModel.from_pretrained and AutoTokenizer.from_pretrained functions.

from unittest import mock

import models  # the module that defines get_model_and_tokenizer


def test_get_model(model, tokenizer):
    with mock.patch.object(
        models.AutoModel, "from_pretrained", return_value=model
    ) as mock_model, mock.patch.object(
        models.AutoTokenizer, "from_pretrained", return_value=tokenizer
    ) as mock_tokenizer:
        model_returned, tokenizer_returned = models.get_model_and_tokenizer("bert")

    assert model == model_returned
    assert tokenizer == tokenizer_returned

In the above case, we are effectively testing that from_pretrained gets called during the function.

In order to use mock.patch.object, the first argument is models.AutoModel, despite the fact that AutoModel comes from the transformers library. This is because the "instance" that we are patching is the one referenced in the models.py file. The second argument is a string naming the function that we are patching, and finally the return_value argument forces that function to return this value regardless of what arguments it is called with.
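For what it's worth, the same idea can be written with mock.patch and a dotted path string; the sketch below is my own variation, and again targets the names as they are reachable through models.py:

def test_get_model_with_patch_strings(model, tokenizer):
    with mock.patch("models.AutoModel.from_pretrained", return_value=model), \
         mock.patch("models.AutoTokenizer.from_pretrained", return_value=tokenizer):
        model_returned, tokenizer_returned = models.get_model_and_tokenizer("bert")

    assert model == model_returned
    assert tokenizer == tokenizer_returned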

Parametrizing

You may want to test for varying values of a certain input. While it is possible to do so using a for loop, pytest offers the pytest.mark.parametrize decorator. Suppose we have a fake base model for the image classification model we defined above. In the following example we can test multiple values of num_classes without resorting to an ugly for loop.

@pytest.mark.parametrize("num_classes", [10, 15])
def test_create_classification_model(
    base_model: nn.Module, # this comes from a fixture
    num_classes: int,
):
    model = create_classification_model_with_injection(base_model, num_classes)
    fake_input = torch.randn(16, 3, 28, 28) 
    assert model(fake_input).shape[-1] == num_classes
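The base_model fixture referenced above is not shown in the post; one possible version (an assumption on my part) that keeps the forward pass consistent with create_classification_model_with_injection is:

@pytest.fixture
def base_model() -> nn.Module:
    # everything before the final child must end in 512 features, since the function
    # appends nn.Linear(512, num_classes) after dropping the last child via [:-1]
    return nn.Sequential(
        nn.Flatten(),                  # (16, 3, 28, 28) -> (16, 2352)
        nn.Linear(3 * 28 * 28, 512),   # cheap stand-in for a resnet backbone
        nn.Identity(),                 # dropped by the [:-1] slice, like resnet's final fc
    )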

Conclusion

In my concluding remarks, I would like to stress that some tests are better than none. I personally don’t believe that tests have to be exhaustive, but I can understand if this is a point of contention.

Also, occasionally there are tests which do not include any assert statements; they simply check whether a group of functions runs end to end.
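For example, a rough sketch of such a smoke test (reusing the fixtures from earlier, with my own naming) might be:

def test_pipeline_runs_end_to_end(model, tokenizer, test_sentences):
    # no asserts: the test fails only if something in the chain raises
    batch = tokenizer(test_sentences, return_tensors="pt", padding=True)
    model(**batch)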

Best of luck with your testing journey!

Kudos

Kudos to Ryan Lin for all the help with writing tests.

Shameless Self Promotion

If you enjoyed the tutorial, buy my course (usually 90% off).