import torch
import torch.nn.functional as F
from torch import nn


def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    # Recursively swap every nn.Linear in `module` for a LinearLora that
    # reuses the original weight and bias.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )

            new_lora.weight = child.weight
            new_lora.bias = child.bias if child.bias is not None else None

            setattr(module, name, new_lora)
        else:
            replace_linear_with_lora(
                module=child,
                max_rank=max_rank,
                scale=scale,
            )


class LinearLora(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        assert isinstance(scale, float), "scale must be a float"

        self.scale = scale
        self.rank = rank
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        # Clamp the rank so it never exceeds the smaller of the two dimensions.
        if rank > (new_rank := min(self.out_features, self.in_features)):
            self.rank = new_rank

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )

        # Zero-initialize lora_B so the layer matches the base Linear at init.
        nn.init.zeros_(self.lora_B.weight)
        if self.lora_B.bias is not None:
            nn.init.zeros_(self.lora_B.bias)

    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale value must be a float"
        self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        base_out = super().forward(input)

        _lora_out_B = self.lora_B(self.lora_A(input))
        lora_update = _lora_out_B * self.scale

        return base_out + lora_update


class MixtureOfLoRAExperts(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_experts: int,
        rank: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        scale: float = 1.0,
        top_k: int = 2,  # number of experts routed to per sample
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

        self.num_experts = num_experts
        self.rank = min(rank, min(in_features, out_features))
        self.scale = scale
        self.top_k = min(top_k, num_experts)

        # Shared LoRA module (always active)
        self.shared_lora_A = nn.Linear(in_features, self.rank, bias=False, dtype=dtype, device=device)
        self.shared_lora_B = nn.Linear(self.rank, out_features, bias=False, dtype=dtype, device=device)

        # Per-expert LoRA modules
        self.expert_lora_A = nn.ModuleList([
            nn.Linear(in_features, self.rank, bias=False, dtype=dtype, device=device)
            for _ in range(num_experts)
        ])
        self.expert_lora_B = nn.ModuleList([
            nn.Linear(self.rank, out_features, bias=False, dtype=dtype, device=device)
            for _ in range(num_experts)
        ])

        # Gating network
        self.gate = nn.Linear(in_features, num_experts, dtype=dtype, device=device)

        # Initialize weights
        self._init_weights()

    def _init_weights(self) -> None:
        # Shared LoRA: zero B so the shared branch starts as a no-op
        nn.init.zeros_(self.shared_lora_B.weight)

        # Expert LoRA: zero every expert's B for the same reason
        for expert_B in self.expert_lora_B:
            nn.init.zeros_(expert_B.weight)

        # Gating network: zero bias and small weights for near-uniform routing at init
        nn.init.zeros_(self.gate.bias)
        nn.init.normal_(self.gate.weight, std=0.01)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Expects input of shape (batch_size, in_features).
        batch_size = input.shape[0]

        # Base output (identical to a plain Linear layer)
        base_out = super().forward(input)

        # Shared LoRA output
        shared_lora = self.shared_lora_B(self.shared_lora_A(input))

        # Gating weights
        gate_logits = self.gate(input)
        gate_weights = F.softmax(gate_logits, dim=-1)

        # Select the top-k experts and renormalize their weights
        top_k_weights, top_k_indices = torch.topk(gate_weights, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        # Expert outputs
        expert_outputs = torch.zeros_like(base_out)

        for k in range(self.top_k):
            # Expert index chosen at rank k for each sample in the batch
            expert_idx = top_k_indices[:, k]
            expert_weight = top_k_weights[:, k].unsqueeze(-1)

            # Compute the selected expert's output for each sample individually
            for i in range(batch_size):
                idx = expert_idx[i]
                expert_out = self.expert_lora_B[idx](self.expert_lora_A[idx](input[i:i + 1]))
                expert_outputs[i:i + 1] += expert_out * expert_weight[i]

        # Combine all outputs
        final_output = base_out + self.scale * (shared_lora + expert_outputs)

        return final_output

    def set_scale(self, scale: float) -> None:
        assert isinstance(scale, float), "scale must be a float"
        self.scale = scale
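

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal example, under the assumption that these layers are used as
# drop-in nn.Linear replacements: `replace_linear_with_lora` swaps every
# nn.Linear in a model for a LinearLora carrying the original weights, and
# MixtureOfLoRAExperts is constructed directly. The toy MLP and the
# dimensions below are hypothetical and only serve to exercise the code.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Hypothetical two-layer MLP standing in for a real model.
    model = nn.Sequential(
        nn.Linear(64, 128),
        nn.GELU(),
        nn.Linear(128, 64),
    )

    x = torch.randn(4, 64)
    base_out = model(x)

    # Swap every nn.Linear for a LinearLora; lora_B is zero-initialized,
    # so the output is expected to be unchanged right after conversion.
    replace_linear_with_lora(model, max_rank=8, scale=1.0)
    lora_out = model(x)
    print("max diff after LoRA injection:", (base_out - lora_out).abs().max().item())

    # A MixtureOfLoRAExperts layer used directly; its forward pass assumes
    # 2D input of shape (batch_size, in_features).
    moe = MixtureOfLoRAExperts(
        in_features=64,
        out_features=64,
        num_experts=4,
        rank=8,
        top_k=2,
    )
    print("MoE-LoRA output shape:", moe(x).shape)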