Spaces:
Sleeping
Sleeping
vaishnav
committed on
Commit
·
23d9a47
1
Parent(s):
eb93a33
add chat history
Browse files
- .gitignore +4 -1
- app.py +14 -8
- configs/.env +1 -1
- configs/config.py +6 -7
- llm_setup/llm_setup.py +51 -42
- requirements.txt +153 -152
.gitignore
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
venv
|
2 |
configs/.env
|
3 |
.idea
|
4 |
-
*__pycache__
|
|
|
|
|
|
|
|
1 |
venv
|
2 |
configs/.env
|
3 |
.idea
|
4 |
+
*__pycache__
|
5 |
+
venv
|
6 |
+
embeddings
|
7 |
+
*.gradio
|
app.py
CHANGED
@@ -20,16 +20,22 @@ service.scrape_and_get_store_vector_retriever(config.URLS)
|
|
20 |
# Initialize the LLMService with logger, prompt, and store vector retriever
|
21 |
llm_svc = LLMService(logger, config.SYSTEM_PROMPT, store.get_chroma_instance().as_retriever())
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
response = llm_svc.conversational_rag_chain().invoke(user_input)
|
29 |
|
30 |
return response
|
31 |
|
32 |
-
|
33 |
if __name__ == '__main__':
|
34 |
logging.info("Starting AIVIz Bot")
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
# Initialize the LLMService with logger, prompt, and store vector retriever
|
21 |
llm_svc = LLMService(logger, config.SYSTEM_PROMPT, store.get_chroma_instance().as_retriever())
|
22 |
|
23 |
+
def respond(user_input):
|
24 |
+
response = llm_svc.conversational_rag_chain().invoke(
|
25 |
+
{"input": user_input},
|
26 |
+
config={"configurable": {"session_id": "abc"}},
|
27 |
+
)["answer"]
|
|
|
28 |
|
29 |
return response
|
30 |
|
|
|
31 |
if __name__ == '__main__':
|
32 |
logging.info("Starting AIVIz Bot")
|
33 |
+
|
34 |
+
# Using ChatInterface to create the chatbot interface
|
35 |
+
chat_interface = gr.ChatInterface(
|
36 |
+
fn=respond,
|
37 |
+
title="AISDb Bot",
|
38 |
+
description="LLM's are prone to hallucinations"
|
39 |
+
)
|
40 |
+
|
41 |
+
chat_interface.launch(share=True)
|
configs/.env
CHANGED
@@ -1 +1 @@
|
|
1 |
-
GOOGLE_API_KEY="
|
|
|
1 |
+
GOOGLE_API_KEY=""
|
configs/config.py
CHANGED
@@ -57,13 +57,12 @@ EMBEDDINGS = HuggingFaceEmbeddings(
|
|
57 |
model_kwargs={"device": "cpu"},
|
58 |
)
|
59 |
|
60 |
-
SYSTEM_PROMPT = """
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
|
68 |
def set_envs():
|
69 |
if "GOOGLE_API_KEY" not in os.environ:
|
|
|
57 |
model_kwargs={"device": "cpu"},
|
58 |
)
|
59 |
|
60 |
+
SYSTEM_PROMPT = """You are an assistant for question-answering tasks. \
|
61 |
+
Use the following pieces of retrieved context to answer the question. \
|
62 |
+
Try to keep the answer concise, unless aksed by the user to be eloborated.\
|
63 |
+
Analyse the question, and provide necessary python code help if necessary, as you will be mainly used for ML research.\
|
64 |
+
If you don't know the answer, just say that you don't know. \
|
65 |
+
Context: {context}"""
|
|
|
66 |
|
67 |
def set_envs():
|
68 |
if "GOOGLE_API_KEY" not in os.environ:
|
llm_setup/llm_setup.py
CHANGED
@@ -1,13 +1,16 @@
|
|
1 |
from langchain_core.output_parsers import StrOutputParser
|
2 |
from langchain_core.prompts import (
|
3 |
ChatPromptTemplate,
|
4 |
-
|
5 |
-
HumanMessagePromptTemplate
|
6 |
)
|
|
|
|
|
7 |
from langchain_core.runnables import RunnablePassthrough
|
8 |
from langchain_core.vectorstores import VectorStoreRetriever
|
9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
10 |
-
|
|
|
|
|
11 |
from processing.documents import format_documents
|
12 |
|
13 |
|
@@ -15,20 +18,12 @@ def _initialize_llm() -> ChatGoogleGenerativeAI:
|
|
15 |
"""
|
16 |
Initializes the LLM instance.
|
17 |
"""
|
18 |
-
llm = ChatGoogleGenerativeAI(model="gemini-
|
|
|
19 |
return llm
|
20 |
|
21 |
|
22 |
class LLMService:
|
23 |
-
"""
|
24 |
-
Service for managing LLM interactions and conversational RAG chain.
|
25 |
-
|
26 |
-
Args:
|
27 |
-
logger: Logger instance for logging.
|
28 |
-
system_prompt: The prompt for the QA system.
|
29 |
-
web_retriever: A VectorStoreRetriever instance for retrieving web documents.
|
30 |
-
"""
|
31 |
-
|
32 |
def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever):
|
33 |
self._conversational_rag_chain = None
|
34 |
self._logger = logger
|
@@ -39,27 +34,55 @@ class LLMService:
|
|
39 |
|
40 |
self._initialize_conversational_rag_chain()
|
41 |
|
|
|
|
|
|
|
42 |
def _initialize_conversational_rag_chain(self):
|
43 |
"""
|
44 |
Initializes the conversational RAG chain.
|
45 |
"""
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
56 |
)
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
def conversational_rag_chain(self):
|
65 |
"""
|
@@ -69,20 +92,6 @@ class LLMService:
|
|
69 |
The conversational RAG chain instance.
|
70 |
"""
|
71 |
return self._conversational_rag_chain
|
72 |
-
|
73 |
-
def update_chat_history(self, user_input: str, llm_response: str):
|
74 |
-
"""
|
75 |
-
Updates the chat history with the latest question and response.
|
76 |
-
"""
|
77 |
-
self.chat_history.append(f"User: {user_input}\nAI: {llm_response}")
|
78 |
-
|
79 |
-
def ask_question(self, question: str):
|
80 |
-
"""
|
81 |
-
Processes a user question using the conversational RAG chain and updates history.
|
82 |
-
"""
|
83 |
-
response = self._conversational_rag_chain.invoke(question)
|
84 |
-
self.update_chat_history(question, response)
|
85 |
-
return response
|
86 |
|
87 |
def get_llm(self) -> ChatGoogleGenerativeAI:
|
88 |
"""
|
|
|
1 |
from langchain_core.output_parsers import StrOutputParser
|
2 |
from langchain_core.prompts import (
|
3 |
ChatPromptTemplate,
|
4 |
+
MessagesPlaceholder,
|
|
|
5 |
)
|
6 |
+
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
|
7 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
8 |
from langchain_core.runnables import RunnablePassthrough
|
9 |
from langchain_core.vectorstores import VectorStoreRetriever
|
10 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
11 |
+
from langchain_core.chat_history import BaseChatMessageHistory
|
12 |
+
from langchain_community.chat_message_histories import ChatMessageHistory
|
13 |
+
from langchain_core.runnables.history import RunnableWithMessageHistory
|
14 |
from processing.documents import format_documents
|
15 |
|
16 |
|
|
|
18 |
"""
|
19 |
Initializes the LLM instance.
|
20 |
"""
|
21 |
+
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
|
22 |
+
|
23 |
return llm
|
24 |
|
25 |
|
26 |
class LLMService:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever):
|
28 |
self._conversational_rag_chain = None
|
29 |
self._logger = logger
|
|
|
34 |
|
35 |
self._initialize_conversational_rag_chain()
|
36 |
|
37 |
+
### Statefully manage chat history ###
|
38 |
+
self.store = {}
|
39 |
+
|
40 |
def _initialize_conversational_rag_chain(self):
|
41 |
"""
|
42 |
Initializes the conversational RAG chain.
|
43 |
"""
|
44 |
+
### Contextualize question ###
|
45 |
+
contextualize_q_system_prompt = """Given a chat history and the latest user question \
|
46 |
+
which might reference context in the chat history, formulate a standalone question \
|
47 |
+
which can be understood without the chat history. Do NOT answer the question, \
|
48 |
+
just reformulate it if needed and otherwise return it as is."""
|
49 |
+
|
50 |
+
contextualize_q_prompt = ChatPromptTemplate.from_messages(
|
51 |
+
[
|
52 |
+
("system", contextualize_q_system_prompt),
|
53 |
+
MessagesPlaceholder("chat_history"),
|
54 |
+
("human", "{input}"),
|
55 |
+
]
|
56 |
)
|
57 |
|
58 |
+
|
59 |
+
|
60 |
+
history_aware_retriever = create_history_aware_retriever(
|
61 |
+
self.llm, self._web_retriever, contextualize_q_prompt)
|
62 |
+
|
63 |
+
qa_prompt = ChatPromptTemplate.from_messages(
|
64 |
+
[
|
65 |
+
("system", self.system_prompt),
|
66 |
+
MessagesPlaceholder("chat_history"),
|
67 |
+
("human", "{input}"),
|
68 |
+
]
|
69 |
+
)
|
70 |
+
|
71 |
+
question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
|
72 |
+
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
|
73 |
+
|
74 |
+
self._conversational_rag_chain = RunnableWithMessageHistory(
|
75 |
+
rag_chain,
|
76 |
+
self._get_session_history,
|
77 |
+
input_messages_key="input",
|
78 |
+
history_messages_key="chat_history",
|
79 |
+
output_messages_key="answer",
|
80 |
+
)
|
81 |
+
|
82 |
+
def _get_session_history(self, session_id: str) -> BaseChatMessageHistory:
|
83 |
+
if session_id not in self.store:
|
84 |
+
self.store[session_id] = ChatMessageHistory()
|
85 |
+
return self.store[session_id]
|
86 |
|
87 |
def conversational_rag_chain(self):
|
88 |
"""
|
|
|
92 |
The conversational RAG chain instance.
|
93 |
"""
|
94 |
return self._conversational_rag_chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
def get_llm(self) -> ChatGoogleGenerativeAI:
|
97 |
"""
|
requirements.txt
CHANGED
@@ -1,154 +1,155 @@
|
|
|
|
1 |
huggingface_hub
|
2 |
-
aiofiles
|
3 |
-
aiohappyeyeballs
|
4 |
-
aiohttp
|
5 |
-
aiosignal
|
6 |
-
annotated-types
|
7 |
-
anyio
|
8 |
-
asgiref
|
9 |
-
attrs
|
10 |
-
backoff
|
11 |
-
bcrypt
|
12 |
-
build
|
13 |
-
cachetools
|
14 |
-
certifi
|
15 |
-
charset-normalizer
|
16 |
-
chroma-hnswlib
|
17 |
-
chromadb
|
18 |
-
click
|
19 |
-
colorama
|
20 |
-
coloredlogs
|
21 |
-
contourpy
|
22 |
-
cycler
|
23 |
-
dataclasses-json
|
24 |
-
Deprecated
|
25 |
-
fastapi
|
26 |
-
ffmpy
|
27 |
-
filelock
|
28 |
-
flatbuffers
|
29 |
-
fonttools
|
30 |
-
frozenlist
|
31 |
-
fsspec
|
32 |
-
google-ai-generativelanguage
|
33 |
-
google-api-core
|
34 |
-
google-api-python-client
|
35 |
-
google-auth
|
36 |
-
google-auth-httplib2
|
37 |
-
google-generativeai
|
38 |
-
googleapis-common-protos
|
39 |
-
gradio
|
40 |
-
gradio_client
|
41 |
-
greenlet
|
42 |
-
grpcio
|
43 |
-
grpcio-status
|
44 |
-
h11
|
45 |
-
httpcore
|
46 |
-
httplib2
|
47 |
-
httptools
|
48 |
-
httpx
|
49 |
-
huggingface-hub
|
50 |
-
humanfriendly
|
51 |
-
idna
|
52 |
-
importlib_metadata
|
53 |
-
importlib_resources
|
54 |
-
Jinja2
|
55 |
-
joblib
|
56 |
-
jsonpatch
|
57 |
-
jsonpointer
|
58 |
-
kiwisolver
|
59 |
-
kubernetes
|
60 |
-
langchain
|
61 |
-
langchain-chroma
|
62 |
-
langchain-community
|
63 |
-
langchain-core
|
64 |
-
langchain-google-genai
|
65 |
-
langchain-huggingface
|
66 |
-
langchain-text-splitters
|
67 |
-
langsmith
|
68 |
-
markdown-it-py
|
69 |
-
MarkupSafe
|
70 |
-
marshmallow
|
71 |
-
matplotlib
|
72 |
-
mdurl
|
73 |
-
mmh3
|
74 |
-
monotonic
|
75 |
-
mpmath
|
76 |
-
multidict
|
77 |
-
mypy-extensions
|
78 |
-
networkx
|
79 |
-
numpy
|
80 |
-
oauthlib
|
81 |
-
onnxruntime
|
82 |
-
opentelemetry-api
|
83 |
-
opentelemetry-exporter-otlp-proto-common
|
84 |
-
opentelemetry-exporter-otlp-proto-grpc
|
85 |
-
opentelemetry-instrumentation
|
86 |
-
opentelemetry-instrumentation-asgi
|
87 |
-
opentelemetry-instrumentation-fastapi
|
88 |
-
opentelemetry-proto
|
89 |
-
opentelemetry-sdk
|
90 |
-
opentelemetry-semantic-conventions
|
91 |
-
opentelemetry-util-http
|
92 |
-
orjson
|
93 |
-
overrides
|
94 |
-
packaging
|
95 |
-
pandas
|
96 |
-
pillow
|
97 |
-
posthog
|
98 |
-
proto-plus
|
99 |
-
protobuf
|
100 |
-
pyasn1
|
101 |
-
pyasn1_modules
|
102 |
-
pydantic
|
103 |
-
pydantic-settings
|
104 |
-
pydantic_core
|
105 |
-
pydub
|
106 |
-
Pygments
|
107 |
-
pyparsing
|
108 |
-
PyPika
|
109 |
-
pyproject_hooks
|
110 |
-
pyreadline3
|
111 |
-
python-dateutil
|
112 |
-
python-dotenv
|
113 |
-
python-multipart
|
114 |
-
pytz
|
115 |
-
PyYAML
|
116 |
-
regex
|
117 |
-
requests
|
118 |
-
requests-oauthlib
|
119 |
-
rich
|
120 |
-
rsa
|
121 |
-
ruff
|
122 |
-
safetensors
|
123 |
-
scikit-learn
|
124 |
-
scipy
|
125 |
-
semantic-version
|
126 |
-
sentence-transformers
|
127 |
-
setuptools
|
128 |
-
shellingham
|
129 |
-
six
|
130 |
-
sniffio
|
131 |
-
SQLAlchemy
|
132 |
-
starlette
|
133 |
-
sympy
|
134 |
-
tenacity
|
135 |
-
threadpoolctl
|
136 |
-
tokenizers
|
137 |
-
tomlkit
|
138 |
-
torch
|
139 |
-
tqdm
|
140 |
-
transformers
|
141 |
-
typer
|
142 |
-
typing-inspect
|
143 |
-
typing_extensions
|
144 |
-
tzdata
|
145 |
-
uritemplate
|
146 |
-
urllib3
|
147 |
-
uvicorn
|
148 |
-
watchfiles
|
149 |
-
websocket-client
|
150 |
-
websockets
|
151 |
-
wrapt
|
152 |
-
yarl
|
153 |
-
zipp
|
154 |
bs4
|
|
|
1 |
+
|
2 |
huggingface_hub
|
3 |
+
aiofiles
|
4 |
+
aiohappyeyeballs
|
5 |
+
aiohttp
|
6 |
+
aiosignal
|
7 |
+
annotated-types
|
8 |
+
anyio
|
9 |
+
asgiref
|
10 |
+
attrs
|
11 |
+
backoff
|
12 |
+
bcrypt
|
13 |
+
build
|
14 |
+
cachetools
|
15 |
+
certifi
|
16 |
+
charset-normalizer
|
17 |
+
chroma-hnswlib
|
18 |
+
chromadb
|
19 |
+
click
|
20 |
+
colorama
|
21 |
+
coloredlogs
|
22 |
+
contourpy
|
23 |
+
cycler
|
24 |
+
dataclasses-json
|
25 |
+
Deprecated
|
26 |
+
fastapi
|
27 |
+
ffmpy
|
28 |
+
filelock
|
29 |
+
flatbuffers
|
30 |
+
fonttools
|
31 |
+
frozenlist
|
32 |
+
fsspec
|
33 |
+
google-ai-generativelanguage
|
34 |
+
google-api-core
|
35 |
+
google-api-python-client
|
36 |
+
google-auth
|
37 |
+
google-auth-httplib2
|
38 |
+
google-generativeai
|
39 |
+
googleapis-common-protos
|
40 |
+
gradio
|
41 |
+
gradio_client
|
42 |
+
greenlet
|
43 |
+
grpcio
|
44 |
+
grpcio-status
|
45 |
+
h11
|
46 |
+
httpcore
|
47 |
+
httplib2
|
48 |
+
httptools
|
49 |
+
httpx
|
50 |
+
huggingface-hub
|
51 |
+
humanfriendly
|
52 |
+
idna
|
53 |
+
importlib_metadata
|
54 |
+
importlib_resources
|
55 |
+
Jinja2
|
56 |
+
joblib
|
57 |
+
jsonpatch
|
58 |
+
jsonpointer
|
59 |
+
kiwisolver
|
60 |
+
kubernetes
|
61 |
+
langchain
|
62 |
+
langchain-chroma
|
63 |
+
langchain-community
|
64 |
+
langchain-core
|
65 |
+
langchain-google-genai
|
66 |
+
langchain-huggingface
|
67 |
+
langchain-text-splitters
|
68 |
+
langsmith
|
69 |
+
markdown-it-py
|
70 |
+
MarkupSafe
|
71 |
+
marshmallow
|
72 |
+
matplotlib
|
73 |
+
mdurl
|
74 |
+
mmh3
|
75 |
+
monotonic
|
76 |
+
mpmath
|
77 |
+
multidict
|
78 |
+
mypy-extensions
|
79 |
+
networkx
|
80 |
+
numpy
|
81 |
+
oauthlib
|
82 |
+
onnxruntime
|
83 |
+
opentelemetry-api
|
84 |
+
opentelemetry-exporter-otlp-proto-common
|
85 |
+
opentelemetry-exporter-otlp-proto-grpc
|
86 |
+
opentelemetry-instrumentation
|
87 |
+
opentelemetry-instrumentation-asgi
|
88 |
+
opentelemetry-instrumentation-fastapi
|
89 |
+
opentelemetry-proto
|
90 |
+
opentelemetry-sdk
|
91 |
+
opentelemetry-semantic-conventions
|
92 |
+
opentelemetry-util-http
|
93 |
+
orjson
|
94 |
+
overrides
|
95 |
+
packaging
|
96 |
+
pandas
|
97 |
+
pillow
|
98 |
+
posthog
|
99 |
+
proto-plus
|
100 |
+
protobuf
|
101 |
+
pyasn1
|
102 |
+
pyasn1_modules
|
103 |
+
pydantic
|
104 |
+
pydantic-settings
|
105 |
+
pydantic_core
|
106 |
+
pydub
|
107 |
+
Pygments
|
108 |
+
pyparsing
|
109 |
+
PyPika
|
110 |
+
pyproject_hooks
|
111 |
+
pyreadline3
|
112 |
+
python-dateutil
|
113 |
+
python-dotenv
|
114 |
+
python-multipart
|
115 |
+
pytz
|
116 |
+
PyYAML
|
117 |
+
regex
|
118 |
+
requests
|
119 |
+
requests-oauthlib
|
120 |
+
rich
|
121 |
+
rsa
|
122 |
+
ruff
|
123 |
+
safetensors
|
124 |
+
scikit-learn
|
125 |
+
scipy
|
126 |
+
semantic-version
|
127 |
+
sentence-transformers
|
128 |
+
setuptools
|
129 |
+
shellingham
|
130 |
+
six
|
131 |
+
sniffio
|
132 |
+
SQLAlchemy
|
133 |
+
starlette
|
134 |
+
sympy
|
135 |
+
tenacity
|
136 |
+
threadpoolctl
|
137 |
+
tokenizers
|
138 |
+
tomlkit
|
139 |
+
torch
|
140 |
+
tqdm
|
141 |
+
transformers
|
142 |
+
typer
|
143 |
+
typing-inspect
|
144 |
+
typing_extensions
|
145 |
+
tzdata
|
146 |
+
uritemplate
|
147 |
+
urllib3
|
148 |
+
uvicorn
|
149 |
+
watchfiles
|
150 |
+
websocket-client
|
151 |
+
websockets
|
152 |
+
wrapt
|
153 |
+
yarl
|
154 |
+
zipp
|
155 |
bs4
|