import glob
import os
import pickle
import re
from pathlib import Path

import gradio as gr
import numpy as np
import spaces
from pypdf import PdfReader
from transformers import AutoModel

chunk_size = int(os.environ.get("CHUNK_SIZE", 1000))
default_k = int(os.environ.get("DEFAULT_K", 5))

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v2-base-es", trust_remote_code=True)
docs = {}


def extract_text_from_pdf(reader: PdfReader) -> str:
    """Extract text from PDF pages

    Parameters
    ----------
    reader : PdfReader
        PDF reader

    Returns
    -------
    str
        Raw text
    """
    content = [page.extract_text().strip() for page in reader.pages]
    return "\n\n".join(content).strip()


def convert(filename: str) -> str:
    """Convert file content to raw text

    Parameters
    ----------
    filename : str
        The filename or path

    Returns
    -------
    str
        The raw text

    Raises
    ------
    ValueError
        If the file type is not supported.
    """
    plain_text_filetypes = [
        ".txt",
        ".csv",
        ".tsv",
        ".md",
        ".yaml",
        ".toml",
        ".json",
        ".json5",
        ".jsonc",
    ]
    # Already a plain text file that wouldn't benefit from pandoc, so return the content as-is
    if any(filename.endswith(ft) for ft in plain_text_filetypes):
        with open(filename, "r", encoding="utf-8") as f:
            return f.read()

    if filename.endswith(".pdf"):
        return extract_text_from_pdf(PdfReader(filename))

    raise ValueError(f"Unsupported file type: {filename}")


def generate_chunks(text: str, max_length: int) -> list[str]:
    """Generate chunks from a file's raw text.

    Chunks are calculated based on the `max_length` parameter and the
    split character (.)

    Parameters
    ----------
    text : str
        The raw text
    max_length : int
        Maximum number of characters a chunk can have. Note that chunks
        may not have this exact length, as another component is also
        involved in the splitting process

    Returns
    -------
    list[str]
        A list of chunks/nodes
    """
    segments = text.split(".")
    chunks = []
    chunk = ""
    for current_segment in segments:
        # normalize whitespace in the current segment
        current_segment = re.sub(r"\s+", " ", current_segment).strip()
        if len(chunk) < max_length:
            # only re-insert the sentence separator when the chunk already has content
            chunk += f". {current_segment}" if chunk else current_segment
        else:
            chunks.append(chunk)
            chunk = current_segment

    if chunk:
        chunks.append(chunk)

    return chunks


@spaces.GPU
def predict(query: str, k: int = 5) -> str:
    """Find the k most relevant chunks for the given query

    Parameters
    ----------
    query : str
        The input query
    k : int, optional
        Number of relevant chunks to return, by default 5

    Returns
    -------
    str
        The k chunks concatenated together as a single string.

    Example
    -------
    If k=2, the returned string might look like:
    "CONTEXT:\n\nchunk-1\n\nchunk-2"
    """
    # Embed the query
    query_embedding = model.encode(query)

    # Collect all chunks and their similarities across all documents
    all_chunks = []

    # Iterate through all documents
    for filename, doc in docs.items():
        # Cosine similarity between the query and each chunk embedding (per-row norms)
        similarities = np.dot(doc["embeddings"], query_embedding) / (
            np.linalg.norm(doc["embeddings"], axis=1) * np.linalg.norm(query_embedding)
        )

        # Add chunks and similarities to the all_chunks list
        all_chunks.extend([(filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)])

    # Sort all chunks by similarity, highest first
    all_chunks.sort(key=lambda x: x[2], reverse=True)

    return "CONTEXT:\n\n" + "\n\n".join(f"{filename}: {chunk}" for filename, chunk, _ in all_chunks[:k])


def init():
    """Init function

    It will load or calculate the embeddings
    """
    global docs  # pylint: disable=W0603
    embeddings_file = Path("embeddings-es.pickle")
    if embeddings_file.exists():
        with open(embeddings_file, "rb") as embeddings_pickle:
            docs = pickle.load(embeddings_pickle)
    else:
        for filename in glob.glob("sources/*"):
            converted_doc = convert(filename)
            chunks = generate_chunks(converted_doc, chunk_size)
            embeddings = model.encode(chunks)
            # get the filename and slugify it
            docs[filename.rsplit("/", 1)[-1].lower().replace(" ", "-")] = {
                "chunks": chunks,
                "embeddings": embeddings,
            }

        with open(embeddings_file, "wb") as pickle_file:
            pickle.dump(docs, pickle_file)


init()

gr.Interface(
    predict,
    inputs=[
        gr.Textbox(label="Query asked about the documents"),
        gr.Number(label="Number of relevant sources returned (k)", value=default_k),
    ],
    outputs=[gr.Text(label="Relevant chunks")],
    title="ContextQA tool - El Salvador",
    description="Forked and customized RAG tool working with law documents from El Salvador",
).launch()
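
# Usage sketch (illustrative only, not executed here): once the app is running, the same
# retrieval endpoint can be exercised programmatically with the gradio_client package.
# The Space id below is a placeholder, and the query is just an example string.
#
#   from gradio_client import Client
#   client = Client("<user>/<space-name>")
#   print(client.predict("¿Qué establece la ley sobre contratos de trabajo?", 3))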