import torch
from diffusers.loaders import AttnProcsLayers

from modules.BEATs.BEATs import BEATs, BEATsConfig
from modules.AudioToken.embedder import FGAEmbedder
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import LoRAAttnProcessor


class AudioTokenWrapper(torch.nn.Module):
    """Simple wrapper module for Stable Diffusion that holds all the models together"""

    def __init__(
        self,
        args,
        accelerator,
    ):
        super().__init__()

        # Load the pre-trained Stable Diffusion components
        from modules.clip_text_model.modeling_clip import CLIPTextModel
        self.text_encoder = CLIPTextModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
        )
        self.vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
        )

        # Load the pre-trained BEATs audio encoder and drop its classification head
        checkpoint = torch.load(
            'models/BEATs/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt')
        cfg = BEATsConfig(checkpoint['cfg'])
        self.aud_encoder = BEATs(cfg)
        self.aud_encoder.load_state_dict(checkpoint['model'])
        self.aud_encoder.predictor = None
        # FGA embedder input dimension (3 x the 768-dim BEATs hidden size)
        input_size = 768 * 3

        if args.pretrained_model_name_or_path == "CompVis/stable-diffusion-v1-4":
            # SD v1.x text embeddings are 768-dim
            self.embedder = FGAEmbedder(input_size=input_size, output_size=768)
        else:
            # SD v2.x text embeddings are 1024-dim
            self.embedder = FGAEmbedder(input_size=input_size, output_size=1024)

        self.vae.eval()
        self.unet.eval()
        self.text_encoder.eval()
        self.aud_encoder.eval()

        if 'lora' in args and args.lora:
            # Attach a LoRA attention processor to every attention layer of the UNet
            lora_attn_procs = {}
            for name in self.unet.attn_processors.keys():
                # Self-attention layers (attn1) have no cross-attention dimension
                cross_attention_dim = None if name.endswith(
                    "attn1.processor") else self.unet.config.cross_attention_dim
                if name.startswith("mid_block"):
                    hidden_size = self.unet.config.block_out_channels[-1]
                elif name.startswith("up_blocks"):
                    block_id = int(name[len("up_blocks.")])
                    hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
                elif name.startswith("down_blocks"):
                    block_id = int(name[len("down_blocks.")])
                    hidden_size = self.unet.config.block_out_channels[block_id]

                lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size,
                                                          cross_attention_dim=cross_attention_dim)

            self.unet.set_attn_processor(lora_attn_procs)
            self.lora_layers = AttnProcsLayers(self.unet.attn_processors)

        if args.data_set == 'train':
            # Freeze vae, unet, text encoder and audio encoder; only the embedder is trained
            self.vae.requires_grad_(False)
            self.unet.requires_grad_(False)
            self.text_encoder.requires_grad_(False)
            self.aud_encoder.requires_grad_(False)
            self.embedder.requires_grad_(True)
            self.embedder.train()

            if 'lora' in args and args.lora:
                self.unet.train()

        if args.data_set == 'test':
            # At inference time, swap in the standard Hugging Face CLIP text encoder
            from transformers import CLIPTextModel
            self.text_encoder = CLIPTextModel.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
            )

            self.embedder.eval()
            embedder_learned_embeds = args.learned_embeds
            self.embedder.load_state_dict(torch.load(embedder_learned_embeds, map_location=accelerator.device))

            if 'lora' in args and args.lora:
                self.lora_layers.eval()
                lora_layers_learned_embeds = args.lora_learned_embeds
                self.lora_layers.load_state_dict(torch.load(lora_layers_learned_embeds, map_location=accelerator.device))
                self.unet.load_attn_procs(lora_layers_learned_embeds)
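

# Minimal usage sketch, assuming an argparse-style `args` namespace that carries the
# fields read above (pretrained_model_name_or_path, revision, data_set, lora,
# learned_embeds, lora_learned_embeds) and that the BEATs checkpoint exists at the
# hard-coded path; the values below are placeholders, not the project's training config.
if __name__ == "__main__":
    import argparse
    from accelerate import Accelerator

    args = argparse.Namespace(
        pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
        revision=None,
        data_set="train",          # 'train' freezes SD/BEATs and leaves only the embedder trainable
        lora=False,                # True would attach LoRA attention processors to the UNet
        learned_embeds=None,       # only read when data_set == 'test'
        lora_learned_embeds=None,  # only read when data_set == 'test' and lora is set
    )

    accelerator = Accelerator()
    model = AudioTokenWrapper(args, accelerator)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{trainable} trainable parameters (FGA embedder)")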