rahideer commited on
Commit
bc01fb2
·
verified ·
1 Parent(s): 9805430

Create rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +31 -0
rag_pipeline.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ from sentence_transformers import SentenceTransformer
3
+ import numpy as np
4
+ import faiss
5
+ from datasets import load_dataset
6
+
7
+ # Load Dataset
8
+ dataset = load_dataset("pubmed_qa", "pqa_labeled")
9
+ corpus = [entry['context'] for entry in dataset['train']]
10
+
11
+ # Embedding model
12
+ embed_model = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
13
+ corpus_embeddings = embed_model.encode(corpus, show_progress_bar=True)
14
+
15
+ # FAISS index
16
+ index = faiss.IndexFlatL2(len(corpus_embeddings[0]))
17
+ index.add(np.array(corpus_embeddings))
18
+
19
+ # Generator model
20
+ tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
21
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")
22
+
23
+ # Generate Answer Function
24
+ def generate_answer(query, index, embeddings, corpus, embed_model):
25
+ query_embedding = embed_model.encode([query])
26
+ D, I = index.search(np.array(query_embedding), k=5)
27
+ retrieved = [corpus[i] for i in I[0]]
28
+ prompt = f"Context: {retrieved}\n\nQuestion: {query}\n\nAnswer:"
29
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
30
+ outputs = model.generate(**inputs, max_new_tokens=128)
31
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)