AminFaraji committed on
Commit ffb24a6 · verified · 1 Parent(s): 1c77e84

Create app.py

Files changed (1)
  1. app.py +298 -0
app.py ADDED
import argparse
# from dataclasses import dataclass
from langchain.prompts import ChatPromptTemplate
from langchain_community.vectorstores import Chroma
# from langchain_openai import OpenAIEmbeddings
# from langchain_openai import ChatOpenAI

# from langchain.document_loaders import DirectoryLoader
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
# from langchain.embeddings import OpenAIEmbeddings
import openai
from dotenv import load_dotenv
import os
import shutil

import re
import warnings
from typing import List

import torch
from langchain import PromptTemplate
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.llms import HuggingFacePipeline
from langchain.schema import BaseOutputParser
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    pipeline,
)

warnings.filterwarnings("ignore", category=UserWarning)

MODEL_NAME = "tiiuae/falcon-7b-instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True
)
model = model.eval()
print("model loaded")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"Model device: {model.device}")

# A custom embedding wrapper around sentence-transformers, exposing the
# embed_documents/embed_query interface that LangChain vector stores expect.
from sentence_transformers import SentenceTransformer
from langchain_experimental.text_splitter import SemanticChunker


class MyEmbeddings:
    def __init__(self):
        self.model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.model.encode(t).tolist() for t in texts]

    def embed_query(self, query: str) -> List[float]:
        return self.model.encode(query).tolist()


embeddings = MyEmbeddings()
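
# Quick sanity check of the wrapper (the 384-dim output is a property of
# all-MiniLM-L6-v2, stated here as an assumption, not verified by this script):
#   vec = embeddings.embed_query("hello")            # list of 384 floats
#   docs = embeddings.embed_documents(["a", "b"])    # 2 lists of 384 floats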

splitter = SemanticChunker(embeddings)

PROMPT_TEMPLATE = """
Answer the question based only on the following context:

{context}

---

Answer the question based on the above context: {question}
"""


# Create CLI.
# parser = argparse.ArgumentParser()
# parser.add_argument("query_text", type=str, help="The query text.")
# args = parser.parse_args()
# query_text = args.query_text

# A sample query for the bot; it is expected to be answered from the retrieved context.
query_text = "what did alice say to rabbit"

# Prepare the DB.
# embedding_function = OpenAIEmbeddings()  # main

CHROMA_PATH = "chroma8"
# Load the Chroma index that was persisted to this directory.
db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
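
# NOTE: `splitter` above is otherwise unused here; the persisted "chroma8" index
# is assumed to have been built in advance. A minimal sketch of how it might be
# built with the same splitter and embeddings (the "data/" path is an assumption,
# not part of this commit):
#   loader = DirectoryLoader("data/", glob="*.txt")
#   chunks = splitter.split_documents(loader.load())
#   Chroma.from_documents(chunks, embeddings, persist_directory=CHROMA_PATH)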

# Search the DB for documents similar to the query.
results = db.similarity_search_with_relevance_scores(query_text, k=2)
if len(results) == 0 or results[0][1] < 0.5:
    print("Unable to find matching results.")

context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
prompt = prompt_template.format(context=context_text, question=query_text)
print(prompt)


# Greedy decoding: num_return_sequences=1 and no sampling, so `temperature`
# has no effect unless do_sample is enabled.
generation_config = model.generation_config
generation_config.temperature = 0
generation_config.num_return_sequences = 1
generation_config.max_new_tokens = 256
generation_config.use_cache = False
generation_config.repetition_penalty = 1.7
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id

# A quick smoke test of the prompt format; `input_ids` is built but not
# consumed below (the pipeline handles tokenization itself).
prompt = """
The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context.

Current conversation:

Human: Who is Dwight K Schrute?
AI:
""".strip()
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(model.device)

class StopGenerationCriteria(StoppingCriteria):
    """Stops generation once the tail of the sequence matches any stop-token run."""

    def __init__(
        self, tokens: List[List[str]], tokenizer: AutoTokenizer, device: torch.device
    ):
        stop_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
        self.stop_token_ids = [
            torch.tensor(x, dtype=torch.long, device=device) for x in stop_token_ids
        ]

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        for stop_ids in self.stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                return True
        return False


# Stop as soon as the model starts writing the next "Human:" or "AI:" turn,
# so it cannot role-play both sides of the dialogue.
stop_tokens = [["Human", ":"], ["AI", ":"]]
stopping_criteria = StoppingCriteriaList(
    [StopGenerationCriteria(stop_tokens, tokenizer, model.device)]
)


generation_pipeline = pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=True,
    task="text-generation",
    stopping_criteria=stopping_criteria,
    generation_config=generation_config,
)

llm = HuggingFacePipeline(pipeline=generation_pipeline)
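
# Optional smoke test (left commented out; uses the Dwight prompt defined above
# and the classic LangChain LLM __call__ interface):
#   print(llm(prompt))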


class CleanupOutputParser(BaseOutputParser):
    """Strips role markers ("User", "question:", "answer:") from the model output."""

    def parse(self, text: str) -> str:
        user_pattern = r"\nUser"
        text = re.sub(user_pattern, "", text)
        human_pattern = r"\nquestion:"
        text = re.sub(human_pattern, "", text)
        ai_pattern = r"\nanswer:"
        return re.sub(ai_pattern, "", text).strip()

    @property
    def _type(self) -> str:
        return "output_parser"


template = """
The following is a conversation between a human and an AI. The AI acts exactly like Dwight K Schrute from the TV show The Office.
Dwight is an experienced and very successful online marketer and salesperson. He is witty, persuasive, uncensored, direct and practical.
Dwight helps with every marketing task given to him. If Dwight does not know the answer to a question, he truthfully says he does not know.

Current conversation:
{history}
Human: {input}
AI:""".strip()

prompt = PromptTemplate(input_variables=["history", "input"], template=template)

# k=3 keeps the last three conversation turns in memory for follow-up questions.
memory = ConversationBufferWindowMemory(
    memory_key="history", k=3, return_only_outputs=True
)

chain = ConversationChain(
    llm=llm,
    memory=memory,
    prompt=prompt,
    output_parser=CleanupOutputParser(),
    verbose=True,
)


def get_llama_response(message: str, history: list = None) -> str:
    """
    Generates a conversational response from the Falcon model, grounded in
    documents retrieved from the Chroma index.

    Parameters:
        message (str): User's input message.
        history (list): Past conversation history (unused; the chain keeps
            its own memory). Defaults to None so the function also works
            when Gradio passes only the message.

    Returns:
        str: Generated response from the model.
    """
    query_text = message

    results = db.similarity_search_with_relevance_scores(query_text, k=2)
    if len(results) == 0 or results[0][1] < 0.5:
        print("Unable to find matching results.")

    context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])

    template = """
The following is a conversation between a human and an AI. Answer the question based only on the conversation.

Current conversation:
{history}

"""

    s = """
\n question: {input}

\n answer:""".strip()

    # Rebuild the chain's prompt with the freshly retrieved context spliced in.
    prompt = PromptTemplate(
        input_variables=["history", "input"],
        template=template + context_text + "\n" + s,
    )
    chain.prompt = prompt
    return chain.predict(input=query_text)


import gradio as gr

iface = gr.Interface(fn=get_llama_response, inputs="text", outputs="text")
iface.launch(share=True)
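
# A chat-style UI is a natural alternative here, assuming a Gradio version that
# ships ChatInterface (it passes (message, history) to the function automatically):
#   iface = gr.ChatInterface(fn=get_llama_response)
#   iface.launch(share=True)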