lzyhha committed
Commit 7ecea30 · 1 Parent(s): 5887ab3
Files changed (2):
  1. app.py +2 -0
  2. flux/math.py +2 -2
app.py CHANGED
@@ -1,3 +1,5 @@
+import subprocess
+subprocess.run('pip install flash-attn --no-build-isolation', shell=True)
 import argparse
 import spaces
 from visualcloze import VisualClozeModel
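
The two added lines install flash-attn at container startup, before flux/math.py is imported. A minimal sketch of a slightly more defensive variant, assuming the same entry point; the helper name below is illustrative and not part of the commit:

import subprocess
import sys

def ensure_flash_attn() -> None:
    """Install flash-attn at runtime if it is not already importable.

    Hypothetical helper, not part of the commit: it mirrors the
    subprocess.run(...) call added to app.py, but skips the install
    when the package is already present in the environment.
    """
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        subprocess.run(
            [sys.executable, "-m", "pip", "install",
             "flash-attn", "--no-build-isolation"],
            check=True,
        )

ensure_flash_attn()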
flux/math.py CHANGED
@@ -2,8 +2,8 @@ from einops import rearrange
 import torch
 from torch import Tensor
 import torch.nn.functional as F
-# from flash_attn import flash_attn_varlen_func
-# from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+from flash_attn import flash_attn_varlen_func
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 
 def _upad_input(query_layer, key_layer, value_layer, query_mask, key_mask, query_length):
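
In flux/math.py the commit re-enables the previously commented-out flash-attn imports, so the unpadding path around `_upad_input` can hand packed, variable-length sequences to `flash_attn_varlen_func`. A minimal sketch of how that entry point is typically called, assuming flash-attn 2.x on a CUDA device with fp16/bf16 tensors; the wrapper and shapes below are illustrative, not code from this commit:

import torch
from flash_attn import flash_attn_varlen_func

def varlen_attention(q, k, v, seqlens):
    """Run variable-length flash attention on packed sequences.

    q, k, v: (total_tokens, num_heads, head_dim) fp16/bf16 CUDA tensors,
    where all sequences in the batch are concatenated along dim 0.
    seqlens: per-sequence lengths whose sum equals total_tokens.
    """
    device = q.device
    # Cumulative sequence boundaries, e.g. [0, len0, len0+len1, ...],
    # as int32 as required by flash_attn_varlen_func.
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = torch.cumsum(
        torch.tensor(seqlens, dtype=torch.int32, device=device), dim=0)
    max_seqlen = max(seqlens)
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    )

Using cu_seqlens this way avoids attention across sequence boundaries without materializing padding, which is why the module also imports `unpad_input`/`pad_input` from `flash_attn.bert_padding` to convert between padded and packed layouts.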