Ankerkraut committed
Commit d3077b7 · 1 Parent(s): 4c9753f

remove unnecessary functions

Files changed (1)
  1. app.py +8 -108
app.py CHANGED
@@ -1,42 +1,12 @@
 import spaces
 import gradio as gr
 from huggingface_hub import InferenceClient
-from qdrant_client import QdrantClient, models
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from sentence_transformers import SentenceTransformer
 from huggingface_hub import login
 import torch
-import json
-import bs4
 import os
 os.environ["USE_FLASH_ATTENTION"] = "0"
-
-product_strings = []
-recipe_strings = []
-with open('./Data/product_strings.json', 'r', encoding='utf-8') as f:
-    product_strings = [product for product in json.load(f)["product_strings"]]
-with open('./Data/recipe_strings.json', 'r', encoding='utf-8') as f:
-    recipe_strings = [recipe for recipe in json.load(f)["recipe_strings"]]
-
-client = QdrantClient(":memory:")  # QdrantClient("localhost:6333")
-client.set_model("sentence-transformers/all-MiniLM-L6-v2")
-client.set_sparse_model("prithivida/Splade_PP_en_v1")
-client.delete_collection(collection_name="products")
-client.create_collection(
-    collection_name="products",
-    vectors_config=client.get_fastembed_vector_params(),
-    sparse_vectors_config=client.get_fastembed_sparse_vector_params(),
-)
-client.delete_collection(collection_name="recipes")
-client.create_collection(
-    collection_name="recipes",
-    vectors_config=client.get_fastembed_vector_params(),
-    sparse_vectors_config=client.get_fastembed_sparse_vector_params(),
-)
-client.add(collection_name="products",
-           documents=product_strings)
-client.add(collection_name="recipes",
-           documents=recipe_strings)
 model_name = "LeoLM/leo-hessianai-13b-chat"
 
 last_messages = []
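
For reference, the removed block wired up qdrant-client's fastembed integration (dense MiniLM vectors plus SPLADE sparse vectors) against an in-memory instance. Below is a minimal, self-contained sketch of that setup and a query against it; the collection name, model names, and file path are taken from the removed code, and the sample query string is illustrative:

```python
# Minimal sketch of the removed indexing/search setup, assuming
# qdrant-client is installed with the fastembed extra:
#   pip install "qdrant-client[fastembed]"
import json
from qdrant_client import QdrantClient

client = QdrantClient(":memory:")  # or QdrantClient("localhost:6333") for a server
client.set_model("sentence-transformers/all-MiniLM-L6-v2")  # dense embeddings
client.set_sparse_model("prithivida/Splade_PP_en_v1")       # sparse (SPLADE) embeddings

with open("./Data/product_strings.json", encoding="utf-8") as f:
    product_strings = json.load(f)["product_strings"]

# Drop and recreate the collection so reruns start clean, as the removed code did.
client.delete_collection(collection_name="products")
client.create_collection(
    collection_name="products",
    vectors_config=client.get_fastembed_vector_params(),
    sparse_vectors_config=client.get_fastembed_sparse_vector_params(),
)
client.add(collection_name="products", documents=product_strings)

# Hybrid (dense + sparse) query; each hit carries its source text in metadata.
hits = client.query(collection_name="products", query_text="Gewürz für Steak", limit=3)
for hit in hits:
    print(hit.score, hit.metadata["document"])
```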
@@ -84,7 +54,7 @@ def generate_response(query, context, prompts, max_tokens, temperature, top_p, g
 <|im_end|>
 {"".join(last_messages)}
 <|im_start|>user
-Frage: {query}
+{query}
 <|im_end|>
 <|im_start|>assistant"""
 
@@ -100,7 +70,7 @@ def generate_response(query, context, prompts, max_tokens, temperature, top_p, g
 <|im_end|>
 {"".join(last_messages)}
 <|im_start|>user
-Frage: {query}
+{query}
 <|im_end|>
 <|im_start|>assistant"""
 
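The two hunks above (and the one below) make the same edit to generate_response's ChatML templates: the hard-coded `Frage:` prefix is dropped so the user turn carries the query verbatim. A sketch of the assembly these f-strings produce follows; `build_prompt` and its arguments are illustrative stand-ins, not names from the app:

```python
# Illustrative reconstruction of the ChatML prompt shape used by the app;
# LeoLM/leo-hessianai-13b-chat is tuned on this <|im_start|>/<|im_end|> format.
def build_prompt(system_prompt: str, last_messages: list[str], query: str) -> str:
    return (
        f"<|im_start|>system\n{system_prompt}\n<|im_end|>\n"
        + "".join(last_messages)                      # prior turns, already ChatML-wrapped
        + f"<|im_start|>user\n{query}\n<|im_end|>\n"  # was: f"Frage: {query}" before this commit
        + "<|im_start|>assistant"                     # generation continues after this tag
    )

print(build_prompt("Du bist ein KI-Assistent.", [], "Welche Sorten BBQ-Rub gibt es?"))
```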
@@ -116,98 +86,28 @@ def generate_response(query, context, prompts, max_tokens, temperature, top_p, g
 <|im_end|>
 {"".join(last_messages)}
 <|im_start|>user
-Frage: {query}
+{query}
 <|im_end|>
 <|im_start|>assistant"""
 
     system_message = system_message_products
-
-    if collection_name =="recipes":
+
+    if collection_name == "recipes":
         system_message = system_message_recipes
-    elif collection_name =="service":
+    elif collection_name == "services":
         system_message = system_message_support
-
+
 
     print("Prompt: ", system_message)
 
     response = generator(system_message, do_sample=True, top_p=top_p, max_new_tokens=max_tokens, temperature=temperature)[0]["generated_text"]
-
+    print(f"""-----Response: {response}-----""")
     # Extract only the assistant's response
     if "assistant" in response:
         response = response.split("assistant").pop().strip()
 
     return response
 
-def search_qdrant_with_context(query_text, collection_name, top_k=3):
-    """Search Qdrant with a fastembed-encoded query."""
-    print(collection_name)
-    # print(query_embedding)
-    search_results = client.query(
-        collection_name=collection_name,
-        query_text=query_text,
-        query_filter=None,
-        limit=top_k  # Number of top results to return
-    )
-    retrieved_texts = [result.metadata for result in search_results if result.score > 0.3]
-
-    if not retrieved_texts:
-        retrieved_texts = "Keinen passenden Kontext gefunden."
-    print("Retrieved Text ", retrieved_texts)
-
-    return retrieved_texts
-
-@spaces.GPU
-def interactive_chat(query):
-    generator = get_model()
-    collection_name = "products"
-    if "rezept" in query.lower() or "gericht" in query.lower():
-        collection_name = "recipes"
-    elif "bestellung" in query.lower() or "order" in query.lower():
-        collection_name = "products"
-    print(collection_name)
-    print(query)
-    if len(query.split()) < 3:
-        return generate_response(query, "Der Kunde muss womöglich detailliertere Angaben machen, entscheide, was du sagst.", last_messages, 512, 0.2, 0.95, generator[0])
-    context = [document["document"] for document in search_qdrant_with_context(query, collection_name)]
-
-    system_message = f"""<|im_start|>system Rolle: Du bist ein KI-Assistent, der die Informationen in Relation zum Kontext bewertet.
-    Oberstes Ziel: Bewerte die Ergebnisse und stufe sie nach Relevanz in Bezug auf die Konversation ein.
-    Meta-Anweisung: Analysiere die Konversation und mache Vorschläge für Suchbegriffe in Stichpunkten.
-    Suchergebnisse: {context}
-    <|im_end|>
-    {"".join(last_messages)}
-    <|im_start|>user
-    {query}
-    <|im_end|>
-    <|im_start|>assistant"""
-    refined_context = generator[1](system_message, do_sample=True, padding=True, truncation=True, top_p=0.95, max_new_tokens=100)
-    # Retrieve relevant context from Qdrant
-    print(f"""Refined context: {refined_context[0]["generated_text"].split("assistant").pop()}""")
-
-    context = [document["document"] for document in search_qdrant_with_context(query + " " + refined_context[0]["generated_text"].split("assistant\n").pop(), collection_name)]
-    answer = generate_response(query, context, last_messages, 512, 0.2, 0.95, generator)
-    full_conv = f"<|im_start|>user {query}<|im_end|><|im_start|>assistant {answer}<|im_end|>"
-    # if len(last_messages) > 5:
-    #     last_messages.pop(0)
-    #     last_messages.append(full_conv)
-    print(f"last messages: {last_messages}")
-    print()
-    return answer
-@spaces.GPU(duration=1500)
-def get_answers():
-    answers = []
-    last_messages = []
-    with open("./Data/questions.json", "r", encoding="utf-8") as f:
-        json_data = json.load(f)["questions"]
-    for (index, question) in enumerate(json_data):
-        if index <= 5:
-            continue
-        answer = interactive_chat(question)
-        answers.append(answer)
-    with open("./Data/answers.json", "w", encoding="utf-8") as file:
-        json.dump({"answers": answers}, file, ensure_ascii=False, indent=4)
-
-@spaces.GPU
 def respond(
     query,
     history: list[tuple[str, str]],
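
Among the functions deleted above, `interactive_chat` implemented a two-pass retrieval: query Qdrant once, ask the model for keyword suggestions given those hits, then query again with the expanded string. A hedged sketch of that pattern in isolation; `search` and `generator` are stand-ins for `search_qdrant_with_context` and the text-generation pipeline:

```python
# Two-pass retrieval (query expansion) as interactive_chat did it before removal.
# `search(text, collection)` returns matched documents; `generator` is a
# transformers text-generation pipeline. Both are assumed to exist.
def expanded_search(query: str, collection: str, search, generator) -> list[str]:
    first_pass = search(query, collection)  # pass 1: raw query

    # Ask the model for search keywords, given the first-pass hits as context.
    prompt = (
        "<|im_start|>system\nAnalysiere die Konversation und mache Vorschläge "
        f"für Suchbegriffe in Stichpunkten.\nSuchergebnisse: {first_pass}\n<|im_end|>\n"
        f"<|im_start|>user\n{query}\n<|im_end|>\n<|im_start|>assistant"
    )
    out = generator(prompt, do_sample=True, top_p=0.95, max_new_tokens=100)
    keywords = out[0]["generated_text"].split("assistant").pop().strip()

    return search(f"{query} {keywords}", collection)  # pass 2: expanded query
```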
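
What survives the commit is the plain pipeline path: build a ChatML prompt, generate, and cut everything before the final `assistant` marker. A hedged end-to-end sketch of that remaining flow follows; the pipeline construction is an assumption (this diff never shows how the pipeline is created), while the sampling values 512/0.2/0.95 come from the deleted call sites:

```python
# Hedged sketch of the flow left in app.py after this commit. The pipeline
# construction below is assumed, not shown in the diff; device_map="auto"
# requires the accelerate package.
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="LeoLM/leo-hessianai-13b-chat",
    device_map="auto",
)

prompt = (
    "<|im_start|>system\nDu bist ein hilfreicher Assistent.\n<|im_end|>\n"
    "<|im_start|>user\nWelche Gewürze passen zu Lachs?\n<|im_end|>\n"
    "<|im_start|>assistant"
)

response = generator(
    prompt, do_sample=True, top_p=0.95, max_new_tokens=512, temperature=0.2
)[0]["generated_text"]

# Same extraction as generate_response: keep only the text after the last
# literal "assistant" occurrence. Brittle if the reply itself says "assistant".
answer = response.split("assistant").pop().strip()
print(answer)
```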