Yoxas committed on
Commit
befa614
·
verified ·
1 Parent(s): 3c00fbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -16,7 +16,7 @@ dataset = load_dataset("Yoxas/statistical_literacyv2")
16
 
17
  data = dataset["train"]
18
 
19
- # Convert the string embeddings to numerical arrays
20
  def convert_and_ensure_2d_embeddings(example):
21
  # Clean the embedding string
22
  embedding_str = example['embedding']
@@ -31,7 +31,20 @@ def convert_and_ensure_2d_embeddings(example):
31
  # Apply the function to ensure embeddings are 2-dimensional and of type float32
32
  data = data.map(convert_and_ensure_2d_embeddings)
33
 
34
- data = data.add_faiss_index("embedding")
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
37
 
 
16
 
17
  data = dataset["train"]
18
 
19
+ # Convert the string embeddings to numerical arrays and ensure they are 2D
20
  def convert_and_ensure_2d_embeddings(example):
21
  # Clean the embedding string
22
  embedding_str = example['embedding']
 
31
  # Apply the function to ensure embeddings are 2-dimensional and of type float32
32
  data = data.map(convert_and_ensure_2d_embeddings)
33
 
34
+ # Flatten embeddings if they are nested 2D arrays
35
+ def flatten_embeddings(example):
36
+ embedding = example['embedding']
37
+ if embedding.ndim == 2 and embedding.shape[0] == 1:
38
+ embedding = embedding.flatten()
39
+ return {'embedding': embedding}
40
+
41
+ data = data.map(flatten_embeddings)
42
+
43
+ # Ensure embeddings are in the correct shape for FAISS
44
+ embeddings = np.array(data['embedding'].tolist(), dtype=np.float32)
45
+
46
+ # Add FAISS index
47
+ data.add_faiss_index_from_external_arrays(embeddings)
48
 
49
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
50