Source code for secmlt.adv.evasion.advlib_attacks.advlib_fmn
"""Wrapper of the FMN attack implemented in Adversarial Library."""
from __future__ import annotations # noqa: I001
from functools import partial
from adv_lib.attacks import fmn
from secmlt.adv.evasion.advlib_attacks.advlib_base import BaseAdvLibEvasionAttack
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
[docs]
class FMNAdvLib(BaseAdvLibEvasionAttack):
"""Wrapper of the Adversarial Library implementation of the FMN attack."""
def __init__(
self,
perturbation_model: str,
num_steps: int,
max_step_size: float,
min_step_size: float | None = None,
gamma: float | None = 0.05,
y_target: int | None = None,
lb: float = 0.0,
ub: float = 1.0,
**kwargs,
) -> None:
"""
Initialize a FMN attack with the Adversarial Library backend.
Parameters
----------
perturbation_model : str
The perturbation model to be used for the attack.
num_steps : int
The number of iterations for the attack.
max_step_size : float
The attack maximum step size.
min_step_size : float, optional
The attack minimum step size. If None, it is set to max_step_size/100.
The default value is None.
gamma: float, optional
Step size for modifying the eps-ball. Will decay with cosine annealing.
y_target : int | None, optional
The target label for the attack. If None, the attack is
untargeted. The default value is None.
lb : float, optional
The lower bound for the perturbation. The default value is 0.0.
ub : float, optional
The upper bound for the perturbation. The default value is 1.0.
"""
perturbation_models = {
LpPerturbationModels.L0: partial(fmn, norm=0),
LpPerturbationModels.L1: partial(fmn, norm=1),
LpPerturbationModels.L2: partial(fmn, norm=2),
LpPerturbationModels.LINF: partial(fmn, norm=float("inf")),
}
advlib_attack_func = perturbation_models.get(perturbation_model)
advlib_attack = partial(
advlib_attack_func,
steps=num_steps,
α_init=max_step_size,
α_final=min_step_size,
γ_init=gamma,
)
super().__init__(
advlib_attack=advlib_attack, y_target=y_target, lb=lb, ub=ub, **kwargs
)
[docs]
@staticmethod
def get_perturbation_models() -> set[str]:
"""
Check the perturbation models implemented for this attack.
Returns
-------
set[str]
The list of perturbation models implemented for this attack.
"""
return {
LpPerturbationModels.L0,
LpPerturbationModels.L1,
LpPerturbationModels.L2,
LpPerturbationModels.LINF,
}