Kadi-IAM committed
Commit f8bb00c · verified · 1 Parent(s): 71632fa

Upload evaluation_example.ipynb

Files changed (1)
  1. evaluation_example.ipynb +316 -0
evaluation_example.ipynb ADDED
@@ -0,0 +1,316 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tqdm.auto import tqdm\n",
+ "import pandas as pd\n",
+ "import time\n",
+ "\n",
+ "from langchain.document_loaders import PyMuPDFLoader\n",
+ "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
+ "\n",
+ "pd.set_option(\"display.max_colwidth\", None)\n",
+ "\n",
+ "# Set the Mistral API key,\n",
+ "# e.g., export MISTRAL_API_KEY=your_api_key_here,\n",
+ "# or save the API key in a .env file\n",
+ "from dotenv import load_dotenv\n",
+ "load_dotenv()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load the PDF file\n",
+ "filepath = \"data/documents/Brandt et al_2024_Kadi_info_page.pdf\"\n",
+ "loader_module = PyMuPDFLoader\n",
+ "loader = loader_module(filepath)\n",
+ "document = loader.load()"
+ ]
+ },
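
For reference, `loader.load()` returns one `Document` per PDF page. A quick sanity check (a minimal sketch, reusing the `document` variable from the cell above) could look like:

```python
# Inspect the loaded document: PyMuPDFLoader yields one Document per page
print(f"Loaded {len(document)} page(s)")
print(document[0].metadata)            # source path, page number, etc.
print(document[0].page_content[:300])  # preview of the first page's text
```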
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Split the docs into chunks\n",
+ "text_splitter = RecursiveCharacterTextSplitter(\n",
+ "    chunk_size=2000,\n",
+ "    chunk_overlap=200,\n",
+ "    add_start_index=True,\n",
+ "    separators=[\"\\n\\n\", \"\\n\", \".\", \" \", \"\"],\n",
+ ")\n",
+ "\n",
+ "docs_processed = []\n",
+ "for doc in document:\n",
+ "    docs_processed += text_splitter.split_documents([doc])"
+ ]
+ },
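
To verify the splitter behaved as expected, the chunk count and length distribution can be inspected (a sketch assuming `docs_processed` from the cell above):

```python
# Summarize the chunking result
lengths = [len(d.page_content) for d in docs_processed]
print(f"{len(docs_processed)} chunks; min/max length: {min(lengths)}/{max(lengths)} characters")
```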
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create the LLM; here we use Mistral AI\n",
+ "from langchain_mistralai.chat_models import ChatMistralAI\n",
+ "\n",
+ "llm = ChatMistralAI(\n",
+ "    model=\"mistral-large-latest\"\n",
+ ")\n",
+ "\n",
+ "llm.invoke(\"hello\")  # test the llm"
+ ]
+ },
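
For more reproducible generation and evaluation runs, the sampling temperature can be pinned (a sketch; `temperature` is a standard `ChatMistralAI` parameter, and 0 is only a suggestion):

```python
# Optional: make the LLM (near-)deterministic for evaluation runs
llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
```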
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "QA_generation_prompt = \"\"\"\n",
+ "Your task is to write a factoid question and an answer given a context.\n",
+ "Your factoid question should be answerable with a specific, concise piece of factual information from the context.\n",
+ "Your factoid question should be formulated in the same style as questions users could ask in a search engine. Users are usually scientific researchers in the field of materials science.\n",
+ "This means that your factoid question MUST NOT mention something like \"according to the passage\" or \"context\".\n",
+ "Please ask a specific question instead of a general question like 'What is the key information in the given paragraph?'.\n",
+ "\n",
+ "Provide your answer as follows:\n",
+ "\n",
+ "Output:::\n",
+ "Factoid question: (your factoid question)\n",
+ "Answer: (your answer to the factoid question)\n",
+ "\n",
+ "Now here is the context.\n",
+ "\n",
+ "Context: {context}\\n\n",
+ "Output:::\"\"\"\n",
+ "\n",
+ "# Alternatively:\n",
+ "# Ref: https://mlflow.org/docs/latest/llms/rag/notebooks/question-generation-retrieval-evaluation.html\n",
+ "# QA_generation_prompt = \"\"\"\n",
+ "# Please generate a question asking for the key information in the given paragraph.\n",
+ "# Also answer the question using the information in the given paragraph.\n",
+ "# Please ask a specific question instead of a general question, like\n",
+ "# 'What is the key information in the given paragraph?'.\n",
+ "# Please generate the answer using as much information as possible.\n",
+ "# If you are unable to answer it, please generate the answer as 'I don't know.'\n",
+ "\n",
+ "# Provide your answer as follows:\n",
+ "\n",
+ "# Output:::\n",
+ "# Factoid question: (your factoid question)\n",
+ "# Answer: (your answer to the factoid question)\n",
+ "\n",
+ "# Now here is the context.\n",
+ "\n",
+ "# Context: {context}\\n\n",
+ "# Output:::\"\"\""
+ ]
+ },
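
To preview exactly what the LLM will receive, the template can be rendered with one chunk (a minimal sketch reusing `QA_generation_prompt` and `docs_processed` from the cells above):

```python
# Render the prompt for the first chunk and inspect it
print(QA_generation_prompt.format(context=docs_processed[0].page_content))
```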
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generate QA pairs\n",
+ "\n",
+ "import random\n",
+ "\n",
+ "N_GENERATIONS = 5  # generate only 5 QA pairs here for cost and time considerations\n",
+ "\n",
+ "print(f\"Generating {N_GENERATIONS} QA pairs...\")\n",
+ "\n",
+ "outputs = []\n",
+ "for sampled_context in tqdm(random.choices(docs_processed, k=N_GENERATIONS)):\n",
+ "    # Generate a QA pair\n",
+ "    output_QA_couple = llm.invoke(QA_generation_prompt.format(context=sampled_context.page_content)).content\n",
+ "    try:\n",
+ "        question = output_QA_couple.split(\"Factoid question: \")[-1].split(\"Answer: \")[0]\n",
+ "        answer = output_QA_couple.split(\"Answer: \")[-1]\n",
+ "        assert len(answer) < 500, \"Answer is too long\"\n",
+ "        outputs.append(\n",
+ "            {\n",
+ "                \"context\": sampled_context.page_content,\n",
+ "                \"question\": question,\n",
+ "                \"answer\": answer,\n",
+ "                \"source_doc\": sampled_context.metadata[\"source\"],\n",
+ "            }\n",
+ "        )\n",
+ "        time.sleep(3)  # sleep for llm rate limitation\n",
+ "    except Exception:\n",
+ "        time.sleep(3)  # sleep for llm rate limitation\n",
+ "        continue"
+ ]
+ },
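
The string-splitting parse above is brittle if the model deviates from the `Output:::` template; a slightly more defensive variant (a sketch, not part of the original notebook) could use a regular expression:

```python
import re

def parse_qa(text: str):
    """Extract (question, answer) from the 'Factoid question: ... Answer: ...' template, or return None."""
    m = re.search(r"Factoid question:\s*(.+?)\s*Answer:\s*(.+)", text, re.DOTALL)
    return (m.group(1).strip(), m.group(2).strip()) if m else None
```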
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "reference_df = pd.DataFrame(outputs)\n",
+ "display(reference_df.head(1))"
+ ]
+ },
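
Since each QA pair costs API calls, it may be worth persisting the reference set for reuse across runs (a sketch; the file path is an arbitrary choice):

```python
# Save the generated reference set so later runs can skip regeneration
reference_df.to_csv("data/qa_reference.csv", index=False)
# reference_df = pd.read_csv("data/qa_reference.csv")  # reload later
```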
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build a simple RAG chain\n",
+ "from langchain_huggingface import HuggingFaceEmbeddings\n",
+ "from langchain.vectorstores import FAISS\n",
+ "\n",
+ "chunk_size = 1024\n",
+ "chunk_overlap = 256\n",
+ "splitter = RecursiveCharacterTextSplitter(\n",
+ "    separators=[\"\\n\\n\", \"\\n\"], chunk_size=chunk_size, chunk_overlap=chunk_overlap\n",
+ ")\n",
+ "doc_chunks = splitter.split_documents(document)\n",
+ "\n",
+ "embeddings = HuggingFaceEmbeddings(model_name=\"all-mpnet-base-v2\")\n",
+ "\n",
+ "vectorstore = FAISS.from_documents(doc_chunks, embedding=embeddings)\n",
+ "\n",
+ "retriever = vectorstore.as_retriever()\n",
+ "\n",
+ "from langchain.chains import RetrievalQA\n",
+ "\n",
+ "rag_chain = RetrievalQA.from_llm(\n",
+ "    llm=llm, retriever=retriever, return_source_documents=True\n",
+ ")"
+ ]
+ },
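
Before running the full evaluation loop, a single smoke-test query helps confirm the chain is wired up correctly (a sketch; the question is an arbitrary example):

```python
# Smoke-test the RAG chain with one query
result = rag_chain({"query": "What is Kadi4Mat?"})
print(result["result"])
print(f"{len(result['source_documents'])} source chunks retrieved")
```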
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Prepare the evaluation dataset\n",
+ "def prepare_eval_dataset(reference_df, rag_chain):\n",
+ "\n",
+ "    print(\"now loading evaluation dataset...\")\n",
+ "    from datasets import Dataset\n",
+ "\n",
+ "    # Read the reference data\n",
+ "    df = reference_df\n",
+ "\n",
+ "    # Collect the questions and ground-truth answers\n",
+ "    questions = df[\"question\"].values\n",
+ "    ground_truth = []\n",
+ "    for a in df[\"answer\"].values:\n",
+ "        ground_truth.append(a)  # [a] for older versions of ragas\n",
+ "    answers = []\n",
+ "    contexts = []\n",
+ "\n",
+ "    # Get answers from the rag_chain\n",
+ "    print(\"now getting answers from the QA llm...\")\n",
+ "    for query in questions:\n",
+ "        results = rag_chain({\"query\": query})\n",
+ "        answers.append(results[\"result\"])\n",
+ "        contexts.append([docs.page_content for docs in results[\"source_documents\"]])\n",
+ "        time.sleep(3)  # sleep for llm rate limitation\n",
+ "\n",
+ "    # To dict\n",
+ "    data = {\n",
+ "        \"question\": questions,\n",
+ "        \"answer\": answers,\n",
+ "        \"contexts\": contexts,\n",
+ "        \"ground_truth\": ground_truth,\n",
+ "    }\n",
+ "\n",
+ "    # Convert the dict to a Dataset\n",
+ "    dataset = Dataset.from_dict(data)\n",
+ "    return dataset\n",
+ "\n",
+ "eval_dataset = prepare_eval_dataset(reference_df, rag_chain)\n",
+ "eval_dataset"
+ ]
+ },
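
The resulting `Dataset` should expose the four columns this version of ragas expects; a quick check (sketch, reusing `eval_dataset`):

```python
# Verify the evaluation dataset has the expected columns
expected = {"question", "answer", "contexts", "ground_truth"}
assert expected <= set(eval_dataset.column_names), eval_dataset.column_names
print(eval_dataset)
```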
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Ragas evaluation\n",
+ "from ragas.llms import LangchainLLMWrapper\n",
+ "eval_llm = LangchainLLMWrapper(llm)\n",
+ "\n",
+ "from ragas import evaluate\n",
+ "from ragas.metrics import (\n",
+ "    faithfulness,\n",
+ "    answer_relevancy,\n",
+ "    context_recall,\n",
+ "    context_precision,\n",
+ "    answer_correctness,\n",
+ ")\n",
+ "\n",
+ "# Run the evaluation; this can take a while\n",
+ "result_eval = evaluate(\n",
+ "    dataset=eval_dataset,\n",
+ "    metrics=[\n",
+ "        context_precision,\n",
+ "        context_recall,\n",
+ "        faithfulness,\n",
+ "        answer_relevancy,\n",
+ "        answer_correctness,\n",
+ "    ],\n",
+ "    llm=eval_llm,\n",
+ "    embeddings=embeddings,\n",
+ "    raise_exceptions=False,\n",
+ ")\n",
+ "\n",
+ "result_eval_df = result_eval.to_pandas()"
+ ]
+ },
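
For a one-line summary, the per-sample scores can be averaged (a sketch over `result_eval_df`; in recent ragas versions the result columns are named after the metrics):

```python
# Mean score per metric across all evaluated samples
metric_cols = ["context_precision", "context_recall", "faithfulness",
               "answer_relevancy", "answer_correctness"]
print(result_eval_df[metric_cols].mean())
```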
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check the results\n",
+ "result_eval_df\n",
+ "# If you get NaN values in the results, see the \"Frequently Asked Questions\" in the Ragas docs for help"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }