Spaces:
Running
on
L4
Running
on
L4
Commit
·
2b27faa
1
Parent(s):
17aac59
model output answer extraction
Browse files - utils/llama_utils.py +15 -2
utils/llama_utils.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2 |
import torch
|
3 |
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
|
4 |
from peft import PeftModel
|
|
|
5 |
import streamlit as st
|
6 |
|
7 |
# Set the cache directory to persistent storage
|
@@ -128,6 +129,9 @@ def generate_response(
|
|
128 |
return_tensors="pt"
|
129 |
).to(model.device)
|
130 |
|
|
|
|
|
|
|
131 |
generation_params = {
|
132 |
"do_sample": do_sample,
|
133 |
"temperature": temperature if do_sample else None,
|
@@ -139,10 +143,19 @@ def generate_response(
|
|
139 |
|
140 |
output = model.generate(**inputs, **generation_params)
|
141 |
|
142 |
-
#
|
143 |
-
|
|
|
|
|
|
|
144 |
|
145 |
if 'assistant' in response:
|
146 |
response = response.split('assistant')[1].strip()
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
return response
|
|
|
2 |
import torch
|
3 |
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, PreTrainedTokenizerFast
|
4 |
from peft import PeftModel
|
5 |
+
import re
|
6 |
import streamlit as st
|
7 |
|
8 |
# Set the cache directory to persistent storage
|
|
|
129 |
return_tensors="pt"
|
130 |
).to(model.device)
|
131 |
|
132 |
+
# Store the number of input tokens for reference
|
133 |
+
input_token_length = inputs.input_ids.shape[1]
|
134 |
+
|
135 |
generation_params = {
|
136 |
"do_sample": do_sample,
|
137 |
"temperature": temperature if do_sample else None,
|
|
|
143 |
|
144 |
output = model.generate(**inputs, **generation_params)
|
145 |
|
146 |
+
# Extract only the newly generated tokens
|
147 |
+
new_tokens = output[0][input_token_length:]
|
148 |
+
|
149 |
+
# Decode only the new tokens
|
150 |
+
response = tokenizer.decode(new_tokens, skip_special_tokens=True)
|
151 |
|
152 |
if 'assistant' in response:
|
153 |
response = response.split('assistant')[1].strip()
|
154 |
|
155 |
+
# In case there's still any assistant prefix, clean it up
|
156 |
+
if response.startswith("assistant") or response.startswith("<assistant>"):
|
157 |
+
response = re.sub(r"^assistant[:\s]*|^<assistant>[\s]*", "", response, flags=re.IGNORECASE)
|
158 |
+
|
159 |
+
response = re.sub(r'^\s*(?:answer\s*)+:?\s*', '', response, flags=re.IGNORECASE)
|
160 |
+
|
161 |
return response
|