mostafa-sh committed on
Commit
2b27faa
·
1 Parent(s): 17aac59

model output answer extraction

Browse files
Files changed (1) hide show
  1. utils/llama_utils.py +15 -2
utils/llama_utils.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import torch
3
  from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
4
  from peft import PeftModel
 
5
  import streamlit as st
6
 
7
  # Set the cache directory to persistent storage
@@ -128,6 +129,9 @@ def generate_response(
128
  return_tensors="pt"
129
  ).to(model.device)
130
 
 
 
 
131
  generation_params = {
132
  "do_sample": do_sample,
133
  "temperature": temperature if do_sample else None,
@@ -139,10 +143,19 @@ def generate_response(
139
 
140
  output = model.generate(**inputs, **generation_params)
141
 
142
- # Decode and clean up response
143
- response = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
144
 
145
  if 'assistant' in response:
146
  response = response.split('assistant')[1].strip()
147
 
 
 
 
 
 
 
148
  return response
 
2
  import torch
3
  from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
4
  from peft import PeftModel
5
+ import re
6
  import streamlit as st
7
 
8
  # Set the cache directory to persistent storage
 
129
  return_tensors="pt"
130
  ).to(model.device)
131
 
132
+ # Store the number of input tokens for reference
133
+ input_token_length = inputs.input_ids.shape[1]
134
+
135
  generation_params = {
136
  "do_sample": do_sample,
137
  "temperature": temperature if do_sample else None,
 
143
 
144
  output = model.generate(**inputs, **generation_params)
145
 
146
+ # Extract only the newly generated tokens
147
+ new_tokens = output[0][input_token_length:]
148
+
149
+ # Decode only the new tokens
150
+ response = tokenizer.decode(new_tokens, skip_special_tokens=True)
151
 
152
  if 'assistant' in response:
153
  response = response.split('assistant')[1].strip()
154
 
155
+ # In case there's still any assistant prefix, clean it up
156
+ if response.startswith("assistant") or response.startswith("<assistant>"):
157
+ response = re.sub(r"^assistant[:\s]*|^<assistant>[\s]*", "", response, flags=re.IGNORECASE)
158
+
159
+ response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)
160
+
161
  return response