amiguel committed on
Commit
55d3e9f
·
verified ·
1 Parent(s): f149660

Update app.py

Files changed (1)
  1. app.py +279 -55
app.py CHANGED
@@ -1,20 +1,16 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
3
  from huggingface_hub import login
4
- from threading import Thread
5
  import PyPDF2
6
  import pandas as pd
7
  import torch
8
  import time
9
 
10
- # Check if 'peft' is installed (though not used here, kept for potential future use)
11
- try:
12
- from peft import PeftModel, PeftConfig
13
- except ImportError:
14
- raise ImportError(
15
- "The 'peft' library is required but not installed. "
16
- "Please install it using: `pip install peft`"
17
- )
18
 
19
  # Set page configuration
20
  st.set_page_config(
@@ -77,6 +73,234 @@ def process_file(uploaded_file):
77
  st.error(f"📄 Error processing file: {str(e)}")
78
  return ""
79
 
80
  # Model loading function
81
  @st.cache_resource
82
  def load_model(hf_token):
@@ -87,16 +311,15 @@ def load_model(hf_token):
87
 
88
  login(token=hf_token)
89
 
90
- # Load tokenizer
91
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
92
 
93
- # Load the full model (no adapters since we're using the base transformer)
94
- model = AutoModelForCausalLM.from_pretrained(
95
  MODEL_NAME,
96
- torch_dtype=torch.bfloat16,
97
- device_map="auto",
98
  token=hf_token
99
  )
 
100
 
101
  return model, tokenizer
102
 
@@ -104,32 +327,40 @@ def load_model(hf_token):
104
  st.error(f"🤖 Model loading failed: {str(e)}")
105
  return None
106
 
107
- # Generation function with KV caching
108
- def generate_translation(input_text, model, tokenizer, use_cache=True):
109
- full_prompt = TRANSLATION_PROMPT.format(input_text=input_text)
110
-
111
- streamer = TextIteratorStreamer(
112
- tokenizer,
113
- skip_prompt=True,
114
- skip_special_tokens=True
115
- )
116
-
117
- inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
118
-
119
- generation_kwargs = {
120
- "input_ids": inputs["input_ids"],
121
- "attention_mask": inputs["attention_mask"],
122
- "max_new_tokens": 1024,
123
- "temperature": 0.7,
124
- "top_p": 0.9,
125
- "repetition_penalty": 1.1,
126
- "do_sample": True,
127
- "use_cache": use_cache,
128
- "streamer": streamer
129
- }
130
-
131
- Thread(target=model.generate, kwargs=generation_kwargs).start()
132
- return streamer
133
 
134
  # Display chat messages
135
  for message in st.session_state.messages:
@@ -173,20 +404,16 @@ if prompt := st.chat_input("Enter text to translate into French..."):
173
  try:
174
  with st.chat_message("assistant", avatar=BOT_AVATAR):
175
  start_time = time.time()
176
- streamer = generate_translation(input_text, model, tokenizer, use_cache=True)
177
 
178
- response_container = st.empty()
179
- full_response = ""
 
180
 
181
- for chunk in streamer:
182
- cleaned_chunk = chunk.strip()
183
- full_response += cleaned_chunk + " "
184
- response_container.markdown(full_response + "▌", unsafe_allow_html=True)
185
-
186
- # Calculate performance metrics
187
  end_time = time.time()
188
- input_tokens = len(tokenizer(input_text)["input_ids"])
189
- output_tokens = len(tokenizer(full_response)["input_ids"])
190
  speed = output_tokens / (end_time - start_time)
191
 
192
  # Calculate costs (hypothetical pricing model)
@@ -202,9 +429,6 @@ if prompt := st.chat_input("Enter text to translate into French..."):
202
  f"💵 Cost (AOA): {total_cost_aoa:.4f}"
203
  )
204
 
205
- response_container.markdown(full_response)
206
- st.session_state.messages.append({"role": "assistant", "content": full_response})
207
-
208
  except Exception as e:
209
  st.error(f"⚡ Translation error: {str(e)}")
210
  else:
 
1
  import streamlit as st
2
+ from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
3
  from huggingface_hub import login
 
4
  import PyPDF2
5
  import pandas as pd
6
  import torch
+ import torch.nn as nn
7
+ import numpy as np
8
+ from copy import deepcopy
9
+ import math
10
  import time
11
 
12
+ # Device setup
13
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
14
 
15
  # Set page configuration
16
  st.set_page_config(
 
73
  st.error(f"📄 Error processing file: {str(e)}")
74
  return ""
75
 
76
+ # Custom model definition (copied from previous steps)
77
+ # Masking functions
78
+ def subsequent_mask(size):
79
+ attn_shape = (1, size, size)
80
+ subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
81
+ return torch.from_numpy(subsequent_mask) == 0
82
+
83
+ def make_std_mask(tgt, pad):
84
+ tgt_mask = (tgt != pad).unsqueeze(-2)
85
+ return tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
86
+
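For intuition, here is a small check (illustrative, not part of this commit) of what the two masking helpers above produce for a toy target sequence with pad=0:

# Toy check of the masking helpers defined above (illustrative only).
import torch

tgt = torch.tensor([[1, 5, 7, 0]])      # one sequence; the last position is padding (pad=0)
causal = subsequent_mask(tgt.size(-1))  # (1, 4, 4) boolean, True on and below the diagonal
combined = make_std_mask(tgt, pad=0)    # causal mask AND-ed with the non-pad positions

# causal[0]:
#   [[ True, False, False, False],
#    [ True,  True, False, False],
#    [ True,  True,  True, False],
#    [ True,  True,  True,  True]]
# combined[0] is the same except the last column is all False, because position 3 is padding.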
87
+ # Batch class
88
+ class Batch:
89
+ def __init__(self, src, trg=None, pad=0):
90
+ src = torch.from_numpy(src).to(DEVICE).long()
91
+ self.src = src
92
+ self.src_mask = (src != pad).unsqueeze(-2)
93
+ if trg is not None:
94
+ trg = torch.from_numpy(trg).to(DEVICE).long()
95
+ self.trg = trg[:, :-1]
96
+ self.trg_y = trg[:, 1:]
97
+ self.trg_mask = make_std_mask(self.trg, pad)
98
+ self.ntokens = (self.trg_y != pad).data.sum()
99
+
100
+ # Hugging Face config
101
+ class En2FrConfig(PretrainedConfig):
102
+ model_type = "en2fr_transformer"
103
+ def __init__(self, src_vocab=32000, tgt_vocab=32000, N=6, d_model=512,
104
+ d_ff=2048, h=8, dropout=0.1, **kwargs):
105
+ self.src_vocab = src_vocab
106
+ self.tgt_vocab = tgt_vocab
107
+ self.N = N
108
+ self.d_model = d_model
109
+ self.d_ff = d_ff
110
+ self.h = h
111
+ self.dropout = dropout
112
+ super().__init__(**kwargs)
113
+
114
+ # Transformer components
115
+ class Transformer(nn.Module):
116
+ def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
117
+ super().__init__()
118
+ self.encoder = encoder
119
+ self.decoder = decoder
120
+ self.src_embed = src_embed
121
+ self.tgt_embed = tgt_embed
122
+ self.generator = generator
123
+
124
+ def forward(self, src, tgt, src_mask, tgt_mask):
125
+ memory = self.encoder(self.src_embed(src), src_mask)
126
+ output = self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
127
+ return output
128
+
129
+ class Encoder(nn.Module):
130
+ def __init__(self, layer, N):
131
+ super().__init__()
132
+ self.layers = nn.ModuleList([deepcopy(layer) for _ in range(N)])
133
+ self.norm = LayerNorm(layer.size)
134
+
135
+ def forward(self, x, mask):
136
+ for layer in self.layers:
137
+ x = layer(x, mask)
138
+ return self.norm(x)
139
+
140
+ class EncoderLayer(nn.Module):
141
+ def __init__(self, size, self_attn, feed_forward, dropout):
142
+ super().__init__()
143
+ self.self_attn = self_attn
144
+ self.feed_forward = feed_forward
145
+ self.sublayer = nn.ModuleList([deepcopy(SublayerConnection(size, dropout)) for _ in range(2)])
146
+ self.size = size
147
+
148
+ def forward(self, x, mask):
149
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
150
+ return self.sublayer[1](x, self.feed_forward)
151
+
152
+ class Decoder(nn.Module):
153
+ def __init__(self, layer, N):
154
+ super().__init__()
155
+ self.layers = nn.ModuleList([deepcopy(layer) for _ in range(N)])
156
+ self.norm = LayerNorm(layer.size)
157
+
158
+ def forward(self, x, memory, src_mask, tgt_mask):
159
+ for layer in self.layers:
160
+ x = layer(x, memory, src_mask, tgt_mask)
161
+ return self.norm(x)
162
+
163
+ class DecoderLayer(nn.Module):
164
+ def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
165
+ super().__init__()
166
+ self.size = size
167
+ self.self_attn = self_attn
168
+ self.src_attn = src_attn
169
+ self.feed_forward = feed_forward
170
+ self.sublayer = nn.ModuleList([deepcopy(SublayerConnection(size, dropout)) for _ in range(3)])
171
+
172
+ def forward(self, x, memory, src_mask, tgt_mask):
173
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
174
+ x = self.sublayer[1](x, lambda x: self.src_attn(x, memory, memory, src_mask))
175
+ return self.sublayer[2](x, self.feed_forward)
176
+
177
+ class SublayerConnection(nn.Module):
178
+ def __init__(self, size, dropout):
179
+ super().__init__()
180
+ self.norm = LayerNorm(size)
181
+ self.dropout = nn.Dropout(dropout)
182
+
183
+ def forward(self, x, sublayer):
184
+ return x + self.dropout(sublayer(self.norm(x)))
185
+
186
+ class LayerNorm(nn.Module):
187
+ def __init__(self, features, eps=1e-6):
188
+ super().__init__()
189
+ self.a_2 = nn.Parameter(torch.ones(features))
190
+ self.b_2 = nn.Parameter(torch.zeros(features))
191
+ self.eps = eps
192
+
193
+ def forward(self, x):
194
+ mean = x.mean(-1, keepdim=True)
195
+ std = x.std(-1, keepdim=True)
196
+ return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
197
+
198
+ class MultiHeadedAttention(nn.Module):
199
+ def __init__(self, h, d_model, dropout=0.1):
200
+ super().__init__()
201
+ assert d_model % h == 0
202
+ self.d_k = d_model // h
203
+ self.h = h
204
+ self.linears = nn.ModuleList([deepcopy(nn.Linear(d_model, d_model)) for _ in range(4)])
205
+ self.attn = None
206
+ self.dropout = nn.Dropout(p=dropout)
207
+
208
+ def forward(self, query, key, value, mask=None):
209
+ if mask is not None:
210
+ mask = mask.unsqueeze(1)
211
+ nbatches = query.size(0)
212
+ query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
213
+ for l, x in zip(self.linears, (query, key, value))]
214
+ x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
215
+ x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
216
+ return self.linears[-1](x)
217
+
218
+ def attention(query, key, value, mask=None, dropout=None):
219
+ d_k = query.size(-1)
220
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
221
+ if mask is not None:
222
+ scores = scores.masked_fill(mask == 0, -1e9)
223
+ p_attn = nn.functional.softmax(scores, dim=-1)
224
+ if dropout is not None:
225
+ p_attn = dropout(p_attn)
226
+ return torch.matmul(p_attn, value), p_attn
227
+
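The attention() helper above is standard scaled dot-product attention, softmax(Q·K^T / sqrt(d_k))·V. A quick shape check (illustrative, not part of this commit):

# Illustrative shape check for the attention() helper above.
import torch

q = torch.randn(2, 8, 10, 64)  # (batch, heads, query_len, d_k)
k = torch.randn(2, 8, 12, 64)  # (batch, heads, key_len, d_k)
v = torch.randn(2, 8, 12, 64)  # values share key_len
out, weights = attention(q, k, v)
# out.shape     == (2, 8, 10, 64)  -> one mixed value vector per query position
# weights.shape == (2, 8, 10, 12)  -> attention weights, softmax over the key dimension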
228
+ class PositionwiseFeedForward(nn.Module):
229
+ def __init__(self, d_model, d_ff, dropout=0.1):
230
+ super().__init__()
231
+ self.w_1 = nn.Linear(d_model, d_ff)
232
+ self.w_2 = nn.Linear(d_ff, d_model)
233
+ self.dropout = nn.Dropout(dropout)
234
+
235
+ def forward(self, x):
236
+ return self.w_2(self.dropout(self.w_1(x)))
237
+
238
+ class Embeddings(nn.Module):
239
+ def __init__(self, d_model, vocab):
240
+ super().__init__()
241
+ self.lut = nn.Embedding(vocab, d_model)
242
+ self.d_model = d_model
243
+
244
+ def forward(self, x):
245
+ return self.lut(x) * math.sqrt(self.d_model)
246
+
247
+ class PositionalEncoding(nn.Module):
248
+ def __init__(self, d_model, dropout, max_len=5000):
249
+ super().__init__()
250
+ self.dropout = nn.Dropout(p=dropout)
251
+ pe = torch.zeros(max_len, d_model, device=DEVICE)
252
+ position = torch.arange(0., max_len, device=DEVICE).unsqueeze(1)
253
+ div_term = torch.exp(torch.arange(0., d_model, 2, device=DEVICE) * -(math.log(10000.0) / d_model))
254
+ pe[:, 0::2] = torch.sin(position * div_term)
255
+ pe[:, 1::2] = torch.cos(position * div_term)
256
+ pe = pe.unsqueeze(0)
257
+ self.register_buffer('pe', pe)
258
+
259
+ def forward(self, x):
260
+ x = x + self.pe[:, :x.size(1)].requires_grad_(False)
261
+ return self.dropout(x)
262
+
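The buffer built above encodes pe[pos, 2i] = sin(pos / 10000^(2i/d_model)) and pe[pos, 2i+1] = cos of the same argument. A spot check (illustrative, not part of this commit):

# Illustrative spot check of the sinusoidal table built in PositionalEncoding.__init__.
import math

enc = PositionalEncoding(d_model=512, dropout=0.0)
pos, i = 3, 10
expected = math.sin(pos / (10000 ** (2 * i / 512)))
# enc.pe[0, pos, 2 * i] equals `expected` up to floating-point precision.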
263
+ class Generator(nn.Module):
264
+ def __init__(self, d_model, vocab):
265
+ super().__init__()
266
+ self.proj = nn.Linear(d_model, vocab)
267
+
268
+ def forward(self, x):
269
+ return nn.functional.log_softmax(self.proj(x), dim=-1)
270
+
271
+ def create_model(src_vocab, tgt_vocab, N, d_model, d_ff, h, dropout=0.1):
272
+ attn = MultiHeadedAttention(h, d_model).to(DEVICE)
273
+ ff = PositionwiseFeedForward(d_model, d_ff, dropout).to(DEVICE)
274
+ pos = PositionalEncoding(d_model, dropout).to(DEVICE)
275
+ model = Transformer(
276
+ Encoder(EncoderLayer(d_model, deepcopy(attn), deepcopy(ff), dropout).to(DEVICE), N).to(DEVICE),
277
+ Decoder(DecoderLayer(d_model, deepcopy(attn), deepcopy(attn), deepcopy(ff), dropout).to(DEVICE), N).to(DEVICE),
278
+ nn.Sequential(Embeddings(d_model, src_vocab).to(DEVICE), deepcopy(pos)),
279
+ nn.Sequential(Embeddings(d_model, tgt_vocab).to(DEVICE), deepcopy(pos)),
280
+ Generator(d_model, tgt_vocab)).to(DEVICE)
281
+ for p in model.parameters():
282
+ if p.dim() > 1:
283
+ nn.init.xavier_uniform_(p)
284
+ return model
285
+
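A quick way to sanity-check create_model (illustrative, not part of this commit) is to build a tiny model and push a dummy Batch through it:

# Illustrative: tiny model plus one forward pass on dummy data.
import numpy as np

tiny = create_model(src_vocab=100, tgt_vocab=100, N=2, d_model=64, d_ff=128, h=4)
batch = Batch(src=np.array([[1, 5, 7, 2, 0]]), trg=np.array([[1, 9, 4, 2, 0]]), pad=0)
out = tiny(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
# out has shape (1, 4, 64): decoder hidden states; tiny.generator projects them to vocab logits.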
286
+ class En2FrTransformer(PreTrainedModel):
287
+ config_class = En2FrConfig
288
+
289
+ def __init__(self, config):
290
+ super().__init__(config)
291
+ self.model = create_model(
292
+ src_vocab=config.src_vocab,
293
+ tgt_vocab=config.tgt_vocab,
294
+ N=config.N,
295
+ d_model=config.d_model,
296
+ d_ff=config.d_ff,
297
+ h=config.h,
298
+ dropout=config.dropout
299
+ )
300
+
301
+ def forward(self, src, tgt, src_mask, tgt_mask):
302
+ return self.model(src, tgt, src_mask, tgt_mask)
303
+
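Because En2FrTransformer subclasses PreTrainedModel and declares config_class = En2FrConfig, a checkpoint written with save_pretrained() can be reloaded the way load_model() does below. A minimal sketch, with a hypothetical local path:

# Illustrative save/load round trip through the Hugging Face serialization API.
config = En2FrConfig(src_vocab=32000, tgt_vocab=32000)
model = En2FrTransformer(config)
model.save_pretrained("en2fr-checkpoint")                 # writes config.json and the weights
reloaded = En2FrTransformer.from_pretrained("en2fr-checkpoint")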
304
  # Model loading function
305
  @st.cache_resource
306
  def load_model(hf_token):
 
311
 
312
  login(token=hf_token)
313
 
314
+ # Load tokenizer (assuming a tokenizer was saved with the model)
315
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=hf_token)
316
 
317
+ # Load the custom model
318
+ model = En2FrTransformer.from_pretrained(
319
  MODEL_NAME,
 
 
320
  token=hf_token
321
  )
322
+ model.to(DEVICE) # Ensure model is on the correct device
323
 
324
  return model, tokenizer
325
 
 
327
  st.error(f"🤖 Model loading failed: {str(e)}")
328
  return None
329
 
330
+ # Simple tokenization function (placeholder, since we don't have the actual vocab)
331
+ def tokenize_text(text, tokenizer, max_length=10):
332
+ # This is a placeholder; in a real scenario, you'd use the tokenizer's vocabulary
333
+ # For now, we'll create dummy token IDs (0 for padding, 1 for start, 2 for end, 3+ for words)
334
+ words = text.split()
335
+ token_ids = [1] + [i + 3 for i in range(min(len(words), max_length - 2))] + [2]
336
+ if len(token_ids) < max_length:
337
+ token_ids += [0] * (max_length - len(token_ids))
338
+ return torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
339
+
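Since tokenize_text above is explicitly a placeholder, a version that actually uses the loaded Hugging Face tokenizer might look like the sketch below; the helper name and its padding behaviour are assumptions, not part of this commit:

# Hypothetical replacement for tokenize_text that uses the real tokenizer.
def tokenize_text_hf(text, tokenizer, max_length=10):
    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return encoded["input_ids"].to(DEVICE)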
340
+ # Generation function for translation (custom inference loop)
341
+ def generate_translation(input_text, model, tokenizer):
342
+ model.eval()
343
+ with torch.no_grad():
344
+ # Tokenize input (source) and target (start with a dummy start token)
345
+ src = tokenize_text(input_text, tokenizer)
346
+ tgt = torch.tensor([[1]], dtype=torch.long, device=DEVICE) # Start token
347
+ src_mask = (src != 0).unsqueeze(-2)
348
+ max_length = 10 # Adjust as needed
349
+
350
+ # Generate translation token by token
351
+ for _ in range(max_length - 1):
352
+ tgt_mask = make_std_mask(tgt, pad=0)
353
+ output = model(src, tgt, src_mask, tgt_mask)
354
+ output = model.model.generator(output[:, -1, :]) # Get logits for the last token
355
+ next_token = torch.argmax(output, dim=-1).unsqueeze(0)
356
+ tgt = torch.cat((tgt, next_token), dim=1)
357
+ if next_token.item() == 2: # End token
358
+ break
359
+
360
+ # Convert token IDs back to text (placeholder)
361
+ # In a real scenario, you'd use tokenizer.decode()
362
+ translation = " ".join([f"word{i-3}" if i >= 3 else "<start>" if i == 1 else "<end>" for i in tgt[0].tolist()])
363
+ return translation
364
 
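For reference, an illustrative call of the greedy decoding loop above, assuming hf_token holds a valid Hugging Face token; with the placeholder vocabulary the output is dummy tokens rather than real French:

# Illustrative usage of generate_translation (not part of this commit).
loaded = load_model(hf_token)
if loaded is not None:
    model, tokenizer = loaded
    print(generate_translation("Hello world", model, tokenizer))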
365
  # Display chat messages
366
  for message in st.session_state.messages:
 
404
  try:
405
  with st.chat_message("assistant", avatar=BOT_AVATAR):
406
  start_time = time.time()
407
+ translation = generate_translation(input_text, model, tokenizer)
408
 
409
+ # Display the translation
410
+ st.markdown(translation)
411
+ st.session_state.messages.append({"role": "assistant", "content": translation})
412
 
413
+ # Calculate performance metrics (simplified, since we don't have real token counts)
414
  end_time = time.time()
415
+ input_tokens = len(input_text.split()) # Approximate
416
+ output_tokens = len(translation.split()) # Approximate
417
  speed = output_tokens / (end_time - start_time)
418
 
419
  # Calculate costs (hypothetical pricing model)
 
429
  f"💵 Cost (AOA): {total_cost_aoa:.4f}"
430
  )
431
 
 
432
  except Exception as e:
433
  st.error(f"⚡ Translation error: {str(e)}")
434
  else: