from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForCausalLM,
)
import torch


# Display labels mapped to Hugging Face Hub checkpoint names.
MODEL_OPTIONS = {
    "BERT (bert-base-uncased)": "bert-base-uncased",
    "DistilBERT": "distilbert-base-uncased",
    "RoBERTa": "roberta-base",
    "GPT-2": "gpt2",
    "Electra": "google/electra-small-discriminator",
    "ALBERT": "albert-base-v2",
    "XLNet": "xlnet-base-cased",
}


def load_model(model_name):
    """Load a tokenizer and model, with attention weights enabled in the outputs."""
    if "gpt2" in model_name or "causal" in model_name:
        # Decoder-only checkpoints need the causal-LM head class.
        model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True)
    else:
        model = AutoModel.from_pretrained(model_name, output_attentions=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer, model
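
# Usage sketch (illustrative only, kept as comments so nothing runs on import;
# the input sentence is an arbitrary example):
#
#     tokenizer, model = load_model(MODEL_OPTIONS["GPT-2"])
#     inputs = tokenizer("The quick brown fox", return_tensors="pt")
#     outputs = model(**inputs)  # outputs.attentions: one tensor per layer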


def get_model_info(model):
    """Summarize basic architecture details read from the model's config."""
    config = model.config
    return {
        "Model Type": config.model_type,
        # Fall back to "N/A" for configs that do not expose these attributes.
        "Number of Layers": getattr(config, "num_hidden_layers", "N/A"),
        "Number of Attention Heads": getattr(config, "num_attention_heads", "N/A"),
        "Total Parameters": sum(p.numel() for p in model.parameters()),
    }
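

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch rather than part of the module's API.
    # It assumes network access to the Hugging Face Hub; "DistilBERT" is an
    # arbitrary choice from MODEL_OPTIONS, and the sentence is a placeholder.
    tokenizer, model = load_model(MODEL_OPTIONS["DistilBERT"])
    model.eval()

    for name, value in get_model_info(model).items():
        print(f"{name}: {value}")

    inputs = tokenizer("Attention patterns start here.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Because the models are loaded with output_attentions=True, the forward
    # pass returns one attention tensor per layer, each shaped
    # (batch, num_heads, seq_len, seq_len).
    print("Attention layers returned:", len(outputs.attentions))
    print("First layer attention shape:", tuple(outputs.attentions[0].shape))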