# Source code for secmlt.tests.mocks
"""Mock classes for testing."""
from collections.abc import Iterator
import torch
class MockLayer(torch.autograd.Function):
    """Fake layer producing random class scores.

    The forward pass ignores the input values and emits random scores; the
    backward pass ignores the incoming gradient and returns the input itself
    as the gradient.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor) -> torch.Tensor:  # noqa: ANN001
        """Fake forward, returns 10 random scores per sample.

        Parameters
        ----------
        ctx
            Autograd context; ``inputs`` is stashed on it for the backward pass.
        inputs : torch.Tensor
            Batch of samples. Only ``inputs.size(0)`` and ``inputs.device``
            are used to build the output.

        Returns
        -------
        torch.Tensor
            Random scores of shape ``(inputs.size(0), 10)`` allocated on the
            same device as ``inputs``.
        """
        ctx.save_for_backward(inputs)
        # Allocate on the input's device; without ``device=`` torch.randn uses
        # the default device, which breaks the mock for non-CPU inputs.
        return torch.randn(inputs.size(0), 10, device=inputs.device)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:  # noqa: ANN001
        """Fake backward, returns the saved inputs as their own gradient.

        ``grad_output`` is deliberately ignored.
        """
        (inputs,) = ctx.saved_tensors
        return inputs
class MockModel(torch.nn.Module):
    """Mock class for torch model.

    Outputs random scores of shape ``(batch_size, 10)`` through ``MockLayer``,
    so a (fake) gradient with respect to the input is available.
    """

    @staticmethod
    def parameters(
        recurse: bool = True,  # noqa: ARG004, FBT001, FBT002
    ) -> Iterator[torch.Tensor]:
        """Return fake parameters.

        ``recurse`` mirrors the signature of ``torch.nn.Module.parameters`` so
        callers passing it keep working; it has no effect here.

        Returns
        -------
        Iterator[torch.Tensor]
            An iterator over a single ``(10, 10)`` random tensor. A new tensor
            is drawn on every call.
        """
        return iter([torch.rand(10, 10)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return random outputs for classification and add fake gradients to x."""
        # Mock output shape (batch_size, 10)
        return MockLayer.apply(x)

    def decision_function(self, *args, **kwargs) -> torch.Tensor:
        """Return random outputs for classification and add fake gradients to x.

        Delegates directly to ``forward`` (module hooks are not invoked).
        """
        return self.forward(*args, **kwargs)
class MockLoss(torch.nn.Module):
    """Fake loss function.

    Ignores its inputs and returns a random tensor of shape ``(10,)`` whose
    ``backward`` is replaced by a stub that returns the tensor itself.
    """

    def forward(self, *args: object, **kwargs: object) -> torch.Tensor:  # noqa: ARG002
        """Return a fake loss, ignoring all positional and keyword inputs.

        Returns
        -------
        torch.Tensor
            A random tensor of shape ``(10,)``. Its ``backward`` attribute is
            shadowed on the instance, so calling it returns the tensor and
            runs no autograd.
        """
        loss = torch.rand(10)
        # The stub accepts and ignores any arguments so that real-API calls
        # such as ``loss.backward(gradient, retain_graph=True)`` keep working.
        loss.backward = lambda *_args, **_kwargs: loss
        return loss