积极的屁孩 commited on
Commit
0a31a84
·
1 Parent(s): 2b148a9
Files changed (1) hide show
  1. app.py +43 -0
app.py CHANGED
@@ -192,6 +192,49 @@ except ImportError as e:
192
  torchaudio.save(output_path, waveform, sr)
193
  return output_path
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  # 模型配置常量
196
  REPO_ID = "amphion/Vevo"
197
  CACHE_DIR = "./ckpts/Vevo"
 
192
  torchaudio.save(output_path, waveform, sr)
193
  return output_path
194
 
195
+ # 修复可能存在的递归调用问题
196
+ # 检查是否在运行时发生了transformers库中的注意力机制递归
197
+ try:
198
+ import transformers
199
+ from transformers.models.llama.modeling_llama import LlamaAttention, LlamaModel
200
+
201
+ # 保存原始的注意力前向函数
202
+ if hasattr(LlamaAttention, "forward"):
203
+ original_attention_forward = LlamaAttention.forward
204
+
205
+ # 创建防止递归的补丁函数
206
+ def safe_attention_forward(self, *args, **kwargs):
207
+ # 使用原始函数,但避免递归调用
208
+ return original_attention_forward(self, *args, **kwargs)
209
+
210
+ # 替换原始函数
211
+ LlamaAttention.forward = safe_attention_forward
212
+ print("已修复LlamaAttention.forward,防止递归")
213
+
214
+ # 可能存在其他递归路径
215
+ if hasattr(transformers.models.llama.modeling_llama, "LlamaAttention"):
216
+ for attr_name in dir(transformers.models.llama.modeling_llama.LlamaAttention):
217
+ if attr_name.startswith("_") and "forward" in attr_name:
218
+ attr = getattr(transformers.models.llama.modeling_llama.LlamaAttention, attr_name)
219
+ if callable(attr):
220
+ # 保存原始函数
221
+ setattr(transformers.models.llama.modeling_llama.LlamaAttention,
222
+ f"original_{attr_name}", attr)
223
+
224
+ # 创建安全函数
225
+ def create_safe_function(original_func, attr_name):
226
+ def safe_function(self, *args, **kwargs):
227
+ return original_func(self, *args, **kwargs)
228
+ return safe_function
229
+
230
+ # 替换函数
231
+ setattr(transformers.models.llama.modeling_llama.LlamaAttention,
232
+ attr_name,
233
+ create_safe_function(attr, attr_name))
234
+ print(f"已修复潜在的递归函数: {attr_name}")
235
+ except Exception as e:
236
+ print(f"应用注意力机制补丁时出错: {str(e)}")
237
+
238
  # 模型配置常量
239
  REPO_ID = "amphion/Vevo"
240
  CACHE_DIR = "./ckpts/Vevo"