Source code for secmlt.tests.mocks

"""Mock classes for testing."""

from collections.abc import Iterator

import torch


class MockLayer(torch.autograd.Function):
    """Autograd function used as a stand-in network layer.

    The forward pass ignores the values of its input and emits random scores
    for 10 classes. The backward pass hands the original input back as the
    gradient, so attacks that query gradients get a tensor of the right shape.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor) -> torch.Tensor:  # noqa: ANN001
        """Store ``inputs`` for backward and emit random ``(batch, 10)`` scores."""
        ctx.save_for_backward(inputs)
        batch_size = inputs.size(0)
        return torch.randn(batch_size, 10)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:  # noqa: ANN001
        """Ignore ``grad_output`` and return the saved input as the gradient."""
        (saved_input,) = ctx.saved_tensors
        return saved_input
class MockModel(torch.nn.Module):
    """Mock class for torch model.

    Produces random 10-class scores through ``MockLayer`` and exposes a single
    fake parameter tensor.
    """

    @staticmethod
    def parameters(recurse: bool = True) -> Iterator[torch.Tensor]:  # noqa: ARG004, FBT001, FBT002
        """Return an iterator over one fake ``(10, 10)`` parameter tensor.

        Args:
            recurse: Accepted only for signature compatibility with
                ``torch.nn.Module.parameters``; it is ignored.

        Returns:
            Iterator yielding a single freshly drawn random tensor.
        """
        # Kept as a staticmethod so ``MockModel.parameters()`` still works on
        # the class. ``recurse`` lets callers use the standard
        # ``nn.Module.parameters(recurse=...)`` call form without a TypeError.
        return iter([torch.rand(10, 10)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return random ``(batch_size, 10)`` scores; gradients w.r.t. x are fake."""
        return MockLayer.apply(x)

    def decision_function(self, *args, **kwargs) -> torch.Tensor:
        """Alias of ``forward``: return random scores with fake gradients."""
        return self.forward(*args, **kwargs)
class MockLoss(torch.nn.Module):
    """Stand-in loss module that returns random values."""

    def forward(*args) -> torch.Tensor:
        """Ignore all arguments and return 10 uniform random values in [0, 1).

        The returned tensor has no autograd history, so its real ``backward``
        would fail. It is replaced with a no-argument callable that simply
        returns the tensor itself.
        """
        loss_values = torch.rand(10)

        def _fake_backward() -> torch.Tensor:
            return loss_values

        loss_values.backward = _fake_backward
        return loss_values