ccm committed
Commit 3caa7cd · verified · 1 parent: 9dc260f

Some updates but mostly formatting

Files changed (1)
  1. main.py +69 -45
main.py CHANGED
@@ -1,29 +1,35 @@
-import gradio # for the interface
-import transformers # to load an LLM
-import sentence_transformers # to load an embedding model
+import json # to work with JSON
+import threading # for threading
+import time # for better HCI
+
+import datasets # to load the dataset
 import faiss # to create an index
+import gradio # for the interface
 import numpy # to work with vectors
 import pandas # to work with pandas
-import json # to work with JSON
-import datasets # to load the dataset
+import sentence_transformers # to load an embedding model
 import spaces # for GPU
-import threading # for threading
-import time # for better HCI
+import transformers # to load an LLM
 
 # Constants
-GREETING = "Hi there! I'm an AI agent that uses a [retrieval-augmented generation](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the Design Research Collective. And the best part is that I always cite my ssources! What can I tell you about today?"
+GREETING = (
+    "Howdy! I'm an AI agent that uses a [retrieval-augmented generation]("
+    "https://en.wikipedia.org/wiki/Retrieval-augmented_generation) pipeline to answer questions about research by the "
+    "Design Research Collective. And the best part is that I always cite my sources! What can I tell you about today?"
+)
 EMBEDDING_MODEL_NAME = "allenai-specter"
-LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
+# LLM_MODEL_NAME = "Qwen/Qwen2-7B-Instruct"
+LLM_MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"
 
 # Load the dataset and convert to pandas
 full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
 
 # Filter out any publications without an abstract
-filter = [
+abstract_is_null = [
     '"abstract": null' in json.dumps(bibdict)
     for bibdict in full_data["bib_dict"].values
 ]
-data = full_data[~pandas.Series(filter)]
+data = full_data[~pandas.Series(abstract_is_null)]
 data.reset_index(inplace=True)
 
 # Create a FAISS index for fast similarity search
@@ -38,6 +44,7 @@ index.add(vectors)
 # Load the model for later use in embeddings
 model = sentence_transformers.SentenceTransformer(EMBEDDING_MODEL_NAME)
 
+
 # Define the search function
 def search(query: str, k: int) -> tuple[str]:
     query = numpy.expand_dims(model.encode(query), axis=0)
@@ -45,40 +52,66 @@ def search(query: str, k: int) -> tuple[str]:
     D, I = index.search(query, k)
     top_five = data.loc[I[0]]
 
-    search_results = "You are an AI assistant who delights in helping people" \
-        + "learn about research from the Design Research Collective. Here are" \
-        + "several abstracts from really cool, and really relevant, papers:\n\n"
+    search_results = (
+        "You are an AI assistant who delights in helping people learn about research from the Design "
+        "Research Collective. Here are several abstracts from really cool, and really relevant, "
+        "papers:\n\n"
+    )
 
     references = "\n\n## References\n\n"
 
     for i in range(k):
         search_results += top_five["bib_dict"].values[i]["abstract"] + "\n"
-        references += str(i+1) + ". " + ", ".join([author.split(" ")[-1] for author in top_five["bib_dict"].values[i]["author"].split(" and ")]) + ". (" + str(int(top_five["bib_dict"].values[i]["pub_year"])) + "). [" + top_five["bib_dict"].values[i]["title"] + "]" \
-            + "(https://scholar.google.com/citations?view_op=view_citation&citation_for_view=" + top_five["author_pub_id"].values[i] + ").\n"
-
-    search_results += "\nIf these abstract aren't relevant to the followign query, please say that there is not much research in that area. Response to the following query from the perspective of the provided abstracts only:"
+        references += (
+            str(i + 1)
+            + ". "
+            + ", ".join(
+                [
+                    author.split(" ")[-1]
+                    for author in top_five["bib_dict"]
+                    .values[i]["author"]
+                    .split(" and ")
+                ]
+            )
+            + ". ("
+            + str(int(top_five["bib_dict"].values[i]["pub_year"]))
+            + "). ["
+            + top_five["bib_dict"].values[i]["title"]
+            + "]"
+            + "(https://scholar.google.com/citations?view_op=view_citation&citation_for_view="
+            + top_five["author_pub_id"].values[i]
+            + ").\n"
+        )
+
+    search_results += (
+        "\nIf these abstracts aren't relevant to the following query, please reply 'I am unsure' or "
+        "similar. Respond to the following query from the perspective of the provided abstracts only:"
+    )
 
     return search_results, references
 
 
 # Create an LLM pipeline that we can send queries to
 tokenizer = transformers.AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
-streamer = transformers.TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+streamer = transformers.TextIteratorStreamer(
+    tokenizer, skip_prompt=True, skip_special_tokens=True
+)
 chatmodel = transformers.AutoModelForCausalLM.from_pretrained(
-    LLM_MODEL_NAME,
-    torch_dtype="auto",
-    device_map="auto"
+    LLM_MODEL_NAME, torch_dtype="auto", device_map="auto"
 )
 
+
 def preprocess(message: str) -> tuple[str]:
     """Applies a preprocessing step to the user's message before the LLM receives it"""
     block_search_results, formatted_search_results = search(message, 5)
     return block_search_results + message, formatted_search_results
 
+
 def postprocess(response: str, bypass_from_preprocessing: str) -> str:
     """Applies a postprocessing step to the LLM's response before the user receives it"""
     return response + bypass_from_preprocessing
 
+
 @spaces.GPU
 def predict(message: str, history: list[str]) -> str:
     """This function is responsible for crafting a response"""
@@ -92,29 +125,23 @@ def predict(message: str, history: list[str]) -> str:
     history = history[-1]
     print(history)
     history_transformer_format = [
-        {"role": "assistant" if idx&1 else "user", "content": msg}
+        {"role": "assistant" if idx & 1 else "user", "content": msg}
         for idx, msg in enumerate(history)
     ] + [{"role": "user", "content": message}]
 
     # Stream a response from pipe
    text = tokenizer.apply_chat_template(
-        history_transformer_format,
-        tokenize=False,
-        add_generation_prompt=True
+        history_transformer_format, tokenize=False, add_generation_prompt=True
    )
     model_inputs = tokenizer([text], return_tensors="pt").to("cuda:0")
 
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=512
-    )
+    generate_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=512)
     t = threading.Thread(target=chatmodel.generate, kwargs=generate_kwargs)
     t.start()
 
     partial_message = ""
     for new_token in streamer:
-        if new_token != '<':
+        if new_token != "<":
             partial_message += new_token
             time.sleep(0.05)
             yield partial_message
@@ -124,21 +151,18 @@ def predict(message: str, history: list[str]) -> str:
 
 # Create and run the gradio interface
 gradio.ChatInterface(
-    predict,
+    predict,
     examples=[
         "Tell me about new research at the intersection of additive manufacturing and machine learning",
         "What is a physics-informed neural network and what can it be used for?",
-        "What can agent-based models do about climate change?"
+        "What can agent-based models do about climate change?",
     ],
-    chatbot = gradio.Chatbot(
-        show_label=False,
-        show_copy_button=True,
-        value=[["", GREETING]]
+    chatbot=gradio.Chatbot(
+        show_label=False, show_copy_button=True, value=[["", GREETING]]
     ),
-    retry_btn = None,
-    undo_btn = None,
-    clear_btn = None,
-    theme = "monochrome",
-    cache_examples = True,
-    fill_height = True,
-).launch(debug=True)
+    retry_btn=None,
+    undo_btn=None,
+    clear_btn=None,
+    cache_examples=True,
+    fill_height=True,
+).launch(debug=True)
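
For reference, the streamed-generation pattern that predict() relies on (transformers.TextIteratorStreamer consumed on the main thread while generate() runs in a worker thread) can be exercised on its own. This is a minimal sketch, not part of the commit: it assumes the same Qwen/Qwen2-0.5B-Instruct checkpoint the commit switches to, runs on CPU, and the prompt text and max_new_tokens value are illustrative.

import threading

import transformers

MODEL = "Qwen/Qwen2-0.5B-Instruct"  # checkpoint assumed from the diff above

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL)
streamer = transformers.TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True
)

# Build a single-turn chat prompt (illustrative content)
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Say hello in one sentence."}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer([text], return_tensors="pt")

# generate() blocks until completion, so it runs in a worker thread while the
# main thread consumes decoded tokens from the streamer as they arrive
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(inputs, streamer=streamer, max_new_tokens=64),
)
thread.start()
for token in streamer:
    print(token, end="", flush=True)
thread.join()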