# cobra_utils/utils.py
import os
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from PIL import Image
class _bn_relu_conv(nn.Module):
def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
super(_bn_relu_conv, self).__init__()
self.model = nn.Sequential(
nn.BatchNorm2d(in_filters, eps=1e-3),
nn.LeakyReLU(0.2),
nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros')
)
    def forward(self, x):
        return self.model(x)
class _u_bn_relu_conv(nn.Module):
def __init__(self, in_filters, nb_filters, fw, fh, subsample=1):
super(_u_bn_relu_conv, self).__init__()
self.model = nn.Sequential(
nn.BatchNorm2d(in_filters, eps=1e-3),
nn.LeakyReLU(0.2),
nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)),
nn.Upsample(scale_factor=2, mode='nearest')
)
def forward(self, x):
return self.model(x)
class _shortcut(nn.Module):
def __init__(self, in_filters, nb_filters, subsample=1):
super(_shortcut, self).__init__()
self.process = False
self.model = None
if in_filters != nb_filters or subsample != 1:
self.process = True
self.model = nn.Sequential(
nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample)
)
def forward(self, x, y):
#print(x.size(), y.size(), self.process)
if self.process:
y0 = self.model(x)
#print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape)
return y0 + y
else:
#print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape)
return x + y
class _u_shortcut(nn.Module):
def __init__(self, in_filters, nb_filters, subsample):
super(_u_shortcut, self).__init__()
self.process = False
self.model = None
if in_filters != nb_filters:
self.process = True
self.model = nn.Sequential(
nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'),
nn.Upsample(scale_factor=2, mode='nearest')
)
def forward(self, x, y):
if self.process:
return self.model(x) + y
else:
return x + y
class basic_block(nn.Module):
def __init__(self, in_filters, nb_filters, init_subsample=1):
super(basic_block, self).__init__()
self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.residual(x1)
return self.shortcut(x, x2)
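# Usage sketch: basic_block is a pre-activation residual block
# (BN -> LeakyReLU -> Conv, twice) with a 1x1 projection shortcut whenever the
# channel count or stride changes. For example:
#
#     blk = basic_block(in_filters=24, nb_filters=48, init_subsample=2)
#     out = blk(torch.randn(1, 24, 64, 64))  # -> (1, 48, 32, 32)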
class _u_basic_block(nn.Module):
def __init__(self, in_filters, nb_filters, init_subsample=1):
super(_u_basic_block, self).__init__()
self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample)
self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3)
self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample)
def forward(self, x):
y = self.residual(self.conv1(x))
return self.shortcut(x, y)
class _residual_block(nn.Module):
def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False):
super(_residual_block, self).__init__()
layers = []
for i in range(repetitions):
init_subsample = 1
if i == repetitions - 1 and not is_first_layer:
init_subsample = 2
if i == 0:
l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample)
else:
l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample)
layers.append(l)
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class _upsampling_residual_block(nn.Module):
def __init__(self, in_filters, nb_filters, repetitions):
super(_upsampling_residual_block, self).__init__()
layers = []
for i in range(repetitions):
if i == 0:
l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)#(input)
else:
l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)#(input)
layers.append(l)
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class res_skip(nn.Module):
def __init__(self):
super(res_skip, self).__init__()
self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)#(input)
self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)#(block0)
self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)#(block1)
self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)#(block2)
self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)#(block3)
self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)#(block4)
self.res1 = _shortcut(in_filters=192, nb_filters=192)#(block3, block5, subsample=(1,1))
self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)#(res1)
self.res2 = _shortcut(in_filters=96, nb_filters=96)#(block2, block6, subsample=(1,1))
self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)#(res2)
self.res3 = _shortcut(in_filters=48, nb_filters=48)#(block1, block7, subsample=(1,1))
self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)#(res3)
self.res4 = _shortcut(in_filters=24, nb_filters=24)#(block0,block8, subsample=(1,1))
self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)#(res4)
        self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)#(block9)
def forward(self, x):
x0 = self.block0(x)
x1 = self.block1(x0)
x2 = self.block2(x1)
x3 = self.block3(x2)
x4 = self.block4(x3)
x5 = self.block5(x4)
res1 = self.res1(x3, x5)
x6 = self.block6(res1)
res2 = self.res2(x2, x6)
x7 = self.block7(res2)
res3 = self.res3(x1, x7)
x8 = self.block8(res3)
res4 = self.res4(x0, x8)
x9 = self.block9(res4)
y = self.conv15(x9)
return y
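# Usage sketch (assumes res_skip is used as the line-extraction network on a
# single-channel image tensor; H and W should be multiples of 16 so the four
# stride-2 stages line up with the upsampling path):
#
#     model = res_skip().eval()
#     with torch.no_grad():
#         lines = model(torch.rand(1, 1, 512, 512))  # -> (1, 1, 512, 512)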
class MyDataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
    def get_class_label(self, image_name):
        # The "label" is simply the image's file name (basename).
        _, tail = os.path.split(image_name)
        return tail
def __getitem__(self, index):
image_path = self.image_paths[index]
x = Image.open(image_path)
        y = self.get_class_label(image_path)
if self.transform is not None:
x = self.transform(x)
return x, y
def __len__(self):
return len(self.image_paths)
def loadImages(folder):
    # Collect paths of all regular files in `folder`; images are opened
    # lazily by the caller (e.g. MyDataset), not here.
    matches = []
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        if os.path.isfile(file_path):
            matches.append(file_path)
    return matches
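# Usage sketch (hypothetical path; assumes all images share one size after the
# transform so the default collate can batch them):
#
#     from torch.utils.data import DataLoader
#     from torchvision import transforms
#     ds = MyDataset(loadImages("/path/to/images"), transform=transforms.ToTensor())
#     loader = DataLoader(ds, batch_size=4, shuffle=True)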
def crop_center_square(image):
width, height = image.size
side_length = min(width, height)
left = (width - side_length) // 2
top = (height - side_length) // 2
right = left + side_length
bottom = top + side_length
cropped_image = image.crop((left, top, right, bottom))
return cropped_image
def crop_image(image, crop_size, stride):
width, height = image.size
crop_width, crop_height = crop_size
cropped_images = []
for j in range(0, height - crop_height + 1, stride):
for i in range(0, width - crop_width + 1, stride):
crop_box = (i, j, i + crop_width, j + crop_height)
cropped_image = image.crop(crop_box)
cropped_images.append(cropped_image)
return cropped_images
def process_image_ref(image):
resized_image_512 = image.resize((512, 512))
image_list = [resized_image_512]
crop_size_384 = (384, 384)
stride_384 = 128
image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))
return image_list
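# With the 512x512 resize, 384x384 crops, and stride 128, crop_image places
# crops at offsets {0, 128} in each axis (a 2x2 grid), so process_image_ref
# returns 5 images: the full resize plus 4 overlapping crops.
#
#     views = process_image_ref(Image.new("RGB", (640, 480)))
#     assert len(views) == 5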
def process_image_Q(image):
    resized_image_512 = image.resize((512, 512)).convert("RGB")
image_list = []
crop_size_384 = (384, 384)
stride_384 = 128
image_list.extend(crop_image(resized_image_512, crop_size_384, stride_384))
return image_list
def process_image(image, target_width=512, target_height=512):
    img_width, img_height = image.size
    img_ratio = img_width / img_height
    target_ratio = target_width / target_height
    ratio_error = abs(img_ratio - target_ratio) / target_ratio
    if ratio_error < 0.15:
        # Aspect ratios are close enough: resize directly.
        resized_image = image.resize((target_width, target_height), Image.BICUBIC)
    else:
        # Otherwise center-crop to the target aspect ratio before resizing.
        if img_ratio > target_ratio:
            new_width = int(img_height * target_ratio)
            left = (img_width - new_width) // 2
            top = 0
            right = left + new_width
            bottom = img_height
        else:
            new_height = int(img_width / target_ratio)
            left = 0
            top = (img_height - new_height) // 2
            right = img_width
            bottom = top + new_height
        cropped_image = image.crop((left, top, right, bottom))
        resized_image = cropped_image.resize((target_width, target_height), Image.BICUBIC)
    return resized_image.convert('RGB')
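# Example: a 1920x1080 input has ratio ~1.78 vs. the square target's 1.0, a
# relative error of ~0.78 > 0.15, so it is center-cropped to 1080x1080 first:
#
#     out = process_image(Image.new("RGB", (1920, 1080)))
#     assert out.size == (512, 512)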
def crop_image_varres(image, crop_size, h_stride, w_stride):
width, height = image.size
crop_width, crop_height = crop_size
cropped_images = []
for j in range(0, height - crop_height + 1, h_stride):
for i in range(0, width - crop_width + 1, w_stride):
crop_box = (i, j, i + crop_width, j + crop_height)
cropped_image = image.crop(crop_box)
cropped_images.append(cropped_image)
return cropped_images
def process_image_ref_varres(image, target_width=512, target_height=512):
    resized_image = image.resize((target_width, target_height))
    image_list = [resized_image]
    # Crops are 3/4 of each dimension, sliding by 1/4 of each dimension,
    # giving a 2x2 grid of overlapping crops.
    crop_size = (target_width // 4 * 3, target_height // 4 * 3)
    w_stride = target_width // 4
    h_stride = target_height // 4
    image_list.extend(crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride))
    return image_list
def process_image_Q_varres(image, target_width=512, target_height=512):
    resized_image = image.resize((target_width, target_height)).convert("RGB")
    image_list = []
    crop_size = (target_width // 4 * 3, target_height // 4 * 3)
    w_stride = target_width // 4
    h_stride = target_height // 4
    image_list.extend(crop_image_varres(resized_image, crop_size, h_stride=h_stride, w_stride=w_stride))
    return image_list
class ResNetBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(ResNetBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # add the (possibly projected) shortcut
out = F.relu(out)
return out
class TwoLayerResNet(nn.Module):
    # Despite the name, this module stacks four ResNetBlocks; the identifier
    # is kept as-is because other code may import it by this name.
    def __init__(self, in_channels, out_channels):
        super(TwoLayerResNet, self).__init__()
self.block1 = ResNetBlock(in_channels, out_channels)
self.block2 = ResNetBlock(out_channels, out_channels)
self.block3 = ResNetBlock(out_channels, out_channels)
self.block4 = ResNetBlock(out_channels, out_channels)
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
return x
class MultiHiddenResNetModel(nn.Module):
def __init__(self, channels_list, num_tensors):
super(MultiHiddenResNetModel, self).__init__()
self.two_layer_resnets = nn.ModuleList([TwoLayerResNet(channels_list[idx]*2, channels_list[min(len(channels_list)-1,idx+2)]) for idx in range(num_tensors)])
def forward(self, tensor_list):
processed_list = []
for i, tensor in enumerate(tensor_list):
tensor = self.two_layer_resnets[i](tensor)
processed_list.append(tensor)
return processed_list
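# Channel bookkeeping sketch: tensor i enters with channels_list[i]*2 channels
# and is projected to channels_list[min(len(channels_list)-1, i+2)]. With
# channels_list = [64, 128, 256, 512] and num_tensors = 4, the blocks map
# 128->256, 256->512, 512->512, and 1024->512:
#
#     model = MultiHiddenResNetModel([64, 128, 256, 512], num_tensors=4)
#     outs = model([torch.randn(1, 2 * c, 16, 16) for c in [64, 128, 256, 512]])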
def calculate_target_size(h, w):
    # Round each dimension down to a multiple of 8, with a floor of 8 so the
    # result is never zero.
    target_h = max((h // 8) * 8, 8)
    target_w = max((w // 8) * 8, 8)
    return target_h, target_w
def downsample_tensor(tensor):
b, c, h, w = tensor.shape
target_h, target_w = calculate_target_size(h, w)
downsampled_tensor = F.interpolate(tensor, size=(target_h, target_w), mode='bilinear', align_corners=False)
return downsampled_tensor
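# Example: spatial dims are rounded down to multiples of 8, so a
# (1, 4, 67, 130) tensor is resized to (1, 4, 64, 128):
#
#     t = torch.randn(1, 4, 67, 130)
#     assert downsample_tensor(t).shape == (1, 4, 64, 128)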
def get_pixart_config():
pixart_config = {
"_class_name": "Transformer2DModel",
"_diffusers_version": "0.22.0.dev0",
"activation_fn": "gelu-approximate",
"attention_bias": True,
"attention_head_dim": 72,
"attention_type": "default",
"caption_channels": 4096,
"cross_attention_dim": 1152,
"double_self_attention": False,
"dropout": 0.0,
"in_channels": 4,
# "interpolation_scale": 2,
"norm_elementwise_affine": False,
"norm_eps": 1e-06,
"norm_num_groups": 32,
"norm_type": "ada_norm_single",
"num_attention_heads": 16,
"num_embeds_ada_norm": 1000,
"num_layers": 28,
"num_vector_embeds": None,
"only_cross_attention": False,
"out_channels": 8,
"patch_size": 2,
"sample_size": 128,
"upcast_attention": False,
# "use_additional_conditions": False,
"use_linear_projection": False
}
return pixart_config
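# These keys mirror a PixArt-style diffusers Transformer2DModel config, so the
# dict can be passed to from_config (a sketch, assuming a diffusers version
# whose Transformer2DModel accepts these keys):
#
#     from diffusers.models import Transformer2DModel
#     transformer = Transformer2DModel.from_config(get_pixart_config())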
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x):
return self.double_conv(x)
class UNet(nn.Module):
def __init__(self):
super().__init__()
# left
self.left_conv_1 = DoubleConv(6, 64)
self.down_1 = nn.MaxPool2d(2, 2)
self.left_conv_2 = DoubleConv(64, 128)
self.down_2 = nn.MaxPool2d(2, 2)
self.left_conv_3 = DoubleConv(128, 256)
self.down_3 = nn.MaxPool2d(2, 2)
self.left_conv_4 = DoubleConv(256, 512)
self.down_4 = nn.MaxPool2d(2, 2)
# center
self.center_conv = DoubleConv(512, 1024)
# right
self.up_1 = nn.ConvTranspose2d(1024, 512, 2, 2)
self.right_conv_1 = DoubleConv(1024, 512)
self.up_2 = nn.ConvTranspose2d(512, 256, 2, 2)
self.right_conv_2 = DoubleConv(512, 256)
self.up_3 = nn.ConvTranspose2d(256, 128, 2, 2)
self.right_conv_3 = DoubleConv(256, 128)
self.up_4 = nn.ConvTranspose2d(128, 64, 2, 2)
self.right_conv_4 = DoubleConv(128, 64)
# output
self.output = nn.Conv2d(64, 3, 1, 1, 0)
def forward(self, x):
# left
x1 = self.left_conv_1(x)
x1_down = self.down_1(x1)
x2 = self.left_conv_2(x1_down)
x2_down = self.down_2(x2)
x3 = self.left_conv_3(x2_down)
x3_down = self.down_3(x3)
x4 = self.left_conv_4(x3_down)
x4_down = self.down_4(x4)
# center
x5 = self.center_conv(x4_down)
# right
x6_up = self.up_1(x5)
temp = torch.cat((x6_up, x4), dim=1)
x6 = self.right_conv_1(temp)
x7_up = self.up_2(x6)
temp = torch.cat((x7_up, x3), dim=1)
x7 = self.right_conv_2(temp)
x8_up = self.up_3(x7)
temp = torch.cat((x8_up, x2), dim=1)
x8 = self.right_conv_3(temp)
x9_up = self.up_4(x8)
temp = torch.cat((x9_up, x1), dim=1)
x9 = self.right_conv_4(temp)
# output
output = self.output(x9)
return output
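# Usage sketch: the 6 input channels suggest two RGB images concatenated along
# the channel axis; H and W should be multiples of 16 so the four pool/upsample
# stages and skip concatenations line up:
#
#     net = UNet()
#     pred = net(torch.randn(1, 6, 256, 256))  # -> (1, 3, 256, 256)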
def init_causal_dit(model, base_model):
    # Copy all weights from the pretrained base model; shapes must match
    # exactly (strict=True).
    temp_ckpt = deepcopy(base_model)
    checkpoint = temp_ckpt.state_dict()
# checkpoint['pos_embed_1d.weight'] = torch.zeros(3, model.config.num_attention_heads * model.config.attention_head_dim, device=model.pos_embed_1d.weight.device, dtype = model.pos_embed_1d.weight.dtype)
model.load_state_dict(checkpoint, strict=True)
del temp_ckpt
return model
def init_controlnet(model, base_model):
    temp_ckpt = deepcopy(base_model)
    checkpoint = temp_ckpt.state_dict()
    checkpoint_weight = checkpoint['pos_embed.proj.weight']
    # The ControlNet patch embedding takes extra conditioning channels: copy
    # the pretrained weights into the first 4 (latent) input channels and
    # leave the remaining channels zero-initialized.
    new_weight = torch.zeros(model.pos_embed.proj.weight.shape,
                             device=model.pos_embed.proj.weight.device,
                             dtype=model.pos_embed.proj.weight.dtype)
    print('model.pos_embed.proj.weight.shape', model.pos_embed.proj.weight.shape)
    new_weight[:, :4] = checkpoint_weight
    checkpoint['pos_embed.proj.weight'] = new_weight
    print('new_weight dtype:', new_weight.dtype)
    model.load_state_dict(checkpoint, strict=False)
del temp_ckpt
return model
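# Usage sketch (hypothetical model instances): both helpers warm-start a model
# from a pretrained base transformer; init_controlnet additionally widens the
# patch-embedding conv for the extra conditioning channels:
#
#     controlnet = init_controlnet(controlnet_model, base_transformer)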