dh-mc commited on
Commit
de3c294
·
1 Parent(s): 0b2ef31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -165
app.py CHANGED
@@ -1,18 +1,18 @@
1
  """Main entrypoint for the app."""
 
2
  import os
 
3
  import time
4
  from queue import Queue
5
  from timeit import default_timer as timer
6
 
7
  import gradio as gr
8
- from anyio.from_thread import start_blocking_portal
9
 
10
  from app_modules.init import app_init
11
- from app_modules.utils import print_llm_response, remove_extra_spaces
12
 
13
  llm_loader, qa_chain = app_init()
14
 
15
- show_param_settings = os.environ.get("SHOW_PARAM_SETTINGS") == "true"
16
  share_gradio_app = os.environ.get("SHARE_GRADIO_APP") == "true"
17
  using_openai = os.environ.get("LLM_MODEL_TYPE") == "openai"
18
  chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
@@ -28,176 +28,84 @@ href = (
28
  else f"https://huggingface.co/{model}"
29
  )
30
 
31
- name = "PCI DSS v4"
32
-
33
- title = f"""<h1 align="left" style="min-width:200px; margin-top:0;"> Chat with {name} </h1>"""
 
 
34
 
35
- description_top = f"""\
36
  <div align="left">
37
  <p> Currently Running: <a href="{href}">{model}</a></p>
38
  </div>
39
  """
40
 
41
- description = """\
42
- <div align="center" style="margin:16px 0">
43
- The demo is built on <a href="https://github.com/hwchase17/langchain">LangChain</a>.
44
- </div>
45
- """
46
 
47
- CONCURRENT_COUNT = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
 
 
 
 
49
 
50
- def qa(chatbot):
51
- user_msg = chatbot[-1][0]
52
  q = Queue()
53
  result = Queue()
54
- job_done = object()
55
-
56
- def task(question, chat_history):
57
- start = timer()
58
- inputs = {"question": question}
59
- inputs["chat_history"] = chat_history
60
- ret = qa_chain.call_chain(inputs, None, q)
61
- end = timer()
62
-
63
- print(f"Completed in {end - start:.3f}s")
64
- print_llm_response(ret)
65
-
66
- q.put(job_done)
67
- result.put(ret)
68
-
69
- with start_blocking_portal() as portal:
70
- chat_history = []
71
- if chat_history_enabled:
72
- for i in range(len(chatbot) - 1):
73
- element = chatbot[i]
74
- item = (element[0] or "", element[1] or "")
75
- chat_history.append(item)
76
-
77
- portal.start_task_soon(task, user_msg, chat_history)
78
-
79
- content = ""
80
- count = 2 if len(chat_history) > 0 else 1
81
-
82
- while count > 0:
83
- while q.empty():
84
- print("nothing generated yet - retry in 0.5s")
85
- time.sleep(0.5)
86
-
87
- for next_token in llm_loader.streamer:
88
- if next_token is job_done:
89
- break
90
- content += next_token or ""
91
- chatbot[-1][1] = remove_extra_spaces(content)
92
-
93
- if count == 1:
94
- yield chatbot
95
-
96
- count -= 1
97
-
98
- chatbot[-1][1] += "\n\nSources:\n"
99
- ret = result.get()
100
- titles = []
101
- for doc in ret["source_documents"]:
102
- page = doc.metadata["page"] + 1
103
- url = f"{doc.metadata['url']}#page={page}"
104
- file_name = doc.metadata["source"].split("/")[-1]
105
- title = f"{file_name} Page: {page}"
106
- if title not in titles:
107
- titles.append(title)
108
- chatbot[-1][1] += f"1. [{title}]({url})\n"
109
-
110
- yield chatbot
111
-
112
-
113
- with open("assets/custom.css", "r", encoding="utf-8") as f:
114
- customCSS = f.read()
115
-
116
- with gr.Blocks(css=customCSS) as demo:
117
- user_question = gr.State("")
118
- with gr.Row():
119
- gr.HTML(title)
120
- gr.Markdown(description_top)
121
- with gr.Row(equal_height=True):
122
- with gr.Column(scale=5):
123
- with gr.Row():
124
- chatbot = gr.Chatbot(elem_id="inflaton_chatbot", height="100%")
125
- with gr.Row():
126
- with gr.Column(scale=2):
127
- user_input = gr.Textbox(
128
- show_label=False,
129
- placeholder="Enter your question here",
130
- container=False,
131
- )
132
- with gr.Column(
133
- min_width=70,
134
- ):
135
- submitBtn = gr.Button("Send")
136
- with gr.Column(
137
- min_width=70,
138
- ):
139
- clearBtn = gr.Button("Clear")
140
- if show_param_settings:
141
- with gr.Column():
142
- with gr.Column(
143
- min_width=50,
144
- ):
145
- with gr.Tab(label="Parameter Setting"):
146
- gr.Markdown("# Parameters")
147
- top_p = gr.Slider(
148
- minimum=-0,
149
- maximum=1.0,
150
- value=0.95,
151
- step=0.05,
152
- # interactive=True,
153
- label="Top-p",
154
- )
155
- temperature = gr.Slider(
156
- minimum=0.1,
157
- maximum=2.0,
158
- value=0,
159
- step=0.1,
160
- # interactive=True,
161
- label="Temperature",
162
- )
163
- max_new_tokens = gr.Slider(
164
- minimum=0,
165
- maximum=2048,
166
- value=2048,
167
- step=8,
168
- # interactive=True,
169
- label="Max Generation Tokens",
170
- )
171
- max_context_length_tokens = gr.Slider(
172
- minimum=0,
173
- maximum=4096,
174
- value=4096,
175
- step=128,
176
- # interactive=True,
177
- label="Max Context Tokens",
178
- )
179
- gr.Markdown(description)
180
-
181
- def chat(user_message, history):
182
- return "", history + [[user_message, None]]
183
-
184
- user_input.submit(
185
- chat, [user_input, chatbot], [user_input, chatbot], queue=True
186
- ).then(qa, chatbot, chatbot)
187
-
188
- submitBtn.click(
189
- chat, [user_input, chatbot], [user_input, chatbot], queue=True, api_name="chat"
190
- ).then(qa, chatbot, chatbot)
191
-
192
- def reset():
193
- return "", []
194
-
195
- clearBtn.click(
196
- reset,
197
- outputs=[user_input, chatbot],
198
- show_progress=True,
199
- api_name="reset",
200
- )
201
-
202
- demo.title = "Chat with PCI DSS v4"
203
- demo.queue().launch(share=share_gradio_app)
 
1
  """Main entrypoint for the app."""
2
+
3
  import os
4
+ from threading import Thread
5
  import time
6
  from queue import Queue
7
  from timeit import default_timer as timer
8
 
9
  import gradio as gr
 
10
 
11
  from app_modules.init import app_init
12
+ from app_modules.utils import print_llm_response
13
 
14
  llm_loader, qa_chain = app_init()
15
 
 
16
  share_gradio_app = os.environ.get("SHARE_GRADIO_APP") == "true"
17
  using_openai = os.environ.get("LLM_MODEL_TYPE") == "openai"
18
  chat_history_enabled = os.environ.get("CHAT_HISTORY_ENABLED") == "true"
 
28
  else f"https://huggingface.co/{model}"
29
  )
30
 
31
+ title = "Chat with PCI DSS v4"
32
+ examples = [
33
+ "What's PCI DSS?",
34
+ "Can you summarize the changes made from PCI DSS version 3.2.1 to version 4.0?",
35
+ ]
36
 
37
+ description = f"""\
38
  <div align="left">
39
  <p> Currently Running: <a href="{href}">{model}</a></p>
40
  </div>
41
  """
42
 
 
 
 
 
 
43
 
44
+ def task(question, chat_history, q, result):
45
+ start = timer()
46
+ inputs = {"question": question, "chat_history": chat_history}
47
+ ret = qa_chain.call_chain(inputs, None, q)
48
+ end = timer()
49
+
50
+ print(f"Completed in {end - start:.3f}s")
51
+ print_llm_response(ret)
52
+
53
+ result.put(ret)
54
+
55
+
56
+ def predict(message, history):
57
+ print("predict:", message, history)
58
 
59
+ chat_history = []
60
+ if chat_history_enabled:
61
+ for element in history:
62
+ item = (element[0] or "", element[1] or "")
63
+ chat_history.append(item)
64
 
 
 
65
  q = Queue()
66
  result = Queue()
67
+ t = Thread(target=task, args=(message, chat_history, q, result))
68
+ t.start() # Starting the generation in a separate thread.
69
+
70
+ partial_message = ""
71
+ count = 2 if len(chat_history) > 0 else 1
72
+
73
+ while count > 0:
74
+ while q.empty():
75
+ print("nothing generated yet - retry in 0.5s")
76
+ time.sleep(0.5)
77
+
78
+ for next_token in llm_loader.streamer:
79
+ partial_message += next_token or ""
80
+ # partial_message = remove_extra_spaces(partial_message)
81
+ yield partial_message
82
+
83
+ if count == 2:
84
+ partial_message += "\n\n"
85
+
86
+ count -= 1
87
+
88
+ partial_message += "\n\nSources:\n"
89
+ ret = result.get()
90
+ titles = []
91
+ for doc in ret["source_documents"]:
92
+ page = doc.metadata["page"] + 1
93
+ url = f"{doc.metadata['url']}#page={page}"
94
+ file_name = doc.metadata["source"].split("/")[-1]
95
+ title = f"{file_name} Page: {page}"
96
+ if title not in titles:
97
+ titles.append(title)
98
+ partial_message += f"1. [{title}]({url})\n"
99
+
100
+ yield partial_message
101
+
102
+
103
+ # Setting up the Gradio chat interface.
104
+ gr.ChatInterface(
105
+ predict,
106
+ title=title,
107
+ description=description,
108
+ examples=examples,
109
+ ).launch(
110
+ share=share_gradio_app
111
+ ) # Launching the web interface.