Source code for secmlt.adv.evasion.advlib_attacks.advlib_base
"""Generic wrapper for Adversarial Library evasion attacks."""
from __future__ import annotations # noqa: I001
from typing import TYPE_CHECKING, Literal
import torch
from secmlt.adv.evasion.base_evasion_attack import BaseEvasionAttack
from secmlt.models.pytorch.base_pytorch_nn import BasePyTorchClassifier
from secmlt.trackers.model_tracker import ModelTracker
if TYPE_CHECKING:
from collections.abc import Callable
from secmlt.models.base_model import BaseModel
from secmlt.trackers.trackers import Tracker
[docs]
class BaseAdvLibEvasionAttack(BaseEvasionAttack):
    """Generic wrapper for Adversarial Library evasion attacks."""

    def __init__(
        self,
        advlib_attack: Callable[..., torch.Tensor],
        epsilon: float = torch.inf,
        y_target: int | None = None,
        lb: float = 0.0,
        ub: float = 1.0,
        trackers: Tracker | list[Tracker] | None = None,
        **kwargs,
    ) -> None:
        """
        Wrap Adversarial Library attacks.

        Parameters
        ----------
        advlib_attack : Callable[..., torch.Tensor]
            The Adversarial Library attack function to wrap.
            The function returns the adversarial examples.
        epsilon : float, optional
            The perturbation constraint. The default value is
            torch.inf, which means no constraint.
        y_target : int | None, optional
            The target label for the attack. If None, the attack is
            untargeted. The default value is None.
        lb : float, optional
            The lower bound for the perturbation. The default value is 0.0.
        ub : float, optional
            The upper bound for the perturbation. The default value is 1.0.
        trackers : Tracker | list[Tracker] | None, optional
            Trackers for the attack. When set, the model is wrapped in a
            ModelTracker for the duration of the attack. By default None.
        **kwargs
            Extra keyword arguments forwarded unchanged to
            ``advlib_attack`` on every run.
        """
        self.advlib_attack = advlib_attack
        self.lb = lb
        self.ub = ub
        self.epsilon = epsilon
        self.y_target = y_target
        self.trackers = trackers
        self.kwargs = kwargs
        super().__init__()

    @classmethod
    def _trackers_allowed(cls) -> Literal[True]:
        """Return True: this wrapper accepts trackers."""
        # NOTE(review): presumably consulted by the base class when
        # ``trackers`` is assigned -- confirm against BaseEvasionAttack.
        return True

    def _run(
        self,
        model: BaseModel,
        samples: torch.Tensor,
        labels: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the wrapped Adversarial Library attack.

        Parameters
        ----------
        model : BaseModel
            The model to attack. Must be a BasePyTorchClassifier.
        samples : torch.Tensor
            The input samples to perturb.
        labels : torch.Tensor
            The labels of the samples. Replaced by ``y_target`` for every
            sample when the attack is targeted.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The adversarial examples and the perturbation
            (adversarial examples minus the samples on the model device).

        Raises
        ------
        NotImplementedError
            If ``model`` is not a BasePyTorchClassifier.
        """
        if not isinstance(model, BasePyTorchClassifier):
            msg = "Model type not supported."
            raise NotImplementedError(msg)

        targets = (
            torch.ones_like(labels) * self.y_target
            if self.y_target is not None
            else labels
        )

        # Wrap model with ModelTracker if trackers are set.
        # NOTE(review): ``self._trackers`` is not assigned in this class;
        # presumably the base class fills it from ``self.trackers`` -- verify.
        model_tracker = None
        if self._trackers:
            model_tracker = ModelTracker(model, trackers=self._trackers)
            model_tracker.init_tracking(x_orig=samples, y=targets)
            model = model_tracker

        # Build the call arguments on a copy so that a run never leaves a
        # stale "ε" behind in ``self.kwargs`` (e.g. if ``epsilon`` is later
        # changed to torch.inf).
        call_kwargs = dict(self.kwargs)
        if self.epsilon < float(torch.inf):
            call_kwargs["ε"] = self.epsilon

        # Everything after ``init_tracking`` sits inside the try block so
        # tracking is always closed, even if the device transfer fails.
        try:
            device = model._get_device()
            samples = samples.to(device)
            targets = targets.to(device)
            advx = self.advlib_attack(
                model=model,
                inputs=samples,
                labels=targets,
                targeted=(self.y_target is not None),
                **call_kwargs,
            )
        finally:
            if model_tracker is not None:
                model_tracker.end_tracking()
                model_tracker.detach()

        delta = advx - samples
        return advx, delta