secmlt.models.pytorch package

Submodules

secmlt.models.pytorch.base_pytorch_nn module

Wrappers for PyTorch models.

class secmlt.models.pytorch.base_pytorch_nn.BasePyTorchClassifier(model: Module, preprocessing: DataProcessing | None = None, postprocessing: DataProcessing | None = None, trainer: BasePyTorchTrainer | None = None)

Bases: BaseModel

Wrapper for PyTorch classifier.
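The constructor takes a model plus optional preprocessing and postprocessing steps. A minimal, hypothetical sketch of that composition pattern follows, in plain Python with no secmlt import. It assumes each processing step is a callable and that a missing step behaves like the identity. WrapperSketch, its scores method and the double helper are illustrative names, not part of secmlt.

```python
from typing import Callable, Optional

# Hypothetical stand-in type: each step maps a list of numbers to a list.
Transform = Callable[[list], list]


class WrapperSketch:
    """Hypothetical sketch of the wrapper pattern, not the secmlt class."""

    def __init__(
        self,
        model: Transform,
        preprocessing: Optional[Transform] = None,
        postprocessing: Optional[Transform] = None,
    ) -> None:
        self._model = model
        # Assumption: a missing processing step behaves like the identity.
        self._pre = preprocessing or (lambda v: list(v))
        self._post = postprocessing or (lambda v: list(v))

    @property
    def model(self) -> Transform:
        # Expose the wrapped model object itself, unchanged.
        return self._model

    def scores(self, x: list) -> list:
        # Pipeline order: preprocessing -> model -> postprocessing.
        return self._post(self._model(self._pre(x)))


def double(v: list) -> list:
    # Toy "model" that doubles every input value.
    return [2 * t for t in v]


w = WrapperSketch(model=double, preprocessing=lambda v: [t + 1 for t in v])
print(w.scores([1, 2]))  # [4, 6]
```

Keeping the raw model reachable through a property lets callers use it directly (for example, to pass its parameters to an optimizer) while the wrapper owns the processing pipeline.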

property model: Module

Get the wrapped PyTorch model instance.

Returns:

Wrapped PyTorch model.

Return type:

torch.nn.Module

predict(x: Tensor) → Tensor

Return the predicted class for the given samples.

Parameters:

x (torch.Tensor) – Input samples.

Returns:

Predicted class for the samples.

Return type:

torch.Tensor
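The docstring says only that predict returns the predicted class. A common way to obtain it from classifier outputs is an argmax over the class dimension, sketched below in plain PyTorch without importing secmlt. Treating predict as this argmax is an assumption, and predict_sketch is a hypothetical helper; with secmlt itself you would call BasePyTorchClassifier(model).predict(x).

```python
import torch
from torch import nn


def predict_sketch(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the highest-scoring class per sample."""
    with torch.no_grad():
        scores = model(x)  # shape: (n_samples, n_classes)
    return scores.argmax(dim=-1)  # shape: (n_samples,)


torch.manual_seed(0)
model = nn.Linear(4, 3)  # toy classifier: 4 features, 3 classes
x = torch.randn(5, 4)
labels = predict_sketch(model, x)
print(labels.shape)  # torch.Size([5])
```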

train(dataloader: DataLoader) → Module

Train the model with the given dataloader, if a trainer is set.

Parameters:

dataloader (DataLoader) – PyTorch dataloader that provides the training data.

Returns:

Trained PyTorch model.

Return type:

torch.nn.Module

Raises:

ValueError – If no trainer is set.
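The documented contract is that training is delegated to the trainer and a ValueError is raised when none was set. A minimal sketch of that dispatch in plain Python follows. TrainableWrapperSketch, Trainer and CountingTrainer are hypothetical names, and the sketch does not claim to reproduce secmlt's internals.

```python
from typing import Any, Optional, Protocol


class Trainer(Protocol):
    """Hypothetical interface: any object exposing train(model, dataloader)."""

    def train(self, model: Any, dataloader: Any) -> Any: ...


class TrainableWrapperSketch:
    """Hypothetical sketch of the documented train() contract."""

    def __init__(self, model: Any, trainer: Optional[Trainer] = None) -> None:
        self._model = model
        self._trainer = trainer

    def train(self, dataloader: Any) -> Any:
        if self._trainer is None:
            # Documented behavior: training without a trainer raises ValueError.
            raise ValueError("Cannot train without a trainer.")
        # Delegate the actual optimization to the trainer object.
        return self._trainer.train(self._model, dataloader)


class CountingTrainer:
    """Toy trainer that only records how often it was called."""

    def __init__(self) -> None:
        self.calls = 0

    def train(self, model: Any, dataloader: Any) -> Any:
        self.calls += 1
        return model


trainer = CountingTrainer()
wrapped = TrainableWrapperSketch(model="toy-model", trainer=trainer)
print(wrapped.train(dataloader=[]))  # toy-model
```

Keeping the training loop in a separate trainer object means the same wrapper can be used for inference-only models, where no trainer is needed.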

secmlt.models.pytorch.base_pytorch_trainer module

PyTorch model trainers.

class secmlt.models.pytorch.base_pytorch_trainer.BasePyTorchTrainer(optimizer: Optimizer, epochs: int = 5, loss: Module | None = None, scheduler: _LRScheduler | None = None)

Bases: BaseTrainer

Trainer for PyTorch models.

train(model: Module, dataloader: DataLoader) → Module

Train the model with the given dataloader.

Parameters:
  • model (torch.nn.Module) – PyTorch model to be trained.

  • dataloader (DataLoader) – Training dataloader.

Returns:

Trained model.

Return type:

torch.nn.Module
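The trainer's constructor arguments (optimizer, epochs, loss, scheduler) map onto a standard epoch-based PyTorch loop. The plain-PyTorch sketch below shows one such loop. Two behaviors are assumptions, not documented here: cross-entropy is used when loss is None, and the scheduler is stepped once per epoch. train_sketch is a hypothetical helper.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, dataloader, optimizer, epochs=5, loss=None, scheduler=None):
    """Hypothetical epoch-based training loop; returns the trained model."""
    # Assumption: fall back to cross-entropy when no loss is given.
    loss_fn = loss if loss is not None else nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for inputs, targets in dataloader:
            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            optimizer.step()
        # Assumption: the scheduler is stepped once per epoch.
        if scheduler is not None:
            scheduler.step()
    return model


torch.manual_seed(0)
x = torch.randn(64, 4)
y = (x[:, 0] > 0).long()  # toy labels: sign of the first feature
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

with torch.no_grad():
    loss_before = nn.functional.cross_entropy(model(x), y).item()
trained = train_sketch(model, loader, optimizer, epochs=5, scheduler=scheduler)
with torch.no_grad():
    loss_after = nn.functional.cross_entropy(trained(x), y).item()
print(loss_before, loss_after)
```

With secmlt itself, the same objects would be passed through the documented signatures, i.e. BasePyTorchTrainer(optimizer, epochs=5, scheduler=scheduler).train(model, loader).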

Module contents

PyTorch model wrappers.