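"""
Debug utility: iterates the toolkit DataLoader built from a single
DatasetConfig for a given number of epochs and previews each batch (or saves
numbered PNGs) so bucket resolutions, aspect handling, and video frame
counts can be checked by eye.
"""
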
import argparse
import os
import sys
import time

import cv2
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPImageProcessor

# make the repo root importable when this script is run directly
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torchvision.transforms.functional  # keeps `torchvision` bound for the commented-out experiments below

from toolkit.config_modules import DatasetConfig
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
    trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.image_utils import save_tensors, show_img, show_tensors

parser = argparse.ArgumentParser()
# nargs='?' lets the positional default ('input') apply when the arg is omitted
parser.add_argument('dataset_folder', type=str, nargs='?', default='input')
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=1)
parser.add_argument('--output_path', type=str, default=None)

args = parser.parse_args()
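# example invocation (script name and paths are placeholders):
#   python test_dataloader.py /path/to/dataset --epochs 2 --num_frames 16 --output_path ./debug_out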

if args.output_path is not None:
    args.output_path = os.path.abspath(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)

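# fixed settings for this check: 512px target resolution, 64px bucket tolerance, batch size 1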
dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 1

clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16")

# fake SD object: provides just the adapter.clip_image_processor attribute
# chain that this script hands to the dataloader
class FakeAdapter:
    def __init__(self):
        self.clip_image_processor = clip_processor


class FakeSD:
    def __init__(self):
        self.adapter = FakeAdapter()


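# bucketed dataset at the target resolution; the commented-out options are
# kept in place as toggles for other manual checks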
dataset_config = DatasetConfig(
    dataset_path=dataset_folder,
    # clip_image_path=dataset_folder,
    # square_crop=True,
    resolution=resolution,
    # caption_ext='json',
    default_caption='default',
    # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    shrink_video_to_frames=True,
    num_frames=args.num_frames,
    # poi='person',
    # shuffle_augmentations=True,
    # augmentations=[
    #     {
    #         'method': 'Posterize',
    #         'num_bits': [(0, 4), (0, 4), (0, 4)],
    #         'p': 1.0
    #     },
    #
    # ]
)

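# build the DataLoader from the single dataset config; the fake SD supplies
# the CLIP image processor if the loader asks for it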
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())


# run through the requested epochs and check sizes
idx = 0
for epoch in range(args.epochs):
    for batch in tqdm(dataloader):
        batch: 'DataLoaderBatchDTO'
        img_batch = batch.tensor
        # video batches are 5D (B, F, C, H, W); image batches are 4D (B, C, H, W)
        frames = 1
        if len(img_batch.shape) == 5:
            batch_size, frames, channels, height, width = img_batch.shape
        else:
            batch_size, channels, height, width = img_batch.shape

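        # optional experiments, kept commented out below: tile the batch images
        # side by side and stack the resized clip/control images for comparison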
        # img_batch = color_block_imgs(img_batch, neg1_1=True)

        # chunks = torch.chunk(img_batch, batch_size, dim=0)
        # # put them so they are size by side
        # big_img = torch.cat(chunks, dim=3)
        # big_img = big_img.squeeze(0)
        #
        # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
        # big_control_img = torch.cat(control_chunks, dim=3)
        # big_control_img = big_control_img.squeeze(0) * 2 - 1
        #
        #
        # # resize control image
        # big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
        #
        # big_img = torch.cat([big_img, big_control_img], dim=2)
        #
        # min_val = big_img.min()
        # max_val = big_img.max()
        #
        # big_img = (big_img / 2 + 0.5).clamp(0, 1)

        big_img = img_batch
        # big_img = big_img.clamp(-1, 1)
        # save numbered PNGs when an output path was given; otherwise show each batch briefly
        if args.output_path is not None:
            save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
        else:
            show_tensors(big_img)
            # alternative: convert to PIL and show
            # img = transforms.ToPILImage()(big_img)
            # show_img(img)
            time.sleep(0.2)
        idx += 1
    # between epochs, let the dataset re-run its epoch setup (e.g. re-bucketing/shuffling)
    if epoch < args.epochs - 1:
        trigger_dataloader_setup_epoch(dataloader)

# close any OpenCV preview windows
cv2.destroyAllWindows()

print('done')