Source code for secmlt.tests.test_trainer

import torch
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer
from secmlt.tests.mocks import MockLoss
from torch.optim import SGD


[docs] def test_pytorch_trainer(model, data_loader) -> None: pytorch_model = model._model optimizer = SGD(pytorch_model.parameters(), lr=0.01) criterion = MockLoss() # Create the trainer instance trainer = BasePyTorchTrainer(optimizer=optimizer, loss=criterion) # Train the model trained_model = trainer.train(pytorch_model, data_loader) assert isinstance(trained_model, torch.nn.Module)