gavinzli committed on
Commit c7426d8 · 1 Parent(s): a4e857f

Remove unused binary files and refactor main application structure to integrate FastAPI with new routing and utility functions.

chain/__init__.py CHANGED
@@ -4,43 +4,19 @@ import json
 from datetime import datetime
 from venv import logger

-import torch
 from pymongo import errors
 from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.output_parsers import PydanticOutputParser
 from langchain_core.messages import BaseMessage, message_to_dict
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.chains.retrieval import create_retrieval_chain
 from langchain.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
 from langchain_mongodb import MongoDBChatMessageHistory
-from langchain_huggingface import HuggingFacePipeline
-
-from models.llm import GPTModel, Phi4MiniONNXLLM, HuggingfaceModel
-
-# llm = GPTModel()
-# REPO_ID = "microsoft/Phi-4-mini-instruct-onnx"
-# SUBFOLDER = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
-# llm = Phi4MiniONNXLLM(REPO_ID, SUBFOLDER)
-
-# MODEL_NAME = "openai-community/gpt2"
-MODEL_NAME = "microsoft/phi-1_5"
-# llm = HuggingfaceModel(MODEL_NAME)
-
-hf_llm = HuggingFacePipeline.from_model_id(
-    model_id="microsoft/Phi-4",
-    task="text-generation",
-    pipeline_kwargs={
-        "max_new_tokens": 128,
-        "temperature": 0.3,
-        "top_k": 50,
-        "do_sample": True
-    },
-    model_kwargs={
-        "torch_dtype": "auto",
-        "device_map": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
-        "max_memory": {0: "10GB"},
-        "use_cache": False
-    }
-)
+
+from schema import FollowUpQ
+from models.llm import GPTModel
+
+llm = GPTModel()

 SYS_PROMPT = """You are a knowledgeable financial professional. You can provide well elaborated and credible answers to user queries in economic and finance by referring to retrieved contexts.
 You should answer user queries strictly following the instructions below, and do not provide anything irrelevant. \n
@@ -108,7 +84,7 @@ def get_message_history(
     """
     return MessageHistory(
         session_id=session_id,
-        connection_string=str(mongo_url), database_name='emails')
+        connection_string=str(mongo_url), database_name='mailbox')

 class RAGChain(RunnableWithMessageHistory):
     """
@@ -130,3 +106,50 @@ class RAGChain(RunnableWithMessageHistory):
             history_messages_key="chat_history",
             output_messages_key="answer"
         )
+
+class FollowUpChain():
+    """
+    FollowUpChain is a class to generate follow-up questions based on contexts and the initial query.
+
+    Attributes:
+        parser (PydanticOutputParser): An instance of PydanticOutputParser to parse the output.
+        chain (Chain): A chain of prompts and models to generate follow-up questions.
+
+    Methods:
+        __init__():
+            Initializes the FollowUpChain with a parser and a prompt chain.
+
+        invoke(query, contexts):
+            Invokes the chain with the provided contexts and query to generate follow-up questions.
+
+            contexts (str): The contexts to be used for generating follow-up questions.
+            query (str): The initial query to be used for generating follow-up questions.
+    """
+    def __init__(self):
+        self.parser = PydanticOutputParser(pydantic_object=FollowUpQ)
+        prompt = ChatPromptTemplate.from_messages([
+            ("system", "You are a professional commentator on current events. Your task "
+                       "is to provide 3 follow-up questions based on contexts and initial query."),
+            ("system", "contexts: {contexts}"),
+            ("system", "initial query: {query}"),
+            ("human", "Format instructions: {format_instructions}"),
+            ("placeholder", "{agent_scratchpad}"),
+        ])
+        self.chain = prompt | llm | self.parser
+
+    def invoke(self, query, contexts):
+        """
+        Invokes the chain with the provided query and contexts.
+
+        Args:
+            query (str): The initial user query.
+            contexts (list): The retrieved contexts grounding the questions.
+
+        Returns:
+            list[str]: The generated follow-up questions.
+        """
+        result = self.chain.invoke({
+            'contexts': contexts,
+            'format_instructions': self.parser.get_format_instructions(),
+            'query': query
+        })
+        return result.questions
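
For orientation, a minimal sketch of how the new FollowUpChain is driven; this is not part of the commit, and it assumes GPTModel wraps a LangChain-compatible chat model (the query and context strings are made up):

from chain import FollowUpChain

follow_up = FollowUpChain()
questions = follow_up.invoke(
    query="Just give me an update?",                              # illustrative query
    contexts=["Fed minutes flagged sticky services inflation."],  # illustrative context
)
print(questions)  # the parsed FollowUpQ.questions list, expected to hold 3 questions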
collection_data.csv DELETED
The diff for this file is too large to render. See raw diff
 
main.py CHANGED
@@ -1,20 +1,36 @@
-"""Module to run the mail collection process."""
-from dotenv import load_dotenv
-
-# from controllers import mail
-from chain import RAGChain
-from retriever import DocRetriever
-
-load_dotenv()
-
-if __name__ == "__main__":
-    # mail.collect()
-    # mail.get_documents()
-    req = {
-        "query": "Just give me an update?",
-    }
-    chain = RAGChain(DocRetriever(req=req))
-    result = chain.invoke({"input": req['query']},
-                          config={"configurable": {"session_id": "20250301"}})
-    print(result)
-    print(result.get("answer"))
+"""Module to handle the main FastAPI application and its endpoints."""
+import logging
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from router import main
+
+
+app = FastAPI(docs_url="/")
+
+app.include_router(main.router, tags=["content"])
+
+origins = [
+    "*"
+]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+logging.basicConfig(
+    format='%(asctime)s - %(levelname)s - %(funcName)s - %(message)s')
+logging.getLogger().setLevel(logging.ERROR)
+
+
+@app.get("/_health")
+def health():
+    """
+    Returns the health status of the application.
+
+    :return: A string "OK" indicating the health status.
+    """
+    return "OK"
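
A quick smoke test for the refactored app; this is a sketch rather than part of the commit, and it assumes FastAPI's bundled TestClient (which requires httpx) plus an importable main module:

from fastapi.testclient import TestClient
from main import app

client = TestClient(app)
response = client.get("/_health")  # exercise the new health endpoint
assert response.status_code == 200
assert response.json() == "OK"     # the endpoint returns the plain string "OK"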
retriever/__init__.py CHANGED
@@ -23,9 +23,9 @@ class DocRetriever(BaseRetriever):
         list: A list of Document objects with relevant metadata.
     """
     retriever: VectorStoreRetriever = None
-    k: int = 5
+    k: int = 3

-    def __init__(self, req, k: int = 2) -> None:
+    def __init__(self, req, k: int = 3) -> None:
         super().__init__()
         # _filter={}
         # if req.site != []:
@@ -52,7 +52,7 @@ class DocRetriever(BaseRetriever):
             metadata = {
                 "content": doc.page_content,
                 # "id": doc.metadata['id'],
-                # "title": doc.metadata['title'],
+                "title": doc.metadata['subject'],
                 # "site": doc.metadata['site'],
                 # "link": doc.metadata['link'],
                 # "publishDate": doc.metadata['publishDate'].strftime('%Y-%m-%d'),
router/__init__.py ADDED
File without changes
router/main.py ADDED
@@ -0,0 +1,76 @@
+"""Module for defining the main routes of the API."""
+from fastapi import APIRouter
+from fastapi.responses import StreamingResponse
+from schema import ReqData
+from utils import generate
+
+router = APIRouter()
+
+@router.post("/stream")
+async def stream(query: ReqData):
+    """
+    Handles streaming of data based on the provided query.
+
+    Args:
+        query (ReqData): The request data containing the query parameters.
+
+    Returns:
+        StreamingResponse: A streaming response of generated data with media type 'text/event-stream'.
+    """
+    return StreamingResponse(generate(query), media_type='text/event-stream')
+
+# # @router.post("/followup")
+# # def follow_up(req: ReqFollowUp):
+# #     """
+# #     Handles the follow-up POST request.
+
+# #     Args:
+# #         req (ReqFollowUp): The request object containing follow-up data.
+
+# #     Returns:
+# #         Response: The response from the follow-up processing function.
+# #     """
+# #     return followup(req)
+
+# @router.post("/chat/history")
+# def retrieve_history(chat_history: ChatHistory):
+#     """
+#     Endpoint to retrieve chat history.
+
+#     This endpoint handles POST requests to the "/chat/history" URL. It accepts a
+#     ChatHistory object as input and returns the chat history.
+
+#     Args:
+#         chat_history (ChatHistory): The chat history object containing the details
+#             of the chat to be retrieved.
+
+#     Returns:
+#         The chat history retrieved by the retrieve_chat_history function.
+#     """
+#     return get_chat_history(chat_history)
+
+# @router.post("/chat/session")
+# def retrieve_session(chat_session: ChatSession):
+#     """
+#     Retrieve a chat session.
+
+#     Args:
+#         chat_session (ChatSession): The chat session to retrieve.
+
+#     Returns:
+#         ChatSession: The retrieved chat session.
+#     """
+#     return get_chat_session(chat_session)
+
+# @router.post("/chat/history/clear")
+# def clear_history(chat_history: ChatHistory):
+#     """
+#     Clears the chat history.
+
+#     Args:
+#         chat_history (ChatHistory): The chat history object to be cleared.
+
+#     Returns:
+#         The result of the clear_chat_history function.
+#     """
+#     return clear_chat(chat_history)
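
A hedged sketch of consuming the new /stream endpoint from a client; httpx is an assumed extra dependency, and the host, port, and payload values are illustrative:

import httpx

payload = {
    "query": "Any updates this week?",  # illustrative values matching schema.ReqData
    "chat_id": "20250301",
    "user_id": "demo",
}
with httpx.stream("POST", "http://localhost:8000/stream", json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        if line:  # frames arrive as "event: ..." / "data: {...}" pairs
            print(line)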
schema/__init__.py ADDED
@@ -0,0 +1,60 @@
+"""Module containing the data models for the application."""
+from typing import Optional, List
+from pydantic import BaseModel, Field
+
+class ReqData(BaseModel):
+    """
+    ReqData is a Pydantic model that represents the data structure for a request.
+
+    Attributes:
+        query (str): The query string provided by the user.
+        id (Optional[List[str]]): Document ids to filter on. Defaults to an empty list.
+        site (Optional[List[str]]): Sites to filter on. Defaults to an empty list.
+        chat_id (str): The unique identifier for the chat session.
+        user_id (str): The unique identifier for the user.
+        web (Optional[bool]): A flag indicating if the request is from the web. Defaults to False.
+    """
+    query: str
+    id: Optional[List[str]] = []
+    site: Optional[List[str]] = []
+    chat_id: str
+    user_id: str
+    web: Optional[bool] = False
+
+class ReqFollowUp(BaseModel):
+    """
+    ReqFollowUp is a Pydantic model that represents a request for follow-up questions.
+
+    Attributes:
+        query (str): The query string that needs follow-up.
+        contexts (list[str]): A list of context strings related to the query.
+    """
+    query: str
+    contexts: list[str]
+
+class FollowUpQ(BaseModel):
+    """
+    FollowUpQ model to represent follow-up questions based on context information.
+
+    Attributes:
+        questions (list[str]): A list of follow-up questions based on context information.
+    """
+    questions: list[str] = Field(..., description="3 Follow up questions based on context.")
+
+class ChatHistory(BaseModel):
+    """
+    ChatHistory model representing a chat session.
+
+    Attributes:
+        chat_id (str): The unique identifier for the chat session.
+        user_id (str): The unique identifier for the user.
+    """
+    chat_id: str
+    user_id: str
+
+class ChatSession(BaseModel):
+    """
+    ChatSession model representing a chat session.
+
+    Attributes:
+        user_id (str): The unique identifier for the user.
+    """
+    user_id: str
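
A small illustration of the request model doing validation; the field values are made up, and model_dump assumes Pydantic v2:

from schema import ReqData

req = ReqData(query="Just give me an update?", chat_id="20250301", user_id="demo")
print(req.model_dump())
# -> {'query': 'Just give me an update?', 'id': [], 'site': [],
#     'chat_id': '20250301', 'user_id': 'demo', 'web': False}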
utils.py ADDED
@@ -0,0 +1,37 @@
+"""Module containing utility functions for the chatbot application."""
+import json
+from chain import RAGChain, FollowUpChain
+from schema import ReqData
+from retriever import DocRetriever
+
+followUpChain = FollowUpChain()
+
+async def generate(req: ReqData):
+    """
+    Asynchronously generates responses based on the provided request data.
+
+    This function uses different processing chains depending on the `web` attribute of the request.
+    It streams chunks of data and yields server-sent events (SSE) for answers and contexts.
+    Additionally, it generates follow-up questions and updates citations.
+
+    Args:
+        req (ReqData): Request data containing user and chat info, query, and other parameters.
+
+    Yields:
+        str: Server-sent events (SSE) for answers, contexts, and follow-up questions in JSON format.
+    """
+    chain = RAGChain(DocRetriever(req=req))
+    session_id = "/".join([req.user_id, req.chat_id])
+    contexts = []
+    for chunk in chain.stream({"input": req.query},
+                              config={"configurable": {"session_id": session_id}}):
+        if 'answer' in chunk:
+            yield "event: answer\n"
+            yield f"data: {json.dumps(chunk)}\n\n"
+        elif 'context' in chunk:
+            for context in chunk['context']:
+                contexts.append(context.metadata)  # collect contexts so the follow-up chain sees them
+                yield "event: context\n"
+                yield f"data: {json.dumps({'context': context.metadata})}\n\n"
+    yield "event: questions\n"
+    yield f"data: {json.dumps({'questions': followUpChain.invoke(req.query, contexts)})}\n\n"