vaishnav commited on
Commit
23d9a47
·
1 Parent(s): eb93a33

add chat history

Browse files
Files changed (6) hide show
  1. .gitignore +4 -1
  2. app.py +14 -8
  3. configs/.env +1 -1
  4. configs/config.py +6 -7
  5. llm_setup/llm_setup.py +51 -42
  6. requirements.txt +153 -152
.gitignore CHANGED
@@ -1,4 +1,7 @@
1
  venv
2
  configs/.env
3
  .idea
4
- *__pycache__
 
 
 
 
1
  venv
2
  configs/.env
3
  .idea
4
+ *__pycache__
5
+ venv
6
+ embeddings
7
+ *.gradio
app.py CHANGED
@@ -20,16 +20,22 @@ service.scrape_and_get_store_vector_retriever(config.URLS)
20
  # Initialize the LLMService with logger, prompt, and store vector retriever
21
  llm_svc = LLMService(logger, config.SYSTEM_PROMPT, store.get_chroma_instance().as_retriever())
22
 
23
-
24
-
25
- def respond(user_input, history):
26
- print(f"{user_input}")
27
-
28
- response = llm_svc.conversational_rag_chain().invoke(user_input)
29
 
30
  return response
31
 
32
-
33
  if __name__ == '__main__':
34
  logging.info("Starting AIVIz Bot")
35
- gr.ChatInterface(respond).launch(share=True)
 
 
 
 
 
 
 
 
 
20
  # Initialize the LLMService with logger, prompt, and store vector retriever
21
  llm_svc = LLMService(logger, config.SYSTEM_PROMPT, store.get_chroma_instance().as_retriever())
22
 
23
+ def respond(user_input):
24
+ response = llm_svc.conversational_rag_chain().invoke(
25
+ {"input": user_input},
26
+ config={"configurable": {"session_id": "abc"}},
27
+ )["answer"]
 
28
 
29
  return response
30
 
 
31
  if __name__ == '__main__':
32
  logging.info("Starting AIVIz Bot")
33
+
34
+ # Using ChatInterface to create the chatbot interface
35
+ chat_interface = gr.ChatInterface(
36
+ fn=respond,
37
+ title="AISDb Bot",
38
+ description="LLM's are prone to hallucinations"
39
+ )
40
+
41
+ chat_interface.launch(share=True)
configs/.env CHANGED
@@ -1 +1 @@
1
- GOOGLE_API_KEY="AIzaSyCWQsPEq-D3nJZFdMgsTlxDOweTzPKOTwI"
 
1
+ GOOGLE_API_KEY=""
configs/config.py CHANGED
@@ -57,13 +57,12 @@ EMBEDDINGS = HuggingFaceEmbeddings(
57
  model_kwargs={"device": "cpu"},
58
  )
59
 
60
- SYSTEM_PROMPT = """
61
- You are a chatbot to assist users asking about Automatic Identification systems (AIS) database from the context given to you.
62
- Use this Context: {context}. The users are building great Machine learning models using this Database,
63
- so assist them with code, definitions, summarization and so forth like a tutor.
64
- Give scenario based answer that can clearly explain it to users and explain step by step.
65
- Based on this, now answer for this user's question: {question}."""
66
-
67
 
68
  def set_envs():
69
  if "GOOGLE_API_KEY" not in os.environ:
 
57
  model_kwargs={"device": "cpu"},
58
  )
59
 
60
+ SYSTEM_PROMPT = """You are an assistant for question-answering tasks. \
61
+ Use the following pieces of retrieved context to answer the question. \
62
+ Try to keep the answer concise, unless aksed by the user to be eloborated.\
63
+ Analyse the question, and provide necessary python code help if necessary, as you will be mainly used for ML research.\
64
+ If you don't know the answer, just say that you don't know. \
65
+ Context: {context}"""
 
66
 
67
  def set_envs():
68
  if "GOOGLE_API_KEY" not in os.environ:
llm_setup/llm_setup.py CHANGED
@@ -1,13 +1,16 @@
1
  from langchain_core.output_parsers import StrOutputParser
2
  from langchain_core.prompts import (
3
  ChatPromptTemplate,
4
- PromptTemplate,
5
- HumanMessagePromptTemplate
6
  )
 
 
7
  from langchain_core.runnables import RunnablePassthrough
8
  from langchain_core.vectorstores import VectorStoreRetriever
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
-
 
 
11
  from processing.documents import format_documents
12
 
13
 
@@ -15,20 +18,12 @@ def _initialize_llm() -> ChatGoogleGenerativeAI:
15
  """
16
  Initializes the LLM instance.
17
  """
18
- llm = ChatGoogleGenerativeAI(model="gemini-pro")
 
19
  return llm
20
 
21
 
22
  class LLMService:
23
- """
24
- Service for managing LLM interactions and conversational RAG chain.
25
-
26
- Args:
27
- logger: Logger instance for logging.
28
- system_prompt: The prompt for the QA system.
29
- web_retriever: A VectorStoreRetriever instance for retrieving web documents.
30
- """
31
-
32
  def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever):
33
  self._conversational_rag_chain = None
34
  self._logger = logger
@@ -39,27 +34,55 @@ class LLMService:
39
 
40
  self._initialize_conversational_rag_chain()
41
 
 
 
 
42
  def _initialize_conversational_rag_chain(self):
43
  """
44
  Initializes the conversational RAG chain.
45
  """
46
- # Initialize RAG (Retrieval-Augmented Generation) chain
47
- prompt = ChatPromptTemplate(input_variables=['history','context', 'question'], messages=[HumanMessagePromptTemplate(
48
- prompt=PromptTemplate(input_variables=['history','context', 'question'], template="{history}\nContext: {context}\nQuestion: {question}"))])
49
-
50
- # Initialize conversational RAG chain
51
- self._conversational_rag_chain = (
52
- {"context": self._web_retriever | format_documents, "question": RunnablePassthrough(), "history":self.get_chat_history}
53
- | prompt
54
- | self.llm
55
- | StrOutputParser()
 
 
56
  )
57
 
58
- def get_chat_history(self):
59
- """
60
- Retrieves the last 3 chat messages formatted as a string.
61
- """
62
- return "\n".join(self.chat_history) if self.chat_history else "No prior conversation."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def conversational_rag_chain(self):
65
  """
@@ -69,20 +92,6 @@ class LLMService:
69
  The conversational RAG chain instance.
70
  """
71
  return self._conversational_rag_chain
72
-
73
- def update_chat_history(self, user_input: str, llm_response: str):
74
- """
75
- Updates the chat history with the latest question and response.
76
- """
77
- self.chat_history.append(f"User: {user_input}\nAI: {llm_response}")
78
-
79
- def ask_question(self, question: str):
80
- """
81
- Processes a user question using the conversational RAG chain and updates history.
82
- """
83
- response = self._conversational_rag_chain.invoke(question)
84
- self.update_chat_history(question, response)
85
- return response
86
 
87
  def get_llm(self) -> ChatGoogleGenerativeAI:
88
  """
 
1
  from langchain_core.output_parsers import StrOutputParser
2
  from langchain_core.prompts import (
3
  ChatPromptTemplate,
4
+ MessagesPlaceholder,
 
5
  )
6
+ from langchain.chains import create_history_aware_retriever, create_retrieval_chain
7
+ from langchain.chains.combine_documents import create_stuff_documents_chain
8
  from langchain_core.runnables import RunnablePassthrough
9
  from langchain_core.vectorstores import VectorStoreRetriever
10
  from langchain_google_genai import ChatGoogleGenerativeAI
11
+ from langchain_core.chat_history import BaseChatMessageHistory
12
+ from langchain_community.chat_message_histories import ChatMessageHistory
13
+ from langchain_core.runnables.history import RunnableWithMessageHistory
14
  from processing.documents import format_documents
15
 
16
 
 
18
  """
19
  Initializes the LLM instance.
20
  """
21
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
22
+
23
  return llm
24
 
25
 
26
  class LLMService:
 
 
 
 
 
 
 
 
 
27
  def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever):
28
  self._conversational_rag_chain = None
29
  self._logger = logger
 
34
 
35
  self._initialize_conversational_rag_chain()
36
 
37
+ ### Statefully manage chat history ###
38
+ self.store = {}
39
+
40
  def _initialize_conversational_rag_chain(self):
41
  """
42
  Initializes the conversational RAG chain.
43
  """
44
+ ### Contextualize question ###
45
+ contextualize_q_system_prompt = """Given a chat history and the latest user question \
46
+ which might reference context in the chat history, formulate a standalone question \
47
+ which can be understood without the chat history. Do NOT answer the question, \
48
+ just reformulate it if needed and otherwise return it as is."""
49
+
50
+ contextualize_q_prompt = ChatPromptTemplate.from_messages(
51
+ [
52
+ ("system", contextualize_q_system_prompt),
53
+ MessagesPlaceholder("chat_history"),
54
+ ("human", "{input}"),
55
+ ]
56
  )
57
 
58
+
59
+
60
+ history_aware_retriever = create_history_aware_retriever(
61
+ self.llm, self._web_retriever, contextualize_q_prompt)
62
+
63
+ qa_prompt = ChatPromptTemplate.from_messages(
64
+ [
65
+ ("system", self.system_prompt),
66
+ MessagesPlaceholder("chat_history"),
67
+ ("human", "{input}"),
68
+ ]
69
+ )
70
+
71
+ question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
72
+ rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
73
+
74
+ self._conversational_rag_chain = RunnableWithMessageHistory(
75
+ rag_chain,
76
+ self._get_session_history,
77
+ input_messages_key="input",
78
+ history_messages_key="chat_history",
79
+ output_messages_key="answer",
80
+ )
81
+
82
+ def _get_session_history(self, session_id: str) -> BaseChatMessageHistory:
83
+ if session_id not in self.store:
84
+ self.store[session_id] = ChatMessageHistory()
85
+ return self.store[session_id]
86
 
87
  def conversational_rag_chain(self):
88
  """
 
92
  The conversational RAG chain instance.
93
  """
94
  return self._conversational_rag_chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  def get_llm(self) -> ChatGoogleGenerativeAI:
97
  """
requirements.txt CHANGED
@@ -1,154 +1,155 @@
 
1
  huggingface_hub
2
- aiofiles==23.2.1
3
- aiohappyeyeballs==2.4.0
4
- aiohttp==3.10.5
5
- aiosignal==1.3.1
6
- annotated-types==0.7.0
7
- anyio==4.4.0
8
- asgiref==3.8.1
9
- attrs==24.2.0
10
- backoff==2.2.1
11
- bcrypt==4.2.0
12
- build==1.2.2
13
- cachetools==5.5.0
14
- certifi==2024.8.30
15
- charset-normalizer==3.3.2
16
- chroma-hnswlib==0.7.6
17
- chromadb==0.5.7
18
- click==8.1.7
19
- colorama==0.4.6
20
- coloredlogs==15.0.1
21
- contourpy==1.3.0
22
- cycler==0.12.1
23
- dataclasses-json==0.6.7
24
- Deprecated==1.2.14
25
- fastapi==0.115.0
26
- ffmpy==0.4.0
27
- filelock==3.16.1
28
- flatbuffers==24.3.25
29
- fonttools==4.53.1
30
- frozenlist==1.4.1
31
- fsspec==2024.9.0
32
- google-ai-generativelanguage==0.6.6
33
- google-api-core==2.19.2
34
- google-api-python-client==2.146.0
35
- google-auth==2.34.0
36
- google-auth-httplib2==0.2.0
37
- google-generativeai==0.7.2
38
- googleapis-common-protos==1.65.0
39
- gradio==4.44.0
40
- gradio_client==1.3.0
41
- greenlet==3.1.0
42
- grpcio==1.66.1
43
- grpcio-status==1.62.3
44
- h11==0.14.0
45
- httpcore==1.0.5
46
- httplib2==0.22.0
47
- httptools==0.6.1
48
- httpx==0.27.2
49
- huggingface-hub==0.25.0
50
- humanfriendly==10.0
51
- idna==3.10
52
- importlib_metadata==8.4.0
53
- importlib_resources==6.4.5
54
- Jinja2==3.1.4
55
- joblib==1.4.2
56
- jsonpatch==1.33
57
- jsonpointer==3.0.0
58
- kiwisolver==1.4.7
59
- kubernetes==30.1.0
60
- langchain==0.3.0
61
- langchain-chroma==0.1.4
62
- langchain-community==0.3.0
63
- langchain-core==0.3.1
64
- langchain-google-genai==2.0.0
65
- langchain-huggingface==0.1.0
66
- langchain-text-splitters==0.3.0
67
- langsmith==0.1.121
68
- markdown-it-py==3.0.0
69
- MarkupSafe==2.1.5
70
- marshmallow==3.22.0
71
- matplotlib==3.9.2
72
- mdurl==0.1.2
73
- mmh3==5.0.0
74
- monotonic==1.6
75
- mpmath==1.3.0
76
- multidict==6.1.0
77
- mypy-extensions==1.0.0
78
- networkx==3.3
79
- numpy==1.26.4
80
- oauthlib==3.2.2
81
- onnxruntime==1.19.2
82
- opentelemetry-api==1.27.0
83
- opentelemetry-exporter-otlp-proto-common==1.27.0
84
- opentelemetry-exporter-otlp-proto-grpc==1.27.0
85
- opentelemetry-instrumentation==0.48b0
86
- opentelemetry-instrumentation-asgi==0.48b0
87
- opentelemetry-instrumentation-fastapi==0.48b0
88
- opentelemetry-proto==1.27.0
89
- opentelemetry-sdk==1.27.0
90
- opentelemetry-semantic-conventions==0.48b0
91
- opentelemetry-util-http==0.48b0
92
- orjson==3.10.7
93
- overrides==7.7.0
94
- packaging==24.1
95
- pandas==2.2.2
96
- pillow==10.4.0
97
- posthog==3.6.6
98
- proto-plus==1.24.0
99
- protobuf==4.25.4
100
- pyasn1==0.6.1
101
- pyasn1_modules==0.4.1
102
- pydantic==2.9.2
103
- pydantic-settings==2.5.2
104
- pydantic_core==2.23.4
105
- pydub==0.25.1
106
- Pygments==2.18.0
107
- pyparsing==3.1.4
108
- PyPika==0.48.9
109
- pyproject_hooks==1.1.0
110
- pyreadline3==3.5.4
111
- python-dateutil==2.9.0.post0
112
- python-dotenv==1.0.1
113
- python-multipart==0.0.9
114
- pytz==2024.2
115
- PyYAML==6.0.2
116
- regex==2024.9.11
117
- requests==2.32.3
118
- requests-oauthlib==2.0.0
119
- rich==13.8.1
120
- rsa==4.9
121
- ruff==0.6.5
122
- safetensors==0.4.5
123
- scikit-learn==1.5.2
124
- scipy==1.14.1
125
- semantic-version==2.10.0
126
- sentence-transformers==3.1.0
127
- setuptools==75.1.0
128
- shellingham==1.5.4
129
- six==1.16.0
130
- sniffio==1.3.1
131
- SQLAlchemy==2.0.35
132
- starlette==0.38.5
133
- sympy==1.13.2
134
- tenacity==8.5.0
135
- threadpoolctl==3.5.0
136
- tokenizers==0.19.1
137
- tomlkit==0.12.0
138
- torch==2.4.1
139
- tqdm==4.66.5
140
- transformers==4.44.2
141
- typer==0.12.5
142
- typing-inspect==0.9.0
143
- typing_extensions==4.12.2
144
- tzdata==2024.1
145
- uritemplate==4.1.1
146
- urllib3==2.2.3
147
- uvicorn==0.30.6
148
- watchfiles==0.24.0
149
- websocket-client==1.8.0
150
- websockets==12.0
151
- wrapt==1.16.0
152
- yarl==1.11.1
153
- zipp==3.20.2
154
  bs4
 
1
+
2
  huggingface_hub
3
+ aiofiles
4
+ aiohappyeyeballs
5
+ aiohttp
6
+ aiosignal
7
+ annotated-types
8
+ anyio
9
+ asgiref
10
+ attrs
11
+ backoff
12
+ bcrypt
13
+ build
14
+ cachetools
15
+ certifi
16
+ charset-normalizer
17
+ chroma-hnswlib
18
+ chromadb
19
+ click
20
+ colorama
21
+ coloredlogs
22
+ contourpy
23
+ cycler
24
+ dataclasses-json
25
+ Deprecated
26
+ fastapi
27
+ ffmpy
28
+ filelock
29
+ flatbuffers
30
+ fonttools
31
+ frozenlist
32
+ fsspec
33
+ google-ai-generativelanguage
34
+ google-api-core
35
+ google-api-python-client
36
+ google-auth
37
+ google-auth-httplib2
38
+ google-generativeai
39
+ googleapis-common-protos
40
+ gradio
41
+ gradio_client
42
+ greenlet
43
+ grpcio
44
+ grpcio-status
45
+ h11
46
+ httpcore
47
+ httplib2
48
+ httptools
49
+ httpx
50
+ huggingface-hub
51
+ humanfriendly
52
+ idna
53
+ importlib_metadata
54
+ importlib_resources
55
+ Jinja2
56
+ joblib
57
+ jsonpatch
58
+ jsonpointer
59
+ kiwisolver
60
+ kubernetes
61
+ langchain
62
+ langchain-chroma
63
+ langchain-community
64
+ langchain-core
65
+ langchain-google-genai
66
+ langchain-huggingface
67
+ langchain-text-splitters
68
+ langsmith
69
+ markdown-it-py
70
+ MarkupSafe
71
+ marshmallow
72
+ matplotlib
73
+ mdurl
74
+ mmh3
75
+ monotonic
76
+ mpmath
77
+ multidict
78
+ mypy-extensions
79
+ networkx
80
+ numpy
81
+ oauthlib
82
+ onnxruntime
83
+ opentelemetry-api
84
+ opentelemetry-exporter-otlp-proto-common
85
+ opentelemetry-exporter-otlp-proto-grpc
86
+ opentelemetry-instrumentation
87
+ opentelemetry-instrumentation-asgi
88
+ opentelemetry-instrumentation-fastapi
89
+ opentelemetry-proto
90
+ opentelemetry-sdk
91
+ opentelemetry-semantic-conventions
92
+ opentelemetry-util-http
93
+ orjson
94
+ overrides
95
+ packaging
96
+ pandas
97
+ pillow
98
+ posthog
99
+ proto-plus
100
+ protobuf
101
+ pyasn1
102
+ pyasn1_modules
103
+ pydantic
104
+ pydantic-settings
105
+ pydantic_core
106
+ pydub
107
+ Pygments
108
+ pyparsing
109
+ PyPika
110
+ pyproject_hooks
111
+ pyreadline3
112
+ python-dateutil
113
+ python-dotenv
114
+ python-multipart
115
+ pytz
116
+ PyYAML
117
+ regex
118
+ requests
119
+ requests-oauthlib
120
+ rich
121
+ rsa
122
+ ruff
123
+ safetensors
124
+ scikit-learn
125
+ scipy
126
+ semantic-version
127
+ sentence-transformers
128
+ setuptools
129
+ shellingham
130
+ six
131
+ sniffio
132
+ SQLAlchemy
133
+ starlette
134
+ sympy
135
+ tenacity
136
+ threadpoolctl
137
+ tokenizers
138
+ tomlkit
139
+ torch
140
+ tqdm
141
+ transformers
142
+ typer
143
+ typing-inspect
144
+ typing_extensions
145
+ tzdata
146
+ uritemplate
147
+ urllib3
148
+ uvicorn
149
+ watchfiles
150
+ websocket-client
151
+ websockets
152
+ wrapt
153
+ yarl
154
+ zipp
155
  bs4