import torch
from torch import nn
from torch.nn import functional as F
def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every nn.Linear in `module` with a LinearLora layer."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            # carry over the pretrained parameters so the wrapped layer starts
            # out numerically identical to the layer it replaces
            new_lora.weight = child.weight
            new_lora.bias = child.bias
            setattr(module, name, new_lora)
        else:
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )
class LinearLora(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
rank: int,
dtype: torch.dtype,
device: torch.device,
lora_bias: bool = True,
scale: float = 1.0,
*args,
**kwargs,
) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )
assert isinstance(scale, float), "scale must be a float"
self.scale = scale
self.rank = rank
self.lora_bias = lora_bias
self.dtype = dtype
self.device = device
        # a rank above min(out_features, in_features) adds no expressive
        # power, so clamp it
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank
self.lora_A = nn.Linear(
in_features=in_features,
out_features=self.rank,
bias=False,
dtype=dtype,
device=device,
)
self.lora_B = nn.Linear(
in_features=self.rank,
out_features=out_features,
bias=self.lora_bias,
dtype=dtype,
device=device,
)
        # zero-init lora_B so the low-rank update starts at exactly zero and
        # the layer initially matches the plain nn.Linear it wraps
        nn.init.zeros_(self.lora_B.weight)
        if self.lora_B.bias is not None:
            nn.init.zeros_(self.lora_B.bias)
def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
self.scale = scale
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # base projection plus the scaled low-rank update: W x + scale * B(A(x))
        base_out = super().forward(input)
        lora_update = self.lora_B(self.lora_A(input)) * self.scale
        return base_out + lora_update
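
# A minimal sanity-check sketch, not part of the original file; the helper
# name `_check_lora_identity` is made up for illustration. It verifies that,
# because lora_B is zero-initialized, wrapping a model with
# replace_linear_with_lora leaves its outputs unchanged at initialization.
def _check_lora_identity() -> None:
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(16, 32))
    x = torch.randn(2, 16)
    before = model(x)
    replace_linear_with_lora(model, max_rank=4, scale=1.0)
    assert isinstance(model[0], LinearLora)
    # the LoRA update is zero at init, so the swapped model must match exactly
    assert torch.allclose(before, model(x))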
class MixtureOfLoRAExperts(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
num_experts: int,
rank: int,
bias: bool = True,
        dtype: torch.dtype | None = None,
        device: torch.device | None = None,
        scale: float = 1.0,
        top_k: int = 2,  # number of experts each sample is routed to
) -> None:
super().__init__(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
dtype=dtype,
)
self.num_experts = num_experts
self.rank = min(rank, min(in_features, out_features))
self.scale = scale
self.top_k = min(top_k, num_experts)
        # shared LoRA branch, applied to every sample
self.shared_lora_A = nn.Linear(in_features, self.rank, bias=False, dtype=dtype, device=device)
self.shared_lora_B = nn.Linear(self.rank, out_features, bias=False, dtype=dtype, device=device)
        # per-expert LoRA branches
self.expert_lora_A = nn.ModuleList([
nn.Linear(in_features, self.rank, bias=False, dtype=dtype, device=device)
for _ in range(num_experts)
])
self.expert_lora_B = nn.ModuleList([
nn.Linear(self.rank, out_features, bias=False, dtype=dtype, device=device)
for _ in range(num_experts)
])
        # gating network that scores the experts for each sample
self.gate = nn.Linear(in_features, num_experts, dtype=dtype, device=device)
        # initialize LoRA and gate weights
self._init_weights()
    def _init_weights(self) -> None:
        # zero-init the shared LoRA B matrix so its update starts at zero
        nn.init.zeros_(self.shared_lora_B.weight)
        # likewise zero-init each expert's B matrix
        for expert_B in self.expert_lora_B:
            nn.init.zeros_(expert_B.weight)
        # small-variance gate init keeps the routing near uniform at the start
        nn.init.zeros_(self.gate.bias)
        nn.init.normal_(self.gate.weight, std=0.01)
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # NOTE: the routing below assumes a 2D input of shape (batch, in_features)
        batch_size = input.shape[0]
        # base output (identical to the original Linear layer)
        base_out = super().forward(input)
        # shared LoRA output, applied to every sample
        shared_lora = self.shared_lora_B(self.shared_lora_A(input))
        # gating weights over the experts
        gate_logits = self.gate(input)
        gate_weights = F.softmax(gate_logits, dim=-1)
        # pick the top-k experts and renormalize their weights to sum to 1
        top_k_weights, top_k_indices = torch.topk(gate_weights, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        # accumulate the selected experts' outputs
        expert_outputs = torch.zeros_like(base_out)
        for k in range(self.top_k):
            # expert index chosen at rank k for each sample in the batch
            expert_idx = top_k_indices[:, k]
            expert_weight = top_k_weights[:, k].unsqueeze(-1)
            # run each sample through its selected expert individually
            # (simple but slow; batching samples per expert would be faster)
            for i in range(batch_size):
                idx = expert_idx[i]
                expert_out = self.expert_lora_B[idx](self.expert_lora_A[idx](input[i : i + 1]))
                expert_outputs[i : i + 1] += expert_out * expert_weight[i]
        # combine the base, shared-LoRA, and expert outputs
        return base_out + self.scale * (shared_lora + expert_outputs)
def set_scale(self, scale: float) -> None:
assert isinstance(scale, float), "scale must be a float"
self.scale = scale
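
# A minimal usage sketch, not part of the original file; the helper name
# `_demo_moe_lora` and all sizes are made up for illustration. It routes a
# small 2D batch through the layer and checks that, with every LoRA B matrix
# zero-initialized, the layer reproduces its base nn.Linear at init.
def _demo_moe_lora() -> None:
    torch.manual_seed(0)
    layer = MixtureOfLoRAExperts(
        in_features=16,
        out_features=32,
        num_experts=4,
        rank=4,
        top_k=2,
    )
    x = torch.randn(8, 16)  # (batch, in_features); forward assumes 2D input
    y = layer(x)
    assert y.shape == (8, 32)
    # shared and expert B matrices start at zero, so only the base path is active
    assert torch.allclose(y, F.linear(x, layer.weight, layer.bias))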