Source code for secmlt.models.base_language_model
"""Basic wrapper for generic language model."""
from abc import ABC, abstractmethod
from typing import Union
import torch
[docs]
class BaseLanguageModel(ABC):
"""Abstract base class defining the common interface for language models."""
[docs]
@abstractmethod
def encode(self, text: Union[str, list[str]], **kwargs) -> torch.LongTensor:
"""
Convert input text into token IDs.
Parameters
----------
text : str or list of str
Input text(s) to tokenize.
Returns
-------
torch.LongTensor
Token IDs tensor.
"""
...
[docs]
@abstractmethod
def decode(self, ids: torch.LongTensor, **kwargs) -> Union[str, list[str]]:
"""
Convert token IDs back into text.
Parameters
----------
ids : torch.LongTensor
Token IDs tensor.
Returns
-------
str or list of str
Decoded text(s).
"""
...
[docs]
@abstractmethod
def predict(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
"""
Predict logits for the next token given an input sequence.
Parameters
----------
input_ids : torch.LongTensor
Input token IDs.
Returns
-------
torch.Tensor
Logits for the next token (shape [batch, vocab_size]).
"""
...
[docs]
@abstractmethod
def generate(self, prompts: list[str], **kwargs) -> list[str]:
"""
Generate text continuations from given prompts.
Parameters
----------
prompts : list of str
List of prompt strings.
Returns
-------
list of str
Generated text outputs.
"""
...
[docs]
@abstractmethod
def logprobs(
self, prompts: list[str], targets: list[str], **kwargs
) -> list[torch.Tensor]:
"""
Compute log-probabilities for each token in the target continuations.
Parameters
----------
prompts : list of str
Conditioning prompts.
targets : list of str
Target continuations.
Returns
-------
list of torch.Tensor
List of log-probability tensors, one per sample in the batch.
Each tensor has shape [target_len_i].
"""
...
[docs]
@abstractmethod
def hidden_states(
self, input_ids: torch.LongTensor, **kwargs
) -> list[torch.Tensor]:
"""
Return hidden states of the model for given input.
Parameters
----------
input_ids : torch.LongTensor
Input token IDs.
Returns
-------
list of torch.Tensor
Hidden representations per layer.
"""
...
def __call__(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
"""
Shortcut for self.predict().
Parameters
----------
input_ids : torch.LongTensor
Input token IDs.
Returns
-------
torch.Tensor
Logits for the next token.
"""
return self.predict(input_ids, **kwargs)