File size: 2,785 Bytes
4475574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
import torch
from tqdm import tqdm
import os
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the model.
model_path = "./"  # Directory containing config.json and the model weights
if not os.path.exists(os.path.join(model_path, "config.json")):
    raise FileNotFoundError(f"Config file not found in {model_path}")
# FIX: the original comment said "pytorch_model.bin" while the check looked for
# "pytorch_model.pt" — the guard could reject a directory from_pretrained would
# happily load. Accept either common weight filename so the pre-flight check
# agrees with what the loader can consume.
_weight_names = ("pytorch_model.pt", "pytorch_model.bin")
if not any(os.path.exists(os.path.join(model_path, w)) for w in _weight_names):
    raise FileNotFoundError(f"Model weights not found in {model_path}")

model = MatterGPTWrapper.from_pretrained(model_path)
# Run on GPU when available; all tensors below are moved to model.device.
model.to('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Model loaded from {model_path}")

# Load the tokenizer from its vocabulary file.
tokenizer_path = "Voc_prior"
if not os.path.exists(tokenizer_path):
    raise FileNotFoundError(f"Tokenizer vocabulary file not found at {tokenizer_path}")
tokenizer = SimpleTokenizer(tokenizer_path)
logger.info(f"Tokenizer loaded from {tokenizer_path}")

# Function to generate a single sequence
# Function to generate a single sequence
def generate_single(condition):
    """Generate one decoded sequence conditioned on the given property values.

    `condition` is a list of conditioning properties (e.g. [eform, bandgap]);
    returns the decoded string produced by the model.
    """
    start_token = '>'
    # Shape (1, 1): a single-element batch seeded with the start token.
    seed = torch.tensor([tokenizer.stoi[start_token]], dtype=torch.long)
    seed = seed.unsqueeze(0).to(model.device)
    # Shape (1, 1, n_props): conditioning vector for the batch.
    prop = torch.tensor([condition]).unsqueeze(1).to(model.device)

    out = model.generate(
        seed,
        prop=prop,
        max_length=model.config.block_size,
        temperature=1.2,
        do_sample=True,
        top_k=0,
        top_p=0.9,
    )
    return tokenizer.decode(out[0].tolist())

# Function to generate multiple sequences
# Function to generate multiple sequences
def generate_multiple(condition, num_sequences, batch_size=32):
    """Generate `num_sequences` decoded sequences in batches of at most `batch_size`.

    The same conditioning vector is applied to every sequence; sampling is
    stochastic, so results differ between calls.
    """
    results = []
    for _ in tqdm(range(0, num_sequences, batch_size)):
        # Final batch may be smaller than batch_size.
        remaining = num_sequences - len(results)
        n = min(batch_size, remaining)

        start_token = '>'
        # (n, 1) batch of start tokens, one row per sequence to generate.
        seed = torch.tensor([tokenizer.stoi[start_token]], dtype=torch.long)
        seed = seed.unsqueeze(0).repeat(n, 1).to(model.device)
        # (n, 1, n_props) conditioning tensor, identical for every row.
        prop = torch.tensor([condition]).repeat(n, 1).unsqueeze(1).to(model.device)

        out = model.generate(
            seed,
            prop=prop,
            max_length=model.config.block_size,
            temperature=1.2,
            do_sample=True,
            top_k=0,
            top_p=0.9,
        )
        results += [tokenizer.decode(row.tolist()) for row in out]

        if len(results) >= num_sequences:
            break

    return results[:num_sequences]

# Example usage
# Example usage: conditioning vector is (formation energy, band gap).
condition = [-1.0, 2.0]  # eform and bandgap

# Generate a single sequence
logger.info("Generating a single sequence:")
single_sequence = generate_single(condition)
print(single_sequence)
print()

# Generate multiple sequences
num_sequences = 10
logger.info(f"Generating {num_sequences} sequences:")
multiple_sequences = generate_multiple(condition, num_sequences)
# FIX: the original used `for i, seq in enumerate(multiple_sequences, 1)` but
# never used `i`; iterate the sequences directly.
for seq in multiple_sequences:
    print(seq)

logger.info("Generation complete")