import logging
import os

import requests
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from openai import OpenAI
from huggingface_hub import snapshot_download, InferenceClient
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings

logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)s] - %(message)s')


class RAG:
    """Retrieval-augmented generation helper backed by a FAISS vector store.

    Loads a FAISS index (from a Hugging Face Hub repo or a local directory),
    retrieves the chunks closest to a query, optionally reorders them with a
    sequence-classification rerank model, and offers two ways of calling a
    text-generation endpoint (``predict_dolly`` and ``predict_completion``).
    """

    # User-facing fallback answer (Catalan for "Sorry, I could not answer your
    # question."). Kept verbatim because it is shown to end users.
    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."

    # Index names used in earlier experiments, kept for reference only:
    #   "index-intfloat_multilingual-e5-small-500-100-CA-ES"   (mixed CA/ES)
    #   "vectorestore"                                          (CA only)
    #   "index-BAAI_bge-m3-1500-200-recursive_splitter-CA_ES_UE"

    def __init__(self, vs_hf_repo_path, vectorstore_path, hf_token, embeddings_model, model_name, rerank_model, rerank_number_contexts):
        """Load the embeddings model and the FAISS index.

        Args:
            vs_hf_repo_path: Hub repo id holding a saved FAISS index. When truthy
                it is downloaded with ``snapshot_download`` and takes precedence
                over ``vectorstore_path``.
            vectorstore_path: Local directory of a saved FAISS index, used when
                ``vs_hf_repo_path`` is falsy.
            hf_token: Stored on the instance. The methods in this class read the
                ``HF_TOKEN`` environment variable instead of this attribute.
            embeddings_model: Model name passed to ``HuggingFaceEmbeddings``
                (run on CPU).
            model_name: Endpoint URL used by ``predict_dolly``.
            rerank_model: Name/path of the rerank model; falsy disables reranking.
            rerank_number_contexts: Stored on the instance but not read by any
                method in this class.
        """
        self.vs_hf_repo_path = vs_hf_repo_path
        self.vectorstore_path = vectorstore_path
        self.model_name = model_name
        self.hf_token = hf_token
        self.rerank_model = rerank_model
        self.rerank_number_contexts = rerank_number_contexts

        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
        if vs_hf_repo_path:
            index_dir = snapshot_download(repo_id=vs_hf_repo_path)
        else:
            index_dir = self.vectorstore_path
        # allow_dangerous_deserialization permits LangChain's pickle-based
        # loading of the saved index: only point this at trusted indexes.
        self.vectore_store = FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
        logging.info("RAG loaded!")
        logging.info(self.vectore_store)

    def _load_reranker(self):
        """Return ``(tokenizer, model)`` for ``self.rerank_model``, loading them only once."""
        # getattr keeps this working for instances built without __init__, and the
        # stored model name invalidates the cache if rerank_model is reassigned.
        cached = getattr(self, "_reranker", None)
        if cached is None or cached[0] != self.rerank_model:
            tokenizer = AutoTokenizer.from_pretrained(self.rerank_model)
            model = AutoModelForSequenceClassification.from_pretrained(self.rerank_model)
            logging.info("Rerank model loaded!")
            cached = (self.rerank_model, tokenizer, model)
            self._reranker = cached
        return cached[1], cached[2]

    def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
        """Rerank the contexts based on their relevance to the given instruction.

        Args:
            instruction: Query text.
            contexts: Sequence of ``(Document, score)`` pairs, e.g. the output of
                ``similarity_search_with_score_by_vector``.
            number_of_contexts: How many of the best-scoring pairs to keep.

        Returns:
            A tuple of the original ``(Document, score)`` pairs ordered by
            descending rerank score, truncated to ``number_of_contexts``. The
            retrieval score inside each pair is left unchanged. An empty input
            yields an empty tuple.
        """
        if not contexts:
            # Nothing to rank; also avoids loading the rerank model needlessly.
            return ()

        tokenizer, model = self._load_reranker()

        def get_score(query, passage):
            """Calculate the relevance score of a passage with respect to a query."""
            inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
            logging.debug("Inputs: %s", inputs)
            with torch.no_grad():
                logits = model(**inputs).logits
            score = logits.view(-1).float()
            logging.debug("Score: %s", score)
            return score

        scores = [get_score(instruction, context[0].page_content) for context in contexts]
        logging.debug("Scores: %s", scores)
        # float() requires a single-logit head (one score per pair), exactly as the
        # original tensor comparison did; sorted() is stable, so ties keep order.
        ranked = sorted(zip(contexts, scores), key=lambda pair: float(pair[1]), reverse=True)
        return tuple(context for context, _ in ranked[:number_of_contexts])

    def get_context(self, instruction, number_of_contexts=3):
        """Retrieve the most relevant contexts for a given instruction.

        Embeds ``instruction`` with the vector store's embedding function, fetches
        the ``number_of_contexts`` nearest chunks, and reorders them with
        ``rerank_contexts`` when a rerank model is configured.

        Returns:
            ``(Document, score)`` pairs: a list without reranking, a tuple with it.
        """
        logging.info("RETRIEVE DOCUMENTS")
        logging.info(f"Instruction: {instruction}")

        # Embed the query (uses a private LangChain FAISS helper).
        embedding = self.vectore_store._embed_query(instruction)
        logging.info(f"Query embedding generated: {len(embedding)}")

        # Retrieve documents.
        documents_retrieved = self.vectore_store.similarity_search_with_score_by_vector(embedding, k=number_of_contexts)
        logging.info(f"Documents retrieved: {len(documents_retrieved)}")

        # NOTE(review): only number_of_contexts docs are retrieved, so reranking
        # reorders but never filters; self.rerank_number_contexts is unused. Confirm
        # whether a larger candidate pool was intended.
        if self.rerank_model:
            logging.info("RERANK DOCUMENTS")
            return self.rerank_contexts(instruction, documents_retrieved, number_of_contexts=number_of_contexts)

        logging.info("NO RERANKING")
        return documents_retrieved[:number_of_contexts]

    def predict_dolly(self, instruction, context, model_parameters):
        """Generate an answer via a raw HTTP text-generation endpoint (Dolly-style prompt).

        POSTs ``{"inputs": prompt, "parameters": model_parameters}`` to
        ``self.model_name`` with a bearer token from the ``HF_TOKEN`` env var.

        Returns:
            The text after the last ``"###"`` marker of the first generation,
            minus its first 8 characters.
        """
        api_key = os.getenv("HF_TOKEN")
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n "
        payload = {
            "inputs": query,
            "parameters": model_parameters,
        }
        response = requests.post(self.model_name, headers=headers, json=payload)
        # Presumably the endpoint echoes the prompt, so the slice strips the
        # " Answer\n" header after the last "###" (TODO confirm against the endpoint).
        return response.json()[0]["generated_text"].split("###")[-1][8:]

    def predict_completion(self, instruction, context, model_parameters):
        """Generate an answer through an OpenAI-compatible chat-completions API.

        The base URL comes from the ``MODEL`` env var and the API key from
        ``HF_TOKEN``. ``model_parameters`` must provide ``temperature``,
        ``max_new_tokens`` and ``repetition_penalty``.

        Returns:
            The content of the first choice's message.
        """
        client = OpenAI(
            base_url=os.getenv("MODEL"),
            api_key=os.getenv("HF_TOKEN")
        )
        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
        chat_completion = client.chat.completions.create(
            model="tgi",
            messages=[
                {"role": "user", "content": query}
            ],
            temperature=model_parameters["temperature"],
            max_tokens=model_parameters["max_new_tokens"],
            stream=False,
            stop=["<|im_end|>"],
            extra_body={
                # NOTE(review): the "- 2" mapping from repetition_penalty to
                # presence_penalty is not explained here; confirm the intent.
                "presence_penalty": model_parameters["repetition_penalty"] - 2,
                "do_sample": False
            }
        )
        return chat_completion.choices[0].message.content

    def beautiful_context(self, docs):
        """Flatten retrieved documents into prompt/display strings.

        Args:
            docs: Iterable of ``(Document, score)`` pairs as returned by
                ``get_context``. Each Document needs ``title`` and ``url`` metadata.

        Returns:
            ``(text_context, full_context, source_context)``:
            ``text_context`` is the page contents concatenated with no separator,
            ``full_context`` is each page content followed by its title and URL,
            and ``source_context`` is the list of URLs in order.
        """
        documents = [pair[0] for pair in docs]
        # NOTE(review): chunks are glued with no separator in text_context; confirm
        # that is intended before feeding it to a model.
        text_context = "".join(doc.page_content for doc in documents)
        full_context = "".join(
            doc.page_content + "\n" + doc.metadata["title"] + "\n\n" + doc.metadata["url"] + "\n\n"
            for doc in documents
        )
        source_context = [doc.metadata["url"] for doc in documents]
        return text_context, full_context, source_context

    @staticmethod
    def _format_document(position, doc, score):
        """Render one retrieved ``(doc, score)`` pair as a human-readable section."""
        rule = "=" * 100
        dash = "-" * 100
        return (
            f"\n\n{rule}"
            f"\nDocument {position}"
            f"\n{rule}"
            f"\nScore: {score:.5f}"
            f"\nTitle: {doc.metadata['title']}"
            f"\nURL: {doc.metadata['url']}"
            f"\nID: {doc.metadata['id']}"
            f"\nStart index: {doc.metadata['start_index']}"
            f"\n{dash}\n"
            f"\nContent:\n"
            f"{doc.page_content}"
        )

    def get_response(self, prompt: str, model_parameters: dict) -> tuple[str, str, list]:
        """Retrieve documents for ``prompt`` and render them as a readable report.

        Args:
            prompt: User query.
            model_parameters: Must contain ``NUM_CHUNKS`` (number of contexts to
                retrieve).

        Returns:
            Always a 3-tuple ``(response, full_context, source)``. ``full_context``
            and ``source`` are currently always empty. When nothing is retrieved,
            or any error occurs (logged with traceback), ``response`` is
            ``NO_ANSWER_MESSAGE``.
        """
        try:
            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
            response = "".join(
                self._format_document(position, doc, score)
                for position, (doc, score) in enumerate(docs, start=1)
            )
            if not response:
                return self.NO_ANSWER_MESSAGE, "", []

            full_context = ""
            source = []
            return response, full_context, source
        except Exception:
            # Top-level boundary: keep the best-effort behaviour, but log the
            # traceback and return the same tuple shape as the success path
            # instead of an implicit None.
            logging.exception("Failed to build a response for the prompt")
            return self.NO_ANSWER_MESSAGE, "", []