File size: 1,257 Bytes
bc01fb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
import numpy as np
import faiss
from datasets import load_dataset
# Load Dataset
dataset = load_dataset("pubmed_qa", "pqa_labeled")
# ``context`` in PubMedQA is a record of parallel lists (``contexts``,
# ``labels``, ``meshes``, ...), not a plain string, so join the abstract
# passages into one text document per example before embedding.
corpus = [" ".join(entry["context"]["contexts"]) for entry in dataset["train"]]

# Embedding model
embed_model = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
corpus_embeddings = embed_model.encode(corpus, show_progress_bar=True, convert_to_numpy=True)
# FAISS expects a C-contiguous float32 matrix of shape (n_docs, dim).
corpus_embeddings = np.ascontiguousarray(corpus_embeddings, dtype="float32")

# FAISS index: exact L2 search; vectors are added in corpus order, so
# index row i corresponds to corpus[i].
index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
index.add(corpus_embeddings)

# Generator model
# NOTE(review): "facebook/bart-large" is the pretrained denoising checkpoint,
# not a QA-fine-tuned one, so generations may largely echo the prompt —
# consider a QA/instruction-tuned seq2seq model.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large")
# Generate Answer Function
def generate_answer(query, index, embeddings, corpus, embed_model, top_k=5, max_new_tokens=128):
    """Answer *query* with retrieval-augmented generation.

    Embeds the query with ``embed_model``, retrieves up to ``top_k`` nearest
    documents from the FAISS ``index``, and generates an answer with the
    module-level seq2seq ``model`` / ``tokenizer`` conditioned on them.

    Args:
        query: The question text.
        index: FAISS index built over ``corpus``; row ``i`` of the index must
            correspond to ``corpus[i]``.
        embeddings: Unused; kept only so existing callers keep working.
        corpus: Sequence of documents aligned with ``index``.
        embed_model: Encoder exposing ``encode(list_of_str)``; should be the
            same model used to build ``index``.
        top_k: Maximum number of documents to retrieve (default 5).
        max_new_tokens: Generation length cap passed to ``model.generate``.

    Returns:
        The decoded answer string with special tokens removed.
    """
    query_embedding = np.ascontiguousarray(embed_model.encode([query]), dtype="float32")

    # Never ask FAISS for more neighbours than the index holds: surplus slots
    # come back as id -1, which would silently select corpus[-1].
    k = min(top_k, index.ntotal)
    if k > 0:
        _, ids = index.search(query_embedding, k=k)
        retrieved = [corpus[i] for i in ids[0] if i >= 0]
    else:
        retrieved = []

    # Join passages as plain text; interpolating the list itself would leak
    # Python list syntax (brackets, quotes, escapes) into the prompt.
    context = "Context: " + "\n\n".join(str(doc) for doc in retrieved)
    question = f"Question: {query}\n\nAnswer:"

    # Encode as a pair and truncate only the context: plain right-truncation
    # of one concatenated string would cut off the question once the
    # retrieved passages exceed the model's maximum input length.
    inputs = tokenizer(context, question, return_tensors="pt", truncation="only_first")
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|