# Source code for secmlt.adv.evasion.foolbox_attacks.foolbox_base
"""Generic wrapper for Foolbox evasion attacks."""
from __future__ import annotations # noqa: I001
from typing import Literal, TYPE_CHECKING
import torch
from foolbox.criteria import Misclassification, TargetedMisclassification
from foolbox.models.pytorch import PyTorchModel
from secmlt.adv.evasion.base_evasion_attack import BaseEvasionAttack
from secmlt.models.pytorch.base_pytorch_nn import BasePyTorchClassifier
from secmlt.trackers.model_tracker import ModelTracker
if TYPE_CHECKING:
from foolbox.attacks.base import Attack
from secmlt.models.base_model import BaseModel
from secmlt.trackers.trackers import Tracker
# [docs]
class BaseFoolboxEvasionAttack(BaseEvasionAttack):
    """Generic wrapper for Foolbox Evasion attacks."""

    def __init__(
        self,
        foolbox_attack: Attack,
        epsilon: float = torch.inf,
        y_target: int | None = None,
        lb: float = 0.0,
        ub: float = 1.0,
        trackers: Tracker | list[Tracker] | None = None,
    ) -> None:
        """
        Wrap Foolbox attacks.

        Parameters
        ----------
        foolbox_attack : Attack
            Foolbox attack object to wrap. It is invoked directly with the
            ``model``, ``inputs``, ``criterion`` and ``epsilons`` keyword
            arguments when the attack is run.
        epsilon : float, optional
            Perturbation constraint, by default torch.inf.
        y_target : int | None, optional
            Target label for the attack, None if untargeted, by default None.
        lb : float, optional
            Lower bound of the input space, by default 0.0.
        ub : float, optional
            Upper bound of the input space, by default 1.0.
        trackers : Tracker | list[Tracker] | None, optional
            Trackers for the attack, by default None. When set, the model is
            wrapped in a ModelTracker for the duration of the attack.
        """
        self.foolbox_attack = foolbox_attack
        self.lb = lb
        self.ub = ub
        self.epsilon = epsilon
        self.y_target = y_target
        self.trackers = trackers
        super().__init__()

    @classmethod
    def _trackers_allowed(cls) -> Literal[True]:
        """Report that trackers can be used with this attack wrapper."""
        return True

    def _run(
        self,
        model: BaseModel,
        samples: torch.Tensor,
        labels: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the wrapped Foolbox attack on a batch of samples.

        Parameters
        ----------
        model : BaseModel
            Model to attack. Must be a BasePyTorchClassifier.
        samples : torch.Tensor
            Input samples to perturb.
        labels : torch.Tensor
            Labels associated with the samples. They are used for the
            misclassification criterion when the attack is untargeted.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The adversarial examples and the additive perturbation
            ``advx - samples``.

        Raises
        ------
        NotImplementedError
            If the model is not a BasePyTorchClassifier.
        """
        if not isinstance(model, BasePyTorchClassifier):
            msg = "Model type not supported."
            raise NotImplementedError(msg)

        # For targeted attacks, build a label tensor filled with the target
        # class, matching the shape and dtype of the provided labels.
        target = None
        if self.y_target is not None:
            target = (torch.zeros_like(labels) + self.y_target).type(labels.dtype)
        tracking_labels = labels if target is None else target

        # Wrap model with ModelTracker if trackers are set
        model_tracker = None
        if self._trackers:
            model_tracker = ModelTracker(model, trackers=self._trackers)
            model_tracker.init_tracking(x_orig=samples, y=tracking_labels)
            model = model_tracker

        device = model._get_device()
        samples = samples.to(device)
        labels = labels.to(device)
        foolbox_model = PyTorchModel(model.model, (self.lb, self.ub), device=device)

        if self.y_target is None:
            criterion = Misclassification(labels)
        else:
            target = target.to(device)
            criterion = TargetedMisclassification(target)

        try:
            # Only the second element of the returned triple is kept as the
            # adversarial examples; the other two are discarded.
            _, advx, _ = self.foolbox_attack(
                model=foolbox_model,
                inputs=samples,
                criterion=criterion,
                epsilons=self.epsilon,
            )
        finally:
            # Close and detach the tracker even when the attack raises.
            if model_tracker is not None:
                model_tracker.end_tracking()
                model_tracker.detach()

        # foolbox deals only with additive perturbations
        delta = advx - samples
        return advx, delta