Update app.py

app.py CHANGED:
```diff
@@ -1,5 +1,5 @@
 import streamlit as st
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextStreamer
 from huggingface_hub import login
 import PyPDF2
 import pandas as pd
```
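The only change in this hunk pulls in `TextStreamer`, transformers' helper that decodes and emits tokens while `generate()` is still running. A minimal sketch of it in isolation; the checkpoint name is illustrative, since the Space's actual `MODEL_NAME` does not appear in this diff:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextStreamer

# Illustrative checkpoint; any seq2seq translation model behaves the same way.
name = "Helsinki-NLP/opus-mt-en-fr"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSeq2SeqLM.from_pretrained(name)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
# TextStreamer prints each decoded chunk to stdout as soon as it is ready.
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
model.generate(**inputs, streamer=streamer, max_length=128, num_beams=1)
```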
```diff
@@ -70,7 +70,7 @@ def load_model(hf_token):
 
         login(token=hf_token)
 
-        # Load tokenizer
+        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             MODEL_NAME,
             token=hf_token
```
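The rendered diff shows no visible difference on the `# Load tokenizer` line, so the change here is likely whitespace only. For orientation, the surrounding `load_model` presumably has roughly the shape sketched below; `MODEL_NAME`, `DEVICE`, and the caching decorator are assumptions, not part of this diff:

```python
import streamlit as st
import torch
from huggingface_hub import login
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_NAME = "Helsinki-NLP/opus-mt-en-fr"  # placeholder; the Space's real constant is not shown
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

@st.cache_resource  # assumed: cache so the model loads once, not on every rerun
def load_model(hf_token):
    try:
        login(token=hf_token)
        # Load tokenizer and model with the user's Hugging Face token
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, token=hf_token).to(DEVICE)
        return model, tokenizer
    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None, None  # the committed code returns a bare None here
```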
```diff
@@ -91,14 +91,17 @@ def load_model(hf_token):
         st.error(f"🤖 Model loading failed: {str(e)}")
         return None
 
-# Generation function for translation
+# Generation function for translation with streaming
 def generate_translation(input_text, model, tokenizer):
     try:
         # Tokenize the input (no prompt needed for seq2seq translation models)
         inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
         inputs = inputs.to(DEVICE)
 
-        #
+        # Set up the streamer for real-time output
+        streamer = TextStreamer(tokenizer, skip_special_tokens=True)
+
+        # Generate translation with streaming
         model.eval()
         with torch.no_grad():
             outputs = model.generate(
```
```diff
@@ -107,12 +110,15 @@ def generate_translation(input_text, model, tokenizer):
                 max_length=512,
                 num_beams=5,
                 length_penalty=1.0,
-                early_stopping=True
+                early_stopping=True,
+                streamer=streamer,
+                return_dict_in_generate=True,
+                output_scores=True
             )
 
-        # Decode the output
-        translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return translation
+        # Decode the full output for storage and metrics
+        translation = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
+        return translation, streamer
 
     except Exception as e:
         raise Exception(f"Generation error: {str(e)}")
```
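One caveat on the new `generate()` arguments: transformers raises a `ValueError` when a `streamer` is combined with beam search, so with `num_beams=5` this call should fail at runtime, and `output_scores=True` is requested but never read afterwards. A hedged sketch of the same function under greedy decoding, which streaming does support (`DEVICE` comes from the surrounding app):

```python
import torch

def generate_translation(input_text, model, tokenizer, streamer=None):
    # Greedy decoding: transformers only allows a streamer when num_beams == 1.
    inputs = tokenizer(input_text, return_tensors="pt", padding=True,
                       truncation=True, max_length=512).to(DEVICE)
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=1,        # beam search and streaming are mutually exclusive
            streamer=streamer,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```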
```diff
@@ -154,21 +160,30 @@ if prompt := st.chat_input("Enter text to translate into French..."):
     file_context = process_file(uploaded_file)
     input_text = file_context if file_context else prompt
 
-    # Generate translation
+    # Generate translation with streaming
    if model and tokenizer:
         try:
             with st.chat_message("assistant", avatar=BOT_AVATAR):
                 start_time = time.time()
-                translation = generate_translation(input_text, model, tokenizer)
 
-                #
-                st.
-
+                # Create a placeholder for streaming output
+                response_container = st.empty()
+                full_response = ""
+
+                # Generate translation and stream output
+                translation, streamer = generate_translation(input_text, model, tokenizer)
+
+                # Streamlit will automatically display the streamed output via the TextStreamer
+                # Collect the full response for metrics and storage
+                full_response = translation
+
+                # Update the placeholder with the final response
+                response_container.markdown(full_response)
 
                 # Calculate performance metrics
                 end_time = time.time()
                 input_tokens = len(tokenizer(input_text)["input_ids"])
-                output_tokens = len(tokenizer(translation)["input_ids"])
+                output_tokens = len(tokenizer(full_response)["input_ids"])
                 speed = output_tokens / (end_time - start_time)
 
                 # Calculate costs (hypothetical pricing model)
```
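A second caveat: `TextStreamer` writes to the process's stdout, so on a deployed Space the streamed text goes to the container logs rather than the chat window, and the `streamer` returned after generation finishes carries nothing further. The new comment ("Streamlit will automatically display the streamed output") overstates what happens; the page only updates once `response_container.markdown(...)` runs with the finished string. Streaming into the page itself is usually built on `TextIteratorStreamer`, with `generate()` in a worker thread feeding the UI loop, roughly as below (`model`, `tokenizer`, `DEVICE`, and `input_text` come from the surrounding code):

```python
from threading import Thread

import streamlit as st
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
inputs = tokenizer(input_text, return_tensors="pt",
                   truncation=True, max_length=512).to(DEVICE)

# Run generation in the background; the streamer yields chunks as they arrive.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, max_length=512, num_beams=1, streamer=streamer))
thread.start()

response_container = st.empty()
full_response = ""
for chunk in streamer:                      # blocks until the next chunk is ready
    full_response += chunk
    response_container.markdown(full_response)
thread.join()
```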
```diff
@@ -184,6 +199,9 @@ if prompt := st.chat_input("Enter text to translate into French..."):
                     f"💵 Cost (AOA): {total_cost_aoa:.4f}"
                 )
 
+                # Store the full response in chat history
+                st.session_state.messages.append({"role": "assistant", "content": full_response})
+
         except Exception as e:
             st.error(f"⚡ Translation error: {str(e)}")
     else:
```
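Appending the assistant turn to `st.session_state.messages` is what lets the conversation survive Streamlit's rerun-on-interaction model. The initialization and replay side of that pattern is not shown in this diff, but it presumably looks something like:

```python
# Initialize chat history once per session.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay earlier turns on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# New user input is appended before the assistant responds.
if prompt := st.chat_input("Enter text to translate into French..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
```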