|
from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer |
|
import torch |
|
from tqdm import tqdm |
|
import os |
|
import logging |
|
|
|
# Configure root logging once for the whole script; INFO level so the
# load/generate progress messages below are visible on the console.
logging.basicConfig(level=logging.INFO)

# Module-level logger, shared by every block in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
# --- Load the pretrained MatterGPT model ---------------------------------
model_path = "./"

# Both artifacts must be present before we attempt to load anything;
# fail fast with a specific message for whichever file is missing first.
_required_files = {
    "config.json": f"Config file not found in {model_path}",
    "pytorch_model.pt": f"Model weights not found in {model_path}",
}
for _fname, _message in _required_files.items():
    if not os.path.exists(os.path.join(model_path, _fname)):
        raise FileNotFoundError(_message)

# Prefer the GPU when one is available, otherwise run on CPU.
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MatterGPTWrapper.from_pretrained(model_path)
model.to(_device)
logger.info(f"Model loaded from {model_path}")
|
|
|
|
|
# --- Load the tokenizer vocabulary ---------------------------------------
tokenizer_path = "Voc_prior"

if os.path.exists(tokenizer_path):
    tokenizer = SimpleTokenizer(tokenizer_path)
else:
    raise FileNotFoundError(f"Tokenizer vocabulary file not found at {tokenizer_path}")
logger.info(f"Tokenizer loaded from {tokenizer_path}")
|
|
|
|
|
def generate_single(condition, max_length=None, temperature=1.2, top_k=0, top_p=0.9):
    """Generate one sequence conditioned on the given property values.

    Args:
        condition: list of property values the generation is conditioned on
            (e.g. ``[-1.0, 2.0]``); wrapped into a (1, 1, len(condition)) tensor.
        max_length: maximum generated length; defaults to
            ``model.config.block_size`` when None (the original hard-coded value).
        temperature: softmax temperature for sampling (default 1.2, as before).
        top_k: top-k filtering; 0 disables it (as before).
        top_p: nucleus-sampling threshold (default 0.9, as before).

    Returns:
        The decoded generated sequence as a string.
    """
    # '>' is the start-of-sequence token in the vocabulary.
    context = '>'
    # Shape (1, 1): a single-sequence batch containing only the start token.
    x = torch.tensor([tokenizer.stoi[context]], dtype=torch.long)[None, ...].to(model.device)
    # Shape (1, 1, len(condition)): conditioning properties for the batch.
    p = torch.tensor([condition]).unsqueeze(1).to(model.device)

    if max_length is None:
        max_length = model.config.block_size

    generated = model.generate(x, prop=p, max_length=max_length,
                               temperature=temperature, do_sample=True,
                               top_k=top_k, top_p=top_p)
    return tokenizer.decode(generated[0].tolist())
|
|
|
|
|
def generate_multiple(condition, num_sequences, batch_size=32,
                      max_length=None, temperature=1.2, top_k=0, top_p=0.9):
    """Generate ``num_sequences`` sequences in batches, conditioned on *condition*.

    Args:
        condition: list of property values to condition every sequence on.
        num_sequences: total number of sequences to return.
        batch_size: how many sequences to sample per forward pass.
        max_length: maximum generated length; defaults to
            ``model.config.block_size`` when None (the original hard-coded value).
        temperature: softmax temperature for sampling (default 1.2, as before).
        top_k: top-k filtering; 0 disables it (as before).
        top_p: nucleus-sampling threshold (default 0.9, as before).

    Returns:
        A list of exactly ``num_sequences`` decoded sequences.
    """
    if max_length is None:
        max_length = model.config.block_size
    # Loop-invariant: the '>' start-of-sequence token id, looked up once.
    start_token = tokenizer.stoi['>']

    all_sequences = []
    for _ in tqdm(range(0, num_sequences, batch_size)):
        # The final batch may be smaller than batch_size.
        current_batch_size = min(batch_size, num_sequences - len(all_sequences))

        # (current_batch_size, 1): each row starts with the start token.
        x = torch.tensor([start_token], dtype=torch.long)[None, ...] \
            .repeat(current_batch_size, 1).to(model.device)
        # (current_batch_size, 1, len(condition)): per-row conditioning.
        p = torch.tensor([condition]).repeat(current_batch_size, 1) \
            .unsqueeze(1).to(model.device)

        generated = model.generate(x, prop=p, max_length=max_length,
                                   temperature=temperature, do_sample=True,
                                   top_k=top_k, top_p=top_p)
        all_sequences.extend([tokenizer.decode(seq.tolist()) for seq in generated])

        if len(all_sequences) >= num_sequences:
            break

    # Trim any overshoot from the last batch.
    return all_sequences[:num_sequences]
|
|
|
|
|
# --- Demo: conditional generation ----------------------------------------
# Target property values to condition generation on.
# NOTE(review): presumably one value per property the model was trained to
# condition on — confirm against the training configuration.
condition = [-1.0, 2.0]

# Single-sequence generation.
logger.info("Generating a single sequence:")
single_sequence = generate_single(condition)
print(single_sequence)
print()

# Batched multi-sequence generation.
num_sequences = 10
logger.info(f"Generating {num_sequences} sequences:")
multiple_sequences = generate_multiple(condition, num_sequences)
# Fixed: the original enumerated the list but never used the index.
for seq in multiple_sequences:
    print(seq)

logger.info("Generation complete")