Aranwer commited on
Commit
088c109
Β·
verified Β·
1 Parent(s): 8b24483

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -19
app.py CHANGED
@@ -5,46 +5,99 @@ import faiss
5
  import numpy as np
6
  from transformers import pipeline
7
 
 
8
  dataset = load_dataset("lex_glue", "scotus")
9
- corpus_data = dataset['train'].select(range(200))
10
- corpus = [doc['text'] for doc in corpus_data]
11
 
12
  embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
13
  corpus_embeddings = embedder.encode(corpus, convert_to_numpy=True)
14
 
 
15
  dimension = corpus_embeddings.shape[1]
16
  index = faiss.IndexFlatL2(dimension)
17
  index.add(corpus_embeddings)
18
 
 
19
  gen_pipeline = pipeline("text2text-generation", model="facebook/bart-large-cnn")
20
 
 
21
  def rag_query(user_question):
22
  question_embedding = embedder.encode([user_question])
23
- _, indices = index.search(np.array(question_embedding), k=3)
24
- valid_indices = [i for i in indices[0] if i < len(corpus)]
25
- context = " ".join([corpus[i] for i in valid_indices])
 
 
 
 
 
 
 
26
  prompt = f"Question: {user_question}\nContext: {context}\nAnswer:"
27
  result = gen_pipeline(prompt, max_length=250, do_sample=False)[0]['generated_text']
28
  return result
29
 
30
- def chatbot_interface(query, history):
31
- response = rag_query(query)
32
- history.append((query, response))
33
- chat_history = "\n\n".join([f"πŸ‘€ You: {q}\nπŸ§‘β€βš–οΈ Bot: {a}" for q, a in history])
34
- return chat_history, history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  iface = gr.Interface(
37
  fn=chatbot_interface,
38
- inputs=[
39
- gr.Textbox(lines=2, placeholder="Enter your legal question here...", label="Your Question"),
40
- gr.State([]) # Session state to store history
41
- ],
42
- outputs=[
43
- gr.Textbox(label="Chat History", lines=20, interactive=False),
44
- gr.State()
45
- ],
46
  title="πŸ§‘β€βš–οΈ Legal Assistant Chatbot",
47
- description="Ask legal questions based on case data (LexGLUE - SCOTUS subset). The bot retrieves context and generates an answer."
 
 
48
  )
49
 
 
50
  iface.launch()
 
5
  import numpy as np
6
  from transformers import pipeline
7
 
8
+
9
  dataset = load_dataset("lex_glue", "scotus")
10
+ corpus = [doc['text'] for doc in dataset['train'].select(range(200))]
11
+
12
 
13
  embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
14
  corpus_embeddings = embedder.encode(corpus, convert_to_numpy=True)
15
 
16
+
17
  dimension = corpus_embeddings.shape[1]
18
  index = faiss.IndexFlatL2(dimension)
19
  index.add(corpus_embeddings)
20
 
21
+
22
  gen_pipeline = pipeline("text2text-generation", model="facebook/bart-large-cnn")
23
 
24
+
25
  def rag_query(user_question):
26
  question_embedding = embedder.encode([user_question])
27
+ k = 3
28
+ if index.ntotal < k:
29
+ k = index.ntotal
30
+ _, indices = index.search(np.array(question_embedding), k=k)
31
+
32
+ if len(indices[0]) == 0:
33
+ return "Sorry, no relevant documents were found."
34
+
35
+ context = " ".join([corpus[i] for i in indices[0] if i < len(corpus)])
36
+
37
  prompt = f"Question: {user_question}\nContext: {context}\nAnswer:"
38
  result = gen_pipeline(prompt, max_length=250, do_sample=False)[0]['generated_text']
39
  return result
40
 
41
+
42
+ def chatbot_interface(query):
43
+ return rag_query(query)
44
+
45
+
46
+ css = """
47
+ .gradio-container {
48
+ background-color: #f0f4f8;
49
+ font-family: Arial, sans-serif;
50
+ }
51
+ .gradio-input {
52
+ background-color: #ffffff;
53
+ border-radius: 5px;
54
+ border: 1px solid #d1d1d1;
55
+ font-size: 16px;
56
+ padding: 10px;
57
+ }
58
+ .gradio-button {
59
+ background-color: #4CAF50;
60
+ color: white;
61
+ border-radius: 5px;
62
+ border: none;
63
+ padding: 10px 20px;
64
+ font-size: 16px;
65
+ }
66
+ .gradio-button:hover {
67
+ background-color: #45a049;
68
+ }
69
+ .gradio-output {
70
+ background-color: #ffffff;
71
+ border-radius: 5px;
72
+ padding: 15px;
73
+ font-size: 16px;
74
+ border: 1px solid #d1d1d1;
75
+ }
76
+ .gradio-title {
77
+ font-size: 28px;
78
+ font-weight: bold;
79
+ color: #333333;
80
+ text-align: center;
81
+ margin-bottom: 20px;
82
+ }
83
+ .gradio-description {
84
+ font-size: 16px;
85
+ color: #666666;
86
+ text-align: center;
87
+ margin-bottom: 30px;
88
+ }
89
+ """
90
+
91
 
92
  iface = gr.Interface(
93
  fn=chatbot_interface,
94
+ inputs="text",
95
+ outputs="text",
 
 
 
 
 
 
96
  title="πŸ§‘β€βš–οΈ Legal Assistant Chatbot",
97
+ description="Ask legal questions based on case data (LexGLUE - SCOTUS subset). Get answers derived from relevant court case texts.",
98
+ theme="compact",
99
+ css=css
100
  )
101
 
102
+
103
  iface.launch()