#
#
# Agora Real Time Engagement
# Created by Wei Hu in 2024-05.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from ten import (
    Extension,
    TenEnv,
    Cmd,
    StatusCode,
    CmdResult,
)
from typing import List, Any
import json
from datetime import datetime
import uuid
import math
import queue
import threading

CMD_FILE_CHUNK = "file_chunk"
UPSERT_VECTOR_CMD = "upsert_vector"
FILE_CHUNKED_CMD = "file_chunked"

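# Chunking parameters: chunk size/overlap (in tokens) for the SentenceSplitter,
# and the number of chunks sent per "embed_batch" command.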
CHUNK_SIZE = 200
CHUNK_OVERLAP = 20
BATCH_SIZE = 5


def batch(nodes, size):
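    """Yield the node texts in lists of at most `size` items."""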
    batch_texts = []
    for n in nodes:
        batch_texts.append(n.text)
        if len(batch_texts) == size:
            yield batch_texts[:]
            batch_texts.clear()
    if batch_texts:
        yield batch_texts


class FileChunkerExtension(Extension):
    def __init__(self, name: str):
        super().__init__(name)

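        # per-file progress: counters counts stored batches, expected holds
        # the total batch count for each file path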
        self.counters = {}
        self.expected = {}
        self.new_collection_name = ""
        self.file_chunked_event = threading.Event()

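        # a single worker thread consumes the queue so files are processed serially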
        self.thread = None
        self.queue = queue.Queue()
        self.stop = False

    def generate_collection_name(self) -> str:
        """
        follow rules: ^[a-z]+[a-z0-9_]*
        """

        return "coll_" + uuid.uuid1().hex.lower()

    def split(self, ten: TenEnv, path: str) -> List[Any]:
        # lazily import packages that take a long time to load
        from llama_index.core import SimpleDirectoryReader
        from llama_index.core.node_parser import SentenceSplitter

        # load the source file (e.g. a pdf) from the given path
        documents = SimpleDirectoryReader(
            input_files=[path], filename_as_id=True
        ).load_data()

        # split the document into overlapping chunks
        splitter = SentenceSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
        )
        nodes = splitter.get_nodes_from_documents(documents)
        ten.log_info(f"file {path} pages count {documents}, chunking count {nodes}")
        return nodes

    def create_collection(self, ten: TenEnv, collection_name: str, wait: bool):
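        """Issue a create_collection command, optionally blocking until the result arrives."""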
        cmd_out = Cmd.create("create_collection")
        cmd_out.set_property_string("collection_name", collection_name)

        wait_event = threading.Event()
        ten.send_cmd(
            cmd_out,
            lambda ten, result, _: wait_event.set(),
        )
        if wait:
            wait_event.wait()

    def embedding(self, ten: TenEnv, path: str, texts: List[str]):
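        """Send an embed_batch command for one batch of texts; the callback chains into vector_store."""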
        ten.log_info(
            f"generate embeddings for the file: {path}, with batch size: {len(texts)}"
        )

        cmd_out = Cmd.create("embed_batch")
        cmd_out.set_property_from_json("inputs", json.dumps(texts))
        ten.send_cmd(
            cmd_out, lambda ten, result, _: self.vector_store(ten, path, texts, result)
        )

    def vector_store(self, ten: TenEnv, path: str, texts: List[str], result: CmdResult):
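        """Upsert one batch of (text, embedding) pairs into the current collection."""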
        ten.log_info(f"vector store start for one splitting of the file {path}")
        file_name = path.split("/")[-1]
        embed_output_json = result.get_property_string("embeddings")
        embed_output = json.loads(embed_output_json)
        cmd_out = Cmd.create(UPSERT_VECTOR_CMD)
        cmd_out.set_property_string("collection_name", self.new_collection_name)
        cmd_out.set_property_string("file_name", file_name)
        embeddings = [record["embedding"] for record in embed_output]
        content = []
        for text, embedding in zip(texts, embeddings):
            content.append({"text": text, "embedding": embedding})
        cmd_out.set_property_string("content", json.dumps(content))
        # ten.log_info(json.dumps(content))
        ten.send_cmd(cmd_out, lambda ten, result, _: self.file_chunked(ten, path))

    def file_chunked(self, ten: TenEnv, path: str):
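        """Count one stored batch; when all batches for the file are stored,
        send the file_chunked command and wake the worker thread."""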
        if path in self.counters and path in self.expected:
            self.counters[path] += 1
            ten.log_info(
                f"complete vector store for one splitting of the file: {path}, "
                f"current counter: {self.counters[path]}, expected: {self.expected[path]}"
            )
            if self.counters[path] == self.expected[path]:
                chunks_count = self.counters[path]
                del self.counters[path]
                del self.expected[path]
                ten.log_info(
                    f"complete chunk for the file: {path}, chunks_count {chunks_count}"
                )
                cmd_out = Cmd.create(FILE_CHUNKED_CMD)
                cmd_out.set_property_string("path", path)
                cmd_out.set_property_string("collection", self.new_collection_name)
                ten.send_cmd(
                    cmd_out,
                    lambda ten, result, _: ten.log_info("send_cmd done"),
                )
                self.file_chunked_event.set()
        else:
            ten.log_error(f"missing counter for the file path: {path}")

    def on_cmd(self, ten: TenEnv, cmd: Cmd) -> None:
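        """Enqueue incoming file_chunk commands for the worker thread and acknowledge immediately."""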
        cmd_name = cmd.get_name()
        if cmd_name == CMD_FILE_CHUNK:
            path = cmd.get_property_string("path")

            collection = None
            try:
                collection = cmd.get_property_string("collection")
            except Exception:
                ten.log_warn(f"missing collection property in cmd {cmd_name}")

            self.queue.put((path, collection))  # make sure files are processed in order
        else:
            ten.log_info(f"unknown cmd {cmd_name}")

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "ok")
        ten.return_result(cmd_result, cmd)

    def async_handler(self, ten: TenEnv) -> None:
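        """Worker loop: process queued files one at a time until stopped."""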
        while not self.stop:
            value = self.queue.get()
            if value is None:
                break
            path, collection = value

            # start processing the file
            start_time = datetime.now()
            if collection is None:
                collection = self.generate_collection_name()
                ten.log_info(f"collection {collection} generated")
            ten.log_info(f"start processing {path}, collection {collection}")

            # create collection
            self.create_collection(ten, collection, True)
            ten.log_info(f"collection {collection} created")

            # split
            nodes = self.split(ten, path)

            # reset counters and events
            self.new_collection_name = collection
            self.expected[path] = math.ceil(len(nodes) / BATCH_SIZE)
            self.counters[path] = 0
            self.file_chunked_event.clear()

            # trigger embedding and vector storing in parallel
            for texts in batch(nodes, BATCH_SIZE):
                self.embedding(ten, path, texts)

            # wait for all chunks to be processed
            self.file_chunked_event.wait()

            ten.log_info(
                f"finished processing {path}, collection {collection}, cost {int((datetime.now() - start_time).total_seconds() * 1000)}ms"
            )

    def on_start(self, ten: TenEnv) -> None:
        ten.log_info("on_start")

        self.stop = False
        self.thread = threading.Thread(target=self.async_handler, args=[ten])
        self.thread.start()

        ten.on_start_done()

    def on_stop(self, ten: TenEnv) -> None:
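        """Drain pending work, send the stop sentinel, and join the worker thread."""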
        ten.log_info("on_stop")

        self.stop = True
        if self.thread is not None:
            while not self.queue.empty():
                self.queue.get()
            self.queue.put(None)
            self.thread.join()
            self.thread = None

        ten.on_stop_done()
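
# Usage sketch (hypothetical upstream extension, for illustration only): another
# extension in the same graph can trigger chunking by sending a "file_chunk"
# command; "collection" is optional and a new one is generated when absent.
#
#   cmd = Cmd.create("file_chunk")
#   cmd.set_property_string("path", "/path/to/file.pdf")
#   cmd.set_property_string("collection", "coll_existing")  # optional
#   ten.send_cmd(cmd, lambda ten, result, _: ten.log_info("file_chunk acked"))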