Srinivasan Iyer sviyer committed on
Commit
fc946a1
·
unverified ·
1 Parent(s): 083656c

Some fixes for entropy model predictions (#83)

Browse files

Co-authored-by: Srini Iyer <[email protected]>

bytelatent/data/patcher.py CHANGED
@@ -91,7 +91,7 @@ def calculate_entropies(
91
  split = split.reshape(-1, max_length)
92
  if device is not None:
93
  split = split.to(device)
94
- assert torch.all(split >= 0) and torch.all(split < 260)
95
  pred = entropy_model(split)
96
  pred = pred.reshape(-1, pred.shape[-1])[
97
  : split.numel() - pad_size, :
@@ -103,7 +103,7 @@ def calculate_entropies(
103
  concat_entropies = torch.cat(entropies, dim=0)
104
  concat_entropies = concat_entropies.reshape(tokens.shape)
105
  concat_preds = torch.cat(preds, dim=0)
106
- concat_preds = concat_preds.reshape(tokens.shape[0], tokens.shape[1], -1)
107
  return concat_entropies, concat_preds
108
 
109
 
 
91
  split = split.reshape(-1, max_length)
92
  if device is not None:
93
  split = split.to(device)
94
+ # assert torch.all(split >= 0) and torch.all(split < 260)
95
  pred = entropy_model(split)
96
  pred = pred.reshape(-1, pred.shape[-1])[
97
  : split.numel() - pad_size, :
 
103
  concat_entropies = torch.cat(entropies, dim=0)
104
  concat_entropies = concat_entropies.reshape(tokens.shape)
105
  concat_preds = torch.cat(preds, dim=0)
106
+ concat_preds = concat_preds.reshape(tokens.shape[0], -1)
107
  return concat_entropies, concat_preds
108
 
109
 
bytelatent/entropy_model.py CHANGED
@@ -15,7 +15,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
15
  reloaded = json.loads(fr.read())
16
 
17
  torch.set_default_dtype(torch.bfloat16)
18
- model_params = reloaded["model"]
19
  logger.warning(
20
  "Update checkpoint to load attn and sliding window args from checkpoint"
21
  )
@@ -24,7 +24,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
24
  dim=model_params["dim"],
25
  n_layers=model_params["n_layers"],
26
  n_heads=model_params["n_heads"],
27
- max_seqlen=model_params["max_length"],
28
  ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
29
  vocab_size=model_params["vocab_size"],
30
  attn_bias_type="local_block_causal",
@@ -34,7 +34,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
34
  )
35
 
36
  entropy_model.load_state_dict(
37
- torch.load(state_dict_path, map_location=device), strict=False
38
  )
39
  entropy_model.to(device)
40
  entropy_model = entropy_model.eval()
 
15
  reloaded = json.loads(fr.read())
16
 
17
  torch.set_default_dtype(torch.bfloat16)
18
+ model_params = reloaded["entropy_model"]
19
  logger.warning(
20
  "Update checkpoint to load attn and sliding window args from checkpoint"
21
  )
 
24
  dim=model_params["dim"],
25
  n_layers=model_params["n_layers"],
26
  n_heads=model_params["n_heads"],
27
+ max_seqlen=model_params["max_seqlen"],
28
  ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
29
  vocab_size=model_params["vocab_size"],
30
  attn_bias_type="local_block_causal",
 
34
  )
35
 
36
  entropy_model.load_state_dict(
37
+ torch.load(state_dict_path, map_location=device)["model"], strict=False
38
  )
39
  entropy_model.to(device)
40
  entropy_model = entropy_model.eval()