VishnuRamDebyez committed on
Commit
dbce286
·
verified ·
1 Parent(s): 00d807e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -0
app.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ import os
3
+ from typing import List, Dict
4
+ from dotenv import load_dotenv
5
+ import logging
6
+ from pathlib import Path
7
+
8
+ from langchain_community.document_loaders import PyPDFLoader
9
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
10
+ from langchain_community.vectorstores import Qdrant as QdrantVectorStore
11
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
12
+ from langchain_groq import ChatGroq
13
+ from qdrant_client import QdrantClient
14
+ from qdrant_client.http.models import Distance, VectorParams
15
+ from qdrant_client.models import PointIdsList
16
+
17
+ from langgraph.graph import MessagesState, StateGraph
18
+ from langchain_core.messages import SystemMessage, HumanMessage
19
+ from langgraph.prebuilt import ToolNode
20
+ from langgraph.graph import END
21
+ from langgraph.prebuilt import tools_condition
22
+ from langgraph.checkpoint.memory import MemorySaver
23
+
24
# Module-wide logging: INFO and above, with a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pull variables from a .env file (if any) into the process environment
# before the API keys are read.
load_dotenv()
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')

# Both keys are mandatory; refuse to start when either is missing or empty.
if not (GOOGLE_API_KEY and GROQ_API_KEY):
    raise ValueError("API keys not set in environment variables")

# FastAPI application object that the route decorators attach to.
app = FastAPI()
35
+
36
class QASystem:
    """Question answering over a directory of PDF files.

    ``initialize_system()`` loads and chunks the PDFs, stores the chunks in an
    in-memory Qdrant collection, and compiles a two-node LangGraph
    (``query_or_respond`` -> ``generate``) with an in-memory checkpointer.
    ``process_query()`` then streams a question through that graph.
    """

    # Name of the Qdrant collection that holds the PDF chunks.
    COLLECTION_NAME = "pdf_data"
    # Number of chunks fetched from the vector store per question.
    RETRIEVAL_K = 4
    # Page size used when enumerating existing points before re-indexing.
    SCROLL_PAGE_SIZE = 100

    def __init__(self):
        # Everything below is populated by initialize_system().
        self.vector_store = None
        self.graph = None
        self.memory = None
        self.embeddings = None
        self.client = None
        # Directory scanned for *.pdf files, relative to the working directory.
        self.pdf_dir = "pdfs"

    def load_pdf_documents(self):
        """Load every ``*.pdf`` in ``self.pdf_dir`` and split it into chunks.

        A PDF that fails to load is logged and skipped so that one bad file
        does not prevent the others from being indexed.

        Returns:
            The list of chunk documents produced by the text splitter
            (empty when no PDF could be loaded).

        Raises:
            FileNotFoundError: If ``self.pdf_dir`` is not an existing directory.
        """
        pdf_dir = Path(self.pdf_dir)
        # is_dir() also rejects a regular file that happens to carry this name.
        if not pdf_dir.is_dir():
            raise FileNotFoundError(f"PDF directory not found: {self.pdf_dir}")

        documents = []
        # Sorted so that the ingestion order is the same on every run.
        for pdf_path in sorted(pdf_dir.glob("*.pdf")):
            try:
                loader = PyPDFLoader(str(pdf_path))
                documents.extend(loader.load())
                logger.info(f"Loaded PDF: {pdf_path}")
            except Exception as e:
                # Best effort per file; keep the traceback for diagnosis.
                logger.exception("Error loading PDF %s: %s", pdf_path, e)

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
        )
        split_docs = text_splitter.split_documents(documents)
        logger.info(f"Split documents into {len(split_docs)} chunks")
        return split_docs

    def initialize_system(self):
        """Build the vector store, index the PDFs and compile the graph.

        Returns:
            ``True`` on success; ``False`` when any step raised (the error is
            logged with its traceback and not re-raised).
        """
        try:
            self._init_vector_store()
            self._index_documents()
            self._build_graph()
            return True
        except Exception as e:
            logger.exception("System initialization error: %s", e)
            return False

    def _init_vector_store(self):
        """Create the Qdrant client, the collection, the embeddings and the store."""
        self.client = QdrantClient(":memory:")

        try:
            self.client.get_collection(self.COLLECTION_NAME)
        except Exception:
            # Any failure to fetch the collection is treated as "not created yet".
            self.client.create_collection(
                collection_name=self.COLLECTION_NAME,
                vectors_config=VectorParams(size=768, distance=Distance.COSINE),
            )
            logger.info("Created new collection: pdf_data")

        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/embedding-001",
            google_api_key=GOOGLE_API_KEY,
        )
        self.vector_store = QdrantVectorStore(
            client=self.client,
            collection_name=self.COLLECTION_NAME,
            embeddings=self.embeddings,
        )

    def _clear_collection(self):
        """Delete every point currently stored in the collection.

        All ids are gathered by paging through ``scroll`` first; reading only a
        single page would leave anything beyond that page behind.
        """
        point_ids = []
        offset = None
        while True:
            points, offset = self.client.scroll(
                collection_name=self.COLLECTION_NAME,
                limit=self.SCROLL_PAGE_SIZE,
                offset=offset,
                with_payload=False,
            )
            point_ids.extend(p.id for p in points)
            # scroll() reports the end of the collection with a None offset.
            if offset is None:
                break

        if point_ids:
            self.client.delete(
                collection_name=self.COLLECTION_NAME,
                points_selector=PointIdsList(points=point_ids),
            )

    def _index_documents(self):
        """Replace the collection's contents with the freshly loaded PDF chunks."""
        documents = self.load_pdf_documents()
        if not documents:
            return

        try:
            self._clear_collection()
        except Exception as e:
            # Clearing is best effort: indexing proceeds even when it fails.
            logger.error(f"Error clearing vectors: {str(e)}")

        self.vector_store.add_documents(documents)
        logger.info(f"Added {len(documents)} documents to vector store")

    def _retrieve_context(self, messages):
        """Return text retrieved from the vector store for the latest human message.

        Args:
            messages: Conversation messages, oldest first; each exposes
                ``type`` and ``content``.

        Returns:
            The retrieved chunks joined by blank lines, or ``""`` when there is
            no human message, the question is empty, or no vector store exists.
        """
        question = next(
            (m.content for m in reversed(messages) if m.type == "human"),
            None,
        )
        if not question or self.vector_store is None:
            return ""
        docs = self.vector_store.similarity_search(question, k=self.RETRIEVAL_K)
        return "\n\n".join(doc.page_content for doc in docs)

    def _build_graph(self):
        """Compile the two-node graph and attach an in-memory checkpointer."""
        llm = ChatGroq(
            model="llama3-8b-8192",
            api_key=GROQ_API_KEY,
            temperature=0.7,
        )

        def query_or_respond(state: MessagesState):
            # First pass: the model sees the raw conversation only.
            response = llm.invoke(state["messages"])
            return {"messages": [response]}

        def generate(state: MessagesState):
            history = state["messages"]
            tool_texts = [m.content for m in history if m.type == "tool"]
            if tool_texts:
                context = " ".join(tool_texts)
            else:
                # No tool node is wired into this graph, so tool messages never
                # appear; without this lookup the prompt context would always
                # be empty and the indexed PDFs would never be used.
                context = self._retrieve_context(history)

            system_prompt = (
                "You are a senior legal assistant with knowledge in the Indian legal and judiciary system."
                " Provide direct concise summarized answers in 5 sentences based on the following context:\n\n"
                f"{context}"
            )
            # NOTE(review): this keeps the ungrounded reply from query_or_respond
            # in the prompt, as before — confirm that is intended.
            conversation = [
                m for m in history
                if m.type in ("human", "system") or (m.type == "ai" and not m.tool_calls)
            ]
            response = llm.invoke([SystemMessage(content=system_prompt)] + conversation)
            return {"messages": [response]}

        graph_builder = StateGraph(MessagesState)
        graph_builder.add_node("query_or_respond", query_or_respond)
        graph_builder.add_node("generate", generate)

        graph_builder.set_entry_point("query_or_respond")
        graph_builder.add_edge("query_or_respond", "generate")
        graph_builder.add_edge("generate", END)

        self.memory = MemorySaver()
        self.graph = graph_builder.compile(checkpointer=self.memory)

    def process_query(self, query: str, thread_id: str = "abc123") -> List[Dict[str, str]]:
        """Run ``query`` through the graph and collect each streamed step.

        Args:
            query: The user's question.
            thread_id: Checkpointer thread to read and extend. Callers that
                need isolated conversations should pass distinct ids; the
                default keeps the previously hard-coded shared thread.

        Returns:
            One ``{'content', 'type'}`` dict per streamed state, taken from the
            last message of that state. On failure, a single entry with
            ``'type': 'error'`` is returned instead of raising.
        """
        if self.graph is None:
            # Clearer than the AttributeError that streaming on None produces.
            reason = "QA system is not initialized"
            logger.error(f"Query processing error: {reason}")
            return [{'content': f"Query processing error: {reason}", 'type': 'error'}]

        try:
            responses = []
            for step in self.graph.stream(
                {"messages": [HumanMessage(content=query)]},
                stream_mode="values",
                config={"configurable": {"thread_id": thread_id}},
            ):
                if step["messages"]:
                    last = step["messages"][-1]
                    responses.append({'content': last.content, 'type': last.type})
            return responses
        except Exception as e:
            logger.exception("Query processing error: %s", e)
            return [{'content': f"Query processing error: {str(e)}", 'type': 'error'}]
169
+
170
# Single shared QA pipeline, built once when the module is imported.
qa_system = QASystem()

# Abort the import when the pipeline could not be set up; the endpoints
# cannot serve anything without it.
if not qa_system.initialize_system():
    raise RuntimeError("Failed to initialize QA System")
logger.info("QA System Initialized Successfully")
175
+
176
@app.post("/query")
async def query_api(query: str):
    """Answer a question using the shared QA system.

    Args:
        query: The question text. Declared as a plain ``str`` parameter, so
            FastAPI takes it from the request's query string.

    Returns:
        ``{"responses": [...]}`` where the list is whatever
        ``qa_system.process_query`` produced for this question.
    """
    # Imported locally so this handler does not depend on the module's import
    # block; fastapi itself is already a dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    # process_query is synchronous and blocks on LLM calls. Calling it directly
    # from a coroutine would stall the event loop, and every other request,
    # until it returned, so it is handed to the worker thread pool instead.
    responses = await run_in_threadpool(qa_system.process_query, query)
    return {"responses": responses}