import torch
from torch.utils.data import IterableDataset


class BilingualDataset(IterableDataset):
    """Streams fixed-length language-modelling chunks from a text stream."""

    def __init__(self, ds_stream, tokenizer, seq_len):
        super().__init__()
        self.ds_stream = ds_stream
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        # Consecutive windows overlap by half a sequence, so tokens near a
        # chunk boundary also show up with left context in the next chunk.
        self.stride = seq_len // 2
        self.sos_token = tokenizer.token_to_id('<s>')
        self.eos_token = tokenizer.token_to_id('</s>')
        self.pad_token = tokenizer.token_to_id('<pad>')

    def process_text(self, text):
        # Append </s> so the model sees where a document ends.
        token_ids = self.tokenizer.encode(text).ids + [self.eos_token]

        # Slide a window over the token stream; max(1, ...) guarantees at least
        # one chunk even when the text is shorter than seq_len.
        for i in range(0, max(1, len(token_ids) - self.seq_len + 1), self.stride):
            chunk = token_ids[i:i + self.seq_len - 2]  # leave room for <s> and padding
            chunk = [self.sos_token] + chunk
            if len(chunk) < self.seq_len:
                chunk += [self.pad_token] * (self.seq_len - len(chunk))

            # Next-token prediction: the label is the input shifted by one position.
            input_tensor = torch.tensor(chunk[:-1], dtype=torch.long)
            label_tensor = torch.tensor(chunk[1:], dtype=torch.long)
            yield {
                "input": input_tensor,
                "label": label_tensor
            }

    def __iter__(self):
        for item in self.ds_stream:
            yield from self.process_text(item["text"])


"""import torch 
import torch.nn as nn
from torch.utils.data import Dataset

import json

class BilingualDataset(Dataset):
    def __init__(self, ds, tokenizer, seq_len):
        super().__init__()
        
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.ds = ds
        self.stride = seq_len//2
        self.sos_token = torch.tensor([tokenizer.token_to_id('<s>')],dtype=torch.int64) 
        self.eos_token = torch.tensor([tokenizer.token_to_id('</s>')],dtype=torch.int64) 
        self.pad_token = torch.tensor([tokenizer.token_to_id('<pad>')],dtype=torch.int64) 
        
        self.data_tokens = []
        
        for text in self.ds:
            # text = text['instruction'] +" ### " + text['text'] + " \n" + text['output']
            # text = text['user'] +" ### " + text['ai']
            text = text['text']
            tokens = tokenizer.encode(text).ids
            self.data_tokens.extend(tokens + [self.eos_token])
        
    def __len__(self):
        return (len(self.data_tokens) - self.seq_len) // self.stride
    
    def __getitem__(self, index):
        
        input_tokens = torch.tensor(self.data_tokens[index*self.stride:(index*self.stride)+self.seq_len- 1]).tolist()
        
        input_tokens = [self.sos_token] + input_tokens + [self.pad_token]
        if len(input_tokens) < self.seq_len - 1:
            input_tokens+=[self.pad_token] * ((self.seq_len - 1 ) - len(input_tokens))
            
        input_tokens = torch.tensor(input_tokens)
        
        
        return {
            "input": input_tokens[:-1],
            # "input_mask": (input_tokens[:-1] != self.pad_token).unsqueeze(0).int() & causal_mask(input_tokens[:-1].size(0)), # (1, seq_len) & (1, seq_len, seq_len)
            "label":input_tokens[1:]                                        # ^ CONFUSION SYNTAX :)
        }
        
def causal_mask(size):
    mask = torch.triu(torch.ones(1,size,size), diagonal=1).type(torch.int)
    return mask == 0"""
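

# A minimal usage sketch, not part of the training pipeline. It assumes a Hugging Face
# `datasets` streaming split with a "text" field and a `tokenizers` Tokenizer saved to
# "tokenizer.json"; the dataset name and tokenizer path below are placeholders.
if __name__ == "__main__":
    from datasets import load_dataset
    from tokenizers import Tokenizer
    from torch.utils.data import DataLoader

    tokenizer = Tokenizer.from_file("tokenizer.json")  # placeholder path
    ds_stream = load_dataset("wikitext", "wikitext-2-raw-v1",  # placeholder dataset
                             split="train", streaming=True)

    dataset = BilingualDataset(ds_stream, tokenizer, seq_len=128)
    # IterableDataset has no __len__, so any shuffling has to happen upstream
    # (e.g. ds_stream.shuffle(buffer_size=...)).
    loader = DataLoader(dataset, batch_size=8)

    batch = next(iter(loader))
    print(batch["input"].shape, batch["label"].shape)  # (8, 127) each: seq_len - 1 tokens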