import os
import glob
import pickle
from pathlib import Path
import gradio as gr
import spaces
import numpy as np
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
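
# Embedding model and retrieval settings, overridable via environment variables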
model_name = os.environ.get("MODEL", "Snowflake/snowflake-arctic-embed-m")
chunk_size = int(os.environ.get("CHUNK_SIZE", 1000))
default_k = int(os.environ.get("DEFAULT_K", 5))
model = SentenceTransformer(model_name)
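# Maps each source filename to its chunks and their embeddings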
docs = {}
def extract_text_from_pdf(reader: PdfReader) -> str:
"""Extract text from PDF pages
Parameters
----------
reader : PdfReader
PDF reader
Returns
-------
str
Raw text
"""
content = [page.extract_text().strip() for page in reader.pages]
return "\n\n".join(content).strip()
def convert(filename: str) -> str:
"""Convert file content to raw text
Parameters
----------
filename : str
The filename or path
Returns
-------
str
The raw text
Raises
------
ValueError
If the file type is not supported.
"""
plain_text_filetypes = [
".txt",
".csv",
".tsv",
".md",
".yaml",
".toml",
".json",
".json5",
".jsonc",
]
    # Already plain text that wouldn't benefit from pandoc conversion, so return the content as-is
if any(filename.endswith(ft) for ft in plain_text_filetypes):
with open(filename, "r", encoding="utf-8") as f:
return f.read()
if filename.endswith(".pdf"):
return extract_text_from_pdf(PdfReader(filename))
raise ValueError(f"Unsupported file type: {filename}")
def generate_chunks(text: str, max_length: int) -> list[str]:
"""Generate chunks from a file's raw text. Chunks are calculated based
on the `max_lenght` parameter and the split character (.)
Parameters
----------
text : str
The raw text
max_length : int
Maximum number of characters a chunk can have. Note that chunks
may not have this exact lenght, as another component is also
involved in the splitting process
Returns
-------
list[str]
A list of chunks/nodes
"""
segments = text.split(".")
chunks = []
chunk = ""
for current_segment in segments:
if len(chunk) < max_length:
chunk += current_segment
else:
chunks.append(chunk)
chunk = current_segment
if chunk:
chunks.append(chunk)
return chunks
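
# @spaces.GPU requests a ZeroGPU slot on Hugging Face Spaces while each call runs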
@spaces.GPU
def predict(query: str, k: int = 5) -> str:
"""Find k most relevant chunks based on the given query
Parameters
----------
query : str
The input query
k : int, optional
Number of relevant chunks to return, by default 5
Returns
-------
str
The k chunks concatenated together as a single string.
Example
-------
If k=2, the returned string might look like:
"CONTEXT:\n\nchunk-1\n\nchunk-2"
"""
    # Embed the query using the model's query prompt (asymmetric retrieval)
    query_embedding = model.encode(query, prompt_name="query")
# Initialize a list to store all chunks and their similarities across all documents
all_chunks = []
# Iterate through all documents
for doc in docs.values():
        # Cosine similarity between the query and every chunk embedding:
        # normalise each embedding row-wise rather than the whole matrix
        similarities = np.dot(doc["embeddings"], query_embedding) / (
            np.linalg.norm(doc["embeddings"], axis=1) * np.linalg.norm(query_embedding)
        )
# Add chunks and similarities to the all_chunks list
all_chunks.extend(list(zip(doc["chunks"], similarities)))
# Sort all chunks by similarity
all_chunks.sort(key=lambda x: x[1], reverse=True)
return "CONTEXT:\n\n" + "\n\n".join(chunk for chunk, _ in all_chunks[:k])
def init():
"""Init function
It will load or calculate the embeddings
"""
global docs # pylint: disable=W0603
embeddings_file = Path("embeddings.pickle")
if embeddings_file.exists():
with open(embeddings_file, "rb") as embeddings_pickle:
docs = pickle.load(embeddings_pickle)
else:
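        # No cache yet: convert, chunk, and embed every source document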
for filename in glob.glob("sources/*"):
converted_doc = convert(filename)
chunks = generate_chunks(converted_doc, chunk_size)
embeddings = model.encode(chunks)
docs[filename] = {
"chunks": chunks,
"embeddings": embeddings,
}
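        # Cache the embeddings so subsequent startups skip re-encoding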
with open(embeddings_file, "wb") as pickle_file:
pickle.dump(docs, pickle_file)
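
# Build or load the embeddings at startup, then expose the retrieval UI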
init()
gr.Interface(
predict,
inputs=[
gr.Textbox(label="Query asked about the documents"),
        # precision=0 makes Gradio pass k as an int (needed for slicing)
        gr.Number(label="Number of relevant sources returned (k)", value=default_k, precision=0),
],
outputs=[gr.Text(label="Relevant chunks")],
title="ContextQA tool - El Salvador",
description="Forked and customized RAG tool working with law documents from El Salvador",
).launch()