kambris committed on
Commit
b449fa6
·
verified ·
1 Parent(s): 9a7840e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -15
app.py CHANGED
def get_embedding_for_text(text, tokenizer, model):
    """Get a single embedding vector for a complete (possibly long) text.

    The text is tokenized once, split into chunks of at most 510 tokens
    (512 max positions minus 2 for the [CLS] and [SEP] special tokens),
    each chunk is encoded and run through the model independently, and the
    per-chunk [CLS] embeddings are averaged (unweighted).

    Args:
        text: Input string of arbitrary length.
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Transformer model whose first output is the hidden-state
            tensor of shape (batch, seq, hidden).

    Returns:
        A 1-D numpy array of length ``model.config.hidden_size``; the zero
        vector when the text yields no tokens.
    """
    # Tokenize first so chunking is done on exact token counts rather than
    # guessed from characters or words.
    tokens = tokenizer.tokenize(text)

    # 512 maximum positions minus 2 reserved for CLS and SEP.
    chunk_size = 510
    chunk_embeddings = []

    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        # NOTE(review): round-tripping tokens -> string -> tokens may
        # re-tokenize slightly differently for some tokenizers; the
        # truncation below guards against any resulting overflow.
        chunk_text = tokenizer.convert_tokens_to_string(chunk)
        # Re-encode with special tokens added and fixed-length padding.
        encoded = tokenizer(
            chunk_text,
            return_tensors='pt',
            max_length=512,
            truncation=True,
            padding='max_length'
        )

        # Move all input tensors to the model's device.
        encoded = {k: v.to(model.device) for k, v in encoded.items()}

        # Inference only: no gradients needed.
        with torch.no_grad():
            output = model(**encoded)
            # [CLS] embedding: position 0 of the first model output.
            embedding = output[0][:, 0, :].cpu().numpy()

        chunk_embeddings.append(embedding[0])

    # Combine all chunk embeddings with an unweighted mean.
    if chunk_embeddings:
        return np.mean(chunk_embeddings, axis=0)
    return np.zeros(model.config.hidden_size)
271
 
272
  def format_topics(topic_model, topic_counts):