import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import sys
import os
import cv2
import random
from transformers import CLIPImageProcessor
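
# add the parent of this script's directory to sys.path so the toolkit imports below resolve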
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torchvision.transforms.functional
from toolkit.image_utils import save_tensors, show_img, show_tensors
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
    trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
import argparse
from tqdm import tqdm
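
# CLI: a positional dataset folder plus epoch count, frames per sample, and an
# optional output directory; without --output_path batches are shown on screen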
parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('--epochs', type=int, default=1)
parser.add_argument('--num_frames', type=int, default=1)
parser.add_argument('--output_path', type=str, default=None)
args = parser.parse_args()
if args.output_path is not None:
    args.output_path = os.path.abspath(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)
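
# fixed settings for the preview run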
dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 1
clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
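

# stand-in objects so get_dataloader_from_datasets receives an sd object that
# exposes adapter.clip_image_processor without loading a real model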
class FakeAdapter:
    def __init__(self):
        self.clip_image_processor = clip_processor
## make fake sd
class FakeSD:
    def __init__(self):
        self.adapter = FakeAdapter()
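

# bucketed dataset config: 512px resolution buckets with a 64px tolerance and
# optional multi-frame (video) loading; commented-out options show other knobs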
dataset_config = DatasetConfig(
    dataset_path=dataset_folder,
    # clip_image_path=dataset_folder,
    # square_crop=True,
    resolution=resolution,
    # caption_ext='json',
    default_caption='default',
    # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
    buckets=True,
    bucket_tolerance=bucket_tolerance,
    shrink_video_to_frames=True,
    num_frames=args.num_frames,
    # poi='person',
    # shuffle_augmentations=True,
    # augmentations=[
    #     {
    #         'method': 'Posterize',
    #         'num_bits': [(0, 4), (0, 4), (0, 4)],
    #         'p': 1.0
    #     },
    #
    # ]
)
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size, sd=FakeSD())
# run through an epoch and check sizes
dataloader_iterator = iter(dataloader)
idx = 0
for epoch in range(args.epochs):
    for batch in tqdm(dataloader):
        batch: 'DataLoaderBatchDTO'
        img_batch = batch.tensor
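        # batch.tensor is 5-D (batch, frames, channels, height, width) for video
        # datasets and 4-D (batch, channels, height, width) for still images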
        frames = 1
        if len(img_batch.shape) == 5:
            frames = img_batch.shape[1]
            batch_size, frames, channels, height, width = img_batch.shape
        else:
            batch_size, channels, height, width = img_batch.shape
        # img_batch = color_block_imgs(img_batch, neg1_1=True)
        # chunks = torch.chunk(img_batch, batch_size, dim=0)
        # # put them so they are side by side
        # big_img = torch.cat(chunks, dim=3)
        # big_img = big_img.squeeze(0)
        #
        # control_chunks = torch.chunk(batch.clip_image_tensor, batch_size, dim=0)
        # big_control_img = torch.cat(control_chunks, dim=3)
        # big_control_img = big_control_img.squeeze(0) * 2 - 1
        #
        #
        # # resize control image
        # big_control_img = torchvision.transforms.Resize((width, height))(big_control_img)
        #
        # big_img = torch.cat([big_img, big_control_img], dim=2)
        #
        # min_val = big_img.min()
        # max_val = big_img.max()
        #
        # big_img = (big_img / 2 + 0.5).clamp(0, 1)
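        # the side-by-side / control-image composite above is disabled, so preview the raw batch tensor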
        big_img = img_batch
        # big_img = big_img.clamp(-1, 1)
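        # save each batch as a numbered image when --output_path is given, otherwise show it in a preview window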
        if args.output_path is not None:
            save_tensors(big_img, os.path.join(args.output_path, f'{idx}.png'))
        else:
            show_tensors(big_img)
        # convert to image
        # img = transforms.ToPILImage()(big_img)
        #
        # show_img(img)
        time.sleep(0.2)
        idx += 1
    # if not last epoch
    if epoch < args.epochs - 1:
        trigger_dataloader_setup_epoch(dataloader)
cv2.destroyAllWindows()
print('done')