zaldivards commited on
Commit
0918d3a
·
1 Parent(s): 2d2dc23

refactor: cosine similarity and text splitting

Browse files
app.py CHANGED
@@ -1,46 +1,58 @@
1
- import gradio as gr
2
- import spaces
3
- import subprocess
4
  import os
5
- import shutil
6
- import string
7
- import random
8
  import glob
 
 
 
 
 
 
9
  from pypdf import PdfReader
10
  from sentence_transformers import SentenceTransformer
11
 
 
12
  model_name = os.environ.get("MODEL", "Snowflake/snowflake-arctic-embed-m")
13
- chunk_size = int(os.environ.get("CHUNK_SIZE", 128))
14
- default_max_characters = int(os.environ.get("DEFAULT_MAX_CHARACTERS", 258))
15
 
16
  model = SentenceTransformer(model_name)
17
- # model.to(device="cuda")
18
 
19
- @spaces.GPU
20
- def embed(queries, chunks) -> dict[str, list[tuple[str, float]]]:
21
- query_embeddings = model.encode(queries, prompt_name="query")
22
- document_embeddings = model.encode(chunks)
23
 
24
- scores = query_embeddings @ document_embeddings.T
25
- results = {}
26
- for query, query_scores in zip(queries, scores):
27
- chunk_idxs = [i for i in range(len(chunks))]
28
- # Get a structure like {query: [(chunk_idx, score), (chunk_idx, score), ...]}
29
- results[query] = list(zip(chunk_idxs, query_scores))
 
 
 
 
 
 
 
 
 
30
 
31
- return results
32
 
 
 
33
 
34
- def extract_text_from_pdf(reader):
35
- full_text = ""
36
- for idx, page in enumerate(reader.pages):
37
- text = page.extract_text()
38
- if len(text) > 0:
39
- full_text += f"---- Page {idx} ----\n" + page.extract_text() + "\n\n"
40
 
41
- return full_text.strip()
 
 
 
42
 
43
- def convert(filename) -> str:
 
 
 
 
44
  plain_text_filetypes = [
45
  ".txt",
46
  ".csv",
@@ -54,7 +66,7 @@ def convert(filename) -> str:
54
  ]
55
  # Already a plain text file that wouldn't benefit from pandoc so return the content
56
  if any(filename.endswith(ft) for ft in plain_text_filetypes):
57
- with open(filename, "r") as f:
58
  return f.read()
59
 
60
  if filename.endswith(".pdf"):
@@ -63,75 +75,116 @@ def convert(filename) -> str:
63
  raise ValueError(f"Unsupported file type: {filename}")
64
 
65
 
66
- def chunk_to_length(text, max_length=512):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  chunks = []
68
- while len(text) > max_length:
69
- chunks.append(text[:max_length])
70
- text = text[max_length:]
71
- chunks.append(text)
 
 
 
 
 
 
72
  return chunks
73
 
 
74
  @spaces.GPU
75
- def predict(query, max_characters) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # Embed the query
77
  query_embedding = model.encode(query, prompt_name="query")
78
 
79
  # Initialize a list to store all chunks and their similarities across all documents
80
  all_chunks = []
81
-
82
  # Iterate through all documents
83
- for filename, doc in docs.items():
84
  # Calculate dot product between query and document embeddings
85
- similarities = doc["embeddings"] @ query_embedding.T
86
-
 
87
  # Add chunks and similarities to the all_chunks list
88
- all_chunks.extend([(filename, chunk, sim) for chunk, sim in zip(doc["chunks"], similarities)])
89
 
90
  # Sort all chunks by similarity
91
- all_chunks.sort(key=lambda x: x[2], reverse=True)
92
-
93
- # Initialize a dictionary to store relevant chunks for each document
94
- relevant_chunks = {}
95
-
96
- # Add most relevant chunks until max_characters is reached
97
- total_chars = 0
98
- for filename, chunk, _ in all_chunks:
99
- if total_chars + len(chunk) <= max_characters:
100
- if filename not in relevant_chunks:
101
- relevant_chunks[filename] = []
102
- relevant_chunks[filename].append(chunk)
103
- total_chars += len(chunk)
104
- else:
105
- break
106
 
107
- return relevant_chunks
108
 
109
 
 
 
110
 
111
- docs = {}
112
-
113
- for filename in glob.glob("sources/*"):
114
- if filename.endswith("add_your_files_here"):
115
- continue
116
-
117
- converted_doc = convert(filename)
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- chunks = chunk_to_length(converted_doc, chunk_size)
120
- embeddings = model.encode(chunks)
121
 
122
- docs[filename] = {
123
- "chunks": chunks,
124
- "embeddings": embeddings,
125
- }
126
 
127
 
128
  gr.Interface(
129
  predict,
130
  inputs=[
131
  gr.Textbox(label="Query asked about the documents"),
132
- gr.Number(label="Max output characters", value=default_max_characters),
133
  ],
134
- outputs=[gr.JSON(label="Relevant chunks")],
135
- title="RAG Community Tool Template demo",
136
- description="This is a demo of the RAG Community Tool Template. To use RAG in HuggingChat with your own documents, start by cloning this space, add your documents to the `sources` folder, and then create a community tool with this space!",
137
- ).launch()
 
 
 
 
1
  import os
 
 
 
2
  import glob
3
+ import pickle
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+ import spaces
8
+ import numpy as np
9
  from pypdf import PdfReader
10
  from sentence_transformers import SentenceTransformer
11
 
12
+
13
  model_name = os.environ.get("MODEL", "Snowflake/snowflake-arctic-embed-m")
14
+ chunk_size = int(os.environ.get("CHUNK_SIZE", 1000))
15
+ default_k = int(os.environ.get("DEFAULT_K", 5))
16
 
17
  model = SentenceTransformer(model_name)
18
+ docs = {}
19
 
 
 
 
 
20
 
21
+ def extract_text_from_pdf(reader: PdfReader) -> str:
22
+ """Extract text from PDF pages
23
+
24
+ Parameters
25
+ ----------
26
+ reader : PdfReader
27
+ PDF reader
28
+
29
+ Returns
30
+ -------
31
+ str
32
+ Raw text
33
+ """
34
+ content = [page.extract_text().strip() for page in reader.pages]
35
+ return "\n\n".join(content).strip()
36
 
 
37
 
38
+ def convert(filename: str) -> str:
39
+ """Convert file content to raw text
40
 
41
+ Parameters
42
+ ----------
43
+ filename : str
44
+ The filename or path
 
 
45
 
46
+ Returns
47
+ -------
48
+ str
49
+ The raw text
50
 
51
+ Raises
52
+ ------
53
+ ValueError
54
+ If the file type is not supported.
55
+ """
56
  plain_text_filetypes = [
57
  ".txt",
58
  ".csv",
 
66
  ]
67
  # Already a plain text file that wouldn't benefit from pandoc so return the content
68
  if any(filename.endswith(ft) for ft in plain_text_filetypes):
69
+ with open(filename, "r", encoding="utf-8") as f:
70
  return f.read()
71
 
72
  if filename.endswith(".pdf"):
 
75
  raise ValueError(f"Unsupported file type: {filename}")
76
 
77
 
78
+ def generate_chunks(text: str, max_length: int) -> list[str]:
79
+ """Generate chunks from a file's raw text. Chunks are calculated based
80
+ on the `max_lenght` parameter and the split character (.)
81
+
82
+ Parameters
83
+ ----------
84
+ text : str
85
+ The raw text
86
+ max_length : int
87
+ Maximum number of characters a chunk can have. Note that chunks
88
+ may not have this exact lenght, as another component is also
89
+ involved in the splitting process
90
+
91
+ Returns
92
+ -------
93
+ list[str]
94
+ A list of chunks/nodes
95
+ """
96
+
97
+ segments = text.split(".")
98
  chunks = []
99
+ chunk = ""
100
+
101
+ for current_segment in segments:
102
+ if len(chunk) < max_length:
103
+ chunk += current_segment
104
+ else:
105
+ chunks.append(chunk)
106
+ chunk = current_segment
107
+ if chunk:
108
+ chunks.append(chunk)
109
  return chunks
110
 
111
+
112
  @spaces.GPU
113
+ def predict(query: str, k: int = 5) -> str:
114
+ """Find k most relevant chunks based on the given query
115
+
116
+ Parameters
117
+ ----------
118
+ query : str
119
+ The input query
120
+ k : int, optional
121
+ Number of relevant chunks to return, by default 5
122
+
123
+ Returns
124
+ -------
125
+ str
126
+ The k chunks concatenated together as a single string.
127
+
128
+ Example
129
+ -------
130
+ If k=2, the returned string might look like:
131
+
132
+ "CONTEXT:\n\nchunk-1\n\nchunk-2"
133
+
134
+ """
135
  # Embed the query
136
  query_embedding = model.encode(query, prompt_name="query")
137
 
138
  # Initialize a list to store all chunks and their similarities across all documents
139
  all_chunks = []
 
140
  # Iterate through all documents
141
+ for doc in docs.values():
142
  # Calculate dot product between query and document embeddings
143
+ similarities = np.dot(doc["embeddings"], query_embedding) / (
144
+ np.linalg.norm(doc["embeddings"]) * np.linalg.norm(query_embedding)
145
+ )
146
  # Add chunks and similarities to the all_chunks list
147
+ all_chunks.extend(list(zip(doc["chunks"], similarities)))
148
 
149
  # Sort all chunks by similarity
150
+ all_chunks.sort(key=lambda x: x[1], reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ return "CONTEXT:\n\n" + "\n\n".join(chunk for chunk, _ in all_chunks[:k])
153
 
154
 
155
+ def init():
156
+ """Init function
157
 
158
+ It will load or calculate the embeddings
159
+ """
160
+ global docs # pylint: disable=W0603
161
+ embeddings_file = Path("embeddings.pickle")
162
+ if embeddings_file.exists():
163
+ with open(embeddings_file, "rb") as embeddings_pickle:
164
+ docs = pickle.load(embeddings_pickle)
165
+ else:
166
+ for filename in glob.glob("sources/*"):
167
+ converted_doc = convert(filename)
168
+ chunks = generate_chunks(converted_doc, chunk_size)
169
+ embeddings = model.encode(chunks)
170
+ docs[filename] = {
171
+ "chunks": chunks,
172
+ "embeddings": embeddings,
173
+ }
174
+ with open(embeddings_file, "wb") as pickle_file:
175
+ pickle.dump(docs, pickle_file)
176
 
 
 
177
 
178
+ init()
 
 
 
179
 
180
 
181
  gr.Interface(
182
  predict,
183
  inputs=[
184
  gr.Textbox(label="Query asked about the documents"),
185
+ gr.Number(label="Number of relevant sources returned (k)", value=default_k),
186
  ],
187
+ outputs=[gr.Text(label="Relevant chunks")],
188
+ title="ContextQA tool - El Salvador",
189
+ description="Forked and customized RAG tool working with law documents from El Salvador",
190
+ ).launch()
sources/Constitucion de la Republica.pdf ADDED
Binary file (321 kB). View file
 
sources/GeForce-RTX-4090-GAMING-X-TRIO-24G.pdf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:96cb2dd9797ac7dca9df67a7fd499bb45eecb15219c617bb2d73a3eec19649e6
3
- size 1519838
 
 
 
 
sources/Reglamento General de Transito y Seguridad Vial correcto.pdf ADDED
Binary file (387 kB). View file
 
sources/add_your_files_here DELETED
File without changes
sources/march19newarmouriessamplemenu.pdf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:886365911dc9cea7d983108b532729e1a895388b27c096bc6554535073ca351a
3
- size 52843