Remove unused binary files and refactor main application structure to integrate FastAPI with new routing and utility functions.
- chain/__init__.py +54 -31
- collection_data.csv +0 -0
- main.py +36 -20
- retriever/__init__.py +3 -3
- router/__init__.py +0 -0
- router/main.py +76 -0
- schema/__init__.py +60 -0
- utils.py +36 -0
chain/__init__.py
CHANGED
@@ -4,43 +4,19 @@ import json
 from datetime import datetime
 from venv import logger

-import torch
 from pymongo import errors
 from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain_core.output_parsers import PydanticOutputParser
 from langchain_core.messages import BaseMessage, message_to_dict
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.chains.retrieval import create_retrieval_chain
 from langchain.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
 from langchain_mongodb import MongoDBChatMessageHistory
-
-
-from models.llm import GPTModel
-
-
-# REPO_ID = "microsoft/Phi-4-mini-instruct-onnx"
-# SUBFOLDER = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
-# llm = Phi4MiniONNXLLM(REPO_ID, SUBFOLDER)
-
-# MODEL_NAME = "openai-community/gpt2"
-MODEL_NAME = "microsoft/phi-1_5"
-# llm = HuggingfaceModel(MODEL_NAME)
-
-hf_llm = HuggingFacePipeline.from_model_id(
-    model_id="microsoft/Phi-4",
-    task="text-generation",
-    pipeline_kwargs={
-        "max_new_tokens": 128,
-        "temperature": 0.3,
-        "top_k": 50,
-        "do_sample": True
-    },
-    model_kwargs={
-        "torch_dtype": "auto",
-        "device_map": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
-        "max_memory": {0: "10GB"},
-        "use_cache": False
-    }
-)
+
+from schema import FollowUpQ
+from models.llm import GPTModel
+
+llm = GPTModel()

 SYS_PROMPT = """You are a knowledgeable financial professional. You can provide well elaborated and credible answers to user queries in economic and finance by referring to retrieved contexts.
 You should answer user queries strictly following the instructions below, and do not provide anything irrelevant. \n
@@ -108,7 +84,7 @@ def get_message_history(
     """
     return MessageHistory(
        session_id = session_id,
-        connection_string=str(mongo_url), database_name='
+        connection_string=str(mongo_url), database_name='mailbox')

 class RAGChain(RunnableWithMessageHistory):
     """
@@ -130,3 +106,50 @@ class RAGChain(RunnableWithMessageHistory):
            history_messages_key="chat_history",
            output_messages_key="answer"
        )
+
+class FollowUpChain():
+    """
+    FollowUpQChain is a class to generate follow-up questions based on contexts and initial query.
+
+    Attributes:
+        parser (PydanticOutputParser): An instance of PydanticOutputParser to parse the output.
+        chain (Chain): A chain of prompts and models to generate follow-up questions.
+
+    Methods:
+        __init__():
+            Initializes the FollowUpQChain with a parser and a prompt chain.
+
+        invoke(contexts, query):
+            Invokes the chain with the provided contexts and query to generate follow-up questions.
+
+            contexts (str): The contexts to be used for generating follow-up questions.
+            query (str): The initial query to be used for generating follow-up questions.
+    """
+    def __init__(self):
+        self.parser = PydanticOutputParser(pydantic_object=FollowUpQ)
+        prompt = ChatPromptTemplate.from_messages([
+            ("system", "You are a professional commentator on current events.Your task\
+                is to provide 3 follow-up questions based on contexts and initial query."),
+            ("system", "contexts: {contexts}"),
+            ("system", "initial query: {query}"),
+            ("human", "Format instructions: {format_instructions}"),
+            ("placeholder", "{agent_scratchpad}"),
+        ])
+        self.chain = prompt | llm | self.parser
+
+    def invoke(self, query, contexts):
+        """
+        Invokes the chain with the provided content and additional parameters.
+
+        Args:
+            content (str): The article content to be processed.
+
+        Returns:
+            The result of the chain invocation.
+        """
+        result = self.chain.invoke({
+            'contexts': contexts,
+            'format_instructions': self.parser.get_format_instructions(),
+            'query': query
+        })
+        return result.questions
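
For orientation: the new FollowUpChain pipes a ChatPromptTemplate into the GPTModel and then into a PydanticOutputParser bound to the FollowUpQ schema, so invoke() returns the parsed questions list directly. A minimal usage sketch; the query and context strings here are hypothetical, not from this commit:

    from chain import FollowUpChain

    follow_up = FollowUpChain()
    questions = follow_up.invoke(
        query="What is driving current inflation?",   # initial user query
        contexts="...previously retrieved context...",  # context text, as one string
    )
    print(questions)  # list[str] of 3 follow-up questions (FollowUpQ.questions)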
collection_data.csv
DELETED
The diff for this file is too large to render.
main.py
CHANGED
@@ -1,20 +1,36 @@
-"""Module to
-
-
-
-from
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+"""Module to handle the main FastAPI application and its endpoints."""
+import logging
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from router import main
+
+
+app = FastAPI(docs_url="/")
+
+app.include_router(main.router, tags=["content"])
+
+origins = [
+    "*"
+]
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials = True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+logging.basicConfig(
+    format='%(asctime)s - %(levelname)s - %(funcName)s - %(message)s')
+logging.getLogger().setLevel(logging.ERROR)
+
+
+@app.get("/_health")
+def health():
+    """
+    Returns the health status of the application.
+
+    :return: A string "OK" indicating the health status.
+    """
+    return "OK"
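
The rewritten main.py is now pure wiring: it mounts the content router, opens CORS to every origin, raises the log threshold to ERROR, and exposes a /_health probe. A hedged sketch of serving it during development; the launcher file name, host, and port are assumptions, not part of the commit:

    # run_dev.py -- hypothetical dev launcher; assumes uvicorn is installed
    import uvicorn

    if __name__ == "__main__":
        uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)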
retriever/__init__.py
CHANGED
@@ -23,9 +23,9 @@ class DocRetriever(BaseRetriever):
         list: A list of Document objects with relevant metadata.
     """
     retriever: VectorStoreRetriever = None
-    k: int =
+    k: int = 3

-    def __init__(self, req, k: int =
+    def __init__(self, req, k: int = 3) -> None:
         super().__init__()
         # _filter={}
         # if req.site != []:
@@ -52,7 +52,7 @@ class DocRetriever(BaseRetriever):
            metadata = {
                "content": doc.page_content,
                # "id": doc.metadata['id'],
-
+                "title": doc.metadata['subject'],
                # "site": doc.metadata['site'],
                # "link": doc.metadata['link'],
                # "publishDate": doc.metadata['publishDate'].strftime('%Y-%m-%d'),
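
This change pins the previously truncated default to the top k=3 documents and starts surfacing each hit's mail subject as a title field. Assuming a typical stored document, one retrieved entry's metadata now has roughly this shape (values are illustrative):

    # Hypothetical shape of one retrieved item's metadata after this change
    metadata = {
        "content": "...page content of the matched document...",
        "title": "...value copied from doc.metadata['subject']...",
    }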
router/__init__.py
ADDED
File without changes
router/main.py
ADDED
@@ -0,0 +1,76 @@
+"""Module for defining the main routes of the API."""
+from fastapi import APIRouter
+from fastapi.responses import StreamingResponse
+from schema import ReqData
+from utils import generate
+
+router = APIRouter()
+
+@router.post("/stream")
+async def stream(query: ReqData):
+    """
+    Handles streaming of data based on the provided query.
+
+    Args:
+        query (ReqData): The request data containing the query parameters.
+
+    Returns:
+        StreamingResponse: A streaming response with generated data with type 'text/event-stream'.
+    """
+    return StreamingResponse(generate(query), media_type='text/event-stream')
+
+# # @router.post("/followup")
+# # def follow_up(req: ReqFollowUp):
+# #     """
+# #     Handles the follow-up POST request.
+
+# #     Args:
+# #         req (ReqFollowUp): The request object containing follow-up data.
+
+# #     Returns:
+# #         Response: The response from the follow-up processing function.
+# #     """
+# #     return followup(req)
+
+# @router.post("/chat/history")
+# def retrieve_history(chat_history: ChatHistory):
+#     """
+#     Endpoint to retrieve chat history.
+
+#     This endpoint handles POST requests to the "/chat/history" URL. It accepts a
+#     ChatHistory object as input and returns the chat history.
+
+#     Args:
+#         chat_history (ChatHistory): The chat history object containing the details
+#         of the chat to be retrieved.
+
+#     Returns:
+#         The chat history retrieved by the retrieve_chat_history function.
+#     """
+#     return get_chat_history(chat_history)
+
+# @router.post("/chat/session")
+# def retrieve_session(chat_session: ChatSession):
+#     """
+#     Retrieve a chat session.
+
+#     Args:
+#         chat_session (ChatSession): The chat session to retrieve.
+
+#     Returns:
+#         ChatSession: The retrieved chat session.
+#     """
+#     return get_chat_session(chat_session)
+
+# @router.post("/chat/history/clear")
+# def clear_history(chat_history: ChatHistory):
+#     """
+#     Clears the chat history.
+
+#     Args:
+#         chat_history (ChatHistory): The chat history object to be cleared.
+
+#     Returns:
+#         The result of the clear_chat_history function.
+#     """
+#     return clear_chat(chat_history)
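
Only the /stream route is live; the follow-up, history, and session endpoints remain commented out. A minimal client sketch for the streaming endpoint, assuming the app runs locally on port 8000; the URL and payload values are illustrative, and the payload must satisfy the ReqData schema below:

    # stream_client.py -- hypothetical client using the requests library
    import requests

    payload = {"query": "What moved bond yields today?",
               "chat_id": "demo-chat", "user_id": "demo-user"}
    with requests.post("http://127.0.0.1:8000/stream", json=payload, stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line:  # SSE frames arrive as "event: ..." / "data: ..." lines
                print(line)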
schema/__init__.py
ADDED
@@ -0,0 +1,60 @@
+"""Module containing the data models for the application."""
+from typing import Optional, List
+from pydantic import BaseModel, Field
+
+class ReqData(BaseModel):
+    """
+    RequestData is a Pydantic model that represents the data structure for a request.
+
+    Attributes:
+        query (str): The query string provided by the user.
+        chat_id (str): The unique identifier for the chat session.
+        user_id (str): The unique identifier for the user.
+        web (Optional[bool]): A flag indicating if the request is from the web. Defaults to False.
+    """
+    query: str
+    id: Optional[List[str]] = []
+    site: Optional[List[str]] = []
+    chat_id: str
+    user_id: str
+    web: Optional[bool] = False
+
+class ReqFollowUp(BaseModel):
+    """
+    RequestFollowUp is a Pydantic model that represents a request for follow-up.
+
+    Attributes:
+        query (str): The query string that needs follow-up.
+        contexts (list[str]): A list of context strings related to the query.
+    """
+    query: str
+    contexts: list[str]
+
+class FollowUpQ(BaseModel):
+    """
+    FollowUpQ model to represent a follow-up question based on context information.
+
+    Attributes:
+        question (list[str]): A list of follow-up questions based on context information.
+    """
+    questions: list[str] = Field(..., description="3 Follow up questions based on context.")
+
+class ChatHistory(BaseModel):
+    """
+    ChatHistory model representing a chat session.
+
+    Attributes:
+        chat_id (str): The unique identifier for the chat session.
+        user_id (str): The unique identifier for the user.
+    """
+    chat_id: str
+    user_id: str
+
+class ChatSession(BaseModel):
+    """
+    ChatSession model representing a chat session.
+
+    Attributes:
+        user_id (str): The unique identifier for the user.
+    """
+    user_id: str
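
Note that ReqData also declares optional id and site filter lists that its docstring does not mention. A quick construction sketch showing the defaults (field values are illustrative):

    from schema import ReqData

    req = ReqData(query="What is quantitative tightening?",
                  chat_id="chat-123", user_id="user-456")
    print(req.web)   # False -- the web flag defaults off
    print(req.site)  # []    -- the optional filter lists default to empty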
utils.py
ADDED
@@ -0,0 +1,36 @@
+"""Module containing utility functions for the chatbot application."""
+import json
+from chain import RAGChain, FollowUpChain
+from schema import ReqData
+from retriever import DocRetriever
+
+followUpChain = FollowUpChain()
+
+async def generate(req: ReqData):
+    """
+    Asynchronously generates responses based on the provided request data.
+
+    This function uses different processing chains depending on the `web` attribute of the request.
+    It streams chunks of data and yields server-sent events (SSE) for answers and contexts.
+    Additionally, it generates follow-up questions and updates citations.
+
+    Args:
+        req (ReqData): Request data containing user and chat info, query, and other parameters.
+
+    Yields:
+        str: Server-sent events (SSE) for answers, contexts, and follow-up questions in JSON format.
+    """
+    chain = RAGChain(DocRetriever(req=req))
+    session_id = "/".join([req.user_id, req.chat_id])
+    contexts = []
+    for chunk in chain.stream({"input": req.query},
+                              config={"configurable": {"session_id": session_id}}):
+        if 'answer' in chunk:
+            yield "event: answer\n"
+            yield f"data: {json.dumps(chunk)}\n\n"
+        elif 'context' in chunk:
+            for context in chunk['context']:
+                yield "event: context\n"
+                yield f"data: {json.dumps({'context': context.metadata})}\n\n"
+    yield "event: questions\n"
+    yield f"data: {json.dumps({'questions': followUpChain.invoke(req.query, contexts)})}\n\n"
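
On the wire, generate() produces named SSE events: a chunk carrying an 'answer' key becomes an answer event, each retrieved document becomes a context event, and a single questions event follows once the stream ends. An illustrative trace of what a client receives (payloads are made up):

    event: answer
    data: {"answer": "Bond prices typically fall when ..."}

    event: context
    data: {"context": {"content": "...", "title": "..."}}

    event: questions
    data: {"questions": ["...", "...", "..."]}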