secmlt.models.pytorch package
Submodules
secmlt.models.pytorch.base_pytorch_nn module
Wrappers for PyTorch models.
- class secmlt.models.pytorch.base_pytorch_nn.BasePyTorchClassifier(model: Module, preprocessing: DataProcessing | None = None, postprocessing: DataProcessing | None = None, trainer: BasePyTorchTrainer | None = None)
Bases: BaseModel
Wrapper for PyTorch classifiers.
- property model: Module
Get the wrapped instance of PyTorch model.
- Returns:
Wrapped PyTorch model.
- Return type:
torch.nn.Module
- predict(x: Tensor) → Tensor
Return the predicted class for the given samples.
- Parameters:
x (torch.Tensor) – Input samples.
- Returns:
Predicted class for the samples.
- Return type:
torch.Tensor
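A minimal sketch of how a wrapper like this might turn model outputs into class predictions, assuming preprocessing and postprocessing behave as plain callables on tensors and that the predicted class is the index of the highest score per sample. The class `SketchClassifierWrapper` and its `scores` helper are hypothetical illustrations, not the secmlt implementation.

```python
import torch
from torch import nn


class SketchClassifierWrapper:
    """Hypothetical stand-in for BasePyTorchClassifier; not the secmlt class."""

    def __init__(self, model: nn.Module, preprocessing=None, postprocessing=None):
        self._model = model
        # Assumption: pre/postprocessing are plain callables on tensors.
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

    @property
    def model(self) -> nn.Module:
        # Mirrors the documented `model` property.
        return self._model

    def scores(self, x: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper: preprocessing -> forward pass -> postprocessing.
        if self._preprocessing is not None:
            x = self._preprocessing(x)
        out = self._model(x)
        if self._postprocessing is not None:
            out = self._postprocessing(out)
        return out

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        # Predicted class = index of the highest score for each sample (dim 1).
        return self.scores(x).argmax(dim=1)


# A 2-feature, 3-class linear model with fixed weights so the result is predictable.
net = nn.Linear(2, 3, bias=False)
with torch.no_grad():
    net.weight.copy_(torch.tensor([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]))

clf = SketchClassifierWrapper(net)
x = torch.tensor([[2.0, 1.0], [0.0, 3.0], [-1.0, -2.0]])
print(clf.predict(x))  # tensor([0, 1, 2])
```

Keeping preprocessing inside the wrapper means callers always pass raw inputs, and the returned tensor holds one integer label per row of `x`.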
- train(dataloader: DataLoader) → Module
Train the model on the given dataloader using the trainer set at construction.
- Parameters:
dataloader (DataLoader) – PyTorch dataloader that provides the training data.
- Returns:
Trained PyTorch model.
- Return type:
torch.nn.Module
- Raises:
ValueError – If no trainer is set.
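The train() contract above (delegate to a trainer, or raise ValueError when none is set) can be sketched as follows. `SketchTrainer` and `SketchTrainableWrapper` are hypothetical names, and the trainer's `train(model, dataloader)` signature is an assumption made for illustration; the real secmlt classes may differ.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class SketchTrainer:
    """Hypothetical trainer: plain SGD with cross-entropy loss."""

    def __init__(self, epochs: int = 1, lr: float = 0.1):
        self.epochs = epochs
        self.lr = lr

    def train(self, model: nn.Module, dataloader: DataLoader) -> nn.Module:
        optimizer = torch.optim.SGD(model.parameters(), lr=self.lr)
        loss_fn = nn.CrossEntropyLoss()
        model.train()  # nn.Module.train(): switch to training mode
        for _ in range(self.epochs):
            for xb, yb in dataloader:
                optimizer.zero_grad()
                loss_fn(model(xb), yb).backward()
                optimizer.step()
        return model


class SketchTrainableWrapper:
    """Hypothetical wrapper showing the documented train() contract."""

    def __init__(self, model: nn.Module, trainer=None):
        self._model = model
        self._trainer = trainer

    def train(self, dataloader: DataLoader) -> nn.Module:
        # Without a trainer there is nothing to delegate to.
        if self._trainer is None:
            raise ValueError("Cannot train: no trainer was set.")
        self._model = self._trainer.train(self._model, dataloader)
        return self._model


torch.manual_seed(0)
features = torch.randn(32, 4)
labels = (features[:, 0] > 0).long()
loader = DataLoader(TensorDataset(features, labels), batch_size=8)

try:
    SketchTrainableWrapper(nn.Linear(4, 2)).train(loader)
except ValueError as err:
    print(err)  # Cannot train: no trainer was set.

trained = SketchTrainableWrapper(nn.Linear(4, 2), trainer=SketchTrainer(epochs=3)).train(loader)
```

Separating the wrapper from the trainer lets the same model be trained with different strategies without changing the wrapper itself.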
secmlt.models.pytorch.base_pytorch_trainer module
PyTorch model trainers.
- class secmlt.models.pytorch.base_pytorch_trainer.BasePyTorchTrainer(optimizer: Optimizer, epochs: int = 5, loss: Module | None = None, scheduler: _LRScheduler | None = None)
Bases: BaseTrainer
Trainer for PyTorch models.
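A minimal sketch of a training loop driven by the constructor arguments listed above (optimizer, epochs, loss, scheduler). Two behaviours are assumptions, not confirmed by this page: the loss falls back to cross-entropy when `loss` is None, and the scheduler is stepped once per epoch. `SketchPyTorchTrainer` is a hypothetical name.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class SketchPyTorchTrainer:
    """Hypothetical trainer mirroring the documented constructor arguments."""

    def __init__(self, optimizer, epochs: int = 5, loss=None, scheduler=None):
        self._optimizer = optimizer
        self._epochs = epochs
        # Assumption: fall back to cross-entropy when no loss is given.
        self._loss = loss if loss is not None else nn.CrossEntropyLoss()
        self._scheduler = scheduler

    def train(self, model: nn.Module, dataloader: DataLoader) -> nn.Module:
        model.train()
        for _ in range(self._epochs):
            for xb, yb in dataloader:
                self._optimizer.zero_grad()
                loss = self._loss(model(xb), yb)
                loss.backward()
                self._optimizer.step()
            # Assumption: the scheduler is stepped once per epoch.
            if self._scheduler is not None:
                self._scheduler.step()
        return model


torch.manual_seed(0)
x = torch.randn(64, 2)
y = (x[:, 0] + x[:, 1] > 0).long()  # linearly separable labels
loader = DataLoader(TensorDataset(x, y), batch_size=8, shuffle=True)

model = nn.Linear(2, 2)
# The optimizer is bound to the model's parameters before the trainer sees it.
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

trainer = SketchPyTorchTrainer(optimizer, epochs=20, scheduler=scheduler)
model = trainer.train(model, loader)

with torch.no_grad():
    accuracy = (model(x).argmax(dim=1) == y).float().mean().item()
print(f"training accuracy: {accuracy:.2f}")
```

Because the trainer receives an already-built optimizer rather than a model, the optimizer must be created on the same model's parameters that are later passed to `train`.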
Module contents
PyTorch model wrappers.