Sbnos committed on
Commit 69cffa4 · verified · 1 Parent(s): 20a674a

cgpt latest check

Files changed (1): app.py +40 -125
app.py CHANGED
@@ -3,66 +3,28 @@ import os
  from langchain_community.vectorstores import Chroma
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
  from langchain_community.llms import Together
- from langchain import hub
- from operator import itemgetter
- from langchain.schema.runnable import RunnableParallel
+ from langchain.prompts import ChatPromptTemplate, PromptTemplate
  from langchain.schema import format_document
- from typing import List, Tuple
- from langchain.chains import LLMChain
- from langchain.chains import RetrievalQA
- from langchain.schema.output_parser import StrOutputParser
- from langchain_community.chat_message_histories import StreamlitChatMessageHistory
+ from typing import List
  from langchain.memory import ConversationBufferMemory
- from langchain.chains import ConversationalRetrievalChain
- from langchain.memory import ConversationSummaryMemory
- from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
- from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
+ from langchain.schema.runnable import RunnableParallel, RunnablePassthrough, StrOutputParser
+ from langchain_core.chat_message_histories import StreamlitChatMessageHistory
  import time

  # Load the embedding function
  model_name = "BAAI/bge-base-en"
  encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
-
- embedding_function = HuggingFaceBgeEmbeddings(
-     model_name=model_name,
-     encode_kwargs=encode_kwargs
- )
+ embedding_function = HuggingFaceBgeEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)

  # Load the LLM
- llm = Together(
-     model="mistralai/Mixtral-8x22B-Instruct-v0.1",
-     temperature=0.2,
-     max_tokens=19096,
-     top_k=10,
-     together_api_key=os.environ['pilotikval']
- )
-
- # Load the summarizeLLM
- llmc = Together(
-     model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-     temperature=0.2,
-     max_tokens=1024,
-     top_k=1,
-     together_api_key=os.environ['pilotikval']
- )
-
- # Load the reranking model
- reranker = Together(
-     model="mistralai/Mixtral-8x22B-Instruct-v0.1",
-     temperature=0.2,
-     max_tokens=512,
-     top_k=10,
-     together_api_key=os.environ['pilotikval']
- )
+ llm = Together(model="mistralai/Mixtral-8x22B-Instruct-v0.1", temperature=0.2, max_tokens=19096, top_k=10, together_api_key=os.environ['pilotikval'], streaming=True)

  msgs = StreamlitChatMessageHistory(key="langchain_messages")
  memory = ConversationBufferMemory(chat_memory=msgs)

  DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

- def _combine_documents(
-     docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
- ):
+ def _combine_documents(docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
      doc_strings = [format_document(doc, document_prompt) for doc in docs]
      return document_separator.join(doc_strings)

@@ -76,7 +38,9 @@ def render_message_with_copy_button(role: str, content: str, key: str):
      html_code = f"""
      <div class="message" style="position: relative; padding-right: 40px;">
          <div class="message-content">{content}</div>
-         <button onclick="copyToClipboard('{key}')" style="position: absolute; right: 0; top: 0;">Copy</button>
+         <button onclick="copyToClipboard('{key}')" style="position: absolute; right: 0; top: 0; background-color: transparent; border: none; cursor: pointer;">
+             <img src="https://img.icons8.com/material-outlined/24/grey/copy.png" alt="Copy">
+         </button>
      </div>
      <textarea id="{key}" style="display:none;">{content}</textarea>
      <script>
@@ -92,16 +56,28 @@ def render_message_with_copy_button(role: str, content: str, key: str):
      """
      st.write(html_code, unsafe_allow_html=True)

- # Define the Streamlit app
+ def get_streaming_response(user_query, chat_history):
+     template = """
+     You are a knowledgeable assistant. Provide a detailed and thorough answer to the question based on the following context:
+
+     Chat history: {chat_history}
+
+     User question: {user_question}
+     """
+     prompt = ChatPromptTemplate.from_template(template)
+
+     inputs = {
+         "chat_history": chat_history,
+         "user_question": user_query
+     }
+
+     chain = prompt | llm | StrOutputParser()
+     return chain.stream(inputs)
+
  def app():
      with st.sidebar:
          st.title("dochatter")
-         # Create a dropdown selection box
-         option = st.selectbox(
-             'Which retriever would you like to use?',
-             ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine')
-         )
-         # Depending on the selected option, choose the appropriate retriever
+         option = st.selectbox('Which retriever would you like to use?', ('General Medicine', 'RespiratoryFishman', 'RespiratoryMurray', 'MedMRCP2', 'OldMedicine'))
          if option == 'RespiratoryFishman':
              persist_directory = "./respfishmandbcud/"
              vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="fishmannotescud")
@@ -123,91 +99,30 @@ def app():
              vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding_function, collection_name="mrcppassmednotes")
          retriever = vectordb.as_retriever(search_kwargs={"k": 5})

-     # Session State
      if "messages" not in st.session_state.keys():
          st.session_state.messages = [{"role": "assistant", "content": "How may I help you?"}]

-     _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question which contains the themes of the conversation. Do not write the question. Do not write the answer.
-     Chat History:
-     {chat_history}
-     Follow Up Input: {question}
-     Standalone question:"""
-     CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
-
-     template = """You are helping a doctor. Answer with what you know from the context provided. Please be as detailed and thorough. Answer the question based on the following context:
-     {context}
-     Question: {question}
-     """
-     ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
-
-     _inputs = RunnableParallel(
-         standalone_question=RunnablePassthrough.assign(
-             chat_history=lambda x: chistory
-         ) | CONDENSE_QUESTION_PROMPT | llmc | StrOutputParser(),
-     )
-     _context = {
-         "context": itemgetter("standalone_question") | retriever | _combine_documents,
-         "question": lambda x: x["standalone_question"],
-     }
-     conversational_qa_chain = _inputs | _context | ANSWER_PROMPT | llm
-
      st.header("Ask Away!")
      for i, message in enumerate(st.session_state.messages):
          with st.chat_message(message["role"]):
              render_message_with_copy_button(message["role"], message["content"], key=f"message-{i}")
              store_chat_history(message["role"], message["content"])

-     prompts2 = st.chat_input("Say something")
-
-     if prompts2:
-         st.session_state.messages.append({"role": "user", "content": prompts2})
+     user_query = st.chat_input("Say something")
+     if user_query:
+         st.session_state.messages.append({"role": "user", "content": user_query})
          with st.chat_message("user"):
-             st.write(prompts2)
+             st.write(user_query)

-     if st.session_state.messages[-1]["role"] != "assistant":
          with st.chat_message("assistant"):
              with st.spinner("Thinking..."):
-                 for _ in range(3): # Retry up to 3 times
-                     try:
-                         responses = generate_multiple_responses(
-                             conversational_qa_chain,
-                             {
-                                 "question": prompts2,
-                                 "chat_history": "\n".join([f"{msg['role']}: {msg['content']}" for msg in chistory])
-                             },
-                             num_responses=5
-                         )
-                         best_response = rerank_responses(reranker, responses)
-                         st.write(best_response)
-                         message = {"role": "assistant", "content": best_response}
-                         st.session_state.messages.append(message)
-                         break
-                     except Exception as e:
-                         st.error(f"An error occurred: {e}")
-                         time.sleep(2) # Wait 2 seconds before retrying
-
- def generate_multiple_responses(chain, inputs, num_responses=5):
-     responses = []
-     for _ in range(num_responses):
-         response = chain.invoke(inputs)
-         responses.append(response)
-     return responses
-
- def rerank_responses(reranker, responses):
-     scores = []
-     for response in responses:
-         score = reranker.invoke(response)
-         scores.append(score)
-     best_response_idx = scores.index(max(scores))
-     return responses[best_response_idx]
-
- def stream_conversational_qa_chain(chain, inputs):
-     try:
-         response = chain.invoke(inputs)
-         for part in response:
-             yield part
-     except Exception as e:
-         raise e
+                 chat_history = "\n".join([f"{msg['role']}: {msg['content']}" for msg in chistory])
+                 response_generator = get_streaming_response(user_query, chat_history)
+                 response_text = ""
+                 for response_part in response_generator:
+                     response_text += response_part
+                     st.write(response_text)
+                 st.session_state.messages.append({"role": "assistant", "content": response_text})

  if __name__ == '__main__':
      app()
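For context on the change above: the commit collapses the old condense-question / multi-response / rerank pipeline into a single LCEL chain (`prompt | llm | StrOutputParser()`) consumed with `.stream()`. Below is a minimal self-contained sketch of that streaming pattern. The `TOGETHER_API_KEY` variable name, the `stream_answer` helper, and the `st.empty()` placeholder are illustrative assumptions, not part of this commit, and the sketch derives the history string from `st.session_state.messages` rather than the app's `chistory` list.

import os
import streamlit as st
from langchain_community.llms import Together
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Same shape of chain the commit introduces: prompt | llm | parser.
llm = Together(
    model="mistralai/Mixtral-8x22B-Instruct-v0.1",
    temperature=0.2,
    max_tokens=1024,
    together_api_key=os.environ["TOGETHER_API_KEY"],  # assumed env var; the app reads 'pilotikval'
)
prompt = ChatPromptTemplate.from_template(
    "Chat history: {chat_history}\n\nUser question: {user_question}"
)
chain = prompt | llm | StrOutputParser()

def stream_answer(user_query: str) -> str:
    # Build the history string from session state (the app passes its own list here).
    chat_history = "\n".join(
        f"{m['role']}: {m['content']}" for m in st.session_state.get("messages", [])
    )
    placeholder = st.empty()  # one placeholder updated in place per chunk
    response_text = ""
    # With StrOutputParser as the last step, chain.stream() yields string chunks.
    for part in chain.stream({"chat_history": chat_history, "user_question": user_query}):
        response_text += part
        placeholder.markdown(response_text)
    return response_text

Updating a single `st.empty()` placeholder keeps the partial answer rendering in place as chunks arrive; calling `st.write` inside the loop, as the committed code does, emits a new element per chunk.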