Commit e1cda2e · committed by vaishnaveswar · Parent(s): ef67833
.gitignore ADDED
@@ -0,0 +1,3 @@
+venv
+configs/.env
+.idea
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Gradio Chatbot
+title: AIVI Bot
 emoji: 💬
 colorFrom: yellow
 colorTo: purple
@@ -7,6 +7,7 @@ sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
 pinned: false
+license: mit
 ---
 
 An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py CHANGED
@@ -1,63 +1,31 @@
+import logging
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
-if __name__ == "__main__":
-    demo.launch()
+import configs.config as config
+import services.scraper
+import stores.chroma
+from llm_setup.llm_setup import LLMService
+
+logger = logging.getLogger()  # Create a logger object
+logger.setLevel(logging.INFO)  # Set the logging level to INFO
+
+config.set_envs()  # Set environment variables using the config module
+
+store = stores.chroma.ChromaDB(config.EMBEDDINGS)
+service = services.scraper.Service(store)
+
+# Scrape data and get the store vector retriever
+service.scrape_and_get_store_vector_retriever(config.URLS)
+
+# Initialize the LLMService with logger, prompt, and store vector retriever
+llm_svc = LLMService(logger, config.SYSTEM_PROMPT, store.get_chroma_instance().as_retriever())
+
+
+def respond(user_input, history):
+    response = llm_svc.conversational_rag_chain().invoke(user_input)
+
+    return response
+
+
+if __name__ == '__main__':
+    logging.info("Starting AIVIz Bot")
+    gr.ChatInterface(respond).launch(share=True)
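
Gradio's `gr.ChatInterface` invokes `respond(message, history)` on every turn, and the new `respond` ignores `history`, so each question goes through the RAG chain with no conversational memory across turns. A minimal sketch of that contract, using a hypothetical `FakeChain` stand-in for `llm_svc.conversational_rag_chain()`:

class FakeChain:  # hypothetical stand-in for the real RAG chain
    def invoke(self, user_input: str) -> str:
        # A real chain would retrieve context and call the LLM here.
        return f"You asked: {user_input}"

fake_chain = FakeChain()

def respond(user_input, history):
    # `history` is a list of previous (user, assistant) turns; it is unused,
    # matching the behaviour of respond() in app.py.
    return fake_chain.invoke(user_input)

assert respond("What is AIS?", []) == "You asked: What is AIS?"
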
configs/.env ADDED
@@ -0,0 +1 @@
+GOOGLE_API_KEY=""
configs/config.py ADDED
@@ -0,0 +1,45 @@
+import getpass
+import os
+
+from dotenv import load_dotenv
+from langchain_huggingface import HuggingFaceEmbeddings
+
+load_dotenv()
+
+URLS = ["https://aisviz.gitbook.io/documentation", "https://aisviz.gitbook.io/documentation/default-start/quick-start",
+        "https://aisviz.gitbook.io/documentation/default-start/sql-database",
+        "https://aisviz.gitbook.io/documentation/default-start/ais-hardware",
+        "https://aisviz.gitbook.io/documentation/default-start/compile-aisdb",
+        "https://aisviz.gitbook.io/documentation/tutorials/database-loading",
+        "https://aisviz.gitbook.io/documentation/tutorials/data-querying",
+        "https://aisviz.gitbook.io/documentation/tutorials/data-cleaning",
+        "https://aisviz.gitbook.io/documentation/tutorials/data-visualization",
+        "https://aisviz.gitbook.io/documentation/tutorials/track-interpolation",
+        "https://aisviz.gitbook.io/documentation/tutorials/haversine-distance",
+        "https://aisviz.gitbook.io/documentation/tutorials/vessel-speed",
+        "https://aisviz.gitbook.io/documentation/tutorials/coast-shore-and-ports",
+        "https://aisviz.gitbook.io/documentation/tutorials/vessel-metadata",
+        "https://aisviz.gitbook.io/documentation/tutorials/using-your-ais-data",
+        "https://aisviz.gitbook.io/documentation/tutorials/ais-data-to-csv",
+        "https://aisviz.gitbook.io/documentation/tutorials/bathymetric-data",
+        "https://aisviz.gitbook.io/documentation/machine-learning/seq2seq-in-pytorch",
+        "https://aisviz.gitbook.io/documentation/machine-learning/autoencoders-in-keras"]
+CHUNK_SIZE = 2400
+CHUNK_OVERLAP = 200
+TOTAL_RESULTS = 2389
+MAX_SIZE = 100
+EMBEDDINGS = HuggingFaceEmbeddings(
+    model_name="sentence-transformers/all-mpnet-base-v2",
+    model_kwargs={"device": "cpu"},
+)
+
+SYSTEM_PROMPT = """
+You are a chatbot that assists users asking about Automatic Identification Systems (AIS), based on the context given to you.
+Use this context: {context}. If the question is beyond the context, just say you don't know the answer.
+Give a scenario-based answer that clearly explains it to humans.
+Now, answer this user's question in a descriptive manner: {question}."""
+
+
+def set_envs():
+    if "GOOGLE_API_KEY" not in os.environ:
+        os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter your GOOGLE_API_KEY: ")
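
SYSTEM_PROMPT is consumed as a LangChain prompt template, so the `{context}` and `{question}` placeholders are filled at query time by the chain in llm_setup/llm_setup.py. A quick sketch of that substitution, assuming `configs.config` is importable (note that importing it also instantiates the HuggingFace embedding model; the context and question values below are made up for illustration):

from langchain_core.prompts import PromptTemplate

from configs.config import SYSTEM_PROMPT

template = PromptTemplate(input_variables=["context", "question"], template=SYSTEM_PROMPT)
print(template.format(
    context="AIS transceivers broadcast vessel position reports.",  # illustrative
    question="What is AIS?",  # illustrative
))
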
llm_setup/llm_setup.py ADDED
@@ -0,0 +1,75 @@
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    PromptTemplate,
+    HumanMessagePromptTemplate
+)
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.vectorstores import VectorStoreRetriever
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+from processing.documents import format_documents
+
+
+def _initialize_llm() -> ChatGoogleGenerativeAI:
+    """
+    Initializes the LLM instance.
+    """
+    llm = ChatGoogleGenerativeAI(model="gemini-pro")
+    return llm
+
+
+class LLMService:
+    """
+    Service for managing LLM interactions and the conversational RAG chain.
+
+    Args:
+        logger: Logger instance for logging.
+        system_prompt: The prompt for the QA system.
+        web_retriever: A VectorStoreRetriever instance for retrieving web documents.
+    """
+
+    def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever):
+        self._conversational_rag_chain = None
+        self._logger = logger
+        self.system_prompt = system_prompt
+        self._web_retriever = web_retriever
+
+        self.llm = _initialize_llm()
+
+        self._initialize_conversational_rag_chain()
+
+    def _initialize_conversational_rag_chain(self):
+        """
+        Initializes the conversational RAG chain.
+        """
+        # Wrap the system prompt in a single human message with {context} and {question} slots
+        prompt = ChatPromptTemplate(input_variables=['context', 'question'], messages=[HumanMessagePromptTemplate(
+            prompt=PromptTemplate(input_variables=['context', 'question'], template=self.system_prompt))])
+
+        # Initialize conversational RAG chain
+        self._conversational_rag_chain = (
+            {"context": self._web_retriever | format_documents, "question": RunnablePassthrough()}
+            | prompt
+            | self.llm
+            | StrOutputParser()
+        )
+
+    def conversational_rag_chain(self):
+        """
+        Returns the initialized conversational RAG chain.
+
+        Returns:
+            The conversational RAG chain instance.
+        """
+        return self._conversational_rag_chain
+
+    def get_llm(self) -> ChatGoogleGenerativeAI:
+        """
+        Returns the LLM instance.
+        """
+
+        if self.llm is None:
+            raise Exception("llm is not initialized")
+
+        return self.llm
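
The chain fans the raw question out twice: through the retriever (piped into `format_documents`) to produce `{context}`, and through `RunnablePassthrough()` to become `{question}`. A sketch of the same composition with stub runnables, so the data flow can be exercised without a Google API key or a populated vector store (all the `fake_*` names here are hypothetical stand-ins):

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

fake_retriever = RunnableLambda(lambda q: ["AIS broadcasts vessel positions."])
join_docs = RunnableLambda(lambda docs: "\n\n".join(docs))
fake_llm = RunnableLambda(lambda prompt_value: prompt_value.to_string().upper())

prompt = PromptTemplate.from_template("Context: {context}\nQuestion: {question}")

chain = (
    {"context": fake_retriever | join_docs, "question": RunnablePassthrough()}
    | prompt
    | fake_llm
    | StrOutputParser()
)

print(chain.invoke("What is AIS?"))  # the formatted prompt, upper-cased by fake_llm
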
processing/documents.py ADDED
@@ -0,0 +1,45 @@
+from langchain_community.document_loaders import WebBaseLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from typing import Iterable
+
+
+def load_documents(website: str) -> list[Document]:
+    """
+    Loads documents from a given website.
+
+    Args:
+        website (str): The URL of the website to load documents from.
+
+    Returns:
+        list[Document]: A list of loaded documents.
+    """
+    loader = WebBaseLoader(website)
+    return loader.load()
+
+
+def format_documents(docs: list[Document]) -> str:
+    """
+    Formats a list of documents into a single string.
+
+    Args:
+        docs (list[Document]): The list of documents to format.
+
+    Returns:
+        str: The formatted documents as a single string.
+    """
+    return "\n\n".join(doc.page_content for doc in docs)
+
+
+def split_documents(documents: Iterable[Document]) -> list[Document]:
+    """
+    Splits documents into smaller chunks.
+
+    Args:
+        documents (Iterable[Document]): The documents to split.
+
+    Returns:
+        list[Document]: A list of split documents.
+    """
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+    return text_splitter.split_documents(documents)
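
Note that `split_documents` hardcodes chunk_size=1000 and chunk_overlap=100; the CHUNK_SIZE and CHUNK_OVERLAP constants in configs/config.py are not referenced here. A small sketch of the overlap behaviour on a toy document (tiny sizes chosen so the effect is visible):

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)
chunks = splitter.split_documents([Document(page_content="ship track point " * 4)])
for chunk in chunks:
    # Consecutive chunks repeat up to 5 characters of the previous one.
    print(repr(chunk.page_content))
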
processing/texts.py ADDED
@@ -0,0 +1,6 @@
+def clean_text(text: str) -> str:
+    """
+    Clean the text by replacing newlines with spaces and trimming surrounding whitespace.
+    """
+    cleaned_text = text.replace("\n", " ").strip()
+    return cleaned_text
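
For example, assuming the project root is on the import path:

from processing.texts import clean_text

# Interior newlines become spaces; leading/trailing whitespace is trimmed.
assert clean_text("  AIS\ndata\n ") == "AIS data"
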
requirements.txt CHANGED
@@ -1 +1,153 @@
-huggingface_hub==0.22.2
+huggingface_hub
+aiofiles
+aiohappyeyeballs
+aiohttp
+aiosignal
+annotated-types
+anyio
+asgiref
+attrs
+backoff
+bcrypt
+build
+cachetools
+certifi
+charset-normalizer
+chroma-hnswlib
+chromadb
+click
+colorama
+coloredlogs
+contourpy
+cycler
+dataclasses-json
+Deprecated
+fastapi
+ffmpy
+filelock
+flatbuffers
+fonttools
+frozenlist
+fsspec
+google-ai-generativelanguage
+google-api-core
+google-api-python-client
+google-auth
+google-auth-httplib2
+google-generativeai
+googleapis-common-protos
+gradio
+gradio_client
+greenlet
+grpcio
+grpcio-status
+h11
+httpcore
+httplib2
+httptools
+httpx
+huggingface-hub
+humanfriendly
+idna
+importlib_metadata
+importlib_resources
+Jinja2
+joblib
+jsonpatch
+jsonpointer
+kiwisolver
+kubernetes
+langchain
+langchain-chroma
+langchain-community
+langchain-core
+langchain-google-genai
+langchain-huggingface
+langchain-text-splitters
+langsmith
+markdown-it-py
+MarkupSafe
+marshmallow
+matplotlib
+mdurl
+mmh3
+monotonic
+mpmath
+multidict
+mypy-extensions
+networkx
+numpy
+oauthlib
+onnxruntime
+opentelemetry-api
+opentelemetry-exporter-otlp-proto-common
+opentelemetry-exporter-otlp-proto-grpc
+opentelemetry-instrumentation
+opentelemetry-instrumentation-asgi
+opentelemetry-instrumentation-fastapi
+opentelemetry-proto
+opentelemetry-sdk
+opentelemetry-semantic-conventions
+opentelemetry-util-http
+orjson
+overrides
+packaging
+pandas
+pillow
+posthog
+proto-plus
+protobuf
+pyasn1
+pyasn1_modules
+pydantic
+pydantic-settings
+pydantic_core
+pydub
+Pygments
+pyparsing
+PyPika
+pyproject_hooks
+pyreadline3
+python-dateutil
+python-dotenv
+python-multipart
+pytz
+PyYAML
+regex
+requests
+requests-oauthlib
+rich
+rsa
+ruff
+safetensors
+scikit-learn
+scipy
+semantic-version
+sentence-transformers
+setuptools
+shellingham
+six
+sniffio
+SQLAlchemy
+starlette
+sympy
+tenacity
+threadpoolctl
+tokenizers
+tomlkit
+torch
+tqdm
+transformers
+typer
+typing-inspect
+typing_extensions
+tzdata
+uritemplate
+urllib3
+uvicorn
+watchfiles
+websocket-client
+websockets
+wrapt
+yarl
+zipp
services/scraper.py ADDED
@@ -0,0 +1,26 @@
+from langchain.schema import Document
+
+from processing.documents import load_documents, format_documents, split_documents
+from processing.texts import clean_text
+
+
+class Service:
+    def __init__(self, store):
+        self.store = store
+
+    def scrape_and_get_store_vector_retriever(self, urls: list[str]):
+        """
+        Scrapes content from the given URLs and stores the resulting chunks in the vector store.
+        """
+        documents: list[Document] = []
+
+        for url in urls:
+            try:
+                website_documents = load_documents(url)
+                formatted_content = format_documents(website_documents)
+                cleaned_content = clean_text(formatted_content)
+                documents.append(Document(page_content=cleaned_content))
+            except Exception as e:
+                raise Exception(f"Error processing {url}: {e}") from e
+
+        self.store.store_embeddings(split_documents(documents))
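
The `store` argument only needs a `store_embeddings(documents)` method, which is what ChromaDB in stores/chroma.py provides. A sketch of the scrape pipeline with a hypothetical in-memory stand-in (fetching the real URL requires network access):

from services.scraper import Service

class InMemoryStore:  # hypothetical stand-in for stores.chroma.ChromaDB
    def __init__(self):
        self.docs = []

    def store_embeddings(self, documents):
        self.docs.extend(documents)

store = InMemoryStore()
Service(store).scrape_and_get_store_vector_retriever(
    ["https://aisviz.gitbook.io/documentation"])
print(len(store.docs), "chunks stored")
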
stores/chroma.py ADDED
@@ -0,0 +1,23 @@
+from langchain.schema import Document
+from langchain_chroma import Chroma
+
+
+class ChromaDB:
+    def __init__(self, embeddings):
+        # Embeddings are supplied by the caller (see config.EMBEDDINGS);
+        # the Chroma collection is persisted under ./embeddings.
+        self._persistent_directory = "embeddings"
+        self.embeddings = embeddings
+
+        self.chroma = Chroma(persist_directory=self._persistent_directory, embedding_function=self.embeddings)
+
+    def get_chroma_instance(self) -> Chroma:
+        return self.chroma
+
+    def store_embeddings(self, documents: list[Document]):
+        """
+        Embed the documents with the configured embeddings and add them to the Chroma vector store.
+        """
+        # The embedding function and persist directory are already bound to
+        # self.chroma, so add_documents only needs the documents themselves.
+        self.chroma.add_documents(documents=documents)
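
app.py consumes this class by turning the underlying Chroma instance into a retriever. A sketch of that read path (it assumes the HuggingFace embedding model can be downloaded and that documents have already been stored under ./embeddings):

import configs.config as config
from stores.chroma import ChromaDB

store = ChromaDB(config.EMBEDDINGS)
retriever = store.get_chroma_instance().as_retriever()
for doc in retriever.invoke("How do I query AIS data?"):
    print(doc.page_content[:80])
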