Aranwer committed (verified)
Commit dbe8ae7 · 1 Parent(s): 088c109

Update app.py

Files changed (1)
  1. app.py +26 -16
app.py CHANGED
@@ -5,44 +5,54 @@ import faiss
 import numpy as np
 from transformers import pipeline
 
-
+# Load dataset
 dataset = load_dataset("lex_glue", "scotus")
-corpus = [doc['text'] for doc in dataset['train'].select(range(200))]
-
+corpus = [doc['text'] for doc in dataset['train'].select(range(200))]  # just 200 to keep it light
 
+# Embedding model
 embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 corpus_embeddings = embedder.encode(corpus, convert_to_numpy=True)
 
-
+# Build FAISS index
 dimension = corpus_embeddings.shape[1]
 index = faiss.IndexFlatL2(dimension)
 index.add(corpus_embeddings)
 
-
+# Text generation model
 gen_pipeline = pipeline("text2text-generation", model="facebook/bart-large-cnn")
 
-
+# RAG-like query function
 def rag_query(user_question):
+    # Encode the user question
     question_embedding = embedder.encode([user_question])
-    k = 3
+
+    k = 3  # top 3 documents
     if index.ntotal < k:
-        k = index.ntotal
+        k = index.ntotal  # Adjust if there are fewer documents than requested
+
+    # Perform the search in the FAISS index
     _, indices = index.search(np.array(question_embedding), k=k)
 
-    if len(indices[0]) == 0:
+    # Ensure indices are valid (within range of the corpus)
+    valid_indices = [i for i in indices[0] if i < len(corpus)]
+
+    if len(valid_indices) == 0:
         return "Sorry, no relevant documents were found."
+
+    # Extract relevant context from the corpus based on valid indices
+    context = " ".join([corpus[i] for i in valid_indices])
 
-    context = " ".join([corpus[i] for i in indices[0] if i < len(corpus)])
-
+    # Prepare the prompt and generate the response
     prompt = f"Question: {user_question}\nContext: {context}\nAnswer:"
     result = gen_pipeline(prompt, max_length=250, do_sample=False)[0]['generated_text']
+
    return result
 
-
+# Gradio UI
 def chatbot_interface(query):
     return rag_query(query)
 
-
+# Styling for the interface
 css = """
 .gradio-container {
     background-color: #f0f4f8;
@@ -88,7 +98,7 @@ css = """
 }
 """
 
-
+# Create the Gradio interface
 iface = gr.Interface(
     fn=chatbot_interface,
     inputs="text",
@@ -99,5 +109,5 @@ iface = gr.Interface(
     css=css
 )
 
-
-iface.launch()
+# Launch the Gradio interface
+iface.launch()
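Note that the first hunk starts at line 5, with import faiss as its context line, so the import block at the top of app.py is not shown in this diff. Judging from the names the script relies on (gr, load_dataset, SentenceTransformer, faiss, np, pipeline), the unshown preamble is presumably something like the sketch below; these lines are an assumption for readability, not part of the commit:

# Presumed app.py preamble (lines 1-4, not shown in the diff); an assumption
# inferred from the identifiers used later in the file.
import gradio as gr                      # assumed: provides gr.Interface and iface.launch()
from datasets import load_dataset        # assumed: provides load_dataset("lex_glue", "scotus")
from sentence_transformers import SentenceTransformer  # assumed: provides the MiniLM embedder
import faiss                             # grounded: appears in the hunk header context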