{ "cells": [ { "cell_type": "markdown", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "## StoryDiffusion: Consistent Self-Attention for Long-Range Image and Video Generation \n", "[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md-dark.svg)]()\n", "[[Paper]()]   [[Project Page]()]  
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/flex/StoryDiffusion/venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "# %load_ext autoreload\n", "# %autoreload 2\n", "import gradio as gr\n", "import numpy as np\n", "import torch\n", "import requests\n", "import random\n", "import os\n", "import sys\n", "import pickle\n", "from PIL import Image\n", "from tqdm.auto import tqdm\n", "from datetime import datetime\n", "from utils.gradio_utils import is_torch2_available\n", "if is_torch2_available():\n", " from utils.gradio_utils import \\\n", " AttnProcessor2_0 as AttnProcessor\n", "else:\n", " from utils.gradio_utils import AttnProcessor\n", "\n", "import diffusers\n", "from diffusers import StableDiffusionXLPipeline\n", "from diffusers import DDIMScheduler\n", "import torch.nn.functional as F\n", "from utils.gradio_utils import cal_attn_mask_xl\n", "import copy\n", "import os\n", "from diffusers.utils import load_image\n", "from utils.utils import get_comic\n", "from utils.style_template import styles" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set Config " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "## Global\n", "STYLE_NAMES = list(styles.keys())\n", "DEFAULT_STYLE_NAME = \"(No style)\"\n", "MAX_SEED = np.iinfo(np.int32).max\n", "global models_dict\n", "use_va = False\n", "models_dict = {\n", " \"Juggernaut\":\"RunDiffusion/Juggernaut-XL-v8\",\n", " \"RealVision\":\"SG161222/RealVisXL_V4.0\" ,\n", " \"SDXL\":\"stabilityai/stable-diffusion-xl-base-1.0\" ,\n", " \"Unstable\": \"stablediffusionapi/sdxl-unstable-diffusers-y\"\n", "}" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.cuda.is_available()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def setup_seed(seed):\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " np.random.seed(seed)\n", " random.seed(seed)\n", " torch.backends.cudnn.deterministic = True\n", "\n", " \n", "#################################################\n", "########Consistent Self-Attention################\n", "#################################################\n", "class SpatialAttnProcessor2_0(torch.nn.Module):\n", " r\"\"\"\n", " Attention processor for IP-Adapater for PyTorch 2.0.\n", " Args:\n", " hidden_size (`int`):\n", " The hidden size of the attention layer.\n", " cross_attention_dim (`int`):\n", " The number of channels in the `encoder_hidden_states`.\n", " text_context_len (`int`, defaults to 77):\n", " The context length of the text features.\n", " scale (`float`, defaults to 1.0):\n", " the weight scale of image prompt.\n", " \"\"\"\n", "\n", " def __init__(self, hidden_size = None, cross_attention_dim=None,id_length = 4,device = \"cuda\",dtype = torch.float16):\n", " super().__init__()\n", " if not hasattr(F, \"scaled_dot_product_attention\"):\n", " raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n", " 
self.device = device\n", " self.dtype = dtype\n", " self.hidden_size = hidden_size\n", " self.cross_attention_dim = cross_attention_dim\n", " self.total_length = id_length + 1\n", " self.id_length = id_length\n", " self.id_bank = {}\n", "\n", " def __call__(\n", " self,\n", " attn,\n", " hidden_states,\n", " encoder_hidden_states=None,\n", " attention_mask=None,\n", " temb=None):\n", " global total_count,attn_count,cur_step,mask1024,mask4096\n", " global sa32, sa64\n", " global write\n", " global height,width\n", " # Write phase: cache the identity batch's hidden states for this denoising step\n", " # (split into the unconditional and conditional halves of the CFG batch).\n", " # Read phase: surround the current image's states with the cached identity features.\n", " if write:\n", " self.id_bank[cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]]\n", " else:\n", " encoder_hidden_states = torch.cat((self.id_bank[cur_step][0].to(self.device),hidden_states[:1],self.id_bank[cur_step][1].to(self.device),hidden_states[1:]))\n", " # use plain self-attention in the early denoising steps\n", " if cur_step <5:\n", " hidden_states = self.__call2__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)\n", " else:\n", " # afterwards, apply consistent self-attention with probability 0.7 (before step 20) or 0.9 (later steps)\n", " random_number = random.random()\n", " if cur_step <20:\n", " rand_num = 0.3\n", " else:\n", " rand_num = 0.1\n", " if random_number > rand_num:\n", " # pick the precomputed random mask that matches the current self-attention resolution;\n", " # when reading, keep only the rows for the current image, when writing, the block over the identity images\n", " if not write:\n", " if hidden_states.shape[1] == (height//32) * (width//32):\n", " attention_mask = mask1024[mask1024.shape[0] // self.total_length * self.id_length:]\n", " else:\n", " attention_mask = mask4096[mask4096.shape[0] // self.total_length * self.id_length:]\n", " else:\n", " if hidden_states.shape[1] == (height//32) * (width//32):\n", " attention_mask = mask1024[:mask1024.shape[0] // self.total_length * self.id_length,:mask1024.shape[0] // self.total_length * self.id_length]\n", " else:\n", " attention_mask = mask4096[:mask4096.shape[0] // self.total_length * self.id_length,:mask4096.shape[0] // self.total_length * self.id_length]\n", " hidden_states = self.__call1__(attn, hidden_states,encoder_hidden_states,attention_mask,temb)\n", " else:\n", " hidden_states = self.__call2__(attn, hidden_states,None,attention_mask,temb)\n", " # once every processor has run for this step, advance the step counter and resample the masks\n", " attn_count +=1\n", " if attn_count == total_count:\n", " attn_count = 0\n", " cur_step += 1\n", " mask1024,mask4096 = cal_attn_mask_xl(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype)\n", "\n", " return hidden_states\n", " def __call1__(\n", " self,\n", " attn,\n", " hidden_states,\n", " encoder_hidden_states=None,\n", " attention_mask=None,\n", " temb=None,\n", " ):\n", " # Consistent self-attention: the images in each CFG half are flattened into one long\n", " # token sequence so that every image can attend to the (masked) tokens of the others.\n", " residual = hidden_states\n", " if attn.spatial_norm is not None:\n", " hidden_states = attn.spatial_norm(hidden_states, temb)\n", " input_ndim = hidden_states.ndim\n", "\n", " if input_ndim == 4:\n", " total_batch_size, channel, height, width = hidden_states.shape\n", " hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)\n", " total_batch_size,nums_token,channel = hidden_states.shape\n", " img_nums = total_batch_size//2\n", " hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel)\n", "\n", " batch_size, sequence_length, _ = hidden_states.shape\n", "\n", " if attn.group_norm is not None:\n", " hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n", "\n", " query = attn.to_q(hidden_states)\n", "\n", " if encoder_hidden_states is None:\n", " encoder_hidden_states = hidden_states # B, N, C\n", " else:\n", " encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel)\n", "\n", " key = 
attn.to_k(encoder_hidden_states)\n", " value = attn.to_v(encoder_hidden_states)\n", "\n", "\n", " inner_dim = key.shape[-1]\n", " head_dim = inner_dim // attn.heads\n", "\n", " query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", "\n", " key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", " value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", " hidden_states = F.scaled_dot_product_attention(\n", " query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n", " )\n", "\n", " hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)\n", " hidden_states = hidden_states.to(query.dtype)\n", "\n", "\n", "\n", " # linear proj\n", " hidden_states = attn.to_out[0](hidden_states)\n", " # dropout\n", " hidden_states = attn.to_out[1](hidden_states)\n", "\n", "\n", " if input_ndim == 4:\n", " hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)\n", " if attn.residual_connection:\n", " hidden_states = hidden_states + residual\n", " hidden_states = hidden_states / attn.rescale_output_factor\n", " # print(hidden_states.shape)\n", " return hidden_states\n", " def __call2__(\n", " self,\n", " attn,\n", " hidden_states,\n", " encoder_hidden_states=None,\n", " attention_mask=None,\n", " temb=None):\n", " residual = hidden_states\n", "\n", " if attn.spatial_norm is not None:\n", " hidden_states = attn.spatial_norm(hidden_states, temb)\n", "\n", " input_ndim = hidden_states.ndim\n", "\n", " if input_ndim == 4:\n", " batch_size, channel, height, width = hidden_states.shape\n", " hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n", "\n", " batch_size, sequence_length, channel = (\n", " hidden_states.shape\n", " )\n", " # print(hidden_states.shape)\n", " if attention_mask is not None:\n", " attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n", " # scaled_dot_product_attention expects attention_mask shape to be\n", " # (batch, heads, source_length, target_length)\n", " attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n", "\n", " if attn.group_norm is not None:\n", " hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n", "\n", " query = attn.to_q(hidden_states)\n", "\n", " if encoder_hidden_states is None:\n", " encoder_hidden_states = hidden_states # B, N, C\n", " else:\n", " encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel)\n", "\n", " key = attn.to_k(encoder_hidden_states)\n", " value = attn.to_v(encoder_hidden_states)\n", "\n", " inner_dim = key.shape[-1]\n", " head_dim = inner_dim // attn.heads\n", "\n", " query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", "\n", " key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", " value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n", "\n", " hidden_states = F.scaled_dot_product_attention(\n", " query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n", " )\n", "\n", " hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n", " hidden_states = hidden_states.to(query.dtype)\n", "\n", " # linear proj\n", " hidden_states = attn.to_out[0](hidden_states)\n", " # dropout\n", " hidden_states = attn.to_out[1](hidden_states)\n", "\n", " if 
input_ndim == 4:\n", " hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n", "\n", " if attn.residual_connection:\n", " hidden_states = hidden_states + residual\n", "\n", " hidden_states = hidden_states / attn.rescale_output_factor\n", "\n", " return hidden_states\n", "\n", "def set_attention_processor(unet,id_length):\n", " attn_procs = {}\n", " for name in unet.attn_processors.keys():\n", " cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n", " if name.startswith(\"mid_block\"):\n", " hidden_size = unet.config.block_out_channels[-1]\n", " elif name.startswith(\"up_blocks\"):\n", " block_id = int(name[len(\"up_blocks.\")])\n", " hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n", " elif name.startswith(\"down_blocks\"):\n", " block_id = int(name[len(\"down_blocks.\")])\n", " hidden_size = unet.config.block_out_channels[block_id]\n", " if cross_attention_dim is None:\n", " if name.startswith(\"up_blocks\") :\n", " attn_procs[name] = SpatialAttnProcessor2_0(id_length = id_length)\n", " else: \n", " attn_procs[name] = AttnProcessor()\n", " else:\n", " attn_procs[name] = AttnProcessor()\n", "\n", " unet.set_attn_processor(attn_procs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load Pipeline" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading pipeline components...: 14%|█▍ | 1/7 [00:00<00:00, 6.29it/s]\n" ] }, { "ename": "OSError", "evalue": "Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory /Users/flex/.cache/huggingface/hub/models--SG161222--RealVisXL_V4.0/snapshots/49740684ab2d8f4f5dcf6c644df2b33388a8ba85/text_encoder.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[5], line 27\u001b[0m\n\u001b[1;32m 25\u001b[0m sd_model_path \u001b[39m=\u001b[39m models_dict[\u001b[39m\"\u001b[39m\u001b[39mRealVision\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m#\"SG161222/RealVisXL_V4.0\"\u001b[39;00m\n\u001b[1;32m 26\u001b[0m \u001b[39m### LOAD Stable Diffusion Pipeline\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m pipe \u001b[39m=\u001b[39m StableDiffusionXLPipeline\u001b[39m.\u001b[39;49mfrom_pretrained(sd_model_path, torch_dtype\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mfloat16, use_safetensors\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m)\n\u001b[1;32m 28\u001b[0m pipe \u001b[39m=\u001b[39m pipe\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 29\u001b[0m pipe\u001b[39m.\u001b[39menable_freeu(s1\u001b[39m=\u001b[39m\u001b[39m0.6\u001b[39m, s2\u001b[39m=\u001b[39m\u001b[39m0.4\u001b[39m, b1\u001b[39m=\u001b[39m\u001b[39m1.1\u001b[39m, b2\u001b[39m=\u001b[39m\u001b[39m1.2\u001b[39m)\n", "File \u001b[0;32m~/StoryDiffusion/venv/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[39mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m 112\u001b[0m kwargs \u001b[39m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[39m=\u001b[39mfn\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, has_token\u001b[39m=\u001b[39mhas_token, 
kwargs\u001b[39m=\u001b[39mkwargs)\n\u001b[0;32m--> 114\u001b[0m \u001b[39mreturn\u001b[39;00m fn(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", "File \u001b[0;32m~/StoryDiffusion/venv/lib/python3.12/site-packages/diffusers/pipelines/pipeline_utils.py:881\u001b[0m, in \u001b[0;36mDiffusionPipeline.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 878\u001b[0m loaded_sub_model \u001b[39m=\u001b[39m passed_class_obj[name]\n\u001b[1;32m 879\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 880\u001b[0m \u001b[39m# load sub model\u001b[39;00m\n\u001b[0;32m--> 881\u001b[0m loaded_sub_model \u001b[39m=\u001b[39m load_sub_model(\n\u001b[1;32m 882\u001b[0m library_name\u001b[39m=\u001b[39;49mlibrary_name,\n\u001b[1;32m 883\u001b[0m class_name\u001b[39m=\u001b[39;49mclass_name,\n\u001b[1;32m 884\u001b[0m importable_classes\u001b[39m=\u001b[39;49mimportable_classes,\n\u001b[1;32m 885\u001b[0m pipelines\u001b[39m=\u001b[39;49mpipelines,\n\u001b[1;32m 886\u001b[0m is_pipeline_module\u001b[39m=\u001b[39;49mis_pipeline_module,\n\u001b[1;32m 887\u001b[0m pipeline_class\u001b[39m=\u001b[39;49mpipeline_class,\n\u001b[1;32m 888\u001b[0m torch_dtype\u001b[39m=\u001b[39;49mtorch_dtype,\n\u001b[1;32m 889\u001b[0m provider\u001b[39m=\u001b[39;49mprovider,\n\u001b[1;32m 890\u001b[0m sess_options\u001b[39m=\u001b[39;49msess_options,\n\u001b[1;32m 891\u001b[0m device_map\u001b[39m=\u001b[39;49mcurrent_device_map,\n\u001b[1;32m 892\u001b[0m max_memory\u001b[39m=\u001b[39;49mmax_memory,\n\u001b[1;32m 893\u001b[0m offload_folder\u001b[39m=\u001b[39;49moffload_folder,\n\u001b[1;32m 894\u001b[0m offload_state_dict\u001b[39m=\u001b[39;49moffload_state_dict,\n\u001b[1;32m 895\u001b[0m model_variants\u001b[39m=\u001b[39;49mmodel_variants,\n\u001b[1;32m 896\u001b[0m name\u001b[39m=\u001b[39;49mname,\n\u001b[1;32m 897\u001b[0m from_flax\u001b[39m=\u001b[39;49mfrom_flax,\n\u001b[1;32m 898\u001b[0m variant\u001b[39m=\u001b[39;49mvariant,\n\u001b[1;32m 899\u001b[0m low_cpu_mem_usage\u001b[39m=\u001b[39;49mlow_cpu_mem_usage,\n\u001b[1;32m 900\u001b[0m cached_folder\u001b[39m=\u001b[39;49mcached_folder,\n\u001b[1;32m 901\u001b[0m )\n\u001b[1;32m 902\u001b[0m logger\u001b[39m.\u001b[39minfo(\n\u001b[1;32m 903\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mLoaded \u001b[39m\u001b[39m{\u001b[39;00mname\u001b[39m}\u001b[39;00m\u001b[39m as \u001b[39m\u001b[39m{\u001b[39;00mclass_name\u001b[39m}\u001b[39;00m\u001b[39m from `\u001b[39m\u001b[39m{\u001b[39;00mname\u001b[39m}\u001b[39;00m\u001b[39m` subfolder of \u001b[39m\u001b[39m{\u001b[39;00mpretrained_model_name_or_path\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 904\u001b[0m )\n\u001b[1;32m 906\u001b[0m init_kwargs[name] \u001b[39m=\u001b[39m loaded_sub_model \u001b[39m# UNet(...), # DiffusionSchedule(...)\u001b[39;00m\n", "File \u001b[0;32m~/StoryDiffusion/venv/lib/python3.12/site-packages/diffusers/pipelines/pipeline_loading_utils.py:703\u001b[0m, in \u001b[0;36mload_sub_model\u001b[0;34m(library_name, class_name, importable_classes, pipelines, is_pipeline_module, pipeline_class, torch_dtype, provider, sess_options, device_map, max_memory, offload_folder, offload_state_dict, model_variants, name, from_flax, variant, low_cpu_mem_usage, cached_folder)\u001b[0m\n\u001b[1;32m 701\u001b[0m \u001b[39m# check if the module is in a subdirectory\u001b[39;00m\n\u001b[1;32m 702\u001b[0m \u001b[39mif\u001b[39;00m 
os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39misdir(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(cached_folder, name)):\n\u001b[0;32m--> 703\u001b[0m loaded_sub_model \u001b[39m=\u001b[39m load_method(os\u001b[39m.\u001b[39;49mpath\u001b[39m.\u001b[39;49mjoin(cached_folder, name), \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mloading_kwargs)\n\u001b[1;32m 704\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 705\u001b[0m \u001b[39m# else load from the root directory\u001b[39;00m\n\u001b[1;32m 706\u001b[0m loaded_sub_model \u001b[39m=\u001b[39m load_method(cached_folder, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mloading_kwargs)\n", "File \u001b[0;32m~/StoryDiffusion/venv/lib/python3.12/site-packages/transformers/modeling_utils.py:3447\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3442\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 3443\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mError no file named \u001b[39m\u001b[39m{\u001b[39;00m_add_variant(SAFE_WEIGHTS_NAME,\u001b[39m \u001b[39mvariant)\u001b[39m}\u001b[39;00m\u001b[39m found in directory\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 3444\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00mpretrained_model_name_or_path\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 3445\u001b[0m )\n\u001b[1;32m 3446\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m-> 3447\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 3448\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mError no file named \u001b[39m\u001b[39m{\u001b[39;00m_add_variant(WEIGHTS_NAME,\u001b[39m \u001b[39mvariant)\u001b[39m}\u001b[39;00m\u001b[39m, \u001b[39m\u001b[39m{\u001b[39;00m_add_variant(SAFE_WEIGHTS_NAME,\u001b[39m \u001b[39mvariant)\u001b[39m}\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 3449\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00mTF2_WEIGHTS_NAME\u001b[39m}\u001b[39;00m\u001b[39m, \u001b[39m\u001b[39m{\u001b[39;00mTF_WEIGHTS_NAME\u001b[39m \u001b[39m\u001b[39m+\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.index\u001b[39m\u001b[39m'\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m or \u001b[39m\u001b[39m{\u001b[39;00mFLAX_WEIGHTS_NAME\u001b[39m}\u001b[39;00m\u001b[39m found in directory\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 3450\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00mpretrained_model_name_or_path\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 3451\u001b[0m )\n\u001b[1;32m 3452\u001b[0m \u001b[39melif\u001b[39;00m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39misfile(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(subfolder, pretrained_model_name_or_path)):\n\u001b[1;32m 3453\u001b[0m archive_file \u001b[39m=\u001b[39m pretrained_model_name_or_path\n", "\u001b[0;31mOSError\u001b[0m: Error no file named pytorch_model.bin, model.safetensors, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory /Users/flex/.cache/huggingface/hub/models--SG161222--RealVisXL_V4.0/snapshots/49740684ab2d8f4f5dcf6c644df2b33388a8ba85/text_encoder." 
] } ], "source": [ "global attn_count, total_count, id_length, total_length,cur_step, cur_model_type\n", "global write\n", "global sa32, sa64\n", "global height,width\n", "attn_count = 0\n", "total_count = 0\n", "cur_step = 0\n", "id_length = 4\n", "total_length = 5\n", "cur_model_type = \"\"\n", "device=\"cuda\"\n", "global attn_procs,unet\n", "attn_procs = {}\n", "###\n", "write = False\n", "### strength of consistent self-attention: the larger, the stronger\n", "sa32 = 0.5\n", "sa64 = 0.5\n", "### Res. of the Generated Comics. Please Note: SDXL models may do worse in a low-resolution! \n", "height = 768\n", "width = 768\n", "###\n", "global pipe\n", "global sd_model_path\n", "sd_model_path = models_dict[\"RealVision\"] #\"SG161222/RealVisXL_V4.0\"\n", "### LOAD Stable Diffusion Pipeline\n", "pipe = StableDiffusionXLPipeline.from_pretrained(sd_model_path, torch_dtype=torch.float16, use_safetensors=False)\n", "pipe = pipe.to(device)\n", "pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)\n", "pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)\n", "pipe.scheduler.set_timesteps(50)\n", "unet = pipe.unet\n", "\n", "### Insert PairedAttention\n", "for name in unet.attn_processors.keys():\n", " cross_attention_dim = None if name.endswith(\"attn1.processor\") else unet.config.cross_attention_dim\n", " if name.startswith(\"mid_block\"):\n", " hidden_size = unet.config.block_out_channels[-1]\n", " elif name.startswith(\"up_blocks\"):\n", " block_id = int(name[len(\"up_blocks.\")])\n", " hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n", " elif name.startswith(\"down_blocks\"):\n", " block_id = int(name[len(\"down_blocks.\")])\n", " hidden_size = unet.config.block_out_channels[block_id]\n", " if cross_attention_dim is None and (name.startswith(\"up_blocks\") ) :\n", " attn_procs[name] = SpatialAttnProcessor2_0(id_length = id_length)\n", " total_count +=1\n", " else:\n", " attn_procs[name] = AttnProcessor()\n", "print(\"successsfully load consistent self-attention\")\n", "print(f\"number of the processor : {total_count}\")\n", "unet.set_attn_processor(copy.deepcopy(attn_procs))\n", "global mask1024,mask4096\n", "mask1024, mask4096 = cal_attn_mask_xl(total_length,id_length,sa32,sa64,height,width,device=device,dtype= torch.float16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the text description for the comics\n", "Tips: Existing text2image diffusion models may not always generate images that accurately match text descriptions. Our training-free approach can improve the consistency of characters, but it does not enhance the control over the text. Therefore, in some cases, you may need to carefully craft your prompts." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "guidance_scale = 5.0\n", "seed = 2047\n", "sa32 = 0.5\n", "sa64 = 0.5\n", "id_length = 4\n", "num_steps = 50\n", "general_prompt = \"a man with a black suit\"\n", "negative_prompt = \"naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation\"\n", "prompt_array = [\"wake up in the bed\",\n", " \"have breakfast\",\n", " \"is on the road, go to the company\",\n", " \"work in the company\",\n", " \"running in the playground\",\n", " \"reading book in the home\"\n", " ]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def apply_style_positive(style_name: str, positive: str):\n", " p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])\n", " return p.replace(\"{prompt}\", positive) \n", "def apply_style(style_name: str, positives: list, negative: str = \"\"):\n", " p, n = styles.get(style_name, styles[DEFAULT_STYLE_NAME])\n", " return [p.replace(\"{prompt}\", positive) for positive in positives], n + ' ' + negative\n", "### Set the generated Style\n", "style_name = \"Comic book\"\n", "setup_seed(seed)\n", "generator = torch.Generator(device=\"cuda\").manual_seed(seed)\n", "prompts = [general_prompt+\",\"+prompt for prompt in prompt_array]\n", "id_prompts = prompts[:id_length]\n", "real_prompts = prompts[id_length:]\n", "torch.cuda.empty_cache()\n", "write = True\n", "cur_step = 0\n", "attn_count = 0\n", "id_prompts, negative_prompt = apply_style(style_name, id_prompts, negative_prompt)\n", "id_images = pipe(id_prompts, num_inference_steps = num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images\n", "\n", "write = False\n", "for id_image in id_images:\n", " display(id_image)\n", "real_images = []\n", "for real_prompt in real_prompts:\n", " cur_step = 0\n", " real_prompt = apply_style_positive(style_name, real_prompt)\n", " real_images.append(pipe(real_prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])\n", "for real_image in real_images:\n", " display(real_image) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Continued Creation\n", "From now on, you can create endless stories about this character without worrying about memory constraints." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "new_prompt_array = [\"siting on the sofa\",\n", " \"on the bed, at night \"]\n", "new_prompts = [general_prompt+\",\"+prompt for prompt in new_prompt_array]\n", "new_images = []\n", "for new_prompt in new_prompts :\n", " cur_step = 0\n", " new_prompt = apply_style_positive(style_name, new_prompt)\n", " new_images.append(pipe(new_prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images[0])\n", "for new_image in new_images:\n", " display(new_image) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Make pictures into comics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "###\n", "total_images = id_images + real_images + new_images\n", "from PIL import Image,ImageOps,ImageDraw, ImageFont\n", "#### LOAD Fonts, can also replace with any Fonts you have!\n", "font = ImageFont.truetype(\"./fonts/Inkfree.ttf\", 30)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# import importlib\n", "# import utils.utils\n", "# importlib.reload(utils)\n", "from utils.utils import get_row_image\n", "from utils.utils import get_row_image\n", "from utils.utils import get_comic_4panel" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "comics = get_comic_4panel(total_images, captions = prompt_array+ new_prompts,font = font )\n", "for comic in comics:\n", " display(comic)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "fileId": "51613593-0d85-430e-8fce-c85e580fc483", "kernelspec": { "display_name": "venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" }, "vscode": { "interpreter": { "hash": "c1bd42f2f9f6cfcf8171e9e1e863f0572afe983234a3d808193da2fd055f98b3" } } }, "nbformat": 4, "nbformat_minor": 4 }