from typing import Dict, List, Optional

import numpy as np
import torch

from opencompass.models.base import BaseModel, LMTemplateParser


class InternLM(BaseModel):
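    """Model wrapper for InternLM checkpoints loaded through the ``internlm``
    package.

    Args:
        path (str): Path to the model checkpoint.
        max_seq_len (int): Maximum sequence length fed to the model.
            Defaults to 2048.
        tokenizer_only (bool): If True, load only the tokenizer (e.g. for
            counting prompt tokens). Defaults to False.
        tokenizer_path (str, optional): Path to the SentencePiece tokenizer
            model file.
        model_config (str, optional): Path to the model config file.
        tokenizer_type (str, optional): Tokenizer type understood by the
            ``internlm`` tokenizer. Defaults to 'v7'.
        meta_template (Dict, optional): Meta prompt template; if it defines
            an ``eos_token_id``, that id is used to stop generation.
    """
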
def __init__(self,
path: str,
max_seq_len: int = 2048,
tokenizer_only: bool = False,
tokenizer_path: Optional[str] = None,
model_config: Optional[str] = None,
tokenizer_type: Optional[str] = 'v7',
meta_template: Optional[Dict] = None):
if tokenizer_only:
self._load_tokenizer(tokenizer_path=tokenizer_path,
tokenizer_type=tokenizer_type,
max_seq_len=max_seq_len)
else:
self._load_model(path=path,
max_seq_len=max_seq_len,
tokenizer_path=tokenizer_path,
tokenizer_type=tokenizer_type,
model_config=model_config)
self.template_parser = LMTemplateParser(meta_template)
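        # If the meta template defines an EOS token id, use it to stop
        # generation in `generate`.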
self.eos_token_id = None
if meta_template and 'eos_token_id' in meta_template:
self.eos_token_id = meta_template['eos_token_id']

    def _load_model(self,
path: str,
max_seq_len: int,
tokenizer_path: Optional[str] = None,
tokenizer_type: Optional[str] = None,
model_config: Optional[str] = None):
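        """Load the model, tokenizer and generator via ``internlm``."""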
from internlm.load.load_model import load_llm
from internlm.model import build_model_with_cfg
self.model, self.tokenizer, self.generator, _ = load_llm(
path,
max_seq_len,
tokenizer_path=tokenizer_path,
tokenizer_type=tokenizer_type,
module=build_model_with_cfg,
model_config_path=model_config)

    def _load_tokenizer(self, tokenizer_path: str, tokenizer_type: str,
max_seq_len: int):
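        """Load only the SentencePiece-based tokenizer (no model weights)."""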
from internlm.load.tokenizer import LLMTokenizer
from sentencepiece import SentencePieceProcessor
tokenizer = SentencePieceProcessor()
tokenizer.load(tokenizer_path)
tokenizer = LLMTokenizer(tokenizer,
max_seq_len=max_seq_len,
tokenizer_type=tokenizer_type)
self.tokenizer = tokenizer

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized prompt.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens.
        """
tokens = self.tokenizer([prompt], truncation=False)['tokens']
return len(tokens[0])

    def generate(self, inputs: List[str], max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
return self.generator.generate(inputs,
generation_kwargs={
'max_gen_len': max_out_len,
'eos_token_id': self.eos_token_id
})

    def get_ppl(self,
                input_texts: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            input_texts (List[str]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out.

        Returns:
            List[float]: A list of perplexity scores.
        """
outputs, inputs = self.generator.get_logits(input_texts)
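        # Score next-token predictions: logits at position t are compared
        # against the token at position t + 1, hence the one-step shift below.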
shift_logits = outputs[..., :-1, :].contiguous().float()
shift_labels = inputs['tokens'][..., 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss(
reduction='none', ignore_index=self.tokenizer.pad_token_id)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)).view(shift_labels.size())
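        # When mask_length is given, zero out the loss on each sample's first
        # mask_length[i] tokens so that only the continuation is scored.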
if mask_length is not None:
mask = torch.zeros_like(shift_labels) # [batch,seqlen]
for i in range(len(mask)):
for j in range(mask_length[i] - 1, len(mask[i])):
mask[i][j] = 1
loss = loss * mask
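        # Normalise the summed loss by the number of non-padding tokens
        # (minus the masked prefix, if any).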
lens = (inputs['tokens'] !=
self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
if mask_length is not None:
lens -= np.array(mask_length)
ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
return ce_loss
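
# Minimal usage sketch (illustrative only; the checkpoint and tokenizer paths
# below are placeholders, not files shipped with this repo):
#
#   model = InternLM(path='/path/to/internlm_ckpt',
#                    tokenizer_path='/path/to/tokenizer.model',
#                    tokenizer_type='v7',
#                    max_seq_len=2048)
#   outputs = model.generate(['The capital of France is'], max_out_len=32)
#   ppls = model.get_ppl(['The capital of France is Paris.'])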