# wirag / rag.py
# Last commit a880965 by nurasaki: "Added retrieval num chunks options"
import logging
import os
import requests
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from openai import OpenAI
from huggingface_hub import snapshot_download, InferenceClient
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
# Configure the root logger when this module is imported: INFO level, with the
# timestamp and level name prefixed to every message.
logging.basicConfig(
    format='[%(asctime)s][%(levelname)s] - %(message)s',
    level=logging.INFO,
)
class RAG:
    """Retrieval-augmented generation helper built on a FAISS vector store.

    On construction the FAISS index is loaded either from a Hugging Face Hub
    snapshot (``vs_hf_repo_path``) or from a local directory
    (``vectorstore_path``). ``get_context`` retrieves ``(document, score)``
    pairs for a query and optionally reranks them with a sequence-classification
    model; ``predict_dolly`` / ``predict_completion`` call remote generation
    endpoints; ``get_response`` formats the retrieved chunks as a text report.
    """

    # User-facing fallback (Catalan: "Sorry, I could not answer your question.").
    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."

    # Vector store names used previously (kept for reference):
    #   "index-intfloat_multilingual-e5-small-500-100-CA-ES"      (mixed)
    #   "vectorestore"                                            (CA only)
    #   "index-BAAI_bge-m3-1500-200-recursive_splitter-CA_ES_UE"

    def __init__(self, vs_hf_repo_path, vectorstore_path, hf_token, embeddings_model,
                 model_name, rerank_model, rerank_number_contexts):
        """Load the embeddings model and the FAISS index.

        Args:
            vs_hf_repo_path: Hub repo id holding the FAISS index; when truthy it
                takes precedence over ``vectorstore_path``.
            vectorstore_path: Local directory containing the FAISS index.
            hf_token: Stored on the instance; not read by the methods of this
                class (``predict_*`` read ``HF_TOKEN`` from the environment).
            embeddings_model: HuggingFaceEmbeddings model name (run on CPU).
            model_name: Endpoint URL that ``predict_dolly`` posts to.
            rerank_model: Reranker model name; a falsy value disables reranking.
            rerank_number_contexts: Stored on the instance; not read by the
                methods of this class.
        """
        self.vs_hf_repo_path = vs_hf_repo_path
        self.vectorstore_path = vectorstore_path
        self.model_name = model_name
        self.hf_token = hf_token
        self.rerank_model = rerank_model
        # NOTE(review): never read below — confirm whether it was meant to bound
        # retrieval before reranking.
        self.rerank_number_contexts = rerank_number_contexts
        # Lazily populated (model_name, tokenizer, model) cache; see _load_reranker.
        self._reranker = None

        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
        if vs_hf_repo_path:
            index_dir = snapshot_download(repo_id=vs_hf_repo_path)
        else:
            index_dir = self.vectorstore_path
        # allow_dangerous_deserialization is required by FAISS.load_local because
        # it unpickles the stored docstore: only load indexes from trusted sources.
        self.vectore_store = FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
        logging.info("RAG loaded!")
        logging.info(self.vectore_store)

    def _load_reranker(self):
        """Return ``(tokenizer, model)`` for ``self.rerank_model``, loading them only once."""
        cached = getattr(self, "_reranker", None)
        # Reload if nothing is cached yet or the configured model name changed.
        if cached is None or cached[0] != self.rerank_model:
            tokenizer = AutoTokenizer.from_pretrained(self.rerank_model)
            model = AutoModelForSequenceClassification.from_pretrained(self.rerank_model)
            cached = (self.rerank_model, tokenizer, model)
            self._reranker = cached
            logging.info("Rerank model loaded!")
        return cached[1], cached[2]

    def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
        """Rerank the contexts based on their relevance to the given instruction.

        Args:
            instruction: Query text.
            contexts: Sequence of ``(document, score)`` pairs; only
                ``document.page_content`` is scored by the reranker.
            number_of_contexts: Maximum number of pairs to return.

        Returns:
            Tuple of the best ``number_of_contexts`` input pairs, ordered by
            descending reranker logit (ties keep their input order). An empty
            tuple when ``contexts`` is empty.
        """
        if not contexts:
            # zip(*[]) cannot be unpacked, and there is nothing to score anyway.
            return ()
        tokenizer, model = self._load_reranker()

        def get_score(query, passage):
            """Calculate the relevance score of a passage with respect to a query."""
            inputs = tokenizer(query, passage, return_tensors='pt', truncation=True,
                               padding=True, max_length=512)
            with torch.no_grad():
                logits = model(**inputs).logits
            # .item() assumes a single-logit reranker (one score per pair); the
            # previous tensor comparison had the same requirement.
            return logits.view(-1).float().item()

        scores = [get_score(instruction, context[0].page_content) for context in contexts]
        logging.debug("Rerank scores: %s", scores)
        # sorted() is stable (also with reverse=True), so equal scores keep the
        # retrieval order.
        ranked = sorted(zip(contexts, scores), key=lambda pair: pair[1], reverse=True)
        return tuple(context for context, _ in ranked[:number_of_contexts])

    def get_context(self, instruction, number_of_contexts=3):
        """Retrieve the most relevant contexts for a given instruction.

        Args:
            instruction: Query text to embed and search with.
            number_of_contexts: ``k`` for the similarity search and the maximum
                number of pairs returned.

        Returns:
            Sequence of ``(document, score)`` pairs from the vector store,
            reordered by the reranker when ``self.rerank_model`` is set.
        """
        logging.info("RETRIEVE DOCUMENTS")
        logging.info("Instruction: %s", instruction)

        # Embed explicitly (rather than similarity_search_with_score) so the
        # embedding size can be logged.
        embedding = self.vectore_store._embed_query(instruction)
        logging.info("Query embedding generated: %s", len(embedding))

        documents_retrieved = self.vectore_store.similarity_search_with_score_by_vector(
            embedding, k=number_of_contexts)
        logging.info("Documents retrieved: %s", len(documents_retrieved))

        if self.rerank_model:
            logging.info("RERANK DOCUMENTS")
            # NOTE(review): only the k pairs already retrieved are reordered, so
            # reranking cannot change *which* chunks are kept — confirm intent.
            return self.rerank_contexts(instruction, documents_retrieved,
                                        number_of_contexts=number_of_contexts)
        logging.info("NO RERANKING")
        return documents_retrieved[:number_of_contexts]

    def predict_dolly(self, instruction, context, model_parameters):
        """Query a text-generation endpoint with a Dolly-style prompt.

        Args:
            instruction: The user question.
            context: Context text inserted into the prompt.
            model_parameters: Sent verbatim as the payload's ``parameters``.

        Returns:
            The text after the last ``###`` in ``generated_text`` with its first
            8 characters removed (meant to strip the " Answer" heading and its
            newline; assumes the endpoint echoes the prompt — TODO confirm).

        Raises:
            requests.HTTPError: if the endpoint answers with an error status.
        """
        api_key = os.getenv("HF_TOKEN")
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n "
        payload = {
            "inputs": query,
            "parameters": model_parameters,
        }
        response = requests.post(self.model_name, headers=headers, json=payload)
        # Fail with the HTTP status instead of an opaque KeyError/TypeError from
        # indexing an error JSON body.
        response.raise_for_status()
        generated = response.json()[0]["generated_text"]
        return generated.split("###")[-1][8:]

    def predict_completion(self, instruction, context, model_parameters):
        """Answer via an OpenAI-compatible chat-completions endpoint.

        The base URL is read from the ``MODEL`` environment variable and the API
        key from ``HF_TOKEN``; the model is addressed as ``"tgi"``.

        Args:
            instruction: The user question.
            context: Context text placed before the question.
            model_parameters: Must provide ``temperature``, ``max_new_tokens``
                and ``repetition_penalty``.

        Returns:
            The message content of the first returned choice.
        """
        client = OpenAI(base_url=os.getenv("MODEL"), api_key=os.getenv("HF_TOKEN"))
        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
        chat_completion = client.chat.completions.create(
            model="tgi",
            messages=[{"role": "user", "content": query}],
            temperature=model_parameters["temperature"],
            max_tokens=model_parameters["max_new_tokens"],
            stream=False,
            stop=["<|im_end|>"],
            extra_body={
                # NOTE(review): repetition_penalty is mapped onto presence_penalty
                # by subtracting 2 — confirm this offset is intended.
                "presence_penalty": model_parameters["repetition_penalty"] - 2,
                "do_sample": False,
            },
        )
        return chat_completion.choices[0].message.content

    def beautiful_context(self, docs):
        """Build display strings from ``(document, score)`` pairs.

        Returns:
            ``(text_context, full_context, source_context)`` where
            ``text_context`` concatenates every ``page_content`` with no
            separator, ``full_context`` lists content, title and URL per
            document, and ``source_context`` is the list of document URLs.
        """
        documents = [entry[0] for entry in docs]
        # NOTE(review): chunks are glued without whitespace — confirm this is intended.
        text_context = "".join(doc.page_content for doc in documents)
        full_context = "".join(
            f"{doc.page_content}\n{doc.metadata['title']}\n\n{doc.metadata['url']}\n\n"
            for doc in documents
        )
        source_context = [doc.metadata["url"] for doc in documents]
        return text_context, full_context, source_context

    @staticmethod
    def _format_document(position, doc, score):
        """Render one retrieved chunk (1-based ``position``) as a report section."""
        rule = "=" * 100
        meta = doc.metadata
        return (
            f"\n\n{rule}"
            f"\nDocument {position}"
            f"\n{rule}"
            f"\nScore: {score:.5f}"
            f"\nTitle: {meta['title']}"
            f"\nURL: {meta['url']}"
            f"\nID: {meta['id']}"
            f"\nStart index: {meta['start_index']}"
            f"\n{'-' * 100}\n"
            f"\nContent:\n"
            f"{doc.page_content}"
        )

    def get_response(self, prompt: str, model_parameters: dict) -> "tuple[str, str, list]":
        """Retrieve chunks for ``prompt`` and format them as a plain-text report.

        Args:
            prompt: The user query.
            model_parameters: Must contain ``"NUM_CHUNKS"``, the number of
                chunks to retrieve.

        Returns:
            Always a 3-tuple ``(report, full_context, sources)``;
            ``full_context`` and ``sources`` are currently ``""`` and ``[]``.
            When nothing is retrieved, or any step fails (the error is logged),
            ``report`` is ``NO_ANSWER_MESSAGE``.
        """
        no_answer = (self.NO_ANSWER_MESSAGE, "", [])
        try:
            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
            report = "".join(
                self._format_document(position, doc, score)
                for position, (doc, score) in enumerate(docs, start=1)
            )
        except Exception:
            # Top-level boundary: keep the caller alive but record the traceback.
            logging.exception("Failed to build a response for prompt %r", prompt)
            return no_answer
        if not report:
            return no_answer
        # Placeholders so callers can always unpack three values.
        full_context = ""
        source = []
        return report, full_context, source