from ten import (
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
)
from typing import List, Any
import json
from datetime import datetime
import uuid
import math
import queue
import threading

CMD_FILE_CHUNK = "file_chunk"
UPSERT_VECTOR_CMD = "upsert_vector"
FILE_CHUNKED_CMD = "file_chunked"

CHUNK_SIZE = 200
CHUNK_OVERLAP = 20
BATCH_SIZE = 5
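
# Note: CHUNK_SIZE and CHUNK_OVERLAP are passed straight through to
# llama_index's SentenceSplitter, and BATCH_SIZE caps how many chunk texts go
# into a single "embed_batch" command. The values above are this extension's
# defaults, not limits imposed by the framework.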


def batch(nodes, size):
    """Yield the texts of `nodes` in lists of at most `size` items."""
    batch_texts = []
    for n in nodes:
        batch_texts.append(n.text)
        if len(batch_texts) == size:
            # Yield a copy: the buffer is cleared and reused for the next batch.
            yield batch_texts[:]
            batch_texts.clear()
    if batch_texts:
        yield batch_texts
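

# FileChunkerExtension drives the whole pipeline: a "file_chunk" command is
# queued in on_cmd, and a worker thread then creates a collection, splits the
# file into chunks, embeds the chunks in batches, upserts the resulting
# vectors, and finally emits "file_chunked" once every batch has been stored.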
class FileChunkerExtension(Extension):
    def __init__(self, name: str):
        super().__init__(name)

        # Per-file bookkeeping: counters tracks how many batches have been
        # stored so far, expected how many there are in total.
        self.counters = {}
        self.expected = {}
        self.new_collection_name = ""
        self.file_chunked_event = threading.Event()

        # Worker thread and the queue feeding it (path, collection) tuples.
        self.thread = None
        self.queue = queue.Queue()
        self.stop = False

    def generate_collection_name(self) -> str:
        """Generate a collection name matching ^[a-z]+[a-z0-9_]*."""
        # uuid1().hex is lowercase hex, so the "coll_" prefix keeps the whole
        # name within the required pattern.
        return "coll_" + uuid.uuid1().hex.lower()

    def split(self, ten: TenEnv, path: str) -> List[Any]:
        # Imported here so llama_index is only loaded when a file is actually
        # processed.
        from llama_index.core import SimpleDirectoryReader
        from llama_index.core.node_parser import SentenceSplitter

        documents = SimpleDirectoryReader(
            input_files=[path], filename_as_id=True
        ).load_data()

        splitter = SentenceSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
        )
        nodes = splitter.get_nodes_from_documents(documents)
        ten.log_info(
            f"file {path} pages count {len(documents)}, chunking count {len(nodes)}"
        )
        return nodes
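
    # create_collection illustrates the blocking pattern used in this class:
    # send_cmd is callback-based, so a threading.Event turns it into a
    # synchronous call when `wait` is True.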
    def create_collection(self, ten: TenEnv, collection_name: str, wait: bool):
        cmd_out = Cmd.create("create_collection")
        cmd_out.set_property_string("collection_name", collection_name)

        wait_event = threading.Event()
        ten.send_cmd(
            cmd_out,
            lambda ten, result, _: wait_event.set(),
        )
        if wait:
            wait_event.wait()

    def embedding(self, ten: TenEnv, path: str, texts: List[str]):
        ten.log_info(
            f"generate embeddings for the file: {path}, with batch size: {len(texts)}"
        )

        cmd_out = Cmd.create("embed_batch")
        cmd_out.set_property_from_json("inputs", json.dumps(texts))
        ten.send_cmd(
            cmd_out, lambda ten, result, _: self.vector_store(ten, path, texts, result)
        )
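
    # The embedder's reply is expected to carry a JSON array under
    # "embeddings", one record per input text, each holding an "embedding"
    # vector; vector_store pairs those vectors back up with the texts.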
    def vector_store(self, ten: TenEnv, path: str, texts: List[str], result: CmdResult):
        ten.log_info(f"vector store start for one splitting of the file {path}")
        file_name = path.split("/")[-1]
        embed_output_json = result.get_property_string("embeddings")
        embed_output = json.loads(embed_output_json)
        cmd_out = Cmd.create(UPSERT_VECTOR_CMD)
        cmd_out.set_property_string("collection_name", self.new_collection_name)
        cmd_out.set_property_string("file_name", file_name)
        embeddings = [record["embedding"] for record in embed_output]
        content = []
        for text, embedding in zip(texts, embeddings):
            content.append({"text": text, "embedding": embedding})
        cmd_out.set_property_string("content", json.dumps(content))

        ten.send_cmd(cmd_out, lambda ten, result, _: self.file_chunked(ten, path))
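
    # file_chunked runs once per stored batch. Only when the per-file counter
    # reaches the expected batch count does it emit FILE_CHUNKED_CMD and wake
    # the worker thread waiting in async_handler.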
    def file_chunked(self, ten: TenEnv, path: str):
        if path in self.counters and path in self.expected:
            self.counters[path] += 1
            ten.log_info(
                f"complete vector store for one splitting of the file: {path}, "
                f"current counter: {self.counters[path]}, expected: {self.expected[path]}"
            )
            if self.counters[path] == self.expected[path]:
                # Note: this counts stored batches, not individual chunks.
                batches_count = self.counters[path]
                del self.counters[path]
                del self.expected[path]
                ten.log_info(
                    f"complete chunk for the file: {path}, batches_count {batches_count}"
                )
                cmd_out = Cmd.create(FILE_CHUNKED_CMD)
                cmd_out.set_property_string("path", path)
                cmd_out.set_property_string("collection", self.new_collection_name)
                ten.send_cmd(
                    cmd_out,
                    lambda ten, result, _: ten.log_info("send_cmd done"),
                )
                self.file_chunked_event.set()
        else:
            ten.log_error(f"missing counter for the file path: {path}")
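
    # on_cmd only enqueues the request so the command can be acknowledged
    # immediately; the heavy lifting happens on the worker thread.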
    def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
        cmd_name = cmd.get_name()
        if cmd_name == CMD_FILE_CHUNK:
            path = cmd.get_property_string("path")

            collection = None
            try:
                collection = cmd.get_property_string("collection")
            except Exception:
                ten.log_warn(f"missing collection property in cmd {cmd_name}")

            self.queue.put((path, collection))
        else:
            ten.log_info(f"unknown cmd {cmd_name}")

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "ok")
        ten.return_result(cmd_result, cmd)
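
    # The worker loop processes one file at a time, which is also why sharing
    # new_collection_name across callbacks is safe here. A None sentinel on
    # the queue shuts the loop down.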
    def async_handler(self, ten: TenEnv) -> None:
        while not self.stop:
            value = self.queue.get()
            if value is None:
                break
            path, collection = value

            start_time = datetime.now()
            if collection is None:
                collection = self.generate_collection_name()
                ten.log_info(f"collection {collection} generated")
            ten.log_info(f"start processing {path}, collection {collection}")

            self.create_collection(ten, collection, True)
            ten.log_info(f"collection {collection} created")

            nodes = self.split(ten, path)
            if not nodes:
                # Guard against a deadlock: with zero chunks no batch callback
                # would fire and the wait below would never return.
                ten.log_warn(f"no chunks generated for {path}, skipping")
                continue

            self.new_collection_name = collection
            self.expected[path] = math.ceil(len(nodes) / BATCH_SIZE)
            self.counters[path] = 0
            self.file_chunked_event.clear()

            for texts in batch(nodes, BATCH_SIZE):
                self.embedding(ten, path, texts)

            # Block until file_chunked has seen every batch for this file.
            self.file_chunked_event.wait()

            ten.log_info(
                f"finished processing {path}, collection {collection}, "
                f"cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
            )
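
    # Lifecycle hooks: on_start spawns the worker; on_stop drains the queue,
    # pushes the None sentinel, and joins the thread before acknowledging.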
    def on_start(self, ten: TenEnv) -> None:
        ten.log_info("on_start")

        self.stop = False
        self.thread = threading.Thread(target=self.async_handler, args=[ten])
        self.thread.start()

        ten.on_start_done()

    def on_stop(self, ten: TenEnv) -> None:
        ten.log_info("on_stop")

        self.stop = True
        if self.thread is not None:
            # Drop any queued work, then unblock the worker with the sentinel.
            while not self.queue.empty():
                self.queue.get()
            self.queue.put(None)
            self.thread.join()
            self.thread = None

        ten.on_stop_done()