"""Base classes for implementing attacks and wrapping backends."""
from __future__ import annotations
import importlib.util
import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Literal
import torch
from secmlt.adv.backends import Backends
from torch.utils.data import DataLoader, TensorDataset
if TYPE_CHECKING:
from collections.abc import Iterator
from secmlt.models.base_model import BaseModel
from secmlt.models.pytorch.base_pytorch_nn import BasePyTorchClassifier
from secmlt.trackers.trackers import Tracker
[docs]
class BaseEvasionAttackCreator:
    """Generic creator for evasion attacks.

    Subclasses declare which backends they support via ``get_backends`` and
    override the matching ``_get_*_implementation`` hooks to return the
    concrete attack object.
    """

    @classmethod
    def get_implementation(cls, backend: str) -> BaseEvasionAttack:
        """
        Get the implementation of the attack with the given backend.

        Parameters
        ----------
        backend : str
            The backend for the attack. See secmlt.adv.backends for
            available backends.

        Returns
        -------
        BaseEvasionAttack
            Attack implementation.

        Raises
        ------
        NotImplementedError
            If the backend is not listed by ``get_backends`` or the
            corresponding implementation hook is not overridden.
        ImportError
            If the backend requires an optional extra that is not installed.
        """
        factories = {
            Backends.FOOLBOX: cls.get_foolbox_implementation,
            Backends.ADVLIB: cls.get_advlib_implementation,
            Backends.NATIVE: cls._get_native_implementation,
        }
        # Rejects backends this attack does not declare before dispatching.
        cls.check_backend_available(backend)
        factory = factories[backend]
        return factory()

    @classmethod
    def check_backend_available(cls, backend: str) -> bool:
        """
        Check if a given backend is available for the attack.

        Parameters
        ----------
        backend : str
            Backend string.

        Returns
        -------
        bool
            True if the given backend is implemented.

        Raises
        ------
        NotImplementedError
            Raises NotImplementedError if the requested backend is not in
            the list of the possible backends (check secmlt.adv.backends).
        """
        supported = cls.get_backends()
        if backend not in supported:
            msg = "Unsupported or not-implemented backend."
            raise NotImplementedError(msg)
        return True

    @classmethod
    def get_foolbox_implementation(cls) -> BaseEvasionAttack:
        """
        Get the Foolbox implementation of the attack.

        Returns
        -------
        BaseEvasionAttack
            Foolbox implementation of the attack.

        Raises
        ------
        ImportError
            Raises ImportError if Foolbox extra is not installed.
        """
        foolbox_spec = importlib.util.find_spec("foolbox", None)
        if foolbox_spec is None:
            msg = "Foolbox extra not installed."
            raise ImportError(msg)
        return cls._get_foolbox_implementation()

    @staticmethod
    def _get_foolbox_implementation() -> BaseEvasionAttack:
        """Build the Foolbox-backed attack (override in subclasses)."""
        msg = "Foolbox implementation not available."
        raise NotImplementedError(msg)

    @classmethod
    def get_advlib_implementation(cls) -> BaseEvasionAttack:
        """
        Get the Adversarial Library implementation of the attack.

        Returns
        -------
        BaseEvasionAttack
            Adversarial Library implementation of the attack.

        Raises
        ------
        ImportError
            Raises ImportError if Adversarial Library extra is not installed.
        """
        advlib_spec = importlib.util.find_spec("adv_lib", None)
        if advlib_spec is None:
            msg = "Adversarial Library extra not installed."
            raise ImportError(msg)
        return cls._get_advlib_implementation()

    @staticmethod
    def _get_advlib_implementation() -> BaseEvasionAttack:
        """Build the Adversarial Library-backed attack (override in subclasses)."""
        msg = "Adversarial Library implementation not available."
        raise NotImplementedError(msg)

    @staticmethod
    def _get_native_implementation() -> BaseEvasionAttack:
        """Build the native attack implementation (override in subclasses)."""
        msg = "Native implementation not available."
        raise NotImplementedError(msg)

    @staticmethod
    @abstractmethod
    def get_backends() -> set[str]:
        """
        Get the available backends for the given attack.

        Returns
        -------
        set[str]
            Set of implemented backends available for the attack.

        Raises
        ------
        NotImplementedError
            Raises NotImplementedError if not implemented in the inherited class.
        """
        msg = "Backends should be specified in inherited class."
        raise NotImplementedError(msg)
[docs]
class BaseEvasionAttack:
    """Base class for evasion attacks."""

    def _tracker_list(self) -> list[Tracker]:
        """Return the configured trackers as a list (empty when none are set)."""
        trackers = getattr(self, "_trackers", None)
        if trackers is None:
            return []
        if isinstance(trackers, list):
            return trackers
        # ``_trackers`` may have been assigned directly, bypassing the setter.
        return [trackers]

    def _run_batches(
        self,
        model: BaseModel,
        data_loader: DataLoader,
    ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the attack on each batch and yield adversarials with labels.

        Every tracker is initialized before, and finalized after, the attack
        runs on a batch. Finalization happens even if ``_run`` raises.

        Parameters
        ----------
        model : BaseModel
            Model to attack (already wrapped by ``_ensure_wrapped``).
        data_loader : DataLoader
            Loader yielding ``(samples, labels)`` batches.

        Yields
        ------
        tuple[torch.Tensor, torch.Tensor]
            Adversarial examples for the batch and the original labels.
        """
        for samples, labels in data_loader:
            # Looked up per batch, so trackers changed between batches of a
            # stream take effect on the following batches.
            trackers = self._tracker_list()
            for tracker in trackers:
                tracker.init_tracking()
            try:
                x_adv, _ = self._run(model, samples, labels)
            finally:
                for tracker in trackers:
                    tracker.end_tracking()
            yield x_adv, labels

    def __call__(
        self,
        model: BaseModel,
        data_loader: DataLoader,
        stream: bool = False,
    ) -> DataLoader | Iterator[tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute the attack against the model, using the input data.

        Parameters
        ----------
        model : BasePyTorchClassifier | torch.nn.Module
            Model to test. If a raw ``torch.nn.Module`` is passed, it is
            automatically wrapped in ``BasePyTorchClassifier``.
        data_loader : DataLoader
            Test dataloader.
        stream : bool, default=False
            If False, materialize all adversarial batches and return a
            DataLoader.
            If True, return an iterator yielding attacked batches lazily.

        Returns
        -------
        DataLoader | Iterator[tuple[torch.Tensor, torch.Tensor]]
            Materialized dataloader with adversarial examples and original
            labels, or a lazy iterator over attacked batches.

        Raises
        ------
        TypeError
            If ``model`` is neither a ``BasePyTorchClassifier`` nor a
            ``torch.nn.Module``.
        """
        model = self._ensure_wrapped(model)
        attacked_batches = self._run_batches(model, data_loader)
        if stream:
            if getattr(self, "_trackers", None) is not None:
                warnings.warn(
                    "Trackers are enabled while streaming attack batches. "
                    "Only consumed batches will be tracked.",
                    UserWarning,
                    stacklevel=2,
                )
            return attacked_batches
        adversarials = []
        original_labels = []
        for x_adv, labels in attacked_batches:
            adversarials.append(x_adv)
            original_labels.append(labels)
        # Concatenate along the batch dimension. Unlike vstack/hstack this is
        # correct for any rank, e.g. (B, 1) or one-hot (B, K) label tensors.
        adversarial_dataset = TensorDataset(
            torch.cat(adversarials, dim=0),
            torch.cat(original_labels, dim=0),
        )
        return DataLoader(
            adversarial_dataset,
            batch_size=data_loader.batch_size,
        )

    @property
    def trackers(self) -> list[Tracker] | None:
        """
        Get the trackers set for this attack.

        Returns
        -------
        list[Tracker] | None
            Trackers set for the attack, if any.
        """
        return getattr(self, "_trackers", None)

    @trackers.setter
    def trackers(self, trackers: list[Tracker] | Tracker | None = None) -> None:
        # A single tracker is wrapped in a list. Setting None is always
        # accepted. Setting trackers on an attack whose ``_trackers_allowed``
        # is falsy raises NotImplementedError.
        if self._trackers_allowed():
            if trackers is not None and not isinstance(trackers, list):
                trackers = [trackers]
            self._trackers = trackers
        elif trackers is not None:
            msg = "Trackers not implemented for this attack."
            raise NotImplementedError(msg)

    @classmethod
    @abstractmethod
    def _trackers_allowed(cls) -> Literal[False]:
        """Return whether this attack supports trackers (False by default)."""
        return False

    @staticmethod
    def _ensure_wrapped(model: BaseModel | torch.nn.Module) -> BasePyTorchClassifier:
        """
        Wrap a raw nn.Module into BasePyTorchClassifier if needed.

        Raises
        ------
        TypeError
            If ``model`` is neither a ``BasePyTorchClassifier`` nor a
            ``torch.nn.Module``.
        """
        # Local import, presumably to avoid an import cycle at module load.
        from secmlt.models.pytorch.base_pytorch_nn import BasePyTorchClassifier

        if isinstance(model, BasePyTorchClassifier):
            return model
        if isinstance(model, torch.nn.Module):
            return BasePyTorchClassifier(model=model)
        msg = f"Unsupported model type: {type(model)}"
        raise TypeError(msg)

    @classmethod
    def check_perturbation_model_available(cls, perturbation_model: str) -> bool:
        """
        Check whether the given perturbation model is available for the attack.

        Parameters
        ----------
        perturbation_model : str
            A perturbation model.

        Returns
        -------
        bool
            True if the attack implements the given perturbation model.

        Raises
        ------
        NotImplementedError
            Raises NotImplementedError if not implemented in the inherited class.
        """
        if perturbation_model in cls.get_perturbation_models():
            return True
        msg = "Unsupported or not-implemented perturbation model."
        raise NotImplementedError(msg)

    @staticmethod
    @abstractmethod
    def get_perturbation_models() -> set[str]:
        """
        Check the perturbation models implemented for the given attack.

        Returns
        -------
        set[str]
            The set of perturbation models for which the attack is implemented.

        Raises
        ------
        NotImplementedError
            Raises NotImplementedError if not implemented in the inherited class.
        """
        msg = "Perturbation models should be specified in inherited class."
        raise NotImplementedError(msg)

    @abstractmethod
    def _run(
        self,
        model: BasePyTorchClassifier | torch.nn.Module,
        samples: torch.Tensor,
        labels: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the attack on a single batch; must be implemented by subclasses.

        ``_run_batches`` unpacks the result as a pair and keeps only the
        first element (the adversarial examples). The second element is
        presumably the perturbation; it is not used here.
        """
        ...