Source code for secmlt.tests.mocks_lm
"""Mock classes for testing language models."""
from types import SimpleNamespace
from typing import Any
import torch
[docs]
class MockHFTokenizer:
"""Fake tokenizer replicating minimal HF interface."""
def __init__(self, with_chat_template: bool = True) -> None:
"""Initialize the mock tokenizer."""
self.pad_token_id = 0
self.eos_token_id = 1
self.pad_token = "<pad>" # noqa: S105
self.eos_token = "</s>" # noqa: S105
if with_chat_template:
self.apply_chat_template = self._apply_chat_template
def __call__(
self,
texts: list[str],
return_tensors: str = "pt",
**kwargs: Any, # noqa: ANN401
) -> dict[str, torch.Tensor]:
"""Simulate tokenization returning fake input IDs and attention mask."""
batch = len(texts)
seq_len = max(len(t) for t in texts)
input_ids = torch.randint(2, 50, (batch, seq_len))
attn = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attn}
[docs]
def batch_decode(
self,
ids: torch.Tensor,
skip_special_tokens: bool = True,
**kwargs: Any, # noqa: ANN401
) -> list[str]:
"""Simulate decoding of token IDs into dummy text."""
return ["decoded text" for _ in range(ids.size(0))]
def _apply_chat_template(
self,
messages: list[dict[str, str]],
add_generation_prompt: bool = True,
tokenize: bool = False,
) -> str:
return " ".join(m["content"] for m in messages)
[docs]
class MockHFModel(torch.nn.Module):
"""Fake causal LM returning random logits and hidden states."""
def __init__(self) -> None:
"""Initialize the mock model."""
super().__init__()
self.dtype = torch.float32
[docs]
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs: Any, # noqa: ANN401
) -> SimpleNamespace:
"""Simulate a forward pass returning random logits and hidden states."""
b, t = input_ids.shape
device = input_ids.device
logits = torch.randn(b, t, 100, device=device)
hidden_states = [torch.randn(b, t, 16, device=device) for _ in range(3)]
return SimpleNamespace(logits=logits, hidden_states=hidden_states)
[docs]
def generate(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**kwargs: Any, # noqa: ANN401
) -> torch.Tensor:
"""Simulate text generation returning random token IDs."""
b, t = input_ids.shape
return torch.randint(2, 50, (b, t + 5))