File size: 1,555 Bytes
4475574
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import Pipeline
import torch
from typing import Dict, List, Union

class MatterGPTPipeline(Pipeline):
    """Generate sequences conditioned on target material properties.

    Each request is a dict with ``formation_energy`` and ``band_gap`` entries.
    Every generated sequence is seeded with the ``'>'`` token. The two
    property values are passed to ``self.model.generate`` through its
    ``prop`` argument.
    """

    # Sampling settings used unless the caller overrides them via
    # ``__call__(..., **generate_kwargs)``. ``max_length`` defaults to
    # ``self.model.config.block_size`` and is resolved in ``_forward``.
    _DEFAULT_GENERATE_KWARGS = {
        "temperature": 1.2,
        "do_sample": True,
        "top_k": 0,
        "top_p": 0.9,
    }

    def __init__(self, model, tokenizer, device=-1):
        super().__init__(model=model, tokenizer=tokenizer, device=device)

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keywords into (preprocess, forward, postprocess) params.

        All extra keywords are treated as generation overrides and routed to
        ``_forward``. Preprocessing and postprocessing take no parameters.
        """
        return {}, dict(kwargs), {}

    def preprocess(self, inputs: Union[Dict[str, float], List[Dict[str, float]]]) -> Dict[str, torch.Tensor]:
        """Turn one or more property requests into model-ready tensors.

        Args:
            inputs: A single dict or a list of dicts, each providing
                ``formation_energy`` and ``band_gap``.

        Returns:
            A dict containing:

            - ``"input_ids"``: a long tensor of shape ``(batch, 1)`` in which
              every row holds the id of ``'>'``.
            - ``"prop"``: a float tensor of shape ``(batch, 1, 2)`` holding
              ``[formation_energy, band_gap]`` per request.

            Both tensors live on ``self.device``.

        Raises:
            KeyError: If a request lacks ``formation_energy`` or ``band_gap``.
            ValueError: If no requests are given.
        """
        if isinstance(inputs, dict):
            inputs = [inputs]

        conditions = [[item["formation_energy"], item["band_gap"]] for item in inputs]
        if not conditions:
            # An empty batch would yield a (0, 1) ``prop`` tensor instead of
            # (0, 1, 2) and fail obscurely inside generation.
            raise ValueError("MatterGPTPipeline requires at least one property request")

        start_id = self.tokenizer.stoi[">"]  # every sequence starts from the '>' token
        input_ids = torch.full((len(conditions), 1), start_id, dtype=torch.long, device=self.device)
        prop = torch.tensor(conditions, dtype=torch.float, device=self.device).unsqueeze(1)
        return {"input_ids": input_ids, "prop": prop}

    def _forward(self, model_inputs, **generate_kwargs):
        """Sample sequences from the model for the prepared batch.

        Args:
            model_inputs: The dict produced by ``preprocess``.
            **generate_kwargs: Overrides for ``_DEFAULT_GENERATE_KWARGS`` and
                for ``max_length``. ``max_length`` defaults to
                ``self.model.config.block_size``.

        Returns:
            Whatever ``self.model.generate`` returns. ``postprocess`` expects
            an iterable of id tensors.
        """
        params = dict(self._DEFAULT_GENERATE_KWARGS)
        params.update(generate_kwargs)
        if "max_length" not in params:
            params["max_length"] = self.model.config.block_size

        # ``__call__`` invokes this method directly rather than through the
        # base ``Pipeline.forward``. Disable autograd explicitly so that
        # sampling does not build a gradient graph.
        with torch.no_grad():
            return self.model.generate(
                model_inputs["input_ids"],
                prop=model_inputs["prop"],
                **params,
            )

    def postprocess(self, model_outputs):
        """Decode each generated id sequence into a string via the tokenizer."""
        return [self.tokenizer.decode(seq.tolist()) for seq in model_outputs]

    def __call__(self, inputs: Union[Dict[str, float], List[Dict[str, float]]], **generate_kwargs):
        """Run preprocessing, generation and decoding for the given requests.

        Args:
            inputs: A single property dict or a list of them (see ``preprocess``).
            **generate_kwargs: Optional overrides forwarded to
                ``self.model.generate``, such as ``temperature`` or ``top_p``.

        Returns:
            A list with one decoded string per sequence returned by the model,
            presumably one per request. A single dict still yields a list.
        """
        pre_processed = self.preprocess(inputs)
        model_outputs = self._forward(pre_processed, **generate_kwargs)
        return self.postprocess(model_outputs)