Source code for secmlt.adv.evasion.modular_attacks.eot_gradient

"""Modular attack component with Expectation over Transformation (EoT) gradient."""

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from secmlt.models.base_model import BaseModel


class EoTGradientMixin:
    """Modular attack component with Expectation over Transformation (EoT) gradient.

    Add as a mixin to any modular attack to enable EoT gradient computation.
    The host attack must provide ``manipulation_function`` and
    ``forward_loss``, both of which are called by ``_loss_and_grad``.
    """

    def __init__(
        self, eot_samples: int = 10, eot_radius: float = 0.03, *args, **kwargs
    ) -> None:
        """
        Add EoT gradient computation to modular attack.

        Parameters
        ----------
        eot_samples : int, optional
            Number of loss evaluations per input: ``eot_samples // 2``
            antithetic noise pairs, plus one evaluation at the unperturbed
            delta when the value is odd. Default is 10.
        eot_radius : float, optional
            Scale applied to the standard-normal noise added to delta.
            Default is 0.03.
        *args, **kwargs
            Forwarded unchanged to the next class in the MRO.
        """
        super().__init__(*args, **kwargs)
        self.eot_samples = eot_samples
        self.eot_radius = eot_radius

    def _loss_and_grad(
        self,
        model: "BaseModel",
        samples: torch.Tensor,
        delta: torch.Tensor,
        target: torch.Tensor,
        multiplier: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute averaged finite-difference style gradient into delta.grad.

        The loss is evaluated at ``delta + sigma * u`` and
        ``delta - sigma * u`` for standard-normal directions ``u``; the loss
        differences weight the directions to form the gradient estimate.
        The estimate is written to ``delta.grad`` (autograd is not used to
        produce it).

        Parameters
        ----------
        model : BaseModel
            The model to attack.
        samples : torch.Tensor
            Original clean samples.
        delta : torch.Tensor
            Current perturbation.
        target : torch.Tensor
            Target labels.
        multiplier : int
            Multiplier for loss (1 for untargeted, -1 for targeted).

        Returns
        -------
        scores : torch.Tensor
            Model scores for the adversarial examples, averaged over all
            evaluated neighbors.
        losses : torch.Tensor
            Loss values for the adversarial examples, multiplied by
            ``multiplier`` and averaged over all evaluated neighbors.
        """
        # basic params
        b = samples.size(0)  # batch size
        sigma = self.eot_radius  # noise scale
        k = self.eot_samples  # number of neighbors
        device = samples.device

        # ensure delta tracks grad and clear any stale gradient
        delta.requires_grad_()
        if delta.grad is not None:
            delta.grad.detach_()
            delta.grad.zero_()

        # number of antithetic pairs and whether to include a center sample
        pairs = k // 2
        is_odd = (k % 2) == 1

        # accumulators for losses, scores and the gradient estimate
        losses_list = []
        scores_list = []
        grad_est_sum = torch.zeros_like(delta, device=device)

        if pairs > 0:
            item_shape = delta.shape[1:]

            # draw random directions for the pairs; cast to delta's dtype so
            # the perturbed deltas are not promoted to a different precision
            noise_pairs = torch.randn((b, pairs, *item_shape), device=device).to(
                delta.dtype
            )

            # build pos and neg deltas of shape [b*pairs, ...]
            delta_pos = (delta.unsqueeze(1) + sigma * noise_pairs).reshape(
                b * pairs, *item_shape
            )
            delta_neg = (delta.unsqueeze(1) - sigma * noise_pairs).reshape(
                b * pairs, *item_shape
            )

            # expand samples and targets to match the pairs layout
            samples_rep = (
                samples.unsqueeze(1)
                .expand(-1, pairs, *samples.shape[1:])
                .reshape(b * pairs, *samples.shape[1:])
            )
            target_rep = target.repeat_interleave(pairs, dim=0)

            # pass through manipulation function so x_pos/x_neg follow the
            # same constraints as regular adversarial examples
            x_pos, _ = self.manipulation_function(samples_rep, delta_pos)
            x_neg, _ = self.manipulation_function(samples_rep, delta_neg)

            # forward for pos and neg
            pos_scores, pos_losses = self.forward_loss(
                model=model, x=x_pos, target=target_rep
            )
            neg_scores, neg_losses = self.forward_loss(
                model=model, x=x_neg, target=target_rep
            )

            # reshape back to [b, pairs] and [b, pairs, c]; reshape (not view)
            # so non-contiguous outputs of forward_loss are accepted too
            pos_losses = pos_losses.reshape(b, pairs)
            neg_losses = neg_losses.reshape(b, pairs)
            pos_scores = pos_scores.reshape(b, pairs, -1)
            neg_scores = neg_scores.reshape(b, pairs, -1)

            # collect pos/neg for the averaged values that are reported
            losses_list.extend((pos_losses, neg_losses))
            scores_list.extend((pos_scores, neg_scores))

            # finite-difference antithetic contribution:
            # (pos_loss - neg_loss) * noise, with the loss difference
            # broadcast over the trailing (per-sample) dimensions
            diffs = pos_losses - neg_losses  # [b, pairs]
            expand_dims = [1] * (noise_pairs.dim() - 2)
            diffs_exp = diffs.reshape(b, pairs, *expand_dims)  # [b, pairs, 1, ...]
            contribs = diffs_exp * noise_pairs  # [b, pairs, ...]
            # average across pairs -> [b, ...] (or sum if you prefer)
            grad_pairs = contribs.mean(dim=1)
            grad_est_sum += grad_pairs

        # center sample if k is odd
        if is_odd:
            # build the center adversarial example via manipulation
            x_center, _ = self.manipulation_function(samples, delta)
            center_scores, center_losses = self.forward_loss(
                model=model, x=x_center, target=target
            )
            # center is only reported; it does not contribute to the gradient
            losses_list.append(center_losses.reshape(b, 1))
            scores_list.append(center_scores.reshape(b, 1, -1))

        # combine all losses and scores and compute averages for reporting
        losses_all = (
            torch.cat(losses_list, dim=1)
            if losses_list
            else torch.zeros((b, 0), device=device)
        )
        scores_all = (
            torch.cat(scores_list, dim=1)
            if scores_list
            else torch.zeros((b, 0, 0), device=device)
        )

        # multiply by multiplier and compute per-sample average loss
        avg_losses = (
            (losses_all * multiplier).mean(dim=1)
            if losses_all.numel() > 0
            else torch.zeros((b,), device=device)
        )
        avg_scores = (
            scores_all.mean(dim=1)
            if scores_all.numel() > 0
            else torch.zeros((b, 0), device=device)
        )

        # place gradient into delta.grad; guard the divisor so that
        # eot_samples == 0 yields a zero gradient instead of NaN (0 / 0)
        delta.grad = ((grad_est_sum / max(k, 1)) * multiplier).detach()

        return avg_scores.detach(), avg_losses.detach()