ajalisatgi committed · Commit 4d16da0 · verified · 1 Parent(s): 73ab43d

Update app.py

Files changed (1):
  1. app.py +15 -112

app.py CHANGED
@@ -1,113 +1,29 @@
-import torch
 import gradio as gr
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain_community.vectorstores import Chroma
 import openai
-import time
-import logging
 from datasets import load_dataset
-from nltk.tokenize import sent_tokenize
-import nltk
-from langchain.docstore.document import Document
-from tqdm import tqdm
-import os
+import logging
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Download NLTK data
-nltk.download('punkt')
-nltk.download('punkt_tab')
-nltk.download('averaged_perceptron_tagger')
-nltk.download('stopwords')
-
 # Initialize OpenAI API key
 openai.api_key = 'sk-proj-[REDACTED]'
 
-# Load selected datasets
-logger.info("Starting dataset loading...")
-ragbench = {}
-datasets_to_load = ['covidqa', 'hotpotqa', 'pubmedqa']
-
-for dataset in datasets_to_load:
-    try:
-        ragbench[dataset] = load_dataset("rungalileo/ragbench", dataset, split='train')
-        logger.info(f"Successfully loaded {dataset}")
-    except Exception as e:
-        logger.error(f"Failed to load {dataset}: {e}")
-        continue
-
-print(f"Loaded {len(ragbench)} datasets successfully")
-
-# Initialize embedding model
-model_name = 'sentence-transformers/all-mpnet-base-v2'
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-embedding_model = HuggingFaceEmbeddings(model_name=model_name)
-embedding_model.client.to(device)
+# Load just one dataset to start
+dataset = load_dataset("rungalileo/ragbench", "hotpotqa", split='train')
+logger.info("Dataset loaded successfully")
 
-def chunk_documents_semantic(documents, max_chunk_size=500):
-    chunks = []
-    for doc in documents:
-        if isinstance(doc, list):
-            for passage in doc:
-                sentences = sent_tokenize(passage)
-                current_chunk = ""
-                for sentence in sentences:
-                    if len(current_chunk) + len(sentence) <= max_chunk_size:
-                        current_chunk += sentence + " "
-                    else:
-                        chunks.append(current_chunk.strip())
-                        current_chunk = sentence + " "
-                if current_chunk:
-                    chunks.append(current_chunk.strip())
-        else:
-            sentences = sent_tokenize(doc)
-            current_chunk = ""
-            for sentence in sentences:
-                if len(current_chunk) + len(sentence) <= max_chunk_size:
-                    current_chunk += sentence + " "
-                else:
-                    chunks.append(current_chunk.strip())
-                    current_chunk = sentence + " "
-            if current_chunk:
-                chunks.append(current_chunk.strip())
-    return chunks
-
-# Process documents
-documents = []
-for dataset_name, dataset in ragbench.items():
-    logger.info(f"Processing {dataset_name}")
-    original_documents = dataset['documents']
-    chunked_documents = chunk_documents_semantic(original_documents)
-    documents.extend([Document(page_content=chunk) for chunk in chunked_documents])
-    logger.info(f"Processed {len(chunked_documents)} chunks from {dataset_name}")
-
-# Initialize vectordb
-vectordb = Chroma.from_documents(
-    documents=documents,
-    embedding=embedding_model,
-    persist_directory='./docs/chroma/'
-)
-vectordb.persist()
-
-def process_query(query, dataset_choice):
+def process_query(query):
     try:
-        logger.info(f"Processing query for {dataset_choice}: {query}")
-
-        relevant_docs = vectordb.max_marginal_relevance_search(
-            query,
-            k=5,
-            fetch_k=10
-        )
-
-        context = " ".join([doc.page_content for doc in relevant_docs])
+        # Get a relevant document from the dataset
+        context = dataset['documents'][0]  # Using first document as example
 
         response = openai.chat.completions.create(
             model="gpt-3.5-turbo",
             messages=[
-                {"role": "system", "content": "You are a specialized assistant for the RagBench dataset. Provide precise answers based solely on the given context."},
-                {"role": "user", "content": f"Dataset: {dataset_choice}\nContext: {context}\nQuestion: {query}\n\nProvide a detailed answer using only the information from the context above."}
+                {"role": "system", "content": "You are a helpful assistant for the RagBench dataset."},
+                {"role": "user", "content": f"Context: {context}\nQuestion: {query}"}
             ],
             max_tokens=300,
             temperature=0.7,
@@ -116,28 +32,15 @@ def process_query(query, dataset_choice):
         return response.choices[0].message.content.strip()
 
     except Exception as e:
-        logger.error(f"Error processing query: {str(e)}")
-        return f"Error: {str(e)}"
+        return f"Query processing: {str(e)}"
 
-# Create Gradio interface
+# Create simple Gradio interface
 demo = gr.Interface(
     fn=process_query,
-    inputs=[
-        gr.Textbox(label="Question", placeholder="Type your question here...", lines=2),
-        gr.Dropdown(
-            choices=list(ragbench.keys()),
-            label="Select Dataset",
-            value="hotpotqa"
-        )
-    ],
-    outputs=gr.Textbox(label="Answer", lines=5),
-    title="RagBench Question Answering System",
-    description="Ask questions across different RagBench datasets",
-    examples=[
-        ["What role does T-cell count play in severe human adenovirus type 55 (HAdV-55) infection?", "covidqa"],
-        ["In what school district is Governor John R. Rogers High School located?", "hotpotqa"],
-        ["Is there a functional neural correlate of individual differences in cardiovascular reactivity?", "pubmedqa"]
-    ]
+    inputs=gr.Textbox(label="Question"),
+    outputs=gr.Textbox(label="Answer"),
+    title="RagBench QA System",
+    description="Ask questions about HotpotQA dataset"
 )
 
 if __name__ == "__main__":
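One hazard both versions of app.py share: the OpenAI key is hardcoded, so it ships with the Space. A minimal sketch of the usual alternative, reading it from the environment; the `OPENAI_API_KEY` variable name is the openai library's convention, not something this commit sets up:

```python
import os
import openai

# Read the key from the environment instead of committing it with the app.
# Assumes it was exported beforehand, e.g. `export OPENAI_API_KEY=...`.
openai.api_key = os.environ["OPENAI_API_KEY"]  # raises KeyError if unset
```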
 
 
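The simplified process_query answers every question from dataset['documents'][0] rather than retrieving by relevance. If retrieval is wanted without reinstating the Chroma/LangChain stack this commit removes, one rough sketch is to rank passages directly with the same all-mpnet-base-v2 sentence-transformers model the old code loaded; `doc_texts` and `retrieve` here are illustrative names, not part of the commit:

```python
# Sketch: rank RagBench passages for a query with sentence-transformers,
# reusing the 'all-mpnet-base-v2' model named in the removed code.
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, util

dataset = load_dataset("rungalileo/ragbench", "hotpotqa", split="train")
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Each row's 'documents' field may hold a list of passages; flatten to one list.
doc_texts = [
    passage
    for row in dataset["documents"]
    for passage in (row if isinstance(row, list) else [row])
]
doc_emb = model.encode(doc_texts, convert_to_tensor=True)

def retrieve(query, k=5):
    """Return the k passages most similar to the query by cosine score."""
    q_emb = model.encode(query, convert_to_tensor=True)
    hits = util.semantic_search(q_emb, doc_emb, top_k=k)[0]
    return [doc_texts[hit["corpus_id"]] for hit in hits]
```

process_query could then build its context from `" ".join(retrieve(query))`, mirroring how the removed code joined the MMR results before calling the chat model.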