# Source code for secmlt.models.hugging_face.base_hf_lm

"""Wrapper for Hugging Face causal language models."""

import torch
from secmlt.models.base_language_model import BaseLanguageModel
from transformers import AutoModelForCausalLM, AutoTokenizer


class HFCausalLM(BaseLanguageModel):
    """Wrapper for Hugging Face causal language models."""

    def __init__(
        self,
        model_path: str,
        device: torch.device | str | None = None,
        dtype: torch.dtype | None = None,
        tokenizer_kwargs: dict | None = None,
        model_kwargs: dict | None = None,
    ) -> None:
        """
        Create a wrapped Hugging Face causal language model.

        The model and tokenizer are loaded immediately by calling `load()`.

        Parameters
        ----------
        model_path : str
            Model name or local path.
        device : torch.device or str, optional
            Device where the model is loaded (anything accepted by
            ``torch.device``). Defaults to CUDA if available, else CPU.
        dtype : torch.dtype, optional
            Model precision. Defaults to the dtype the model is loaded with.
        tokenizer_kwargs : dict, optional
            Extra arguments for AutoTokenizer.from_pretrained().
        model_kwargs : dict, optional
            Extra arguments for AutoModelForCausalLM.from_pretrained().
        """
        self._model_path = model_path
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # torch.device() accepts both strings and existing device objects.
        self._device = torch.device(device)
        self._dtype = dtype
        self._tokenizer = None
        self._model = None
        self._tokenizer_kwargs = tokenizer_kwargs or {}
        self._model_kwargs = model_kwargs or {}
        self.load()

    def load(self) -> None:
        """
        Load model and tokenizer if not already loaded.

        If the tokenizer has no pad token but defines an EOS token, the EOS
        token is reused as pad token so that batched calls can pad. The model
        is put in eval mode and moved to the configured device and dtype.
        """
        if self._model is not None and self._tokenizer is not None:
            return
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._model_path, **self._tokenizer_kwargs
        )
        self._model = AutoModelForCausalLM.from_pretrained(
            self._model_path, **self._model_kwargs
        ).eval()
        if (
            self._tokenizer.pad_token_id is None
            and self._tokenizer.eos_token_id is not None
        ):
            self._tokenizer.pad_token = self._tokenizer.eos_token
        self._dtype = self._dtype or getattr(self._model, "dtype", torch.float32)
        self._model.to(self._device, dtype=self._dtype)

    @property
    def model(self) -> AutoModelForCausalLM:
        """
        Get the wrapped Hugging Face model.

        Returns
        -------
        AutoModelForCausalLM
            Wrapped Hugging Face model.
        """
        return self._model

    @property
    def tokenizer(self) -> AutoTokenizer:
        """
        Get the wrapped Hugging Face tokenizer.

        Returns
        -------
        AutoTokenizer
            Wrapped Hugging Face tokenizer.
        """
        return self._tokenizer

    def _padding_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Build a 0/1 attention mask hiding pad tokens (all ones if no pad id)."""
        # NOTE: `load()` may set pad == EOS, so genuine EOS tokens in the
        # input are masked too; callers that need them should pass their own
        # mask-aware inputs.
        pad_id = self._tokenizer.pad_token_id
        if pad_id is None:
            return torch.ones_like(input_ids, dtype=torch.long)
        return (input_ids != pad_id).long()

    @staticmethod
    def _last_token_index(attention_mask: torch.Tensor) -> torch.Tensor:
        """Return, per row, the index of the last attended (mask == 1) position."""
        seq_len = attention_mask.size(-1)
        positions = torch.arange(seq_len, device=attention_mask.device)
        positions = positions.expand_as(attention_mask)
        last = positions.masked_fill(attention_mask == 0, -1).amax(dim=-1)
        # Rows with no attended token fall back to the final column.
        return last.masked_fill(last < 0, seq_len - 1)

    @torch.no_grad()
    def encode(self, texts: list[str], **kwargs) -> torch.LongTensor:
        """
        Tokenize a batch of text prompts.

        Parameters
        ----------
        texts : list of str
            Batch of input prompts.
        **kwargs
            Extra tokenizer arguments. ``padding`` and ``truncation`` default
            to True and may be overridden.

        Returns
        -------
        torch.LongTensor
            Tensor of token IDs on the model device.
        """
        options = {"padding": True, "truncation": True, **kwargs}
        enc = self._tokenizer(texts, return_tensors="pt", **options)
        return enc["input_ids"].to(self._device)

    @torch.no_grad()
    def decode(self, ids: torch.LongTensor, **kwargs) -> list[str]:
        """
        Decode a batch of token IDs into text.

        Parameters
        ----------
        ids : torch.LongTensor
            Tensor of token IDs.
        **kwargs
            Extra arguments for ``batch_decode``. ``skip_special_tokens``
            defaults to True and may be overridden.

        Returns
        -------
        list of str
            Decoded text.
        """
        kwargs.setdefault("skip_special_tokens", True)
        return self._tokenizer.batch_decode(ids, **kwargs)

    @torch.no_grad()
    def predict(self, input_ids: torch.LongTensor, **kwargs) -> torch.Tensor:
        """
        Compute next-token logits for each sequence in the batch.

        The attention mask is derived from the pad token. Logits are read at
        each row's last non-pad position, so both left- and right-padded
        batches are handled.

        Parameters
        ----------
        input_ids : torch.LongTensor
            Tensor of token IDs.
        **kwargs
            Additional arguments for the model forward call.

        Returns
        -------
        torch.Tensor
            Logits for the next token, one row per input sequence.
        """
        input_ids = input_ids.to(self._device)
        attention_mask = self._padding_mask(input_ids)
        out = self._model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        last = self._last_token_index(attention_mask)
        rows = torch.arange(input_ids.size(0), device=input_ids.device)
        return out.logits[rows, last, :]

    @torch.no_grad()
    def generate(self, prompts: list[list[dict]], **kwargs) -> list[str]:
        """
        Generate text completions from chat-style prompts.

        Prompts are rendered with the tokenizer chat template (which already
        contains the model's special tokens) and left-padded, as required for
        batched generation with decoder-only models.

        Parameters
        ----------
        prompts : list of list of dict
            Batch of chat messages, each formatted as a list of
            {"role": str, "content": str}.
        **kwargs
            Additional parameters for model.generate().

        Returns
        -------
        list of str
            Generated completions (newly generated tokens only, without the
            prompt).

        Raises
        ------
        ValueError
            If tokenizer does not define `apply_chat_template`.
        """
        tokenizer = self._tokenizer
        if not hasattr(tokenizer, "apply_chat_template"):
            msg = (
                "Tokenizer does not define `apply_chat_template`. "
                "Provide a chat template when loading the model."
            )
            raise ValueError(msg)
        if not prompts:
            return []
        rendered_prompts = [
            tokenizer.apply_chat_template(p, add_generation_prompt=True, tokenize=False)
            for p in prompts
        ]
        # Decoder-only models must be left-padded for batched generation;
        # restore the caller-visible padding side afterwards.
        previous_side = tokenizer.padding_side
        tokenizer.padding_side = "left"
        try:
            # The rendered template already includes BOS & co.; adding them
            # again would duplicate them.
            enc = tokenizer(
                rendered_prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                add_special_tokens=False,
            )
        finally:
            tokenizer.padding_side = previous_side
        input_ids = enc["input_ids"].to(self._device)
        attention_mask = enc.get("attention_mask")
        if attention_mask is None:
            attention_mask = self._padding_mask(input_ids)
        else:
            attention_mask = attention_mask.to(self._device)
        if tokenizer.pad_token_id is not None:
            kwargs.setdefault("pad_token_id", tokenizer.pad_token_id)
        gen_ids = self._model.generate(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        # generate() returns prompt + continuation; keep only the continuation.
        completions = gen_ids[:, input_ids.shape[1] :]
        return tokenizer.batch_decode(completions, skip_special_tokens=True)

    @torch.no_grad()
    def logprobs(
        self, prompts: list[str], targets: list[str], **kwargs
    ) -> list[torch.Tensor]:
        """
        Compute log-probabilities for each token in the target continuation.

        Each row is built as ``prompt tokens + target tokens`` and the batch
        is right-padded as a whole. As a result, no padding sits between
        prompt and target, whatever the prompt lengths.

        Parameters
        ----------
        prompts : list of str
            Conditioning prompts.
        targets : list of str
            Target continuations (tokenized without special tokens).
        **kwargs
            Additional arguments for the model forward call.

        Returns
        -------
        list of torch.Tensor
            List of log-probabilities for each target token, computed in
            float32. Each tensor has shape [target_len_i].

        Raises
        ------
        ValueError
            If ``prompts`` and ``targets`` differ in length, or if a prompt
            encodes to zero tokens while its target is non-empty (the first
            target token would have no conditioning position).
        """
        if len(prompts) != len(targets):
            msg = "Prompts and targets must have the same length."
            raise ValueError(msg)
        if not prompts:
            return []

        prompt_ids = self._tokenizer(list(prompts), truncation=True)["input_ids"]
        target_ids = self._tokenizer(
            list(targets), truncation=True, add_special_tokens=False
        )["input_ids"]

        for prompt, p_ids, t_ids in zip(prompts, prompt_ids, target_ids):
            if t_ids and not p_ids:
                msg = f"Prompt {prompt!r} encodes to zero tokens; cannot score its target."
                raise ValueError(msg)

        sequences = [p_ids + t_ids for p_ids, t_ids in zip(prompt_ids, target_ids)]
        max_len = max(len(seq) for seq in sequences)
        if max_len == 0:
            return [torch.empty(0, device=self._device) for _ in sequences]

        pad_id = self._tokenizer.pad_token_id
        # Pad positions are masked out, so the fill id only has to be valid.
        fill_id = pad_id if pad_id is not None else 0
        input_ids = torch.full((len(sequences), max_len), fill_id, dtype=torch.long)
        attention_mask = torch.zeros_like(input_ids)
        for row, seq in enumerate(sequences):
            input_ids[row, : len(seq)] = torch.tensor(seq, dtype=torch.long)
            attention_mask[row, : len(seq)] = 1

        out = self._model(
            input_ids=input_ids.to(self._device),
            attention_mask=attention_mask.to(self._device),
            **kwargs,
        )
        logits = out.logits  # [B, T, V]

        batch_logps = []
        for row, (p_ids, t_ids) in enumerate(zip(prompt_ids, target_ids)):
            if not t_ids:
                batch_logps.append(torch.empty(0, device=self._device))
                continue
            # Logits at position i predict token i + 1, so the targets at
            # [p_len, p_len + t_len) are scored by positions [p_len - 1, ...).
            start = len(p_ids) - 1
            step_logits = logits[row, start : start + len(t_ids), :].float()
            tgt = torch.tensor(t_ids, dtype=torch.long, device=step_logits.device)
            token_logps = (
                step_logits.log_softmax(dim=-1).gather(1, tgt.unsqueeze(1)).squeeze(1)
            )
            batch_logps.append(token_logps)
        return batch_logps

    @torch.no_grad()
    def hidden_states(
        self, input_ids: torch.LongTensor, **kwargs
    ) -> list[torch.Tensor]:
        """
        Return hidden states of the model.

        Parameters
        ----------
        input_ids : torch.LongTensor
            Tensor of token IDs; pad tokens are masked out.
        **kwargs
            Additional arguments for the model forward call.

        Returns
        -------
        list of torch.Tensor
            Hidden representations per layer, as returned by the model with
            ``output_hidden_states=True``.
        """
        input_ids = input_ids.to(self._device)
        out = self._model(
            input_ids=input_ids,
            attention_mask=self._padding_mask(input_ids),
            output_hidden_states=True,
            **kwargs,
        )
        return list(out.hidden_states)