#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-05.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from ten import (
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
)
from typing import List, Any
import json
from datetime import datetime
import uuid, math
import queue, threading

CMD_FILE_CHUNK = "file_chunk"
UPSERT_VECTOR_CMD = "upsert_vector"
FILE_CHUNKED_CMD = "file_chunked"

CHUNK_SIZE = 200
CHUNK_OVERLAP = 20
BATCH_SIZE = 5


def batch(nodes, size):
    # Yield the node texts in batches of `size`; the last batch may be smaller.
    batch_texts = []
    for n in nodes:
        batch_texts.append(n.text)
        if len(batch_texts) == size:
            yield batch_texts[:]
            batch_texts.clear()
    if batch_texts:
        yield batch_texts


class FileChunkerExtension(Extension):
    """Splits incoming files into text chunks, embeds the chunks in batches,
    upserts the resulting vectors into a collection, and reports completion."""

    def __init__(self, name: str):
        super().__init__(name)
        self.counters = {}  # per-file count of completed embedding batches
        self.expected = {}  # per-file number of expected embedding batches
        self.new_collection_name = ""
        self.file_chunked_event = threading.Event()
        self.thread = None
        self.queue = queue.Queue()
        self.stop = False

    def generate_collection_name(self) -> str:
        """Generate a collection name matching ^[a-z]+[a-z0-9_]*."""
        return "coll_" + uuid.uuid1().hex.lower()

    def split(self, ten: TenEnv, path: str) -> List[Any]:
        # Lazy import of packages that take a long time to load.
        from llama_index.core import SimpleDirectoryReader
        from llama_index.core.node_parser import SentenceSplitter

        # Load the pdf file at the given path.
        documents = SimpleDirectoryReader(
            input_files=[path], filename_as_id=True
        ).load_data()

        # Split the pdf file into chunks.
        splitter = SentenceSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
        )
        nodes = splitter.get_nodes_from_documents(documents)
        ten.log_info(
            f"file {path} pages count {len(documents)}, chunking count {len(nodes)}"
        )
        return nodes

    def create_collection(self, ten: TenEnv, collection_name: str, wait: bool):
        cmd_out = Cmd.create("create_collection")
        cmd_out.set_property_string("collection_name", collection_name)

        wait_event = threading.Event()
        ten.send_cmd(
            cmd_out,
            lambda ten, result, _: wait_event.set(),
        )
        if wait:
            wait_event.wait()

    def embedding(self, ten: TenEnv, path: str, texts: List[str]):
        ten.log_info(
            f"generate embeddings for the file: {path}, with batch size: {len(texts)}"
        )

        cmd_out = Cmd.create("embed_batch")
        cmd_out.set_property_from_json("inputs", json.dumps(texts))
        ten.send_cmd(
            cmd_out,
            lambda ten, result, _: self.vector_store(ten, path, texts, result),
        )

    def vector_store(self, ten: TenEnv, path: str, texts: List[str], result: CmdResult):
        ten.log_info(f"vector store start for one splitting of the file {path}")
        file_name = path.split("/")[-1]

        embed_output_json = result.get_property_string("embeddings")
        embed_output = json.loads(embed_output_json)

        cmd_out = Cmd.create(UPSERT_VECTOR_CMD)
        cmd_out.set_property_string("collection_name", self.new_collection_name)
        cmd_out.set_property_string("file_name", file_name)

        # Pair each chunk text with its embedding before upserting.
        embeddings = [record["embedding"] for record in embed_output]
        content = []
        for text, embedding in zip(texts, embeddings):
            content.append({"text": text, "embedding": embedding})
        cmd_out.set_property_string("content", json.dumps(content))
        # ten.log_info(json.dumps(content))
        ten.send_cmd(cmd_out, lambda ten, result, _: self.file_chunked(ten, path))

    def file_chunked(self, ten: TenEnv, path: str):
        if path in self.counters and path in self.expected:
            self.counters[path] += 1
            ten.log_info(
                f"complete vector store for one splitting of the file: {path}, "
                f"current counter: {self.counters[path]}, expected: {self.expected[path]}"
            )
            if self.counters[path] == self.expected[path]:
                chunks_count = self.counters[path]
                del self.counters[path]
                del self.expected[path]
                ten.log_info(
                    f"complete chunk for the file: {path}, chunks_count {chunks_count}"
                )

                cmd_out = Cmd.create(FILE_CHUNKED_CMD)
                cmd_out.set_property_string("path", path)
                cmd_out.set_property_string("collection", self.new_collection_name)
                ten.send_cmd(
                    cmd_out,
                    lambda ten, result, _: ten.log_info("send_cmd done"),
                )
                self.file_chunked_event.set()
        else:
            ten.log_error(f"missing counter for the file path: {path}")

    def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        if cmd_name == CMD_FILE_CHUNK:
            path = cmd.get_property_string("path")
            collection = None
            try:
                collection = cmd.get_property_string("collection")
            except Exception:
                ten.log_warn(f"missing collection property in cmd {cmd_name}")
            self.queue.put((path, collection))  # make sure files are processed in order
        else:
            ten.log_info(f"unknown cmd {cmd_name}")

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "ok")
        ten.return_result(cmd_result, cmd)

    def async_handler(self, ten: TenEnv) -> None:
        while not self.stop:
            value = self.queue.get()
            if value is None:
                break
            path, collection = value

            # Start processing the file.
            start_time = datetime.now()
            if collection is None:
                collection = self.generate_collection_name()
                ten.log_info(f"collection {collection} generated")
            ten.log_info(f"start processing {path}, collection {collection}")

            # Create the collection, blocking until it is ready.
            self.create_collection(ten, collection, True)
            ten.log_info(f"collection {collection} created")

            # Split the file into nodes.
            nodes = self.split(ten, path)

            # Reset counters and events.
            self.new_collection_name = collection
            self.expected[path] = math.ceil(len(nodes) / BATCH_SIZE)
            self.counters[path] = 0
            self.file_chunked_event.clear()

            # Trigger embedding and vector storing in parallel.
            for texts in list(batch(nodes, BATCH_SIZE)):
                self.embedding(ten, path, texts)

            # Wait for all chunks to be processed.
            self.file_chunked_event.wait()
            ten.log_info(
                f"finished processing {path}, collection {collection}, "
                f"cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
            )

    def on_start(self, ten: TenEnv) -> None:
        ten.log_info("on_start")

        self.stop = False
        self.thread = threading.Thread(target=self.async_handler, args=[ten])
        self.thread.start()

        ten.on_start_done()

    def on_stop(self, ten: TenEnv) -> None:
        ten.log_info("on_stop")

        self.stop = True
        if self.thread is not None:
            # Drain pending work, then unblock the worker so it can exit.
            while not self.queue.empty():
                self.queue.get()
            self.queue.put(None)
            self.thread.join()
            self.thread = None

        ten.on_stop_done()