# Uploaded by MohamedRashad (commit 6dd488f, "Upload code"; 3.89 kB).
import os
import json
import random
import glob
import torch
import einops
import torchvision
import safetensors.torch as sf
def write_to_json(data, file_path):
    """Write ``data`` as indented (4-space) UTF-8 JSON to ``file_path``.

    The JSON is first written to ``file_path + ".tmp"`` and then moved over
    ``file_path`` with ``os.replace``, so an existing target is only
    replaced once the new content has been fully written.

    Args:
        data: A JSON-serializable object.
        file_path: Destination path (string).

    Raises:
        TypeError: If ``data`` is not JSON-serializable. The partially written
            temporary file is removed before the error propagates and the
            existing ``file_path`` is left untouched.
        OSError: On file-system errors.
    """
    temp_file_path = file_path + ".tmp"
    try:
        with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
            json.dump(data, temp_file, indent=4)
        os.replace(temp_file_path, file_path)
    except BaseException:
        # Don't leave a stale, half-written .tmp file behind; cleanup is
        # best-effort so the original error is the one that propagates.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise
    return
def read_from_json(file_path):
    """Parse and return the JSON document stored (UTF-8) at ``file_path``."""
    with open(file_path, mode='rt', encoding='utf-8') as handle:
        return json.load(handle)
def get_active_parameters(m):
    """Return ``{name: parameter}`` for every trainable parameter of ``m``.

    A parameter counts as trainable when ``requires_grad`` is true; entries
    appear in ``m.named_parameters()`` order.
    """
    trainable = {}
    for name, param in m.named_parameters():
        if param.requires_grad:
            trainable[name] = param
    return trainable
def cast_training_params(m, dtype=torch.float32):
    """Cast the trainable parameters of ``m`` to ``dtype`` in place.

    Parameters with ``requires_grad=False`` keep their dtype. Only each
    parameter's ``.data`` is replaced, so the ``Parameter`` objects (and any
    references to them) stay the same. Returns ``None``.
    """
    trainable = (p for p in m.parameters() if p.requires_grad)
    for p in trainable:
        p.data = p.to(dtype)
def set_attr_recursive(obj, attr, value):
    """Assign ``value`` to a dotted attribute path on ``obj``.

    ``set_attr_recursive(m, "a.b.c", v)`` behaves like ``m.a.b.c = v``: every
    component except the last is resolved with ``getattr`` and the final one
    is set with ``setattr``. Missing intermediate attributes raise
    ``AttributeError``. Returns ``None``.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    setattr(target, leaf, value)
@torch.no_grad()
def batch_mixture(a, b, probability_a=0.5, mask_a=None):
    """Select whole samples from ``a`` or ``b`` for each batch index.

    For every index ``i`` along dim 0, the result holds ``a[i]`` where the
    mask is true and ``b[i]`` otherwise.

    Args:
        a, b: Tensors of identical shape; dim 0 is treated as the batch.
        probability_a: When ``mask_a`` is not given, each sample is taken
            from ``a`` with this probability (``torch.rand(...) < p``).
        mask_a: Optional per-sample selector with exactly ``a.size(0)``
            elements. Bool tensors behave as before. Integer/float tensors and
            Python sequences are also accepted, where non-zero selects ``a``.

    Returns:
        A new tensor with the same shape as ``a``.

    Raises:
        ValueError: If ``a`` and ``b`` differ in shape, or ``mask_a`` does not
            have one entry per sample.
    """
    # Explicit check rather than ``assert`` so it survives ``python -O``;
    # otherwise torch.where could silently broadcast mismatched inputs.
    if a.shape != b.shape:
        raise ValueError("Tensors must have the same shape")
    batch_size = a.size(0)
    if mask_a is None:
        mask_a = torch.rand(batch_size) < probability_a
    # torch.where requires a bool condition; coerce int/float masks and
    # sequences, and move the mask onto a's device.
    mask_a = torch.as_tensor(mask_a, device=a.device).to(torch.bool)
    if mask_a.numel() != batch_size:
        raise ValueError(
            f"mask_a must have {batch_size} elements, got {mask_a.numel()}"
        )
    # Shape (batch, 1, 1, ...) so the mask broadcasts over every non-batch dim.
    mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
    return torch.where(mask_a, a, b)
@torch.no_grad()
def zero_module(module):
    """Overwrite every parameter of ``module`` with zeros and return ``module``.

    Only ``module.parameters()`` are touched (buffers are not). The writes
    happen in place under ``torch.no_grad``, so autograd does not record them.
    """
    for weight in module.parameters():
        # detach() yields a view sharing storage, so zero_() edits the parameter.
        weight.detach().zero_()
    return module
def load_last_state(model, folder='accelerator_output'):
    """Load the newest ``model.safetensors`` found under ``folder`` into ``model``.

    ``folder`` is searched recursively (including ``folder`` itself). The
    candidate with the latest modification time is read with
    ``safetensors.torch.load_file`` and applied with
    ``model.load_state_dict(..., strict=False)``. Missing and unexpected keys
    are printed rather than raised. If no file is found, a message is printed
    and the model is left untouched. Always returns ``None``.
    """
    candidates = glob.glob(
        os.path.join(folder, '**', 'model.safetensors'), recursive=True
    )
    if len(candidates) == 0:
        print("No model.safetensors files found in the specified folder.")
        return

    latest = max(candidates, key=os.path.getmtime)
    weights = sf.load_file(latest)
    missing, unexpected = model.load_state_dict(weights, strict=False)

    if missing:
        print("Missing keys:", missing)
    if unexpected:
        print("Unexpected keys:", unexpected)
    print("Loaded model state from:", latest)
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of ``', '``-separated tags.

    A target count is drawn with ``random.randint(min_length, max_length)``
    and capped at the number of available tags. That many tags are sampled
    without replacement (in random order) and re-joined with ``', '``. Uses
    the global ``random`` module state.
    """
    all_tags = tags_str.split(', ')
    wanted = random.randint(min_length, max_length)
    chosen = random.sample(all_tags, k=min(wanted, len(all_tags)))
    return ', '.join(chosen)
def save_bcthw_as_mp4(x, output_filename, fps=10):
    """Encode a batch of clips as a single H.264 grid video.

    Args:
        x: 5-D tensor unpacked as ``(b, c, t, h, w)``. Values are clamped to
            ``[-1, 1]`` and mapped to ``uint8`` via ``* 127.5 + 127.5``.
        output_filename: Destination path; its parent directory is created.
        fps: Frame rate passed to ``torchvision.io.write_video``.

    The ``b`` clips are tiled into a grid. The row width is the first of
    6, 5, 4, 3, 2 that divides ``b``; if none does, all clips go in one row.
    The file is written with ``video_codec='h264'`` and ``crf=0``.

    Returns:
        The ``uint8`` frames that were written, shaped
        ``(t, rows * h, per_row * w, c)``.
    """
    batch, _c, _t, _h, _w = x.shape
    per_row = next((p for p in (6, 5, 4, 3, 2) if batch % p == 0), batch)

    out_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(out_dir, exist_ok=True)

    scaled = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    frames = scaled.detach().cpu().to(torch.uint8)
    frames = einops.rearrange(frames, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, frames, fps=fps, video_codec='h264', options={'crf': '0'})
    return frames
def save_bcthw_as_png(x, output_filename):
    """Save a ``(b, c, t, h, w)`` tensor as one PNG contact sheet.

    Batch items are stacked vertically and each item's frames are laid out
    left to right (``'b c t h w -> c (b h) (t w)'``). Values are clamped to
    ``[-1, 1]`` and mapped to ``uint8``. The parent directory of
    ``output_filename`` is created if needed. Returns ``output_filename``.
    """
    target_dir = os.path.dirname(os.path.abspath(os.path.realpath(output_filename)))
    os.makedirs(target_dir, exist_ok=True)
    pixels = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    pixels = pixels.detach().cpu().to(torch.uint8)
    sheet = einops.rearrange(pixels, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(sheet, output_filename)
    return output_filename
def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors of equal rank, zero-padding each to a common shape.

    If the shapes already match, this is plain ``tensor1 + tensor2``.
    Otherwise each tensor is copied into the leading corner of a zero tensor.
    That tensor's size along every dimension is the larger of the two inputs'
    sizes, and the two padded tensors are summed.

    The padded buffers use the promoted dtype of the inputs
    (``torch.result_type``) and ``tensor1``'s device. The result therefore
    matches what ``tensor1 + tensor2`` would produce for equal shapes, instead
    of always being a CPU float32 tensor.

    Raises:
        ValueError: If the tensors have different numbers of dimensions.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2
    if tensor1.dim() != tensor2.dim():
        raise ValueError(
            f"Tensors must have the same number of dimensions, "
            f"got {tuple(tensor1.shape)} and {tuple(tensor2.shape)}"
        )
    new_shape = tuple(max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape))
    padded_tensor1 = torch.zeros(
        new_shape, dtype=torch.result_type(tensor1, tensor2), device=tensor1.device
    )
    padded_tensor2 = torch.zeros_like(padded_tensor1)
    # Slice assignment copies (converting dtype/device) into the origin corner.
    padded_tensor1[tuple(slice(0, s) for s in tensor1.shape)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in tensor2.shape)] = tensor2
    return padded_tensor1 + padded_tensor2