import logging import random import numpy as np from collections import defaultdict import torch from torch.nn.utils.rnn import pad_sequence import transformers from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer MAX_CONTEXT_LEN = 50 logger = logging.getLogger(__name__) def replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask): """Replaces the trigger tokens in input_ids.""" out = model_inputs.copy() input_ids = model_inputs['input_ids'] device = input_ids.device trigger_ids = trigger_ids.repeat(trigger_mask.size(0), 1).to(device) try: filled = input_ids.masked_scatter(trigger_mask, trigger_ids).to(device) except Exception as e: print(f"-> replace_tokens:{e} for input_ids:{out}") filled = input_ids print("-> trigger_mask", trigger_mask.dtype) print("-> trigger_ids", trigger_ids.dtype) print("-> input_ids", input_ids.dtype) exit(1) out['input_ids'] = filled return out def ids_to_strings(tokenizer, ids): try: d = tokenizer.convert_ids_to_tokens(ids) except: pass try: d = tokenizer.convert_ids_to_tokens(ids.squeeze(0)) except: pass return [x.replace("Ġ", "") for x in d] def set_seed(seed: int): """Sets the relevant random seeds.""" random.seed(seed) np.random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) def hotflip_attack(averaged_grad, embedding_matrix, increase_loss=False, num_candidates=1, filter=None): """Returns the top candidate replacements.""" with torch.no_grad(): gradient_dot_embedding_matrix = torch.matmul( embedding_matrix, averaged_grad ) if filter is not None: gradient_dot_embedding_matrix -= filter if not increase_loss: gradient_dot_embedding_matrix *= -1 _, top_k_ids = gradient_dot_embedding_matrix.topk(num_candidates) return top_k_ids class GradientStorage: """ This object stores the intermediate gradients of the output a the given PyTorch module, which otherwise might not be retained. """ def __init__(self, module): self._stored_gradient = None module.register_backward_hook(self.hook) def hook(self, module, grad_in, grad_out): self._stored_gradient = grad_out[0] def reset(self): self._stored_gradient = None def get(self): return self._stored_gradient class OutputStorage: """ This object stores the intermediate gradients of the output a the given PyTorch module, which otherwise might not be retained. """ def __init__(self, model, config): self._stored_output = None self.config = config self.model = model self.embeddings = self.get_embeddings() self.embeddings.register_forward_hook(self.hook) def hook(self, module, input, output): self._stored_output = output def get(self): return self._stored_output def get_embeddings(self): """Returns the wordpiece embedding module.""" model_type = self.config.model_type if model_type == "llama": base_model = getattr(self.model, "model") embeddings = base_model.embed_tokens elif model_type == "gpt2": base_model = getattr(self.model, "transformer") embeddings = base_model.wte elif model_type == "opt": base_model = getattr(self.model, "model") decoder = getattr(base_model, "decoder") embeddings = decoder.embed_tokens elif model_type == "xlnet": embeddings = self.model.transformer.word_embedding else: base_model = getattr(self.model, model_type) embeddings = base_model.embeddings.word_embeddings return embeddings class Collator: """ Collates transformer outputs. """ def __init__(self, tokenizer=None, pad_token_id=0): self._tokenizer = tokenizer self._pad_token_id = pad_token_id self._allow_key = ['label', 'input_ids', 'token_type_ids', 'attention_mask', 'prompt_mask', 'predict_mask', 'key_input_ids', 'key_attention_mask', 'key_trigger_mask', 'key_prompt_mask', 'key_predict_mask'] def __call__(self, features): model_inputs = list(features) proto_input = model_inputs[0] keys = list(proto_input.keys()) padded_inputs = {} for key in keys: if not key in self._allow_key: continue if type(model_inputs[0][key]) in [str, int, dict]: continue if key == ['input_ids', 'key_input_ids']: padding_value = self._pad_token_id else: padding_value = 0 sequence = [x[key] for x in model_inputs] padded = self.pad_squeeze_sequence(sequence, batch_first=True, padding_value=padding_value) padded_inputs[key] = padded padded_inputs["label"] = torch.tensor([x["label"] for x in model_inputs]).long() if "idx" in keys: padded_inputs["idx"] = torch.tensor([x["idx"] for x in model_inputs], dtype=torch.long) if self._tokenizer is not None: padded_inputs["labels"] = torch.stack([self._tokenizer.label_ids[x["label"]] for x in model_inputs]) padded_inputs["key_labels"] = torch.stack([self._tokenizer.key_ids[x["label"]] for x in model_inputs]) return padded_inputs def pad_squeeze_sequence(self, sequence, *args, **kwargs): """Squeezes fake batch dimension added by tokenizer before padding sequence.""" return pad_sequence([torch.tensor(x).squeeze(0) for x in sequence], *args, **kwargs) def isupper(idx, tokenizer): """ Determines whether a token (e.g., word piece) begins with a capital letter. """ _isupper = False # We only want to check tokens that begin words. Since byte-pair encoding # captures a prefix space, we need to check that the decoded token begins # with a space, and has a capitalized second character. if isinstance(tokenizer, transformers.GPT2Tokenizer): decoded = tokenizer.decode([idx]) if decoded[0] == ' ' and decoded[1].isupper(): _isupper = True # For all other tokenization schemes, we can just check the first character # is capitalized. elif tokenizer.decode([idx])[0].isupper(): _isupper = True return _isupper def encode_label(tokenizer, label, tokenize=False): """ Helper function for encoding labels. Deals with the subtleties of handling multiple tokens. """ if isinstance(label, str): if tokenize: # Ensure label is properly tokenized, and only retain first token # if it gets split into multiple tokens. TODO: Make sure this is # desired behavior. tokens = tokenizer.tokenize(label) if len(tokens) > 1: raise ValueError(f'Label "{label}" gets mapped to multiple tokens.') if tokens[0] == tokenizer.unk_token: raise ValueError(f'Label "{label}" gets mapped to unk.') label = tokens[0] encoded = torch.tensor(tokenizer.convert_tokens_to_ids([label])).unsqueeze(0) elif isinstance(label, list): encoded = torch.tensor(tokenizer.convert_tokens_to_ids(label)).unsqueeze(0) elif isinstance(label, int): encoded = torch.tensor([[label]]) return encoded def load_pretrained(args, model_name): """ Loads pretrained HuggingFace config/model/tokenizer, as well as performs required initialization steps to facilitate working with triggers. """ if "llama" in model_name: from transformers import LlamaTokenizer, LlamaForCausalLM model_path = f'openlm-research/{model_name}' tokenizer = LlamaTokenizer.from_pretrained(model_path) model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32) tokenizer = add_task_specific_tokens(tokenizer) config = model.config elif "glm" in model_name: from transformers import AutoModelForSeq2SeqLM model_path = f'THUDM/{model_name}' config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True) model = model.half() model.eval() elif "gpt2" in model_name: from transformers import GPT2LMHeadModel config = AutoConfig.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) model = GPT2LMHeadModel.from_pretrained(model_name) model.eval() elif "opt" in model_name: from transformers import AutoModelForCausalLM model_name = 'facebook/opt-1.3b' config = AutoConfig.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) model = AutoModelForCausalLM.from_pretrained(model_name)#, load_in_8bit=True) model.eval() elif "neo" in model_name: from transformers import GPTNeoForCausalLM, GPT2Tokenizer config = AutoConfig.from_pretrained(model_name) tokenizer = GPT2Tokenizer.from_pretrained(model_name) model = GPTNeoForCausalLM.from_pretrained(model_name) model.eval() else: config = AutoConfig.from_pretrained(model_name) model = AutoModelWithLMHead.from_pretrained(model_name) model.eval() tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) tokenizer = add_task_specific_tokens(tokenizer) # only for GPT2 if ('gpt' in tokenizer.name_or_path) or ('opt' in tokenizer.name_or_path): tokenizer.mask_token = tokenizer.unk_token config.mask_token = tokenizer.unk_token config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) elif "llama" in tokenizer.name_or_path: tokenizer.mask_token = tokenizer.unk_token tokenizer.mask_token_id = tokenizer.unk_token_id config.mask_token = tokenizer.unk_token config.mask_token_id = tokenizer.unk_token_id tokenizer.key_template = args.template tokenizer.prompt_template = args.template.replace("[K] ", "") tokenizer.label_ids = args.label2ids tokenizer.key_ids = args.key2ids if args.key2ids is not None else args.label2ids tokenizer.num_key_tokens = sum(token == '[K]' for token in tokenizer.key_template.split()) tokenizer.num_prompt_tokens = sum(token == '[T]' for token in tokenizer.prompt_template.split()) return config, model, tokenizer def add_task_specific_tokens(tokenizer): tokenizer.add_special_tokens({ 'additional_special_tokens': ['[K]', '[T]', '[P]', '[Y]'] }) tokenizer.key_token = '[K]' tokenizer.key_token_id = tokenizer.convert_tokens_to_ids('[K]') tokenizer.prompt_token = '[T]' tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]') tokenizer.predict_token = '[P]' tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]') # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token... # tokenizer.lama_x = '[X]' # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]') # tokenizer.lama_y = '[Y]' # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]') return tokenizer def load_datasets(args, tokenizer): if args.task == "super_glue": from .tasks.superglue.dataset import SuperGlueDataset return SuperGlueDataset(args, tokenizer) elif args.task == "glue": from .tasks.glue.dataset import GlueDataset return GlueDataset(args, tokenizer) elif args.task == "financial": from .tasks.financial.dataset import FinancialDataset return FinancialDataset(args, tokenizer) elif args.task == "twitter": from .tasks.twitter.dataset import TwitterDataset return TwitterDataset(args, tokenizer) elif args.task == "imdb": from .tasks.imdb.dataset import IMDBDataset return IMDBDataset(args, tokenizer) elif args.task == "ag_news": from .tasks.ag_news.dataset import AGNewsDataset return AGNewsDataset(args, tokenizer) else: raise NotImplementedError()