# MatterGPT usage example (usage_example.py)
from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
import torch
from tqdm import tqdm
import os
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load the model
model_path = "./" # Directory containing config.json and pytorch_model.bin
if not os.path.exists(os.path.join(model_path, "config.json")):
raise FileNotFoundError(f"Config file not found in {model_path}")
if not os.path.exists(os.path.join(model_path, "pytorch_model.pt")):
raise FileNotFoundError(f"Model weights not found in {model_path}")
model = MatterGPTWrapper.from_pretrained(model_path)
model.to('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Model loaded from {model_path}")
# Load the tokenizer
tokenizer_path = "Voc_prior"
if not os.path.exists(tokenizer_path):
    raise FileNotFoundError(f"Tokenizer vocabulary file not found at {tokenizer_path}")
tokenizer = SimpleTokenizer(tokenizer_path)
logger.info(f"Tokenizer loaded from {tokenizer_path}")
# Function to generate a single sequence
def generate_single(condition):
    """Generate one sequence for a given [eform, bandgap] condition."""
    context = '>'  # start-of-sequence token
    x = torch.tensor([tokenizer.stoi[context]], dtype=torch.long)[None, ...].to(model.device)
    p = torch.tensor([condition]).unsqueeze(1).to(model.device)  # shape (1, 1, 2)
    generated = model.generate(x, prop=p, max_length=model.config.block_size,
                               temperature=1.2, do_sample=True, top_k=0, top_p=0.9)
    return tokenizer.decode(generated[0].tolist())
# Function to generate multiple sequences
def generate_multiple(condition, num_sequences, batch_size=32):
    """Generate num_sequences sequences in batches of up to batch_size."""
    all_sequences = []
    for _ in tqdm(range(0, num_sequences, batch_size)):
        current_batch_size = min(batch_size, num_sequences - len(all_sequences))
        context = '>'  # start-of-sequence token
        x = torch.tensor([tokenizer.stoi[context]], dtype=torch.long)[None, ...].repeat(current_batch_size, 1).to(model.device)
        p = torch.tensor([condition]).repeat(current_batch_size, 1).unsqueeze(1).to(model.device)  # (batch, 1, 2)
        generated = model.generate(x, prop=p, max_length=model.config.block_size,
                                   temperature=1.2, do_sample=True, top_k=0, top_p=0.9)
        all_sequences.extend(tokenizer.decode(seq.tolist()) for seq in generated)
        if len(all_sequences) >= num_sequences:
            break
    return all_sequences[:num_sequences]
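# Hypothetical helper (not part of the original script): sweep several
# [eform, bandgap] pairs by reusing generate_single defined above.
def generate_for_conditions(conditions, per_condition=1):
    """Return {(eform, bandgap): [sequences]} for each condition pair."""
    return {tuple(cond): [generate_single(cond) for _ in range(per_condition)]
            for cond in conditions}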
# Example usage
condition = [-1.0, 2.0]  # target eform (formation energy) and bandgap values
# Generate a single sequence
logger.info("Generating a single sequence:")
single_sequence = generate_single(condition)
print(single_sequence)
print()
# Generate multiple sequences
num_sequences = 10
logger.info(f"Generating {num_sequences} sequences:")
multiple_sequences = generate_multiple(condition, num_sequences)
for i, seq in enumerate(multiple_sequences, 1):
    print(f"{i}: {seq}")
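# Hypothetical follow-up (standard library only): persist the generated
# sequences; the filename is illustrative, not part of the original script.
output_path = "generated_sequences.txt"
with open(output_path, "w") as f:
    f.write("\n".join(multiple_sequences) + "\n")
logger.info(f"Saved {len(multiple_sequences)} sequences to {output_path}")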
logger.info("Generation complete")