lucalp committed on
Commit
661d10b
·
1 Parent(s): 86969f4

xformers when cuda available

Browse files
Files changed (2) hide show
  1. app.py +3 -2
  2. bytelatent/entropy_model.py +1 -1
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import spaces
 
2
  import os
3
  import gradio as gr
4
  import torch
@@ -30,7 +31,7 @@ class Config:
30
 
31
  # Bytelatent Specific
32
  BLT_WEIGHTS_DIR: str = "hf-weights"
33
- BLT_MAX_BYTES_FOR_DEMO: int = 512 # Limit for this specific demo's entropy model
34
 
35
  # Gradio
36
  DEFAULT_PROMPT: str = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
@@ -158,7 +159,7 @@ class BytelatentProcessor:
158
 
159
  return highlighted_data, patch_count
160
 
161
- def process(self, prompt: str, max_bytes: int) -> Tuple[Optional[matplotlib.figure.Figure], List[Tuple[str, str]], int, str]:
162
  """Processes the prompt using the loaded Bytelatent model."""
163
  status = ""
164
  if not self.is_available or self.tokenizer is None or self.patcher is None:
 
1
  import spaces
2
+ import math
3
  import os
4
  import gradio as gr
5
  import torch
 
31
 
32
  # Bytelatent Specific
33
  BLT_WEIGHTS_DIR: str = "hf-weights"
34
+ BLT_MAX_BYTES_FOR_DEMO: float = math.inf # Limit for this specific demo's entropy model
35
 
36
  # Gradio
37
  DEFAULT_PROMPT: str = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
 
159
 
160
  return highlighted_data, patch_count
161
 
162
+ def process(self, prompt: str, max_bytes: float) -> Tuple[Optional[matplotlib.figure.Figure], List[Tuple[str, str]], int, str]:
163
  """Processes the prompt using the loaded Bytelatent model."""
164
  status = ""
165
  if not self.is_available or self.tokenizer is None or self.patcher is None:
bytelatent/entropy_model.py CHANGED
@@ -28,7 +28,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
28
  ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
29
  vocab_size=model_params["vocab_size"],
30
  attn_bias_type="causal",
31
- attn_impl="sdpa",
32
  sliding_window=512,
33
  )
34
  )
 
28
  ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
29
  vocab_size=model_params["vocab_size"],
30
  attn_bias_type="causal",
31
+ attn_impl="xformers" if torch.cuda.is_available() else "sdpa",
32
  sliding_window=512,
33
  )
34
  )