Yoxas commited on
Commit
48dc91c
·
verified ·
1 Parent(s): e5f39fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import gradio as gr
2
  from datasets import load_dataset
3
-
4
  import os
5
  import spaces
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
@@ -22,7 +21,10 @@ print(f"Example embedding shape: {np.array(example_embedding).shape}")
22
 
23
  # Ensure embeddings are 2-dimensional
24
  def ensure_2d_embeddings(embeddings):
25
- return [np.atleast_2d(embedding) for embedding in embeddings]
 
 
 
26
 
27
  # Apply the function to ensure embeddings are 2-dimensional
28
  data = data.map(lambda example: {'embedding': ensure_2d_embeddings(example['embedding'])})
@@ -57,7 +59,7 @@ def search(query: str, k: int = 3):
57
  """a function that embeds a new query and returns the most probable results"""
58
  embedded_query = ST.encode(query) # embed new query
59
  scores, retrieved_examples = data.get_nearest_examples( # retrieve results
60
- "Abstract_Embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
61
  k=k # get only top k results
62
  )
63
  return scores, retrieved_examples
 
1
  import gradio as gr
2
  from datasets import load_dataset
 
3
  import os
4
  import spaces
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 
21
 
22
  # Ensure embeddings are 2-dimensional
23
  def ensure_2d_embeddings(embeddings):
24
+ embeddings = np.array(embeddings)
25
+ if embeddings.ndim == 1:
26
+ embeddings = embeddings.reshape(1, -1)
27
+ return embeddings
28
 
29
  # Apply the function to ensure embeddings are 2-dimensional
30
  data = data.map(lambda example: {'embedding': ensure_2d_embeddings(example['embedding'])})
 
59
  """a function that embeds a new query and returns the most probable results"""
60
  embedded_query = ST.encode(query) # embed new query
61
  scores, retrieved_examples = data.get_nearest_examples( # retrieve results
62
+ "embedding", embedded_query, # compare our new embedded query with the dataset embeddings
63
  k=k # get only top k results
64
  )
65
  return scores, retrieved_examples