import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer, BitsAndBytesConfig, Gemma2ForTokenClassification

os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch.set_float32_matmul_precision("high")


def repeat_function(xs, max_length=128):
    """Pad or truncate each (1, seq_len, dim) tensor to exactly `max_length` steps.

    Sequences longer than max_length - 1 tokens are truncated to max_length - 1;
    the remaining positions are filled by repeating the sequence's mean embedding.
    """
    new_xs = []
    for x in xs:
        if x.shape[1] >= max_length - 1:
            new_xs.append(x[:, : max_length - 1, :])
        else:
            new_xs.append(x)
    xs = new_xs
    # Mean over the sequence dimension, broadcast to fill the padding slots.
    mean_xs = [x.mean(1, keepdim=True).expand(-1, max_length - x.shape[1], -1) for x in xs]
    xs = [torch.cat([x, mean_x], 1) for mean_x, x in zip(mean_xs, xs)]
    return xs


class Gemma2Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
        self.tokenizer_max_length = 128
        # quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        self.model = Gemma2ForTokenClassification.from_pretrained(
            "google/gemma-2-2b",
            # device_map="auto",
            # quantization_config=quantization_config,
        ).float()
        # Replace the token-classification head with an identity so the
        # "logits" output is the backbone's final hidden states.
        self.model.score = nn.Identity()

    @torch.no_grad()
    def forward(self, input_prompt):
        input_prompt = list(input_prompt)
        outputs = []
        for _input_prompt in input_prompt:
            # truncation=True is required for max_length to take effect;
            # without it the tokenizer ignores max_length and only warns.
            # Prompts are cut at 77 tokens here (as in the original), then
            # padded back out to tokenizer_max_length below.
            input_ids = self.tokenizer(
                _input_prompt,
                add_special_tokens=False,
                truncation=True,
                max_length=77,
                return_tensors="pt",
            ).to(self.model.device)  # follow the model's device rather than hardcoding "cuda"
            _outputs = self.model(**input_ids)["logits"]
            outputs.append(_outputs)
        # Bring every sequence to the same length so they can be stacked.
        outputs = repeat_function(outputs, max_length=self.tokenizer_max_length)
        outputs = torch.cat(outputs, 0)
        return outputs


if __name__ == "__main__":
    model = Gemma2Model().cuda()
    input_text = [
        "Write me a poem about Machine Learning.",
        "Write me a poem about Deep Learning.",
    ]
    # Run the forward pass once and reuse the result instead of recomputing it.
    out = model(input_text)
    print(out)
    print(out[0].shape)
    print(out.shape)
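
    # A minimal sanity check of repeat_function on dummy tensors (the feature
    # dim 4 and lengths 10/200 are arbitrary, chosen only for illustration):
    # the short sequence is mean-padded up to 128 steps, the long one is
    # truncated to 127 and gets a single mean-padding step, so both come out
    # at length 128 and can be concatenated along the batch dimension.
    dummy = [torch.randn(1, 10, 4), torch.randn(1, 200, 4)]
    padded = repeat_function(dummy, max_length=128)
    print([tuple(p.shape) for p in padded])  # expected: [(1, 128, 4), (1, 128, 4)]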