"""Module to upsert data into AstraDB"""
import os
import logging
import time
import tiktoken
from pytz import timezone
import pandas as pd
from langchain_astradb import AstraDBVectorStore
# from langchain_openai import AzureOpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DataFrameLoader
# from astrapy import DataAPIClient
# from astrapy.info import CollectionVectorServiceOptions
# from astrapy.exceptions import CollectionAlreadyExistsException
# from astrapy.core.api import APIRequestError
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(funcName)s - %(message)s',
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.ERROR)
# from astrapy import AstraClient
# ASTRA_DB_APPLICATION_TOKEN = os.environ['ASTRA_DB_APPLICATION_TOKEN']
# ASTRA_DB_API_ENDPOINT = os.environ['ASTRA_DB_API_ENDPOINT']
# COLLECTION_NAME = "article"
# VECTOR_OPTIONS = CollectionVectorServiceOptions(
#     provider="azureOpenAI",
#     model_name="text-embedding-3-small",
#     authentication={"providerKey": "AZURE_OPENAI_API_KEY"},
#     parameters={
#         "resourceName": "openai-oe",
#         "deploymentId": "text-embedding-3-small",
#     },
# )
# client = DataAPIClient(token=ASTRA_DB_APPLICATION_TOKEN)
# database = client.get_database(ASTRA_DB_API_ENDPOINT)
# embedding = AzureOpenAIEmbeddings(
#     api_version="2024-07-01-preview",
#     azure_endpoint="https://openai-oe.openai.azure.com/")
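# Connect to the existing "article" collection. With autodetect_collection=True
# the store reads the collection's configuration (including any server-side
# embedding/vectorize settings) from Astra DB rather than from this code.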
vstore = AstraDBVectorStore(
    # collection_vector_service_options=CollectionVectorServiceOptions(
    #     provider="azureOpenAI",
    #     model_name="text-embedding-3-small",
    #     authentication={
    #         "providerKey": "AZURE_OPENAI_API_KEY",
    #     },
    #     parameters={
    #         "resourceName": "openai-oe",
    #         "deploymentId": "text-embedding-3-small",
    #     },
    # ),
    namespace="default_keyspace",
    collection_name="article",
    token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    autodetect_collection=True)


def token_length(text):
    """
    Calculate the length of the encoded text using the tokenizer for the
    "text-embedding-3-small" model.

    Args:
        text (str): The input text to be tokenized and measured.

    Returns:
        int: The length of the encoded text in tokens.
    """
    tokenizer = tiktoken.encoding_for_model("text-embedding-3-small")
    return len(tokenizer.encode(text))
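
# Illustrative check (assumes tiktoken maps this model to the cl100k_base
# encoding): token_length("hello world") should return 2.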


def add_documents_with_retry(chunks, ids, max_retries=3):
    """
    Attempt to add documents to the vstore, retrying on transient failures.

    Parameters:
        chunks (list): The list of document chunks to be added.
        ids (list): The list of document IDs corresponding to the chunks.
        max_retries (int, optional): The maximum number of retry attempts. Default is 3.

    If the operation still fails after max_retries attempts, the error and the
    affected ids are logged; no exception is re-raised.
    """
    for attempt in range(max_retries):
        try:
            vstore.add_documents(chunks, ids=ids)
            return  # Success: stop retrying so the same chunks are not inserted again.
        except (ConnectionError, TimeoutError) as e:
            logging.info("Attempt %d failed: %s", attempt + 1, e)
            if attempt < max_retries - 1:
                time.sleep(10)
            else:
                logging.error("Max retries reached. Operation failed.")
                logging.error(ids)
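
# Note: with the default max_retries=3, a persistent connection problem causes
# two 10-second back-off sleeps before the failure is logged and given up on.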


def vectorize(article):
    """
    Split the given article into chunks and upsert them into the vector store.

    Parameters:
        article (dict): The article to be processed, keyed by field name.

    Returns:
        None
    """
    article['id'] = str(article['id'])
    if isinstance(article, dict):
        article = [article]  # Convert single dictionary to list of dictionaries
    df = pd.DataFrame(article)
    df = df[['id', 'publishDate', 'author', 'category',
             'content', 'referenceid', 'site', 'title', 'link']]
    df['publishDate'] = pd.to_datetime(df['publishDate'], errors='coerce')
    df['publishDate'] = df['publishDate'].dt.tz_localize('UTC', ambiguous='NaT', nonexistent='NaT')
    # Strip the UTC label and relabel the same wall-clock times as 'Etc/GMT+8'
    # (UTC-08:00 under the POSIX sign convention) without shifting them.
    df['publishDate'] = df['publishDate'].dt.tz_localize(None).dt.tz_localize(timezone('Etc/GMT+8'))
    documents = DataFrameLoader(df, page_content_column="content").load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=token_length,
        is_separator_regex=False,
        separators=["\n\n", "\n", "\t", "\\n"]  # Logical separators
    )
    chunks = text_splitter.split_documents(documents)
    ids = []
    for index, chunk in enumerate(chunks):
        _id = f"{chunk.metadata['id']}-{index}"
        ids.append(_id)
    try:
        add_documents_with_retry(chunks, ids)
    except (ConnectionError, TimeoutError, ValueError) as e:
        logging.error("Failed to add documents: %s", e)

# def vectorize(article):
#     """
#     Process the given article.
#     Parameters:
#         article (DataFrame): The article to be processed.
#     Returns:
#         None
#     """
#     article['id'] = str(article['id'])
#     if isinstance(article, dict):
#         article = [article]  # Convert single dictionary to list of dictionaries
#     df = pd.DataFrame(article)
#     df = df[['id', 'site', 'title', 'titleCN', 'category', 'author', 'content',
#              'publishDate', 'link']]
#     df['publishDate'] = pd.to_datetime(df['publishDate'])
#     loader = DataFrameLoader(df, page_content_column="content")
#     documents = loader.load()
#     text_splitter = RecursiveCharacterTextSplitter(
#         chunk_size=800,
#         chunk_overlap=20,
#         length_function=len,
#         is_separator_regex=False,
#     )
#     chunks = text_splitter.split_documents(documents)
#     ids = []
#     for chunk in chunks:
#         _id = f"{chunk.metadata['id']}-{str(uuid.uuid5(uuid.NAMESPACE_OID, chunk.page_content))}"
#         ids.append(_id)
#     inserted_ids = vstore.add_documents(chunks, ids=ids)
#     print(inserted_ids)
#     logging.info(inserted_ids)
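

# Minimal usage sketch (assumed entry point, not part of the original pipeline):
# the field names mirror the columns selected in vectorize(); the values below
# are placeholders for illustration only. Running this requires the
# ASTRA_DB_APPLICATION_TOKEN and ASTRA_DB_API_ENDPOINT environment variables.
if __name__ == "__main__":
    sample_article = {
        "id": 123,
        "publishDate": "2024-01-01T08:00:00",
        "author": "Jane Doe",
        "category": "technology",
        "content": "Example body text to be chunked, embedded and upserted.",
        "referenceid": "ref-123",
        "site": "example-site",
        "title": "Example title",
        "link": "https://example.com/article/123",
    }
    vectorize(sample_article)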