lzyhha committed
Commit f3f3b57 · Parent: 175c8d0

flashattention

Files changed (2)
  1. app.py +1 -1
  2. flux/math.py +2 -2
app.py CHANGED
@@ -473,7 +473,7 @@ def create_demo(model):
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default=None)
+    parser.add_argument("--model_path", type=str, default="models/visualcloze-384-lora.pth")
     parser.add_argument("--precision", type=str, choices=["fp32", "bf16", "fp16"], default="bf16")
     parser.add_argument("--resolution", type=int, default=384)
     return parser.parse_args()
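
The only functional change in app.py is the new default for --model_path: the demo now starts without the flag and loads the bundled VisualCloze LoRA checkpoint. A minimal sketch of the resulting parser behavior; the parse_args([]) call and the assert are illustrative, not part of the commit:

# Sketch of the new default (assumes the checkpoint exists at the default
# path; parse_args([]) simulates launching the app with no CLI flags).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default="models/visualcloze-384-lora.pth")
parser.add_argument("--precision", type=str, choices=["fp32", "bf16", "fp16"], default="bf16")
parser.add_argument("--resolution", type=int, default=384)

args = parser.parse_args([])
assert args.model_path == "models/visualcloze-384-lora.pth"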
flux/math.py CHANGED
@@ -2,8 +2,8 @@ from einops import rearrange
 import torch
 from torch import Tensor
 import torch.nn.functional as F
-from flash_attn import flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+# from flash_attn import flash_attn_varlen_func
+# from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
 
 def _upad_input(query_layer, key_layer, value_layer, query_mask, key_mask, query_length):
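
Commenting out these imports lets flux/math.py be imported in environments where the flash-attn package is not installed; the diff itself does not show what the attention call falls back to. A minimal sketch of one plausible fallback, PyTorch's built-in F.scaled_dot_product_attention; the name attention_fallback and the tensor shapes are assumptions for illustration:

# Hypothetical fallback sketch: torch's fused SDPA kernel can stand in for
# flash_attn_varlen_func when the flash-attn package is unavailable.
# Assumed shapes: q, k, v as (batch, heads, seq_len, head_dim).
import torch
import torch.nn.functional as F


def attention_fallback(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # scaled_dot_product_attention dispatches to a flash / memory-efficient
    # kernel when one is available, else uses the plain math path.
    return F.scaled_dot_product_attention(q, k, v)


if __name__ == "__main__":
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)
    out = attention_fallback(q, k, v)
    print(out.shape)  # torch.Size([1, 8, 16, 64])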