Source code for secmlt.tests.test_data
import torch
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
from secmlt.data.distributions import GeneralizedNormal, Rademacher
from secmlt.data.lp_uniform_sampling import LpUniformSampling
[docs]
def test_rademacher_dist_shape():
    """Sampling a Rademacher distribution yields a tensor of the requested shape."""
    expected_shape = torch.Size([3, 4])
    drawn = Rademacher().sample(expected_shape)
    assert drawn.shape == expected_shape
[docs]
def test_rademacher_dist_values():
    """Rademacher samples contain both -1 and +1, and no other value."""
    dist = Rademacher()
    # A 3x4 sample is all one sign with probability 2 * 2**-12 (~1 in 2000
    # runs), which made the "both values appear" checks flaky. With 1200
    # elements that probability drops to 2**-1199.
    shape = torch.Size([30, 40])
    sample = dist.sample(shape)
    assert (sample == -1).any()
    assert (sample == 1).any()
    # Every element must be exactly -1 or +1. (The previous form,
    # `not (outside).all()`, only failed when *every* element was invalid.)
    assert ((sample == -1) | (sample == 1)).all()
[docs]
def test_gnormal_dist_shape():
    """Sampling a GeneralizedNormal yields a tensor of the requested shape."""
    expected_shape = torch.Size([3, 4])
    drawn = GeneralizedNormal().sample(expected_shape)
    assert drawn.shape == expected_shape
[docs]
def test_gnormal_dist_dtype():
    """GeneralizedNormal samples are produced as float32 tensors."""
    requested = torch.Size([3, 4])
    drawn_dtype = GeneralizedNormal().sample(requested).dtype
    assert drawn_dtype == torch.float32
[docs]
def test_gnormal_dist_p():
    """The shape parameter ``p`` changes the values GeneralizedNormal samples.

    Two independent random draws almost always differ, so comparing them says
    nothing about ``p``. Instead both draws start from the same RNG seed; any
    difference between them is then caused by ``p`` itself.
    """
    dist = GeneralizedNormal()
    shape = torch.Size([3, 4])

    def draw(p):
        # Fork the global RNG so the fixed seed does not leak into other tests.
        # NOTE(review): assumes GeneralizedNormal.sample draws from torch's
        # global CPU generator; the control assertion below checks this.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(0)
            return dist.sample(shape, p=p)

    # Control: with the seed fixed, the same p reproduces the same sample.
    assert torch.equal(draw(2), draw(2))
    # Same seed, different p: any difference must come from p.
    assert not torch.equal(draw(1), draw(2))