File size: 4,060 Bytes
5e373a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and
# Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors.
# CreativeML Open RAIL-M
#
# ==========================================================================================
#
# Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved.
# Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit
# LICENSE.md.
#
# ==========================================================================================
"""
Transformer implementation adapted from CLIP ViT:
https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py
"""
import math
import torch as th
import torch.nn as nn
def convert_module_to_f16(l):
    """
    Cast the weight (and bias, when present) of a primitive layer to float16.

    Only ``nn.Linear``, ``nn.Conv2d`` and ``nn.ConvTranspose2d`` instances are
    modified; every other module type is left untouched. The conversion is
    done in place on the parameters' ``.data`` and nothing is returned.
    """
    if not isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        return
    l.weight.data = l.weight.data.half()
    bias = l.bias
    if bias is not None:
        bias.data = bias.data.half()
class LayerNorm(nn.LayerNorm):
    """
    LayerNorm that accepts reduced-precision inputs.

    The input is upcast to float32 before the parent normalization runs, and
    the result is cast back to the dtype the input arrived in.
    """

    def forward(self, x: th.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.to(input_dtype)
class MultiheadAttention(nn.Module):
    """
    Self-attention layer: a fused QKV projection, multi-head attention, and an
    output projection.
    """

    def __init__(self, n_ctx, width, heads):
        """
        :param n_ctx: context length; stored and forwarded to the attention core.
        :param width: feature size of the layer's input and output.
        :param heads: number of attention heads; must evenly divide ``width``.
        :raises ValueError: if ``width`` or ``heads`` is not positive, or if
            ``width`` is not a multiple of ``heads``.
        """
        super().__init__()
        # The attention core cuts each head's slice of the 3 * width projection
        # into equal q/k/v parts, which only works when width is a multiple of
        # heads. Checking here reports the problem at construction time instead
        # of as an obscure shape error in the first forward pass.
        if heads <= 0 or width <= 0 or width % heads != 0:
            raise ValueError(
                f"width ({width}) must be a positive multiple of heads ({heads})"
            )
        self.n_ctx = n_ctx
        self.width = width
        self.heads = heads
        # Single linear layer producing queries, keys and values together.
        self.c_qkv = nn.Linear(width, width * 3)
        self.c_proj = nn.Linear(width, width)
        self.attention = QKVMultiheadAttention(heads, n_ctx)

    def forward(self, x):
        """
        Project ``x`` to packed QKV, attend, and project back to ``width``.
        """
        x = self.c_qkv(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x
class MLP(nn.Module):
    """
    Two-layer feed-forward network: a linear expansion to four times the
    width, a GELU activation, and a linear projection back to the width.
    """

    def __init__(self, width):
        super().__init__()
        hidden_width = width * 4
        self.width = width
        self.c_fc = nn.Linear(width, hidden_width)
        self.c_proj = nn.Linear(hidden_width, width)
        self.gelu = nn.GELU()

    def forward(self, x):
        expanded = self.c_fc(x)
        activated = self.gelu(expanded)
        return self.c_proj(activated)
class QKVMultiheadAttention(nn.Module):
    """
    Multi-head softmax attention over a packed query/key/value tensor.

    The module holds no learnable parameters. The last dimension of the input
    is first divided into ``n_heads`` equal slices, and each slice is then
    split into its query, key and value parts (in that order).
    """

    def __init__(self, n_heads: int, n_ctx: int):
        super().__init__()
        self.n_heads = n_heads
        # Stored only; forward() reads the sequence length from its input.
        self.n_ctx = n_ctx

    def forward(self, qkv):
        """
        :param qkv: a [bs x n_ctx x width] tensor of packed queries, keys and
            values, where ``width`` is a multiple of ``3 * n_heads``.
        :return: a [bs x n_ctx x (width // 3)] tensor of attended values.
        :raises ValueError: if ``width`` cannot be split evenly into
            ``n_heads`` groups of equally sized q, k and v parts.
        """
        bs, n_ctx, width = qkv.shape
        # Without this check an incompatible width surfaces as an unrelated
        # looking view/unpack error, or as a ZeroDivisionError when the
        # per-head channel count rounds down to zero.
        if self.n_heads <= 0 or width <= 0 or width % (3 * self.n_heads) != 0:
            raise ValueError(
                f"qkv width ({width}) must be a positive multiple of "
                f"3 * n_heads (3 * {self.n_heads})"
            )
        attn_ch = width // self.n_heads // 3
        # q and k are each scaled by attn_ch ** -0.25, so their product carries
        # the usual 1 / sqrt(attn_ch) factor.
        scale = 1 / math.sqrt(math.sqrt(attn_ch))
        qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
        q, k, v = th.split(qkv, attn_ch, dim=-1)
        weight = th.einsum(
            "bthc,bshc->bhts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        # Softmax is evaluated in float32 and cast back to the logits' dtype.
        wdtype = weight.dtype
        weight = th.softmax(weight.float(), dim=-1).type(wdtype)
        return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
class ResidualAttentionBlock(nn.Module):
    """
    One transformer layer made of two residual branches.

    The first branch applies ``ln_1`` followed by self-attention, the second
    applies ``ln_2`` followed by the MLP; each branch output is added back to
    its own input.
    """

    def __init__(self, n_ctx: int, width: int, heads: int):
        super().__init__()
        self.attn = MultiheadAttention(n_ctx, width, heads)
        self.ln_1 = LayerNorm(width)
        self.mlp = MLP(width)
        self.ln_2 = LayerNorm(width)

    def forward(self, x: th.Tensor):
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
class Transformer(nn.Module):
    """
    A stack of ``layers`` residual attention blocks applied one after another.

    Every block is built with the same ``n_ctx``, ``width`` and ``heads``.
    """

    def __init__(self, n_ctx: int, width: int, layers: int, heads: int):
        super().__init__()
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList(
            ResidualAttentionBlock(n_ctx, width, heads) for _ in range(layers)
        )

    def forward(self, x: th.Tensor):
        hidden = x
        for resblock in self.resblocks:
            hidden = resblock(hidden)
        return hidden
|