OxbridgeEconomics committed · Commit 62774df · unverified · 1 Parent(s): 2dbc5e6

Update vectorizer.py

Files changed (1):
  1. controllers/vectorizer.py +26 -132
controllers/vectorizer.py CHANGED
@@ -1,141 +1,30 @@
 """Module to upsert data into AstraDB"""
 import os
-import glob
-import uuid
 import logging
 
-import boto3
 import pandas as pd
-import tiktoken
-from astrapy import DataAPIClient
-from astrapy.constants import VectorMetric
-from astrapy.info import CollectionVectorServiceOptions
+from langchain_astradb import AstraDBVectorStore
+from langchain_openai import AzureOpenAIEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import DataFrameLoader
-from dotenv import load_dotenv
-
-load_dotenv()
-
-ASTRA_DB_APPLICATION_TOKEN = os.environ['ASTRA_DB_APPLICATION_TOKEN']
-ASTRA_DB_API_ENDPOINT = os.environ['ASTRA_DB_API_ENDPOINT']
 
 logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(funcName)s - %(message)s',
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO)
 
-client = DataAPIClient(ASTRA_DB_APPLICATION_TOKEN)
-database = client.get_database_by_api_endpoint(ASTRA_DB_API_ENDPOINT)
-logging.info("* Database : %s", database.info().name)
-
-if "finfast_marco_china" not in database.list_collection_names():
-    collection = database.create_collection(
-        "finfast_marco_china",
-        metric=VectorMetric.COSINE,
-        service=CollectionVectorServiceOptions(
-            provider="nvidia",
-            model_name="NV-Embed-QA",
-        ),
-    )
-else:
-    collection = database.get_collection("finfast_marco_china")
-logging.info("* Collection: %s", collection.full_name)
-
-
-def truncate_tokens(string: str, encoding_name: str, max_length: int = 8192) -> str:
-    """
-    Truncates a string of tokens to a maximum length.
-
-    Args:
-        string (str): The input string to be truncated.
-        encoding_name (str): The name of the encoding used for tokenization.
-        max_length (int, optional): The maximum length of the truncated string. Defaults to 8192.
-
-    Returns:
-        str: The truncated string.
-
-    """
-    encoding = tiktoken.encoding_for_model(encoding_name)
-    encoded_string = encoding.encode(string)
-    num_tokens = len(encoded_string)
-
-    if num_tokens > max_length:
-        string = encoding.decode(encoded_string[:max_length])
-    return string
-
-def upsert(article, db_collection):
-    """
-    Upserts articles into the index.
-
-    Args:
-        articles (list): A list of articles to be upserted.
-
-    Returns:
-        None
-    """
-    if article is None or 'content' not in article:
-        return None
-    article = {k: v for k, v in article.items() if v is not None}
-    article["articleid"] = str(article["id"])
-    logging.info(article["id"])
-    del article["id"]
-    if len(article['subtitle'].encode('utf-8')) > 8000:
-        del article['subtitle']
-    article['$vectorize'] = article['content']
-    _id = uuid.uuid5(uuid.NAMESPACE_URL, article['content'])
-    del article["content"]
-    db_collection.update_one(
-        {"_id": _id},
-        {"$set": article},
-        upsert=True)
-
-def split_documents(content, page_content_column="content", chunk_size=800, chunk_overlap=20):
-    """
-    Splits a given content into smaller documents using a recursive character text splitter.
-
-    Args:
-        content (pandas.DataFrame): The input content to be split.
-        page_content_column (str, optional): \
-            The name of the column in the input content that contains the text to be split.
-        chunk_size (int, optional): The maximum size of each chunk. Defaults to 800.
-        chunk_overlap (int, optional): \
-            The number of overlapping characters between chunks. Defaults to 20.
-
-    Returns:
-        list: A list of the split documents.
-    """
-    loader = DataFrameLoader(content, page_content_column=page_content_column)
-    documents = loader.load()
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=chunk_size,
-        chunk_overlap=chunk_overlap,
-        length_function=len,
-        is_separator_regex=False,
-    )
-    _docs = text_splitter.split_documents(documents)
-    return _docs
-
-def documents_to_list_of_dicts(documents):
-    """
-    Converts a list of documents to a list of dictionaries.
-
-    Parameters:
-    - documents (list): A list of documents.
-
-    Returns:
-    - doc_dict_list (list): A list of dictionaries, where each dictionary represents a document.
-      Each dictionary contains the following keys:
-        - 'content': The page content of the document.
-        - Other keys represent metadata items of the document.
-    """
-    doc_dict_list = []
-    for _doc in documents:
-        doc_dict = {}
-        doc_dict['content'] = _doc.page_content
-        for key, item in _doc.metadata.items():
-            doc_dict[key] = item
-        doc_dict_list.append(doc_dict)
-    return doc_dict_list
+ASTRA_DB_APPLICATION_TOKEN = os.environ['ASTRA_DB_APPLICATION_TOKEN']
+ASTRA_DB_API_ENDPOINT = os.environ['ASTRA_DB_API_ENDPOINT']
+
+embedding = AzureOpenAIEmbeddings(
+    api_version="2024-07-01-preview",
+    azure_endpoint="https://openai-oe.openai.azure.com/")
+
+vstore = AstraDBVectorStore(embedding=embedding,
+                            namespace="default_keyspace",
+                            collection_name="finfast_china_test",
+                            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+                            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"])
 
 def vectorize(article):
     """
@@ -148,13 +37,18 @@ def vectorize(article):
     None
     """
     df = pd.DataFrame(article)
-    df = df[['id','site','title','titleCN','contentCN','category','author','content','subtitle','publishDate','link','attachment','sentimentScore','sentimentLabel']]
+    df = df[['id','site','title','titleCN','category','author','content',
+             'publishDate','link','attachment','sentimentScore','sentimentLabel']]
     df['sentimentScore'] = df['sentimentScore'].round(2)
-    docs = split_documents(df)
-    documents_list = documents_to_list_of_dicts(docs)
-    for doc in documents_list:
-        try:
-            upsert(doc, collection)
-        except Exception as e:
-            logging.info(e)
-            pass
+    loader = DataFrameLoader(df, page_content_column="content")
+    documents = loader.load()
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=800,
+        chunk_overlap=20,
+        length_function=len,
+        is_separator_regex=False,
+    )
+
+    docs = text_splitter.split_documents(documents)
+    inserted_ids = vstore.add_documents(docs)
+    logging.info(inserted_ids)
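
Not part of the commit: a minimal usage sketch, assuming the module is importable as controllers.vectorizer and that real AstraDB and Azure OpenAI credentials are available. The article field names follow the column selection inside vectorize(); the values, and the assumption that AzureOpenAIEmbeddings reads its key from AZURE_OPENAI_API_KEY, are placeholders, not verified project settings. Because the embedding client and vector store are created at import time, the environment variables must be set before the import.

# Hypothetical usage sketch; credentials and field values are placeholders.
import os

os.environ.setdefault("ASTRA_DB_APPLICATION_TOKEN", "<astra-token>")
os.environ.setdefault("ASTRA_DB_API_ENDPOINT", "<astra-endpoint>")
os.environ.setdefault("AZURE_OPENAI_API_KEY", "<azure-openai-key>")  # assumed env var name

from controllers.vectorizer import vectorize

articles = [{
    "id": 1,
    "site": "example-site",
    "title": "Sample title",
    "titleCN": "Sample title (Chinese)",
    "category": "macro",
    "author": "Analyst",
    "content": "Body text to be chunked into 800-character pieces and embedded ...",
    "publishDate": "2024-01-01",
    "link": "https://example.com/article",
    "attachment": None,
    "sentimentScore": 0.1234,
    "sentimentLabel": "neutral",
}]

# Builds a DataFrame, splits the 'content' column into chunks, and upserts
# the chunks (with the remaining columns as metadata) into the AstraDB collection.
vectorize(articles)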