import torch
from torch import nn
from torch.nn import functional as F
def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    # Recursively swap every nn.Linear in the module tree for a LinearLora
    # that reuses the original weight and bias tensors.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            # Reuse the pretrained parameters so the base path is unchanged.
            new_lora.weight = child.weight
            new_lora.bias = child.bias
            setattr(module, name, new_lora)
        else:
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )
class LinearLora(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )
        assert isinstance(scale, float), "scale must be a float"

        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        # Clamp the rank to the smaller matrix dimension; a larger rank
        # would add parameters without adding expressivity.
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )
        # Zero-init the B matrix (and its bias) so the LoRA update starts
        # at zero and the wrapped layer initially matches the base layer.
        nn.init.zeros_(self.lora_B.weight)
        if self.lora_B.bias is not None:
            nn.init.zeros_(self.lora_B.bias)
    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        base_out = super().forward(input)

        # Low-rank update: B(A(x)), scaled by self.scale.
        _lora_out_B = self.lora_B(self.lora_A(input))
        lora_update = _lora_out_B * self.scale

        return base_out + lora_update
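
# A minimal sanity-check sketch, not part of the original module: because
# lora_B is zero-initialized, swapping LoRA into a model should leave its
# outputs unchanged at initialization. The toy model and shapes below are
# illustrative assumptions.
def _demo_linear_lora() -> None:
    model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 8))
    x = torch.randn(4, 16)
    out_before = model(x)
    replace_linear_with_lora(model, max_rank=4, scale=1.0)
    out_after = model(x)
    # The LoRA branch contributes exactly zero at initialization.
    assert torch.allclose(out_before, out_after)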
class MixtureOfLoRAExperts(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int,
        rank: int,
        bias: bool = True,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        scale: float = 1.0,
        top_k: int = 2,  # number of experts routed per sample
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )
        self.num_experts = num_experts
        self.rank = min(rank, min(in_features, out_features))
        self.scale = scale
        self.top_k = min(top_k, num_experts)

        # Shared LoRA pair, applied to every input regardless of routing.
        self.shared_lora_A = nn.Linear(in_features, self.rank, bias=False, dtype=dtype, device=device)
        self.shared_lora_B = nn.Linear(self.rank, out_features, bias=False, dtype=dtype, device=device)

        # Per-expert LoRA pairs.
        self.expert_lora_A = nn.ModuleList([
            nn.Linear(in_features, self.rank, bias=False, dtype=dtype, device=device)
            for _ in range(num_experts)
        ])
        self.expert_lora_B = nn.ModuleList([
            nn.Linear(self.rank, out_features, bias=False, dtype=dtype, device=device)
            for _ in range(num_experts)
        ])

        # Gating network that scores each expert per input.
        self.gate = nn.Linear(in_features, num_experts, dtype=dtype, device=device)

        self._init_weights()
    def _init_weights(self) -> None:
        # Zero-init every B matrix so all LoRA updates start at zero.
        nn.init.zeros_(self.shared_lora_B.weight)
        for expert_B in self.expert_lora_B:
            nn.init.zeros_(expert_B.weight)
        # Small random gate weights with a zero bias give near-uniform
        # routing at the start of training.
        nn.init.zeros_(self.gate.bias)
        nn.init.normal_(self.gate.weight, std=0.01)
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Routing below indexes along dim 0, so input is expected to be
        # 2D with shape (batch_size, in_features).
        batch_size = input.shape[0]

        # Base output (same as a plain nn.Linear).
        base_out = super().forward(input)

        # Shared LoRA update, applied to every sample.
        shared_lora = self.shared_lora_B(self.shared_lora_A(input))

        # Gating weights over the experts.
        gate_logits = self.gate(input)
        gate_weights = F.softmax(gate_logits, dim=-1)

        # Keep the top-k experts per sample and renormalize their weights.
        top_k_weights, top_k_indices = torch.topk(gate_weights, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Accumulate the weighted outputs of the selected experts.
        expert_outputs = torch.zeros_like(base_out)
        for k in range(self.top_k):
            # Expert chosen for each sample at routing slot k.
            expert_idx = top_k_indices[:, k]
            expert_weight = top_k_weights[:, k].unsqueeze(-1)
            # Run each sample through its selected expert. The per-sample
            # loop is simple but slow; batching by expert would be faster.
            for i in range(batch_size):
                idx = expert_idx[i]
                expert_out = self.expert_lora_B[idx](self.expert_lora_A[idx](input[i : i + 1]))
                expert_outputs[i : i + 1] += expert_out * expert_weight[i]

        # Combine the base, shared, and expert paths.
        return base_out + self.scale * (shared_lora + expert_outputs)
    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
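
# A minimal usage sketch, not part of the original module: the shapes and
# hyperparameters below are illustrative assumptions. At initialization the
# layer should match a plain nn.Linear, since every B matrix starts at zero.
def _demo_mixture_of_lora() -> None:
    layer = MixtureOfLoRAExperts(
        in_features=16,
        out_features=8,
        num_experts=4,
        rank=4,
        top_k=2,
    )
    x = torch.randn(4, 16)  # forward expects 2D (batch, in_features) input
    out = layer(x)
    assert out.shape == (4, 8)
    # Both the shared and expert LoRA paths contribute zero at init, so the
    # output equals the base linear projection.
    assert torch.allclose(out, F.linear(x, layer.weight, layer.bias))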