Yoxas committed on
Commit
cfee418
·
verified ·
1 Parent(s): 2e13b19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  from datasets import load_dataset
 
3
  import os
4
  import spaces
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
@@ -15,21 +16,19 @@ dataset = load_dataset("Yoxas/statistical_literacyv2")
15
 
16
  data = dataset["train"]
17
 
18
- # Check the structure of embeddings
19
- example_embedding = data[0]['embedding']
20
- print(f"Example embedding shape: {np.array(example_embedding).shape}")
21
-
22
- # Ensure embeddings are 2-dimensional and of type string
23
- def ensure_2d_embeddings(embeddings):
24
- embeddings = np.array(embeddings, dtype=np.float32)
25
  if embeddings.ndim == 1:
26
  embeddings = embeddings.reshape(1, -1)
27
- return embeddings
28
 
29
  # Apply the function to ensure embeddings are 2-dimensional and of type float32
30
- data = data.map(lambda example: {'embedding': ensure_2d_embeddings(example['embedding'])})
31
 
32
- data = data.add_faiss_index("Abstract")
33
 
34
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
35
 
 
1
  import gradio as gr
2
  from datasets import load_dataset
3
+
4
  import os
5
  import spaces
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 
16
 
17
  data = dataset["train"]
18
 
19
+ # Convert the string embeddings to numerical arrays
20
def convert_and_ensure_2d_embeddings(example):
    """Parse a string-encoded embedding into a 2-D float32 array.

    Args:
        example: Mapping whose 'embedding' entry is the vector written as
            text, e.g. "[0.1 0.2 0.3]". Values may be separated by any
            whitespace (including newlines) and/or commas, and the text may
            be wrapped in square brackets.

    Returns:
        A dict ``{'embedding': array}`` where ``array`` is a float32 NumPy
        array of shape ``(1, n)``.

    Raises:
        ValueError: If any token in the text is not a valid number.
    """
    text = example['embedding']
    # Trim outer whitespace first so the bracket strip actually reaches the
    # brackets, then treat commas as plain separators. Parsing token by token
    # avoids np.fromstring's text mode, which stops at the first unexpected
    # character and hands back a truncated vector instead of failing.
    tokens = text.strip().strip("[]").replace(",", " ").split()
    embeddings = np.array([float(token) for token in tokens], dtype=np.float32)
    # NOTE(review): the (1, n) shape is kept for backward compatibility;
    # confirm the downstream FAISS index expects 2-D rows rather than one
    # flat vector per example.
    return {'embedding': embeddings.reshape(1, -1)}
27
 
28
  # Apply the function to ensure embeddings are 2-dimensional and of type float32
29
+ data = data.map(convert_and_ensure_2d_embeddings)
30
 
31
+ data = data.add_faiss_index("embedding")
32
 
33
  model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
34