"""Module to upsert data into AstraDB"""
import os
import glob
import uuid
import logging
import boto3
import pandas as pd
import tiktoken
from astrapy import DataAPIClient
from astrapy.constants import VectorMetric
from astrapy.info import CollectionVectorServiceOptions
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DataFrameLoader
from dotenv import load_dotenv
load_dotenv()
ASTRA_DB_APPLICATION_TOKEN = os.environ['ASTRA_DB_APPLICATION_TOKEN']
ASTRA_DB_API_ENDPOINT = os.environ['ASTRA_DB_API_ENDPOINT']

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(funcName)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)

client = DataAPIClient(ASTRA_DB_APPLICATION_TOKEN)
database = client.get_database_by_api_endpoint(ASTRA_DB_API_ENDPOINT)
logging.info("* Database: %s", database.info().name)
if "finfast_marco_china" not in database.list_collection_names():
collection = database.create_collection(
"finfast_marco_china",
metric=VectorMetric.COSINE,
service=CollectionVectorServiceOptions(
provider="nvidia",
model_name="NV-Embed-QA",
),
)
else:
collection = database.get_collection("finfast_marco_china")
logging.info("* Collection: %s", collection.full_name)


def truncate_tokens(string: str, encoding_name: str, max_length: int = 8192) -> str:
    """
    Truncates a string to a maximum number of tokens.

    Args:
        string (str): The input string to be truncated.
        encoding_name (str): The name of the tiktoken encoding used for tokenization.
        max_length (int, optional): The maximum number of tokens to keep. Defaults to 8192.

    Returns:
        str: The truncated string.
    """
    # The parameter is an encoding name, so use get_encoding();
    # encoding_for_model() expects a model name such as "gpt-4" instead.
    encoding = tiktoken.get_encoding(encoding_name)
    encoded_string = encoding.encode(string)
    if len(encoded_string) > max_length:
        string = encoding.decode(encoded_string[:max_length])
    return string
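
# Illustrative usage (the encoding name and text below are examples only, not
# values used elsewhere in this module):
#
#     clipped = truncate_tokens(long_text, "cl100k_base", max_length=4096)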


def upsert(article, db_collection):
    """
    Upserts a single article into the collection.

    Args:
        article (dict): The article to be upserted.
        db_collection (Collection): The AstraDB collection to write to.

    Returns:
        None
    """
    if article is None or 'content' not in article:
        return None
    # Drop empty fields, and rename `id` to `articleid` so it does not
    # collide with the document's `_id`.
    article = {k: v for k, v in article.items() if v is not None}
    article["articleid"] = article["id"]
    logging.info(article["id"])
    del article["id"]
    # Drop oversized subtitles that would exceed the 8,000-byte limit on
    # indexed string fields; also guard against the key being absent after
    # the None-filtering above.
    if 'subtitle' in article and len(article['subtitle'].encode('utf-8')) > 8000:
        del article['subtitle']
    # Writing the text to $vectorize lets the collection's embedding service
    # compute the vector server-side.
    article['$vectorize'] = article['content']
    # A deterministic UUID derived from the content makes the upsert idempotent.
    _id = uuid.uuid5(uuid.NAMESPACE_URL, article['content'])
    del article["content"]
    db_collection.update_one(
        {"_id": _id},
        {"$set": article},
        upsert=True,
    )
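
# Because the _id is a uuid5 of the content, re-running upsert on the same
# chunk updates the existing document rather than inserting a duplicate, e.g.
# (hypothetical data):
#
#     chunk = {"id": "a1", "content": "GDP grew 5.2% ...", "subtitle": "Macro"}
#     upsert(dict(chunk), collection)
#     upsert(dict(chunk), collection)  # same _id, so this is an update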


def split_documents(content, page_content_column="content", chunk_size=800, chunk_overlap=20):
    """
    Splits the given content into smaller documents using a recursive character text splitter.

    Args:
        content (pandas.DataFrame): The input content to be split.
        page_content_column (str, optional): The name of the column in the input
            content that contains the text to be split. Defaults to "content".
        chunk_size (int, optional): The maximum size of each chunk. Defaults to 800.
        chunk_overlap (int, optional): The number of overlapping characters
            between chunks. Defaults to 20.

    Returns:
        list: A list of the split documents.
    """
    loader = DataFrameLoader(content, page_content_column=page_content_column)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    return text_splitter.split_documents(documents)
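
# Minimal sketch of split_documents on a tiny frame (hypothetical data); the
# non-content columns travel along as metadata on each chunk:
#
#     frame = pd.DataFrame({"content": ["a long article ..."], "id": ["a1"]})
#     chunks = split_documents(frame, chunk_size=200, chunk_overlap=10)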


def documents_to_list_of_dicts(documents):
    """
    Converts a list of LangChain documents to a list of dictionaries.

    Args:
        documents (list): A list of documents.

    Returns:
        list: A list of dictionaries, one per document, holding the page
            content under 'content' plus the document's metadata items.
    """
    doc_dict_list = []
    for _doc in documents:
        doc_dict = {'content': _doc.page_content, **_doc.metadata}
        doc_dict_list.append(doc_dict)
    return doc_dict_list
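
# E.g. (illustrative): Document(page_content="text", metadata={"id": "a1"})
# becomes {"content": "text", "id": "a1"}.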


def vectorize(article):
    """
    Processes the given article: splits it into chunks and upserts each chunk.

    Args:
        article: Article data accepted by the pandas.DataFrame constructor.

    Returns:
        None
    """
    df = pd.DataFrame(article)
    df = df[[
        'id', 'site', 'title', 'titleCN', 'contentCN', 'category', 'author',
        'content', 'subtitle', 'publishDate', 'link', 'attachment',
        'sentimentScore', 'sentimentLabel',
    ]]
    docs = split_documents(df)
    documents_list = documents_to_list_of_dicts(docs)
    for doc in documents_list:
        try:
            upsert(doc, collection)
        except Exception as e:  # pylint: disable=broad-except
            # Log at error level so failed upserts are visible, and keep going.
            logging.error("Failed to upsert document: %s", e)
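

# Illustrative entry point (an assumption about how this module is driven;
# every field value below is a placeholder, not real data):
if __name__ == "__main__":
    SAMPLE_ARTICLES = [{
        'id': 'demo-001',
        'site': 'example-site',
        'title': 'Sample headline',
        'titleCN': 'Sample headline (Chinese)',
        'contentCN': 'Sample body (Chinese)',
        'category': 'macro',
        'author': 'demo',
        'content': 'Sample article body to be chunked, embedded, and upserted.',
        'subtitle': 'Sample subtitle',
        'publishDate': '2024-01-01',
        'link': 'https://example.com/demo-001',
        'attachment': '',
        'sentimentScore': 0.0,
        'sentimentLabel': 'neutral',
    }]
    vectorize(SAMPLE_ARTICLES)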