Spaces:
Runtime error
Update app.py
app.py CHANGED
@@ -1,42 +1,32 @@
 import gradio as gr
 from datasets import load_dataset
-
+
 import os
 import spaces
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 import torch
 from threading import Thread
 from sentence_transformers import SentenceTransformer
-
-import time
+import numpy as np
 
 token = os.environ["HF_TOKEN"]
 ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
 
 dataset = load_dataset("Yoxas/statistical_literacyv2")
-data = dataset["train"]
-
-# Convert the list to a numpy array
-embeddings_array = np.array(data["Abstract_Embeddings"])
 
-
-print(embeddings_array.shape)
+data = dataset["train"]
 
-
-
-
-if len(embeddings_array.shape) == 1:
-    embeddings_array = embeddings_array.reshape(-1, 1)
-embeddings_array = np.ascontiguousarray(embeddings_array)
-new_data = new_data.add_column("Abstract_Embeddings", embeddings_array.tolist())
+# Check the structure of embeddings
+example_embedding = data[0]['Abstract_Embeddings']
+print(f"Example embedding shape: {np.array(example_embedding).shape}")
 
-#
-
+# Ensure embeddings are 2-dimensional
+def ensure_2d_embeddings(embeddings):
+    return [np.atleast_2d(embedding) for embedding in embeddings]
 
-#
-
+# Apply the function to ensure embeddings are 2-dimensional
+data = data.map(lambda example: {'Abstract_Embeddings': ensure_2d_embeddings(example['Abstract_Embeddings'])})
 
-# Now you can use the Dataset with the Faiss index
 data = data.add_faiss_index("Abstract_Embeddings")
 
 model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -46,7 +36,7 @@ bnb_config = BitsAndBytesConfig(
     load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
 )
 
-tokenizer = AutoTokenizer.from_pretrained(model_id,token=token)
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     torch_dtype=torch.bfloat16,
@@ -56,58 +46,55 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 terminators = [
     tokenizer.eos_token_id,
-    tokenizer.convert_tokens_to_ids("<|eot_id|>")
+    tokenizer.convert_tokens_to_ids("<|eot_id|>")
 ]
 
 SYS_PROMPT = """You are an assistant for answering questions.
 You are given the extracted parts of a long document and a question. Provide a conversational answer.
 If you don't know the answer, just say "I do not know." Don't make up an answer."""
 
-
-
-def search(query: str, k: int = 3 ):
+def search(query: str, k: int = 3):
     """a function that embeds a new query and returns the most probable results"""
-    embedded_query = ST.encode(query)
-    scores, retrieved_examples = data.get_nearest_examples(
-        "
-        k=k
+    embedded_query = ST.encode(query) # embed new query
+    scores, retrieved_examples = data.get_nearest_examples( # retrieve results
+        "Abstract_Embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
+        k=k # get only top k results
     )
     return scores, retrieved_examples
 
-def format_prompt(prompt,retrieved_documents,k):
+def format_prompt(prompt, retrieved_documents, k):
     """using the retrieved documents we will prompt the model to generate our responses"""
     PROMPT = f"Question:{prompt}\nContext:"
-    for idx in range(k)
-        PROMPT+= f"{retrieved_documents['text'][idx]}\n"
+    for idx in range(k):
+        PROMPT += f"{retrieved_documents['text'][idx]}\n"
     return PROMPT
 
-
 @spaces.GPU(duration=150)
-def talk(prompt,history):
-    k = 1
-    scores
-    formatted_prompt = format_prompt(prompt,retrieved_documents,k)
-    formatted_prompt = formatted_prompt[:2000]
-    messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
+def talk(prompt, history):
+    k = 1 # number of retrieved documents
+    scores, retrieved_documents = search(prompt, k)
+    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
+    formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
+    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
     # tell the model to generate
     input_ids = tokenizer.apply_chat_template(
-
-
-
+        messages,
+        add_generation_prompt=True,
+        return_tensors="pt"
    ).to(model.device)
     outputs = model.generate(
-
-
-
-
-
-
+        input_ids,
+        max_new_tokens=1024,
+        eos_token_id=terminators,
+        do_sample=True,
+        temperature=0.6,
+        top_p=0.9,
     )
     streamer = TextIteratorStreamer(
-
-
+        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+    )
     generate_kwargs = dict(
-        input_ids=
+        input_ids=input_ids,
         streamer=streamer,
         max_new_tokens=1024,
         do_sample=True,
@@ -124,7 +111,6 @@ def talk(prompt,history):
     print(outputs)
     yield "".join(outputs)
 
-
 TITLE = "# RAG"
 
 DESCRIPTION = """
@@ -136,7 +122,6 @@ Resources used to build this project :
 * chatbot : https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
 """
 
-
 demo = gr.ChatInterface(
     fn=talk,
     chatbot=gr.Chatbot(
@@ -151,6 +136,5 @@ demo = gr.ChatInterface(
     examples=[["what's anarchy ? "]],
     title=TITLE,
     description=DESCRIPTION,
-
 )
 demo.launch(debug=True)
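The search() and format_prompt() helpers added in the third hunk can be exercised outside the Gradio UI. A minimal sketch, assuming it is appended to app.py after the FAISS index has been built; the query string and k value are only examples, not part of this commit:

# hypothetical smoke test of the retrieval path
scores, retrieved_documents = search("What is statistical literacy?", k=3)
print(scores)  # FAISS scores for the top-k abstracts
print(format_prompt("What is statistical literacy?", retrieved_documents, 3))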
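The third hunk stops inside talk() right after generate_kwargs; the lines that actually start the generation thread are unchanged in this commit and therefore not displayed. Judging from the Thread import, the generate_kwargs dict, and the final yield "".join(outputs), the tail of talk() presumably follows the standard TextIteratorStreamer pattern. A sketch of that pattern, not the file's exact contents:

# assumed continuation of talk(), shown only as the standard streaming idiom
t = Thread(target=model.generate, kwargs=generate_kwargs)  # run generation in the background
t.start()
outputs = []
for text in streamer:  # partial text arrives as it is generated
    outputs.append(text)
    yield "".join(outputs)  # stream the growing answer to the chat UI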