#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-05.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from ten import (
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
)
from typing import List, Any
import json
from datetime import datetime
import uuid, math
import queue, threading
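
# Command names exchanged with peer extensions, plus the chunking parameters.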
CMD_FILE_CHUNK = "file_chunk"
UPSERT_VECTOR_CMD = "upsert_vector"
FILE_CHUNKED_CMD = "file_chunked"
CHUNK_SIZE = 200
CHUNK_OVERLAP = 20
BATCH_SIZE = 5
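
# Yield node texts in batches of `size`; the final batch may be smaller.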
def batch(nodes, size):
    batch_texts = []
    for n in nodes:
        batch_texts.append(n.text)
        if len(batch_texts) == size:
            yield batch_texts[:]
            batch_texts.clear()
    if batch_texts:
        yield batch_texts
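
# Extension that splits incoming files into text chunks, embeds them in
# batches, and upserts the resulting vectors into a collection.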
class FileChunkerExtension(Extension):
    def __init__(self, name: str):
        super().__init__(name)
        self.counters = {}  # per-file count of batches already stored
        self.expected = {}  # per-file number of batches expected
        self.new_collection_name = ""
        self.file_chunked_event = threading.Event()
        self.thread = None
        self.queue = queue.Queue()
        self.stop = False

    def generate_collection_name(self) -> str:
        """
        Generate a collection name following the rule: ^[a-z]+[a-z0-9_]*
        """
        return "coll_" + uuid.uuid1().hex.lower()
    def split(self, ten: TenEnv, path: str) -> List[Any]:
        # lazy-import packages that take a long time to load
        from llama_index.core import SimpleDirectoryReader
        from llama_index.core.node_parser import SentenceSplitter

        # load pdf file by path
        documents = SimpleDirectoryReader(
            input_files=[path], filename_as_id=True
        ).load_data()

        # split pdf file into chunks
        splitter = SentenceSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
        )
        nodes = splitter.get_nodes_from_documents(documents)
        ten.log_info(
            f"file {path} pages count {len(documents)}, chunking count {len(nodes)}"
        )
        return nodes
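
    # Send a create_collection command and optionally block until it is acknowledged.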
    def create_collection(self, ten: TenEnv, collection_name: str, wait: bool):
        cmd_out = Cmd.create("create_collection")
        cmd_out.set_property_string("collection_name", collection_name)

        wait_event = threading.Event()
        ten.send_cmd(
            cmd_out,
            lambda ten, result, _: wait_event.set(),
        )
        if wait:
            wait_event.wait()
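
    # Embed one batch of texts via the embed_batch command; the result is
    # passed to vector_store() for persistence.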
    def embedding(self, ten: TenEnv, path: str, texts: List[str]):
        ten.log_info(
            f"generate embeddings for the file: {path}, with batch size: {len(texts)}"
        )
        cmd_out = Cmd.create("embed_batch")
        cmd_out.set_property_from_json("inputs", json.dumps(texts))
        ten.send_cmd(
            cmd_out, lambda ten, result, _: self.vector_store(ten, path, texts, result)
        )
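
    # Pair each text with its embedding from the embed_batch result and upsert
    # the batch into the current collection.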
    def vector_store(self, ten: TenEnv, path: str, texts: List[str], result: CmdResult):
        ten.log_info(f"vector store start for one splitting of the file {path}")
        file_name = path.split("/")[-1]

        embed_output_json = result.get_property_string("embeddings")
        embed_output = json.loads(embed_output_json)

        cmd_out = Cmd.create(UPSERT_VECTOR_CMD)
        cmd_out.set_property_string("collection_name", self.new_collection_name)
        cmd_out.set_property_string("file_name", file_name)
        embeddings = [record["embedding"] for record in embed_output]
        content = []
        for text, embedding in zip(texts, embeddings):
            content.append({"text": text, "embedding": embedding})
        cmd_out.set_property_string("content", json.dumps(content))
        # ten.log_info(json.dumps(content))
        ten.send_cmd(cmd_out, lambda ten, result, _: self.file_chunked(ten, path))
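
    # Invoked after each batch is upserted; when all expected batches of a file
    # are stored, emit the file_chunked command and release the worker thread.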
    def file_chunked(self, ten: TenEnv, path: str):
        if path in self.counters and path in self.expected:
            self.counters[path] += 1
            ten.log_info(
                f"complete vector store for one splitting of the file: {path}, "
                f"current counter: {self.counters[path]}, expected: {self.expected[path]}"
            )
            if self.counters[path] == self.expected[path]:
                chunks_count = self.counters[path]
                del self.counters[path]
                del self.expected[path]
                ten.log_info(
                    f"complete chunk for the file: {path}, chunks_count {chunks_count}"
                )
                cmd_out = Cmd.create(FILE_CHUNKED_CMD)
                cmd_out.set_property_string("path", path)
                cmd_out.set_property_string("collection", self.new_collection_name)
                ten.send_cmd(
                    cmd_out,
                    lambda ten, result, _: ten.log_info("send_cmd done"),
                )
                self.file_chunked_event.set()
        else:
            ten.log_error(f"missing counter for the file path: {path}")

    def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        if cmd_name == CMD_FILE_CHUNK:
            path = cmd.get_property_string("path")
            collection = None
            try:
                collection = cmd.get_property_string("collection")
            except Exception:
                ten.log_warn(f"missing collection property in cmd {cmd_name}")
            self.queue.put((path, collection))  # make sure files are processed in order
        else:
            ten.log_info(f"unknown cmd {cmd_name}")
        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "ok")
        ten.return_result(cmd_result, cmd)
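
    # Worker loop: process queued files one at a time. For each file, create the
    # collection, split it into nodes, embed and store every batch, then wait
    # until all batches are persisted before taking the next file.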
    def async_handler(self, ten: TenEnv) -> None:
        while not self.stop:
            value = self.queue.get()
            if value is None:
                break
            path, collection = value

            # start processing the file
            start_time = datetime.now()
            if collection is None:
                collection = self.generate_collection_name()
                ten.log_info(f"collection {collection} generated")
            ten.log_info(f"start processing {path}, collection {collection}")

            # create collection
            self.create_collection(ten, collection, True)
            ten.log_info(f"collection {collection} created")

            # split
            nodes = self.split(ten, path)

            # reset counters and events
            self.new_collection_name = collection
            self.expected[path] = math.ceil(len(nodes) / BATCH_SIZE)
            self.counters[path] = 0
            self.file_chunked_event.clear()

            # trigger embedding and vector storing in parallel
            for texts in list(batch(nodes, BATCH_SIZE)):
                self.embedding(ten, path, texts)

            # wait for all chunks to be processed
            self.file_chunked_event.wait()
            ten.log_info(
                f"finished processing {path}, collection {collection}, "
                f"cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
            )

    def on_start(self, ten: TenEnv) -> None:
        ten.log_info("on_start")
        self.stop = False
        self.thread = threading.Thread(target=self.async_handler, args=[ten])
        self.thread.start()
        ten.on_start_done()

    def on_stop(self, ten: TenEnv) -> None:
        ten.log_info("on_stop")
        self.stop = True
        if self.thread is not None:
            # drain pending work and unblock the worker so it can exit
            while not self.queue.empty():
                self.queue.get()
            self.queue.put(None)
            self.thread.join()
            self.thread = None
        ten.on_stop_done()
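
# Illustrative sketch (not part of the runtime code): a peer extension could
# trigger chunking by sending a "file_chunk" command; the "collection" property
# is optional and, if omitted, a collection name is generated automatically.
# Path and collection values below are hypothetical.
#
#   cmd = Cmd.create("file_chunk")
#   cmd.set_property_string("path", "/data/sample.pdf")
#   cmd.set_property_string("collection", "coll_demo")
#   ten_env.send_cmd(cmd, lambda ten, result, _: None)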