abancp committed on
Commit
17a3eb0
·
verified ·
1 Parent(s): cc68838

Update inference_fine_tune.py

Browse files
Files changed (1) hide show
  1. inference_fine_tune.py +17 -18
inference_fine_tune.py CHANGED
@@ -7,6 +7,21 @@ from pathlib import Path
7
  from config import get_config, get_weights_file_path
8
  from train import get_model
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def generate_text(
11
  model, text, tokenizer, max_len, device,
12
  temperature=0.7, top_k=50
@@ -104,19 +119,7 @@ def run_model(config):
104
 
105
  def generate_response(prompt:str):
106
  print("Prompt : ",prompt)
107
- config = get_config("./openweb.config.json")
108
- device = "cuda" if torch.cuda.is_available() else "cpu"
109
- tokenizer = get_tokenizer(config)
110
- pad_token_id = tokenizer.token_to_id("<pad>")
111
- eos_token_id = tokenizer.token_to_id("</s>")
112
- user_token_id = tokenizer.token_to_id("<user>")
113
- ai_token_id = tokenizer.token_to_id("<ai>")
114
-
115
- model = get_model(config, tokenizer.get_vocab_size()).to(device)
116
- model_path = get_weights_file_path(config,config['preload'])
117
- model.eval()
118
- state = torch.load(model_path,map_location=torch.device('cpu'))
119
- model.load_state_dict(state['model_state_dict'])
120
  word = ""
121
  input_tokens = tokenizer.encode(prompt).ids
122
  input_tokens.extend([user_token_id] + input_tokens + [ai_token_id] )
@@ -149,8 +152,4 @@ def generate_response(prompt:str):
149
  if next_token.item() == eos_token_id:
150
  break
151
  print("Output : ",word)
152
- return word
153
-
154
- if __name__ == "__main__":
155
- config = get_config("openweb.config.json")
156
- run_model(config)
 
7
  from config import get_config, get_weights_file_path
8
  from train import get_model
9
 
10
+
11
+ config = get_config("./openweb.config.json")
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ tokenizer = get_tokenizer(config)
14
+ pad_token_id = tokenizer.token_to_id("<pad>")
15
+ eos_token_id = tokenizer.token_to_id("</s>")
16
+ user_token_id = tokenizer.token_to_id("<user>")
17
+ ai_token_id = tokenizer.token_to_id("<ai>")
18
+
19
+ model = get_model(config, tokenizer.get_vocab_size()).to(device)
20
+ model_path = get_weights_file_path(config,config['preload'])
21
+ model.eval()
22
+ state = torch.load(model_path,map_location=torch.device('cpu'))
23
+ model.load_state_dict(state['model_state_dict'])
24
+
25
  def generate_text(
26
  model, text, tokenizer, max_len, device,
27
  temperature=0.7, top_k=50
 
119
 
120
  def generate_response(prompt:str):
121
  print("Prompt : ",prompt)
122
+
 
 
 
 
 
 
 
 
 
 
 
 
123
  word = ""
124
  input_tokens = tokenizer.encode(prompt).ids
125
  input_tokens.extend([user_token_id] + input_tokens + [ai_token_id] )
 
152
  if next_token.item() == eos_token_id:
153
  break
154
  print("Output : ",word)
155
+ return word