zaldivards's picture
feat: improve text splitting logic
13ce80f
import glob
import os
import pickle
import re
from pathlib import Path
import gradio as gr
import spaces
import numpy as np
from pypdf import PdfReader
from transformers import AutoModel
# Tunables for chunking and retrieval; each can be overridden through an environment variable
chunk_size = int(os.getenv("CHUNK_SIZE", "250"))
default_k = int(os.getenv("DEFAULT_K", "5"))
# Embedding model shared by init() (chunk embeddings) and predict() (query embedding).
# The "-es" checkpoint is presumably Spanish-tuned — matches the Spanish source documents.
# trust_remote_code=True executes model code fetched from the Hugging Face Hub repository.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-es", trust_remote_code=True)
# (pattern, replacement) regex substitutions applied, in order, to the raw text by generate_chunks
replace_pairs = [
    # Map specific colón (¢) amounts to fixed dollar amounts. The decimal point is
    # escaped so that it only matches a literal "." (an unescaped "." matches any character).
    (r"¢\s+100\.00", "$50"),
    (r"¢\s+300\.00", "$100"),
    (r"¢\s+500\.00", "$150"),
    # Collapse a run of spaces into a single space, unless the run is immediately
    # preceded by a word character or one of ".,;"
    (r"(?<!\w|[.,;]) +", " "),
    # Remove line breaks that are not preceded by a word character, a space, "." or ":"
    # (joining those lines), runs of underscores, and unwanted page header/footer texts
    (r"(?<!\w|[ .:])\n|_+|INDICE LEGISLATIVO|ASAMBLEA LEGISLATIVA \- REPUBLICA DE EL SALVADOR", ""),
]
# In-memory index: slugified source filename -> {"chunks": list of chunk strings,
# "embeddings": model.encode output for those chunks}. Filled (or loaded from the
# pickle cache) by init() and read by predict().
docs: dict[str, dict] = {}
def extract_text_from_pdf(reader: PdfReader) -> str:
    """Concatenate the text of every page of a PDF

    Parameters
    ----------
    reader : PdfReader
        An already opened PDF reader

    Returns
    -------
    str
        The stripped text of each page, pages separated by a blank line,
        with the combined result stripped as well
    """
    page_texts = []
    for page in reader.pages:
        page_texts.append(page.extract_text().strip())
    joined = "\n\n".join(page_texts)
    return joined.strip()
def convert(filename: str) -> str:
    """Convert file content to raw text

    Plain-text formats are read as UTF-8 and returned unchanged; PDFs are
    passed through ``extract_text_from_pdf``. Extension matching is
    case-insensitive, so names such as ``NOTES.TXT`` or ``Ley.PDF`` are
    accepted as well.

    Parameters
    ----------
    filename : str
        The filename or path

    Returns
    -------
    str
        The raw text

    Raises
    ------
    ValueError
        If the file type is not supported.
    """
    plain_text_filetypes = (".txt", ".csv", ".tsv", ".md", ".yaml", ".toml", ".json", ".json5", ".jsonc")
    # Compare extensions case-insensitively; the original name is still used to open the file
    lowered = filename.lower()
    # Already a plain text file: no conversion needed, return the content as-is
    if lowered.endswith(plain_text_filetypes):
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()
    if lowered.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))
    raise ValueError(f"Unsupported file type: {filename}")
def add_prefix(chunk: str, art_prefix: str) -> tuple[str, str]:
    """Add an article marker to chunks that continue or contain articles

    Articles are recognised by the ``Articulo <n> -`` pattern. The chunk is
    returned:

    - unchanged, when it contains exactly one article and that article starts
      within its first few characters (the chunk opens with the article);
    - unchanged, when it contains no article and no current article is known
      yet (e.g. text before the first article);
    - otherwise prefixed with ``<<label|label...>>``, where the labels are the
      current article (only when the chunk contains at most one article and a
      current article is known) followed by every article found in the chunk.

    Parameters
    ----------
    chunk : str
        original chunk
    art_prefix : str
        label of the current article, i.e. the one the previous chunk ended
        in ("" when none has been seen yet)

    Returns
    -------
    tuple[str, str]
        The updated chunk and the new prefix (the label of the last article
        relevant to this chunk), to be passed in for the next chunk
    """
    found = re.findall(r"(Articulo \d+)\s+-", chunk)
    # A small offset (<= 4) tolerates the leading space(s) chunks start with
    if len(found) == 1 and chunk.find(found[0]) <= 4:
        return chunk, found[0]
    labels = found
    # A chunk that does not open with an article continues the current one
    if len(found) <= 1 and art_prefix:
        labels = [art_prefix, *found]
    if not labels:
        # Nothing to label yet: do not emit an empty "<<>>" marker
        return chunk, art_prefix
    return f"<<{'|'.join(labels)}>>{chunk}", labels[-1]
def generate_chunks(text: str, max_length: int) -> list[str]:
    """Generate chunks from a file's raw text. Chunks are calculated based
    on the `max_length` parameter and the split character (.)

    The text is normalised first (``Art. N.`` becomes ``Articulo N`` and every
    ``replace_pairs`` substitution is applied), then processed line by line.
    Lines are appended to the current chunk while it stays within
    ``max_length``. A line that does not fit is split into sentences on
    periods (decimal points such as ``1.0`` are ignored), and as many whole
    sentences as fit are appended before the chunk is closed. Closed chunks
    are labelled with ``add_prefix`` and lower-cased; blank chunks are dropped.

    Parameters
    ----------
    text : str
        The raw text
    max_length : int
        Maximum number of characters a chunk can have. Note that chunks
        may not have this exact length, as another component is also
        involved in the splitting process: text that cannot be split small
        enough still becomes a (longer) chunk, and ``add_prefix`` may
        prepend an article marker.

    Returns
    -------
    list[str]
        A list of chunks/nodes
    """
    # replace "Art. X." with "Articulo X " in a single pass
    text = re.sub(r"Art\. (\d+)\.", r"Articulo \1 ", text)
    for regex, new in replace_pairs:
        text = re.sub(regex, new, text)

    chunks = []
    chunk = ""
    art_prefix = ""
    for current_segment in text.split("\n"):
        if len(chunk) + len(current_segment) + 1 <= max_length:
            chunk += f" {current_segment}"
            continue
        # split using period (.) but ignoring numbers such as 1.0, 2.000, etc
        sentences = re.split(r"(?<!\d)\.", current_segment)
        taken = 0
        # Only sentences followed by a period are candidates: appending every
        # sentence separately takes more room than the whole line, which did
        # not fit, so the last piece can never fit here.
        for sentence in sentences[:-1]:
            if len(chunk) + len(sentence) + 2 > max_length:
                break
            # keep the period that ended this sentence
            chunk += f" {sentence}."
            taken += 1
        # the split character is a single ".", so joining restores the untaken text
        remaining = ".".join(sentences[taken:])
        if not remaining:
            continue
        # whitespace-only chunks carry no content (and no article), so skip them
        if chunk.strip():
            chunk, art_prefix = add_prefix(chunk, art_prefix)
            chunks.append(chunk.lower())
        chunk = remaining
    if chunk.strip():
        chunk, _ = add_prefix(chunk, art_prefix)
        chunks.append(chunk.lower())
    return chunks
@spaces.GPU
def predict(query: str, k: int = 5) -> str:
    """Find k most relevant chunks based on the given query

    Relevance is the cosine similarity between the query embedding and each
    chunk embedding; chunks from all loaded documents are ranked together.

    Parameters
    ----------
    query : str
        The input query
    k : int, optional
        Number of relevant chunks to return, by default 5. Non-integer values
        (``gr.Number`` sends floats) are truncated to int; negative values
        return no chunks.

    Returns
    -------
    str
        The k chunks concatenated together as a single string.

    Example
    -------
    If k=2, the returned string might look like:
    "CONTEXT:\n\nchunk-1\n\nchunk-2"
    """
    # Slice bounds must be ints; the Gradio number input is not guaranteed to be one
    k = max(int(k), 0)
    # Embed the query once and reuse its norm for every document
    query_embedding = model.encode(query)
    query_norm = np.linalg.norm(query_embedding)
    # (filename, chunk, similarity) triples across all documents
    all_chunks = []
    for filename, doc in docs.items():
        embeddings = doc["embeddings"]
        # Cosine similarity needs the norm of each chunk embedding (one per row).
        # Dividing by the norm of the whole matrix scales every score of a
        # document by the same size-dependent factor, skewing cross-document ranking.
        # assumes embeddings is 2-D (n_chunks, dim) as produced by model.encode(list) — TODO confirm
        similarities = np.dot(embeddings, query_embedding) / (np.linalg.norm(embeddings, axis=1) * query_norm)
        all_chunks.extend((filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities))
    # Most similar first; sort is stable, so ties keep document/chunk order
    all_chunks.sort(key=lambda x: x[2], reverse=True)
    return "CONTEXT:\n\n" + "\n\n".join(f"{filename}: {chunk}" for filename, chunk, _ in all_chunks[:k])
def init():
    """Init function

    Populate the module-level ``docs`` index. If ``embeddings-es.pickle``
    exists it is loaded as-is (it is never invalidated automatically, so
    delete it after changing ``sources/``). Otherwise every file under
    ``sources/`` is converted, chunked with ``chunk_size`` and embedded, and
    the resulting index is written to that pickle for the next start-up.
    """
    global docs  # pylint: disable=W0603
    embeddings_file = Path("embeddings-es.pickle")
    if embeddings_file.exists():
        # NOTE: pickle.load can execute arbitrary code. This is only acceptable
        # because the file is the cache written below; never load untrusted pickles here.
        with open(embeddings_file, "rb") as embeddings_pickle:
            docs = pickle.load(embeddings_pickle)
        return
    # sorted() makes document order (and so tie order in predict) reproducible
    for filename in sorted(glob.glob("sources/*")):
        converted_doc = convert(filename)
        chunks = generate_chunks(converted_doc, chunk_size)
        embeddings = model.encode(chunks)
        # slugify the bare file name; Path.name also handles OS-specific separators
        docs[Path(filename).name.lower().replace(" ", "-")] = {
            "chunks": chunks,
            "embeddings": embeddings,
        }
    with open(embeddings_file, "wb") as pickle_file:
        pickle.dump(docs, pickle_file)
# Build (or load from the cache) the document index before serving requests
init()

# Gradio UI: a question and k go in, the k most relevant chunks come out
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Question asked about the documents"),
        gr.Number(label="Number of relevant sources to return (k)", value=default_k),
    ],
    outputs=[gr.Text(label="Relevant chunks")],
    title="ContextqaSV",
    description=(
        "RAG tool enabling questions and answers on legal documents from El Salvador. Legal"
        " documents supported:\n- Constitución de la república\n- Reglamento de tránsito y seguridad vial"
    ),
)
demo.launch()