Spaces: Running on Zero
Some updates but mostly formatting
main.py CHANGED
@@ -1,29 +1,35 @@
-import
-import
-import
+import json  # to work with JSON
+import threading  # for threading
+import time  # for better HCI
+
+import datasets  # to load the dataset
 import faiss  # to create an index
+import gradio  # for the interface
 import numpy  # to work with vectors
 import pandas  # to work with pandas
-import
-import datasets  # to load the dataset
+import sentence_transformers  # to load an embedding model
 import spaces  # for GPU
-import
-import time  # for better HCI
+import transformers  # to load an LLM
 
 # Constants
-GREETING =
+GREETING = (
+    "Howdy! I'm an AI agent that uses a [retrieval-augmented generation]("
+    "https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the "
+    "Design Research Collective. And the best part is that I always cite my sources! What can I tell you about today?"
+)
 EMBEDDING_MODEL_NAME = "allenai-specter"
-LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
+# LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
+LLM_MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"
 
 # Load the dataset and convert to pandas
 full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
 
 # Filter out any publications without an abstract
-
+abstract_is_null = [
     '"abstract": null' in json.dumps(bibdict)
     for bibdict in full_data["bib_dict"].values
 ]
-data = full_data[~pandas.Series(
+data = full_data[~pandas.Series(abstract_is_null)]
 data.reset_index(inplace=True)
 
 # Create a FAISS index for fast similarity search
@@ -38,6 +44,7 @@ index.add(vectors)
 # Load the model for later use in embeddings
 model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
 
+
 # Define the search function
 def search(query: str, k: int) -> tuple[str]:
     query = numpy.expand_dims(model.encode(query), axis=0)
@@ -45,40 +52,66 @@ def search(query: str, k: int) -> tuple[str]:
     D, I = index.search(query, k)
     top_five = data.loc[I[0]]
 
-    search_results =
-
-
+    search_results = (
+        "You are an AI assistant who delights in helping people learn about research from the Design "
+        "Research Collective. Here are several abstracts from really cool, and really relevant, "
+        "papers:\n\n"
+    )
 
     references = "\n\n## References\n\n"
 
     for i in range(k):
         search_results += top_five["bib_dict"].values[i]["abstract"] + "\n"
-        references +=
-
-
-
+        references += (
+            str(i + 1)
+            + ". "
+            + ", ".join(
+                [
+                    author.split(" ")[-1]
+                    for author in top_five["bib_dict"]
+                    .values[i]["author"]
+                    .split(" and ")
+                ]
+            )
+            + ". ("
+            + str(int(top_five["bib_dict"].values[i]["pub_year"]))
+            + "). ["
+            + top_five["bib_dict"].values[i]["title"]
+            + "]"
+            + "(https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
+            + top_five["author_pub_id"].values[i]
+            + ").\n"
+        )
+
+    search_results += (
+        "\nIf these abstracts aren't relevant to the following query, please reply 'I am unsure' or "
+        "similar. Respond to the following query from the perspective of the provided abstracts only:"
+    )
 
     return search_results, references
 
 
 # Create an LLM pipeline that we can send queries to
 tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
-streamer = transformers.TextIteratorStreamer(
+streamer = transformers.TextIteratorStreamer(
+    tokenizer, skip_prompt=True, skip_special_tokens=True
+)
 chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
-    LLM_MODEL_NAME,
-    torch_dtype="auto",
-    device_map="auto"
+    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
 )
 
+
 def preprocess(message: str) -> tuple[str]:
     """Applies a preprocessing step to the user's message before the LLM receives it"""
     block_search_results, formatted_search_results = search(message, 5)
     return block_search_results + message, formatted_search_results
 
+
 def postprocess(response: str, bypass_from_preprocessing: str) -> str:
     """Applies a postprocessing step to the LLM's response before the user receives it"""
     return response + bypass_from_preprocessing
 
+
 @spaces.GPU
 def predict(message: str, history: list[str]) -> str:
     """This function is responsible for crafting a response"""
@@ -92,29 +125,23 @@ def predict(message: str, history: list[str]) -> str:
     history = history[-1]
     print(history)
     history_transformer_format = [
-        {"role": "assistant" if idx&1 else "user", "content": msg}
+        {"role": "assistant" if idx & 1 else "user", "content": msg}
         for idx, msg in enumerate(history)
     ] + [{"role": "user", "content": message}]
 
     # Stream a response from pipe
     text = tokenizer.apply_chat_template(
-        history_transformer_format,
-        tokenize=False,
-        add_generation_prompt=True
+        history_transformer_format, tokenize=False, add_generation_prompt=True
     )
     model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")
 
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=512
-    )
+    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
     t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
     t.start()
 
     partial_message = ""
     for new_token in streamer:
-        if new_token !=
+        if new_token != "<":
             partial_message += new_token
             time.sleep(0.05)
             yield partial_message
@@ -124,21 +151,18 @@ def predict(message: str, history: list[str]) -> str:
 
 # Create and run the gradio interface
 gradio.ChatInterface(
-    predict,
+    predict,
     examples=[
         "Tell me about new research at the intersection of additive manufacturing and machine learning",
         "What is a physics-informed neural network and what can it be used for?",
-        "What can agent-based models do about climate change?"
+        "What can agent-based models do about climate change?",
    ],
-    chatbot
-    show_label=False,
-    show_copy_button=True,
-    value=[["", GREETING]]
+    chatbot=gradio.Chatbot(
+        show_label=False, show_copy_button=True, value=[["", GREETING]]
     ),
-    retry_btn
-    undo_btn
-    clear_btn
-
-
-
-).launch(debug=True)
+    retry_btn=None,
+    undo_btn=None,
+    clear_btn=None,
+    cache_examples=True,
+    fill_height=True,
+).launch(debug=True)
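
The hunks above use a FAISS index that is built in an unchanged part of the file (only the `index.add(vectors)` context line is visible), so the construction itself is not shown in this diff. Below is a minimal sketch of how such an index can be built and queried with the same SPECTER embedding model, assuming a flat L2 index and placeholder abstracts; both are assumptions for illustration, not lines from this commit.

import faiss  # to create an index
import numpy  # to work with vectors
import sentence_transformers  # to load an embedding model

# Placeholder abstracts; in main.py the texts come from the filtered `data` frame.
abstracts = [
    "A study of machine learning for additive manufacturing process planning.",
    "Physics-informed neural networks for engineering design optimization.",
]

model = sentence_transformers.SentenceTransformer("allenai-specter")
vectors = numpy.asarray(model.encode(abstracts), dtype="float32")

index = faiss.IndexFlatL2(vectors.shape[1])  # exact L2 search over the embedding dimension
index.add(vectors)  # corresponds to the `index.add(vectors)` context line in the diff

# Query the way search() does: encode, add a batch dimension, take the top k.
query = numpy.expand_dims(model.encode("machine learning for 3D printing"), axis=0)
D, I = index.search(numpy.asarray(query, dtype="float32"), 2)  # D: distances, I: row indices

The returned indices `I` point back into whatever table the vectors were built from, which is why `search()` can look up matching rows with `data.loc[I[0]]`.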
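
The reference-building expression added inside `search()` is fairly dense, so here is what one pass of that loop produces for a made-up record; the `bib` dictionary and `author_pub_id` below are purely illustrative.

# A made-up record shaped like one entry of data["bib_dict"].
bib = {
    "author": "Ada Lovelace and Charles Babbage",
    "pub_year": 1843,
    "title": "Notes on the Analytical Engine",
}
author_pub_id = "ABC123:XYZ789"  # illustrative Google Scholar citation id

i = 0
reference_line = (
    str(i + 1)
    + ". "
    + ", ".join([author.split(" ")[-1] for author in bib["author"].split(" and ")])
    + ". ("
    + str(int(bib["pub_year"]))
    + "). ["
    + bib["title"]
    + "]"
    + "(https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
    + author_pub_id
    + ").\n"
)

# reference_line is a numbered Markdown link:
# "1. Lovelace, Babbage. (1843). [Notes on the Analytical Engine](https://scholar.google.com/citations?view_op=view_citation&citation_for_view=ABC123:XYZ789).\n"

Only surnames appear in the rendered list because each "A and B"-style author string is split on " and " and the last whitespace-separated token of each name is kept.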