yzm0034 committed on
Commit
76a5b51
·
verified ·
1 Parent(s): fdfb446

Fixed: tensor dtype handling when converting embeddings to NumPy.

Browse files
Files changed (1) hide show
  1. src/evaluate.py +10 -16
src/evaluate.py CHANGED
@@ -20,15 +20,6 @@ def read_pertubed_data(filename, task, lang="en"):
20
  raise FileNotFoundError(f"File {filename} not found.")
21
  return pd.read_csv(filename)
22
 
23
- def compute_metrics(emb1, emb2,metric="cosine"):
24
- """Compute all metrics between two sets of embeddings."""
25
- # sim = utils.cosine_similarity(emb1, emb2)
26
- # ned = compute_ned_distance(emb1, emb2)
27
- # ed = np.linalg.norm(emb1 - emb2, axis=1)
28
- # dotp = np.sum(emb1 * emb2, axis=1)
29
- if metric=="cosine":
30
- sim = CosineMetric(emb1,emb2)
31
- return sim
32
 
33
  def run(args_model, dataset_name, target_lang,args_task, default_gpu="cuda", metric="cosine",save=False,batch_size=2):
34
  model = LLMEmbeddings(args_model, device=default_gpu)
@@ -61,9 +52,15 @@ def run(args_model, dataset_name, target_lang,args_task, default_gpu="cuda", met
61
 
62
  # Batch process embeddings
63
  embeddings = model.encode_batch(sentences,batch_size=batch_size)
64
- if args_model != "chatgpt":
65
- embeddings = [emb.cpu().numpy() for emb in embeddings]
66
- embeddings = np.array(embeddings)
 
 
 
 
 
 
67
 
68
  # Process embeddings based on task
69
  if args_task == "anto":
@@ -151,6 +148,7 @@ if __name__ == "__main__":
151
  "batch_size":2
152
  }
153
  else:
 
154
  config = {
155
  "args_model": "llama3",
156
  "dataset_name": "mrpc",
@@ -161,7 +159,3 @@ if __name__ == "__main__":
161
 
162
  }
163
  run(**config)
164
-
165
-
166
- # file_path = "/home/yash/ALIGN-SIM/data/perturbed_dataset/en/anto/mrpc_anto_perturbed_en.csv"
167
- # run("llama3","mrpc_anto_perturbed_en", "anto", "cuda:2", False)
 
20
  raise FileNotFoundError(f"File {filename} not found.")
21
  return pd.read_csv(filename)
22
 
 
 
 
 
 
 
 
 
 
23
 
24
  def run(args_model, dataset_name, target_lang,args_task, default_gpu="cuda", metric="cosine",save=False,batch_size=2):
25
  model = LLMEmbeddings(args_model, device=default_gpu)
 
52
 
53
  # Batch process embeddings
54
  embeddings = model.encode_batch(sentences,batch_size=batch_size)
55
+ # Ensure embeddings are on CPU and in numpy format
56
+ if args_model == "chatgpt":
57
+ # For chatgpt, embeddings is likely a list of torch tensors
58
+ embeddings = [emb.cpu().numpy() if isinstance(emb, torch.Tensor) else emb for emb in embeddings]
59
+ embeddings = np.array(embeddings)
60
+ else:
61
+ # For other models, assume a single torch tensor
62
+ if isinstance(embeddings, torch.Tensor):
63
+ embeddings = embeddings.cpu().numpy()
64
 
65
  # Process embeddings based on task
66
  if args_task == "anto":
 
148
  "batch_size":2
149
  }
150
  else:
151
+ #sentence-transformers/all-MiniLM-L6-v2
152
  config = {
153
  "args_model": "llama3",
154
  "dataset_name": "mrpc",
 
159
 
160
  }
161
  run(**config)