amiguel committed on
Commit
3c9f4cd
·
verified ·
1 Parent(s): c7ff1b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -14
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
  from huggingface_hub import login
4
  import PyPDF2
5
  import pandas as pd
@@ -70,7 +70,7 @@ def load_model(hf_token):
70
 
71
  login(token=hf_token)
72
 
73
- # Load tokenizer (requires sentencepiece)
74
  tokenizer = AutoTokenizer.from_pretrained(
75
  MODEL_NAME,
76
  token=hf_token
@@ -91,14 +91,17 @@ def load_model(hf_token):
91
  st.error(f"🤖 Model loading failed: {str(e)}")
92
  return None
93
 
94
- # Generation function for translation
95
  def generate_translation(input_text, model, tokenizer):
96
  try:
97
  # Tokenize the input (no prompt needed for seq2seq translation models)
98
  inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
99
  inputs = inputs.to(DEVICE)
100
 
101
- # Generate translation
 
 
 
102
  model.eval()
103
  with torch.no_grad():
104
  outputs = model.generate(
@@ -107,12 +110,15 @@ def generate_translation(input_text, model, tokenizer):
107
  max_length=512,
108
  num_beams=5,
109
  length_penalty=1.0,
110
- early_stopping=True
 
 
 
111
  )
112
 
113
- # Decode the output
114
- translation = tokenizer.decode(outputs[0], skip_special_tokens=True)
115
- return translation
116
 
117
  except Exception as e:
118
  raise Exception(f"Generation error: {str(e)}")
@@ -154,21 +160,30 @@ if prompt := st.chat_input("Enter text to translate into French..."):
154
  file_context = process_file(uploaded_file)
155
  input_text = file_context if file_context else prompt
156
 
157
- # Generate translation
158
  if model and tokenizer:
159
  try:
160
  with st.chat_message("assistant", avatar=BOT_AVATAR):
161
  start_time = time.time()
162
- translation = generate_translation(input_text, model, tokenizer)
163
 
164
- # Display the translation
165
- st.markdown(translation)
166
- st.session_state.messages.append({"role": "assistant", "content": translation})
 
 
 
 
 
 
 
 
 
 
167
 
168
  # Calculate performance metrics
169
  end_time = time.time()
170
  input_tokens = len(tokenizer(input_text)["input_ids"])
171
- output_tokens = len(tokenizer(translation)["input_ids"])
172
  speed = output_tokens / (end_time - start_time)
173
 
174
  # Calculate costs (hypothetical pricing model)
@@ -184,6 +199,9 @@ if prompt := st.chat_input("Enter text to translate into French..."):
184
  f"💵 Cost (AOA): {total_cost_aoa:.4f}"
185
  )
186
 
 
 
 
187
  except Exception as e:
188
  st.error(f"⚡ Translation error: {str(e)}")
189
  else:
 
1
  import streamlit as st
2
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextStreamer
3
  from huggingface_hub import login
4
  import PyPDF2
5
  import pandas as pd
 
70
 
71
  login(token=hf_token)
72
 
73
+ # Load tokenizer
74
  tokenizer = AutoTokenizer.from_pretrained(
75
  MODEL_NAME,
76
  token=hf_token
 
91
  st.error(f"🤖 Model loading failed: {str(e)}")
92
  return None
93
 
94
+ # Generation function for translation with streaming
95
  def generate_translation(input_text, model, tokenizer):
96
  try:
97
  # Tokenize the input (no prompt needed for seq2seq translation models)
98
  inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
99
  inputs = inputs.to(DEVICE)
100
 
101
+ # Set up the streamer for real-time output
102
+ streamer = TextStreamer(tokenizer, skip_special_tokens=True)
103
+
104
+ # Generate translation with streaming
105
  model.eval()
106
  with torch.no_grad():
107
  outputs = model.generate(
 
110
  max_length=512,
111
  num_beams=5,
112
  length_penalty=1.0,
113
+ early_stopping=True,
114
+ streamer=streamer,
115
+ return_dict_in_generate=True,
116
+ output_scores=True
117
  )
118
 
119
+ # Decode the full output for storage and metrics
120
+ translation = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
121
+ return translation, streamer
122
 
123
  except Exception as e:
124
  raise Exception(f"Generation error: {str(e)}")
 
160
  file_context = process_file(uploaded_file)
161
  input_text = file_context if file_context else prompt
162
 
163
+ # Generate translation with streaming
164
  if model and tokenizer:
165
  try:
166
  with st.chat_message("assistant", avatar=BOT_AVATAR):
167
  start_time = time.time()
 
168
 
169
+ # Create a placeholder for streaming output
170
+ response_container = st.empty()
171
+ full_response = ""
172
+
173
+ # Generate translation and stream output
174
+ translation, streamer = generate_translation(input_text, model, tokenizer)
175
+
176
+ # Streamlit will automatically display the streamed output via the TextStreamer
177
+ # Collect the full response for metrics and storage
178
+ full_response = translation
179
+
180
+ # Update the placeholder with the final response
181
+ response_container.markdown(full_response)
182
 
183
  # Calculate performance metrics
184
  end_time = time.time()
185
  input_tokens = len(tokenizer(input_text)["input_ids"])
186
+ output_tokens = len(tokenizer(full_response)["input_ids"])
187
  speed = output_tokens / (end_time - start_time)
188
 
189
  # Calculate costs (hypothetical pricing model)
 
199
  f"💵 Cost (AOA): {total_cost_aoa:.4f}"
200
  )
201
 
202
+ # Store the full response in chat history
203
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
204
+
205
  except Exception as e:
206
  st.error(f"⚡ Translation error: {str(e)}")
207
  else: