import torch
from torch import nn
from torch.nn import functional as F
def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every nn.Linear in `module` with a LinearLora
    that reuses the original weight and bias tensors."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            # Share the original parameters so the base mapping is unchanged.
            new_lora.weight = child.weight
            new_lora.bias = child.bias if child.bias is not None else None
            setattr(module, name, new_lora)
        else:
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )
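

# Illustrative usage sketch (not part of the original file): wrap a small
# Sequential model and check that every top-level nn.Linear was swapped for a
# LinearLora. The model and its sizes are assumptions made for this demo only.
def _demo_replace_linear_with_lora() -> None:
    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    replace_linear_with_lora(model, max_rank=8, scale=1.0)
    assert isinstance(model[0], LinearLora) and isinstance(model[2], LinearLora)
    # The forward pass still produces the expected shape.
    assert model(torch.randn(2, 16)).shape == (2, 4)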
class LinearLora(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: torch.Tensor | None,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias is not None,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        # The rank cannot exceed the smaller of the two weight dimensions.
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )

        # Zero-initialise B so the LoRA update starts as a no-op.
        nn.init.zeros_(self.lora_B.weight)
        if self.lora_B.bias is not None:
            nn.init.zeros_(self.lora_B.bias)

    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        base_out = super().forward(input)
        lora_update = self.lora_B(self.lora_A(input)) * self.scale
        return base_out + lora_update
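

# Quick sanity check (illustrative, not from the original file): because
# lora_B is zero-initialised, a fresh LinearLora reproduces the base
# nn.Linear output exactly. Shapes, dtype, and device are demo assumptions.
def _demo_linear_lora() -> None:
    layer = LinearLora(
        in_features=16,
        out_features=8,
        bias=None,
        rank=4,
        dtype=torch.float32,
        device=torch.device("cpu"),
    )
    x = torch.randn(2, 16)
    base = F.linear(x, layer.weight, layer.bias)
    assert torch.allclose(layer(x), base)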
class MixtureOfLoRAExperts(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int,
        rank: int,
        bias: bool = True,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        scale: float = 1.0,
        top_k: int = 2,  # number of experts selected per sample
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.num_experts = num_experts
        self.rank = min(rank, min(in_features, out_features))
        self.scale = scale
        self.top_k = min(top_k, num_experts)

        # Shared LoRA module, applied to every sample.
        self.shared_lora_A = nn.Linear(in_features, self.rank, bias=False, dtype=dtype, device=device)
        self.shared_lora_B = nn.Linear(self.rank, out_features, bias=False, dtype=dtype, device=device)

        # Per-expert LoRA modules.
        self.expert_lora_A = nn.ModuleList([
            nn.Linear(in_features, self.rank, bias=False, dtype=dtype, device=device)
            for _ in range(num_experts)
        ])
        self.expert_lora_B = nn.ModuleList([
            nn.Linear(self.rank, out_features, bias=False, dtype=dtype, device=device)
            for _ in range(num_experts)
        ])

        # Gating network that scores the experts for each sample.
        self.gate = nn.Linear(in_features, num_experts, dtype=dtype, device=device)

        self._init_weights()

    def _init_weights(self) -> None:
        # Zero-initialise the shared LoRA's B matrix (no-op at start).
        nn.init.zeros_(self.shared_lora_B.weight)
        # Zero-initialise every expert's B matrix.
        for expert_B in self.expert_lora_B:
            nn.init.zeros_(expert_B.weight)
        # Small random gate weights, zero bias: near-uniform routing at start.
        nn.init.zeros_(self.gate.bias)
        nn.init.normal_(self.gate.weight, std=0.01)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Expects a 2-D input of shape (batch_size, in_features).
        batch_size = input.shape[0]

        # Base output (identical to the plain nn.Linear layer).
        base_out = super().forward(input)

        # Shared LoRA output.
        shared_lora = self.shared_lora_B(self.shared_lora_A(input))

        # Gating weights over the experts.
        gate_logits = self.gate(input)
        gate_weights = F.softmax(gate_logits, dim=-1)

        # Select the top-k experts and renormalise their weights.
        top_k_weights, top_k_indices = torch.topk(gate_weights, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Accumulate the expert outputs.
        expert_outputs = torch.zeros_like(base_out)
        for k in range(self.top_k):
            # Expert index chosen at rank k for each sample in the batch.
            expert_idx = top_k_indices[:, k]
            expert_weight = top_k_weights[:, k].unsqueeze(-1)
            # Run each sample through its selected expert individually.
            for i in range(batch_size):
                idx = expert_idx[i]
                expert_out = self.expert_lora_B[idx](self.expert_lora_A[idx](input[i : i + 1]))
                expert_outputs[i : i + 1] += expert_out * expert_weight[i]

        # Combine base, shared-LoRA, and expert outputs.
        return base_out + self.scale * (shared_lora + expert_outputs)

    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
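

# Shape demo (illustrative, not from the original file): run a fresh
# MixtureOfLoRAExperts on a 2-D batch. All LoRA B matrices start at zero, so
# the output should equal the base linear output. Sizes are demo assumptions.
def _demo_mixture_of_lora_experts() -> None:
    moe = MixtureOfLoRAExperts(in_features=16, out_features=8, num_experts=4, rank=4, top_k=2)
    x = torch.randn(3, 16)
    out = moe(x)
    assert out.shape == (3, 8)
    assert torch.allclose(out, F.linear(x, moe.weight, moe.bias))


if __name__ == "__main__":
    _demo_replace_linear_with_lora()
    _demo_linear_lora()
    _demo_mixture_of_lora_experts()
    print("All demos passed.")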