# Source code for secmlt.adv.evasion.base_evasion_attack

"""Base classes for implementing attacks and wrapping backends."""

from __future__ import annotations

import importlib.util
import warnings
from abc import abstractmethod
from typing import TYPE_CHECKING, Literal

import torch
from secmlt.adv.backends import Backends
from torch.utils.data import DataLoader, TensorDataset

if TYPE_CHECKING:
    from collections.abc import Iterator

    from secmlt.models.base_model import BaseModel
    from secmlt.models.pytorch.base_pytorch_nn import BasePyTorchClassifier
    from secmlt.trackers.trackers import Tracker



class BaseEvasionAttackCreator:
    """Generic creator that resolves an evasion attack for a chosen backend."""

    @classmethod
    def get_implementation(cls, backend: str) -> BaseEvasionAttack:
        """
        Return the attack implementation registered for ``backend``.

        Parameters
        ----------
        backend : str
            Name of the requested backend. See secmlt.adv.backends for the
            available values.

        Returns
        -------
        BaseEvasionAttack
            Whatever the backend-specific factory method returns.

        Raises
        ------
        NotImplementedError
            If ``backend`` is not reported by ``get_backends``.
        """
        # One factory per known backend; nothing is instantiated until the
        # availability check below has passed.
        factories = {
            Backends.FOOLBOX: cls.get_foolbox_implementation,
            Backends.ADVLIB: cls.get_advlib_implementation,
            Backends.NATIVE: cls._get_native_implementation,
        }
        cls.check_backend_available(backend)
        build = factories[backend]
        return build()

    @classmethod
    def check_backend_available(cls, backend: str) -> bool:
        """
        Tell whether ``backend`` is among the backends of this attack.

        Parameters
        ----------
        backend : str
            Backend string.

        Returns
        -------
        bool
            Always True; an unavailable backend raises instead.

        Raises
        ------
        NotImplementedError
            If ``backend`` is not contained in ``cls.get_backends()``.
        """
        if backend not in cls.get_backends():
            msg = "Unsupported or not-implemented backend."
            raise NotImplementedError(msg)
        return True

    @classmethod
    def get_foolbox_implementation(cls) -> BaseEvasionAttack:
        """
        Return the Foolbox implementation of the attack.

        Returns
        -------
        BaseEvasionAttack
            Result of ``_get_foolbox_implementation``.

        Raises
        ------
        ImportError
            If the ``foolbox`` package cannot be located.
        """
        foolbox_spec = importlib.util.find_spec("foolbox", None)
        if foolbox_spec is None:
            msg = "Foolbox extra not installed."
            raise ImportError(msg)
        return cls._get_foolbox_implementation()

    @staticmethod
    def _get_foolbox_implementation() -> BaseEvasionAttack:
        """Foolbox factory hook; raises NotImplementedError unless overridden."""
        msg = "Foolbox implementation not available."
        raise NotImplementedError(msg)

    @classmethod
    def get_advlib_implementation(cls) -> BaseEvasionAttack:
        """
        Return the Adversarial Library implementation of the attack.

        Returns
        -------
        BaseEvasionAttack
            Result of ``_get_advlib_implementation``.

        Raises
        ------
        ImportError
            If the ``adv_lib`` package cannot be located.
        """
        advlib_spec = importlib.util.find_spec("adv_lib", None)
        if advlib_spec is None:
            msg = "Adversarial Library extra not installed."
            raise ImportError(msg)
        return cls._get_advlib_implementation()

    @staticmethod
    def _get_advlib_implementation() -> BaseEvasionAttack:
        """Adversarial Library factory hook; raises unless overridden."""
        msg = "Adversarial Library implementation not available."
        raise NotImplementedError(msg)

    @staticmethod
    def _get_native_implementation() -> BaseEvasionAttack:
        """Native factory hook; raises NotImplementedError unless overridden."""
        msg = "Native implementation not available."
        raise NotImplementedError(msg)

    @staticmethod
    @abstractmethod
    def get_backends() -> set[str]:
        """
        Return the backends available for the given attack.

        Returns
        -------
        set[str]
            Set of implemented backends available for the attack.

        Raises
        ------
        NotImplementedError
            Always, in this base version; inheriting classes are expected
            to override it.
        """
        msg = "Backends should be specified in inherited class."
        raise NotImplementedError(msg)
class BaseEvasionAttack:
    """
    Base class for evasion attacks.

    Calling an instance runs ``_run`` on every batch of a data loader and
    returns the adversarial examples paired with the original labels, either
    materialized in a new ``DataLoader`` or lazily as an iterator.
    """

    def _tracker_list(self) -> list[Tracker]:
        """
        Return the configured trackers as a list.

        ``_trackers`` may be unset, ``None``, a list, or a single tracker
        object (e.g. when assigned without going through the ``trackers``
        setter). All cases are normalized here so callers need a single loop.

        Returns
        -------
        list[Tracker]
            The trackers to notify; empty when none are configured.
        """
        trackers = getattr(self, "_trackers", None)
        if trackers is None:
            return []
        if isinstance(trackers, list):
            return trackers
        return [trackers]

    def _run_batches(
        self,
        model: BaseModel,
        data_loader: DataLoader,
    ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the attack on each batch and yield adversarials with labels.

        Parameters
        ----------
        model : BaseModel
            Model passed through unchanged to ``_run``.
        data_loader : DataLoader
            Iterable yielding ``(samples, labels)`` pairs.

        Yields
        ------
        tuple[torch.Tensor, torch.Tensor]
            Adversarial samples of the batch and the batch's original labels.
        """
        for samples, labels in data_loader:
            # Resolved once per batch so that init and end notifications
            # reach exactly the same trackers.
            trackers = self._tracker_list()
            for tracker in trackers:
                tracker.init_tracking()
            try:
                x_adv, _ = self._run(model, samples, labels)
            finally:
                # End tracking even if the attack raised on this batch.
                for tracker in trackers:
                    tracker.end_tracking()
            yield x_adv, labels

    def __call__(
        self,
        model: BaseModel,
        data_loader: DataLoader,
        stream: bool = False,
    ) -> DataLoader | Iterator[tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute the attack against the model, using the input data.

        Parameters
        ----------
        model : BasePyTorchClassifier | torch.nn.Module
            Model to test. If a raw ``torch.nn.Module`` is passed, it is
            automatically wrapped in ``BasePyTorchClassifier``.
        data_loader : DataLoader
            Test dataloader.
        stream : bool, default=False
            If False, materialize all adversarial batches and return a
            DataLoader. If True, return an iterator yielding attacked
            batches lazily.

        Returns
        -------
        DataLoader | Iterator[tuple[torch.Tensor, torch.Tensor]]
            Materialized dataloader with adversarial examples and original
            labels, or a lazy iterator over attacked batches.

        Raises
        ------
        TypeError
            If ``model`` is neither a ``BasePyTorchClassifier`` nor a
            ``torch.nn.Module``.
        """
        model = self._ensure_wrapped(model)
        # Generator: no batch is attacked until it is iterated.
        attacked_batches = self._run_batches(model, data_loader)

        if stream:
            trackers = getattr(self, "_trackers", None)
            if trackers is not None:
                warnings.warn(
                    "Trackers are enabled while streaming attack batches. \n"
                    "Only consumed batches will be tracked.",
                    UserWarning,
                    stacklevel=2,
                )
            return attacked_batches

        adversarials = []
        original_labels = []
        for x_adv, labels in attacked_batches:
            adversarials.append(x_adv)
            original_labels.append(labels)
        adversarials = torch.vstack(adversarials)
        original_labels = torch.hstack(original_labels)
        adversarial_dataset = TensorDataset(adversarials, original_labels)
        # Reuse the input loader's batch size for the returned loader.
        return DataLoader(
            adversarial_dataset,
            batch_size=data_loader.batch_size,
        )

    @property
    def trackers(self) -> list[Tracker] | None:
        """
        Get the trackers set for this attack.

        Returns
        -------
        list[Tracker] | None
            Trackers set for the attack, if any.
        """
        return getattr(self, "_trackers", None)

    @trackers.setter
    def trackers(self, trackers: list[Tracker] | Tracker | None = None) -> None:
        """
        Set the trackers, wrapping a single tracker in a list.

        Raises
        ------
        NotImplementedError
            If a tracker is given but ``_trackers_allowed()`` is false.
        """
        if self._trackers_allowed():
            if trackers is not None and not isinstance(trackers, list):
                trackers = [trackers]
            self._trackers = trackers
        elif trackers is not None:
            # Assigning None on an attack without tracker support is a no-op.
            msg = "Trackers not implemented for this attack."
            raise NotImplementedError(msg)

    @classmethod
    @abstractmethod
    def _trackers_allowed(cls) -> Literal[False]:
        """Tell whether trackers may be set; this base version returns False."""
        return False

    @staticmethod
    def _ensure_wrapped(model: BaseModel | torch.nn.Module) -> BasePyTorchClassifier:
        """
        Wrap a raw nn.Module into BasePyTorchClassifier if needed.

        Raises
        ------
        TypeError
            If ``model`` is neither a ``BasePyTorchClassifier`` nor a
            ``torch.nn.Module``.
        """
        # Imported at call time rather than at module level.
        from secmlt.models.pytorch.base_pytorch_nn import BasePyTorchClassifier

        if isinstance(model, BasePyTorchClassifier):
            return model
        if isinstance(model, torch.nn.Module):
            return BasePyTorchClassifier(model=model)
        msg = f"Unsupported model type: {type(model)}"
        raise TypeError(msg)

    @classmethod
    def check_perturbation_model_available(cls, perturbation_model: str) -> bool:
        """
        Check whether the given perturbation model is available for the attack.

        Parameters
        ----------
        perturbation_model : str
            A perturbation model.

        Returns
        -------
        bool
            True if the attack implements the given perturbation model.

        Raises
        ------
        NotImplementedError
            If ``perturbation_model`` is not contained in
            ``cls.get_perturbation_models()``.
        """
        if perturbation_model in cls.get_perturbation_models():
            return True
        msg = "Unsupported or not-implemented perturbation model."
        raise NotImplementedError(msg)

    @staticmethod
    @abstractmethod
    def get_perturbation_models() -> set[str]:
        """
        Check the perturbation models implemented for the given attack.

        Returns
        -------
        set[str]
            The set of perturbation models for which the attack is implemented.

        Raises
        ------
        NotImplementedError
            Raises NotImplementedError if not implemented in the inherited class.
        """
        msg = "Perturbation models should be specified in inherited class."
        raise NotImplementedError(msg)

    @abstractmethod
    def _run(
        self,
        model: BasePyTorchClassifier | torch.nn.Module,
        samples: torch.Tensor,
        labels: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Attack a single batch; to be provided by concrete attacks.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            A pair whose first element is the adversarial batch. The second
            element is discarded by ``_run_batches`` (presumably the applied
            perturbation -- confirm against the concrete attacks).
        """
        ...