import os
from threading import Thread

import gradio as gr
import numpy as np
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Access token for the gated Llama 3 checkpoint; startup fails fast if unset.
token = os.environ["HF_TOKEN"]

# Embedding model used to encode incoming queries for retrieval.
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

dataset = load_dataset("Yoxas/statistical_literacyv2")
data = dataset["train"]


def convert_and_ensure_2d_embeddings(example):
    """Parse one row's ``embedding`` field into a flat float32 vector.

    The field may be a string such as ``"[0.1 0.2\\n 0.3]"`` or ``"[0.1, 0.2]"``,
    or an already-numeric sequence. Each row is returned as a single 1-D
    vector, so that the column as a whole forms the 2-D ``(n_rows, dim)``
    matrix that ``Dataset.add_faiss_index`` indexes. A per-row ``(1, dim)``
    array would make the index think the vector dimension is 1.

    Args:
        example: A dataset row containing an ``"embedding"`` key.

    Returns:
        ``{"embedding": np.ndarray}`` of dtype float32 and rank 1.
    """
    raw = example["embedding"]
    if isinstance(raw, str):
        # Drop surrounding whitespace and brackets, treat commas like whitespace,
        # then split on any whitespace (which also covers embedded newlines).
        raw = list(map(float, raw.strip().strip("[]").replace(",", " ").split()))
    embedding = np.asarray(raw, dtype=np.float32).reshape(-1)
    return {"embedding": embedding}


# Parse the stored embeddings, then build a FAISS index over that column.
data = data.map(convert_and_ensure_2d_embeddings)
data = data.add_faiss_index("embedding")

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# use quantization to lower GPU usage
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
    token=token,
)

# Stop generation at either the generic EOS token or Llama 3's end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

SYS_PROMPT = """You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer."""


def search(query: str, k: int = 3):
    """Embed *query* and return the *k* nearest rows from the FAISS index.

    Args:
        query: Free-text user question.
        k: Number of nearest examples to retrieve.

    Returns:
        ``(scores, retrieved_examples)`` as returned by
        ``Dataset.get_nearest_examples``. ``retrieved_examples`` maps column
        names to lists of values.
    """
    embedded_query = ST.encode(query)  # embed new query
    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
        "embedding",  # compare the embedded query with the dataset embeddings
        embedded_query,
        k=k,  # get only top k results
    )
    return scores, retrieved_examples


def format_prompt(prompt, retrieved_documents, k):
    """Build the user message: the question followed by retrieved passages.

    Args:
        prompt: The user's question.
        retrieved_documents: Mapping with a ``"text"`` list, as returned by
            ``search``.
        k: Maximum number of passages to include. If fewer are available, all
            of them are used instead of raising ``IndexError``.

    Returns:
        ``"Question:<prompt>\\nContext:<doc1>\\n<doc2>\\n..."``.
    """
    passages = retrieved_documents["text"][:k]
    context = "".join(f"{passage}\n" for passage in passages)
    return f"Question:{prompt}\nContext:{context}"


@spaces.GPU(duration=150)
def talk(prompt, history):
    """Answer *prompt* with retrieval-augmented generation, streaming the reply.

    Gradio's ``ChatInterface`` calls this with the latest user message and the
    chat history. The history is accepted for interface compatibility but is
    not used.

    Yields:
        The cumulative response text after each new chunk is generated.
    """
    k = 1  # number of retrieved documents
    _scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    # Crude character cap on prompt length to avoid GPU OOM.
    formatted_prompt = formatted_prompt[:2000]
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": formatted_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Generate exactly once, on a background thread, and consume tokens through
    # the streamer. (A previous blocking generate() call whose result was
    # discarded doubled GPU time; it has been removed.)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.75,
        eos_token_id=terminators,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    thread.join()


TITLE = "# RAG"

DESCRIPTION = """
A rag pipeline with a chatbot feature

Resources used to build this project :

* embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset : https://huggingface.co/datasets/Yoxas/statistical_literacyv2
* faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot : https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
"""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="Soft",
    examples=[["what's anarchy ? "]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch(debug=True)