"""Module to upsert data into AstraDB"""

import os
import glob
import uuid
import logging

import boto3
import pandas as pd
import tiktoken
from astrapy import DataAPIClient
from astrapy.constants import VectorMetric
from astrapy.info import CollectionVectorServiceOptions
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DataFrameLoader
from dotenv import load_dotenv

load_dotenv()

ASTRA_DB_APPLICATION_TOKEN = os.environ['ASTRA_DB_APPLICATION_TOKEN']
ASTRA_DB_API_ENDPOINT = os.environ['ASTRA_DB_API_ENDPOINT']

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(funcName)s - %(message)s',
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO)

client = DataAPIClient(ASTRA_DB_APPLICATION_TOKEN)
database = client.get_database_by_api_endpoint(ASTRA_DB_API_ENDPOINT)
logging.info("* Database : %s", database.info().name)

if "finfast_marco_china" not in database.list_collection_names():
    collection = database.create_collection(
        "finfast_marco_china",
        metric=VectorMetric.COSINE,
        service=CollectionVectorServiceOptions(
            provider="nvidia",
            model_name="NV-Embed-QA",
        ),
    )
else:
    collection = database.get_collection("finfast_marco_china")
logging.info("* Collection: %s", collection.full_name)


def truncate_tokens(string: str, encoding_name: str, max_length: int = 8192) -> str:
    """
    Truncates a string to a maximum number of tokens.

    Args:
        string (str): The input string to be truncated.
        encoding_name (str): The model name whose tokenizer is used
            (resolved via ``tiktoken.encoding_for_model``).
        max_length (int, optional): The maximum number of tokens to keep. Defaults to 8192.

    Returns:
        str: The truncated string.
    """
    encoding = tiktoken.encoding_for_model(encoding_name)
    encoded_string = encoding.encode(string)
    num_tokens = len(encoded_string)

    if num_tokens > max_length:
        string = encoding.decode(encoded_string[:max_length])
    return string

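# Illustrative usage only (this helper is not called elsewhere in this module);
# the model name "text-embedding-3-small" is an assumed example, not a value
# configured by this script:
#   body = truncate_tokens(article_text, "text-embedding-3-small", max_length=8192)
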
def upsert(article, db_collection):
    """
    Upserts a single article into the collection.

    Args:
        article (dict): The article to be upserted.
        db_collection (Collection): The AstraDB collection to upsert into.

    Returns:
        None
    """
    if article is None or 'content' not in article:
        return None
    article = {k: v for k, v in article.items() if v is not None}
    article["articleid"] = article["id"]
    logging.info("Upserting article %s", article["id"])
    del article["id"]
    # Drop oversized subtitles, which would exceed Astra DB's limit on indexed string fields.
    if 'subtitle' in article and len(article['subtitle'].encode('utf-8')) > 8000:
        del article['subtitle']
    article['$vectorize'] = article['content']
    # Deterministic document id derived from the chunk content, so re-running
    # the pipeline updates existing documents instead of duplicating them.
    _id = uuid.uuid5(uuid.NAMESPACE_URL, article['content'])
    del article["content"]
    db_collection.update_one(
        {"_id": _id},
        {"$set": article},
        upsert=True)

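# Illustrative call (the field values are invented; keys mirror the columns
# selected in ``vectorize`` below):
#   upsert({"id": "abc-123", "content": "Chunk text ...", "subtitle": "..."}, collection)
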
def split_documents(content, page_content_column="content", chunk_size=800, chunk_overlap=20):
    """
    Splits a given content into smaller documents using a recursive character text splitter.

    Args:
        content (pandas.DataFrame): The input content to be split.
        page_content_column (str, optional): The name of the column in the input content
            that contains the text to be split. Defaults to "content".
        chunk_size (int, optional): The maximum size of each chunk. Defaults to 800.
        chunk_overlap (int, optional): The number of overlapping characters between chunks.
            Defaults to 20.

    Returns:
        list: A list of the split documents.
    """
    loader = DataFrameLoader(content, page_content_column=page_content_column)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    _docs = text_splitter.split_documents(documents)
    return _docs


def documents_to_list_of_dicts(documents):
    """
    Converts a list of documents to a list of dictionaries.

    Parameters:
    - documents (list): A list of documents.

    Returns:
    - doc_dict_list (list): A list of dictionaries, where each dictionary represents
      a document and contains the following keys:
        - 'content': The page content of the document.
        - Other keys represent metadata items of the document.
    """
    doc_dict_list = []
    for _doc in documents:
        doc_dict = {}
        doc_dict['content'] = _doc.page_content
        for key, item in _doc.metadata.items():
            doc_dict[key] = item
        doc_dict_list.append(doc_dict)
    return doc_dict_list


def vectorize(article):
    """
    Splits the given article records into chunks and upserts them into the collection.

    Parameters:
        article: The article records to be processed (anything accepted by
            ``pd.DataFrame``, e.g. a list of dicts).

    Returns:
        None
    """
    df = pd.DataFrame(article)
    df = df[['id', 'site', 'title', 'titleCN', 'contentCN', 'category', 'author',
             'content', 'subtitle', 'publishDate', 'link', 'attachment',
             'sentimentScore', 'sentimentLabel']]
    docs = split_documents(df)
    documents_list = documents_to_list_of_dicts(docs)
    for doc in documents_list:
        try:
            upsert(doc, collection)
        except Exception as e:
            logging.error(e)
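

# Minimal usage sketch. The record below is invented to illustrate the shape
# ``vectorize`` expects (a list of article dicts whose keys match the columns
# selected above); how articles are produced is outside this module.
if __name__ == "__main__":
    sample_articles = [{
        "id": "example-001",
        "site": "example-site",
        "title": "Example title",
        "titleCN": "Example title (CN)",
        "contentCN": "Example content (CN)",
        "category": "macro",
        "author": "Example Author",
        "content": "Example article body to be chunked, vectorized and upserted.",
        "subtitle": "Example subtitle",
        "publishDate": "2024-01-01",
        "link": "https://example.com/article",
        "attachment": "",
        "sentimentScore": 0.0,
        "sentimentLabel": "neutral",
    }]
    vectorize(sample_articles)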