Commit e1cda2e (parent: ef67833): revert

Files changed:
- .gitignore +3 -0
- README.md +2 -1
- app.py +30 -62
- configs/.env +1 -0
- configs/config.py +45 -0
- llm_setup/llm_setup.py +75 -0
- processing/documents.py +45 -0
- processing/texts.py +6 -0
- requirements.txt +153 -1
- services/scraper.py +26 -0
- stores/chroma.py +23 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+venv
+configs/.env
+.idea
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: AIVI Bot
 emoji: 💬
 colorFrom: yellow
 colorTo: purple
@@ -7,6 +7,7 @@ sdk: gradio
 sdk_version: 4.36.1
 app_file: app.py
 pinned: false
+license: mit
 ---

 An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
app.py CHANGED
@@ -1,63 +1,31 @@
+import logging
 import gradio as gr
-… (old lines 2-30 not shown) …
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
-if __name__ == "__main__":
-    demo.launch()
+import configs.config as config
+import services.scraper
+import stores.chroma
+from llm_setup.llm_setup import LLMService
+
+logger = logging.getLogger()  # Create a logger object
+logger.setLevel(logging.INFO)  # Set the logging level to INFO
+
+config.set_envs()  # Set environment variables using the config module
+
+store = stores.chroma.ChromaDB(config.EMBEDDINGS)
+service = services.scraper.Service(store)
+
+# Scrape data and get the store vector retriever
+service.scrape_and_get_store_vector_retriever(config.URLS)
+
+# Initialize the LLMService with logger, prompt, and store vector retriever
+llm_svc = LLMService(logger, config.SYSTEM_PROMPT, store.get_chroma_instance().as_retriever())
+
+
+def respond(user_input, history):
+    response = llm_svc.conversational_rag_chain().invoke(user_input)
+
+    return response
+
+
+if __name__ == '__main__':
+    logging.info("Starting AIVIz Bot")
+    gr.ChatInterface(respond).launch(share=True)
configs/.env ADDED
@@ -0,0 +1 @@
+GOOGLE_API_KEY=""
configs/config.py ADDED
@@ -0,0 +1,45 @@
+import getpass as getpass
+import os
+
+from dotenv import load_dotenv
+from langchain_huggingface import HuggingFaceEmbeddings
+
+load_dotenv()
+
+URLS = ["https://aisviz.gitbook.io/documentation", "https://aisviz.gitbook.io/documentation/default-start/quick-start",
+        "https://aisviz.gitbook.io/documentation/default-start/sql-database",
+        "https://aisviz.gitbook.io/documentation/default-start/ais-hardware",
+        "https://aisviz.gitbook.io/documentation/default-start/compile-aisdb",
+        "https://aisviz.gitbook.io/documentation/tutorials/database-loading",
+        "https://aisviz.gitbook.io/documentation/tutorials/data-querying",
+        "https://aisviz.gitbook.io/documentation/tutorials/data-cleaning",
+        "https://aisviz.gitbook.io/documentation/tutorials/data-visualization",
+        "https://aisviz.gitbook.io/documentation/tutorials/track-interpolation",
+        "https://aisviz.gitbook.io/documentation/tutorials/haversine-distance",
+        "https://aisviz.gitbook.io/documentation/tutorials/vessel-speed",
+        "https://aisviz.gitbook.io/documentation/tutorials/coast-shore-and-ports",
+        "https://aisviz.gitbook.io/documentation/tutorials/vessel-metadata",
+        "https://aisviz.gitbook.io/documentation/tutorials/using-your-ais-data",
+        "https://aisviz.gitbook.io/documentation/tutorials/ais-data-to-csv",
+        "https://aisviz.gitbook.io/documentation/tutorials/bathymetric-data",
+        "https://aisviz.gitbook.io/documentation/machine-learning/seq2seq-in-pytorch",
+        "https://aisviz.gitbook.io/documentation/machine-learning/autoencoders-in-keras"]
+CHUNK_SIZE = 2400
+CHUNK_OVERLAP = 200
+TOTAL_RESULTS = 2389
+MAX_SIZE = 100
+EMBEDDINGS = HuggingFaceEmbeddings(
+    model_name="sentence-transformers/all-mpnet-base-v2",
+    model_kwargs={"device": "cpu"},
+)
+
+SYSTEM_PROMPT = """
+You are a chatbot to assist users asking about Automatic Identification systems (AIS) from the context given to you.
+Use this Context: {context}. If the question is beyond the context, just tell you don't know the answer.
+Give scenario based answer that can clearly explain it to humans.
+Now, answer for this user's question in a descriptive manner: {question}."""
+
+
+def set_envs():
+    if "GOOGLE_API_KEY" not in os.environ:
+        os.environ["GOOGLE_API_KEY"] = getpass.getpass(os.getenv("GOOGLE_API_KEY"))
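The `{context}` and `{question}` placeholders in `SYSTEM_PROMPT` are filled at query time by the prompt template built in llm_setup/llm_setup.py. A minimal sketch of that substitution outside the chain (the sample context and question strings are illustrative only, not part of the commit):

```python
from langchain_core.prompts import PromptTemplate

from configs.config import SYSTEM_PROMPT

# Build the same template the RAG chain uses and render it by hand.
prompt = PromptTemplate(input_variables=["context", "question"], template=SYSTEM_PROMPT)
rendered = prompt.format(
    context="AIS transponders broadcast vessel position reports...",  # stand-in for retrieved chunks
    question="What is AIS?",                                          # stand-in for the user's message
)
print(rendered)
```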
llm_setup/llm_setup.py ADDED
@@ -0,0 +1,75 @@
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import (
+    ChatPromptTemplate,
+    PromptTemplate,
+    HumanMessagePromptTemplate
+)
+from langchain_core.runnables import RunnablePassthrough
+from langchain_core.vectorstores import VectorStoreRetriever
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+from processing.documents import format_documents
+
+
+def _initialize_llm() -> ChatGoogleGenerativeAI:
+    """
+    Initializes the LLM instance.
+    """
+    llm = ChatGoogleGenerativeAI(model="gemini-pro")
+    return llm
+
+
+class LLMService:
+    """
+    Service for managing LLM interactions and conversational RAG chain.
+
+    Args:
+        logger: Logger instance for logging.
+        system_prompt: The prompt for the QA system.
+        web_retriever: A VectorStoreRetriever instance for retrieving web documents.
+    """
+
+    def __init__(self, logger, system_prompt: str, web_retriever: VectorStoreRetriever):
+        self._conversational_rag_chain = None
+        self._logger = logger
+        self.system_prompt = system_prompt
+        self._web_retriever = web_retriever
+
+        self.llm = _initialize_llm()
+
+        self._initialize_conversational_rag_chain()
+
+    def _initialize_conversational_rag_chain(self):
+        """
+        Initializes the conversational RAG chain.
+        """
+        # Initialize RAG (Retrieval-Augmented Generation) chain
+        prompt = ChatPromptTemplate(input_variables=['context', 'question'], messages=[HumanMessagePromptTemplate(
+            prompt=PromptTemplate(input_variables=['context', 'question'], template=self.system_prompt))])
+
+        # Initialize conversational RAG chain
+        self._conversational_rag_chain = (
+                {"context": self._web_retriever | format_documents, "question": RunnablePassthrough()}
+                | prompt
+                | self.llm
+                | StrOutputParser()
+        )
+
+    def conversational_rag_chain(self):
+        """
+        Returns the initialized conversational RAG chain.
+
+        Returns:
+            The conversational RAG chain instance.
+        """
+        return self._conversational_rag_chain
+
+    def get_llm(self) -> ChatGoogleGenerativeAI:
+        """
+        Returns the LLM instance.
+        """
+
+        if self.llm is None:
+            raise Exception("llm is not initialized")
+
+        return self.llm
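The dict at the head of the pipeline in `_initialize_conversational_rag_chain` is LCEL's parallel-map idiom: for a single question string, the retriever branch produces formatted context while `RunnablePassthrough()` forwards the question unchanged. A self-contained sketch of the same composition with a stubbed retriever and a stand-in for the LLM call, showing what each stage receives (everything here is illustrative, not part of the commit):

```python
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Stub "retriever": returns canned context text instead of hitting Chroma.
fake_retriever = RunnableLambda(lambda q: f"(docs retrieved for: {q})")

prompt = PromptTemplate.from_template("Context: {context}\nQuestion: {question}")

chain = (
    {"context": fake_retriever, "question": RunnablePassthrough()}
    | prompt
    | RunnableLambda(lambda p: p.to_string().upper())  # stand-in for the Gemini call
    | StrOutputParser()
)

print(chain.invoke("What is AIS?"))
```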
processing/documents.py ADDED
@@ -0,0 +1,45 @@
+from langchain_community.document_loaders import WebBaseLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_core.documents import Document
+from typing import Iterable
+
+
+def load_documents(website: str) -> list[Document]:
+    """
+    Loads documents from a given website.
+
+    Args:
+        website (str): The URL of the website to load documents from.
+
+    Returns:
+        list[Document]: A list of loaded documents.
+    """
+    loader = WebBaseLoader(website)
+    return loader.load()
+
+
+def format_documents(docs: list[Document]) -> str:
+    """
+    Formats a list of documents into a single string.
+
+    Args:
+        docs (list[Document]): The list of documents to format.
+
+    Returns:
+        str: The formatted documents as a single string.
+    """
+    return "\n\n".join(doc.page_content for doc in docs)
+
+
+def split_documents(documents: Iterable[Document]) -> list[Document]:
+    """
+    Splits documents into smaller chunks.
+
+    Args:
+        documents (Iterable[Document]): The documents to split.
+
+    Returns:
+        list[Document]: A list of split documents.
+    """
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+    return text_splitter.split_documents(documents)
processing/texts.py ADDED
@@ -0,0 +1,6 @@
+def clean_text(text: str) -> str:
+    """
+    Clean the text by removing unwanted characters and formatting.
+    """
+    cleaned_text = text.replace("\n", " ").strip()
+    return cleaned_text
requirements.txt CHANGED
@@ -1 +1,153 @@
-huggingface_hub
+huggingface_hub
+aiofiles
+aiohappyeyeballs
+aiohttp
+aiosignal
+annotated-types
+anyio
+asgiref
+attrs
+backoff
+bcrypt
+build
+cachetools
+certifi
+charset-normalizer
+chroma-hnswlib
+chromadb
+click
+colorama
+coloredlogs
+contourpy
+cycler
+dataclasses-json
+Deprecated
+fastapi
+ffmpy
+filelock
+flatbuffers
+fonttools
+frozenlist
+fsspec
+google-ai-generativelanguage
+google-api-core
+google-api-python-client
+google-auth
+google-auth-httplib2
+google-generativeai
+googleapis-common-protos
+gradio
+gradio_client
+greenlet
+grpcio
+grpcio-status
+h11
+httpcore
+httplib2
+httptools
+httpx
+huggingface-hub
+humanfriendly
+idna
+importlib_metadata
+importlib_resources
+Jinja2
+joblib
+jsonpatch
+jsonpointer
+kiwisolver
+kubernetes
+langchain
+langchain-chroma
+langchain-community
+langchain-core
+langchain-google-genai
+langchain-huggingface
+langchain-text-splitters
+langsmith
+markdown-it-py
+MarkupSafe
+marshmallow
+matplotlib
+mdurl
+mmh3
+monotonic
+mpmath
+multidict
+mypy-extensions
+networkx
+numpy
+oauthlib
+onnxruntime
+opentelemetry-api
+opentelemetry-exporter-otlp-proto-common
+opentelemetry-exporter-otlp-proto-grpc
+opentelemetry-instrumentation
+opentelemetry-instrumentation-asgi
+opentelemetry-instrumentation-fastapi
+opentelemetry-proto
+opentelemetry-sdk
+opentelemetry-semantic-conventions
+opentelemetry-util-http
+orjson
+overrides
+packaging
+pandas
+pillow
+posthog
+proto-plus
+protobuf
+pyasn1
+pyasn1_modules
+pydantic
+pydantic-settings
+pydantic_core
+pydub
+Pygments
+pyparsing
+PyPika
+pyproject_hooks
+pyreadline3
+python-dateutil
+python-dotenv
+python-multipart
+pytz
+PyYAML
+regex
+requests
+requests-oauthlib
+rich
+rsa
+ruff
+safetensors
+scikit-learn
+scipy
+semantic-version
+sentence-transformers
+setuptools
+shellingham
+six
+sniffio
+SQLAlchemy
+starlette
+sympy
+tenacity
+threadpoolctl
+tokenizers
+tomlkit
+torch
+tqdm
+transformers
+typer
+typing-inspect
+typing_extensions
+tzdata
+uritemplate
+urllib3
+uvicorn
+watchfiles
+websocket-client
+websockets
+wrapt
+yarl
+zipp
services/scraper.py ADDED
@@ -0,0 +1,26 @@
+from langchain.schema import Document
+
+from processing.documents import load_documents, format_documents, split_documents
+from processing.texts import clean_text
+
+
+class Service:
+    def __init__(self, store):
+        self.store = store
+
+    def scrape_and_get_store_vector_retriever(self, urls: list[str]):
+        """
+        Scrapes website content from fetched schemes and creates a VectorStore retriever.
+        """
+        documents: list[Document] = []
+
+        for url in urls:
+            try:
+                website_documents = load_documents(url)
+                formatted_content = format_documents(website_documents)
+                cleaned_content = clean_text(formatted_content)
+                documents.append(Document(page_content=cleaned_content))
+            except Exception as e:
+                raise Exception(f"Error processing {url}: {e}")
+
+        self.store.store_embeddings(split_documents(documents))
stores/chroma.py ADDED
@@ -0,0 +1,23 @@
+from langchain.schema import Document
+from langchain_chroma import Chroma
+
+
+class ChromaDB:
+    def __init__(self, embeddings):
+        self._persistent_directory = "embeddings"
+        model_name = "sentence-transformers/all-mpnet-base-v2"
+        model_kwargs = {'device': 'cpu'}
+        encode_kwargs = {'normalize_embeddings': False}
+        self.embeddings = embeddings
+
+        self.chroma = Chroma(persist_directory=self._persistent_directory, embedding_function=self.embeddings)
+
+    def get_chroma_instance(self) -> Chroma:
+        return self.chroma
+
+    def store_embeddings(self, documents: list[Document]):
+        """
+        Store embeddings for the documents using HuggingFace embeddings and Chroma vectorstore.
+        """
+        self.chroma.add_documents(documents=documents, embeddings=self.embeddings,
+                                  persist_directory=self._persistent_directory)