mebubo commited on
Commit
19904de
·
1 Parent(s): 15b7594

style: Regroup import statements at the top of app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -5
app.py CHANGED
@@ -1,11 +1,8 @@
1
- #%%
2
  from text_processing import split_into_words, Word
3
  import torch
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from pprint import pprint
6
 
7
- #%%
8
-
9
  def load_model_and_tokenizer(model_name):
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
  model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -27,7 +24,6 @@ def calculate_log_probabilities(model, tokenizer, inputs, input_ids):
27
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
28
  return list(zip(tokens[1:], token_log_probs.tolist()))
29
 
30
- from transformers import PreTrainedModel, PreTrainedTokenizer
31
 
32
  def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix: str, device: torch.device, num_samples: int = 5) -> list[str]:
33
  input_context = tokenizer(prefix, return_tensors="pt").to(device)
 
 
1
  from text_processing import split_into_words, Word
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer
4
  from pprint import pprint
5
 
 
 
6
  def load_model_and_tokenizer(model_name):
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
 
24
  tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
25
  return list(zip(tokens[1:], token_log_probs.tolist()))
26
 
 
27
 
28
  def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix: str, device: torch.device, num_samples: int = 5) -> list[str]:
29
  input_context = tokenizer(prefix, return_tensors="pt").to(device)