Towhidul committed on
Commit 69f835c · verified · 1 Parent(s): 7228177

Create app.py

Files changed (1): app.py (+171, -0)
app.py ADDED
@@ -0,0 +1,171 @@
+ import streamlit as st
+ import os
+ import nest_asyncio
+ import re
+ import base64
+ from pathlib import Path
+ from mimetypes import guess_type
+ from typing import List
+ from llama_parse import LlamaParse
+ from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage, Settings
+ from llama_index.core.schema import TextNode, ImageNode, NodeWithScore
+ from llama_index.core.query_engine import CustomQueryEngine
+ from llama_index.core.prompts import PromptTemplate
+ from llama_index.core.base.response.schema import Response
+ from llama_index.core.retrievers import BaseRetriever
+ from llama_index.embeddings.openai import OpenAIEmbedding
+ from llama_index.llms.openai import OpenAI
+ from llama_index.multi_modal_llms.openai import OpenAIMultiModal
+
+ # llama-parse runs its own asyncio loop; nest_asyncio lets it coexist with Streamlit's
+ nest_asyncio.apply()
+
+ # API keys are read from the environment (e.g. Space secrets); default to "" so a
+ # missing key fails at request time with a clear auth error instead of a TypeError here
+ os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY", "")
+ os.environ["LLAMA_CLOUD_API_KEY"] = os.getenv("LLAMA_CLOUD_API_KEY", "")
+
+ # Initialize Streamlit app
+ st.title("Medical Knowledge Base & Query System")
+ st.sidebar.title("Settings")
+
+ # User input for file upload
+ st.sidebar.subheader("Upload Knowledge Base")
+ uploaded_file = st.sidebar.file_uploader("Upload a medical textbook page (image)", type=["jpg", "png"])
+
+ # Initialize the parser
+ parser = LlamaParse(
+     result_type="markdown",
+     parsing_instruction="You are given a medical textbook",
+     use_vendor_multimodal_model=True,
+     vendor_multimodal_model_name="gpt-4o-mini-2024-07-18",
+     show_progress=True,
+     verbose=True,
+     invalidate_cache=True,
+     do_not_cache=True,
+     num_workers=8,
+     language="en",
+ )
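+ # Note: the vendor multimodal model parses each page image; num_workers=8 parallelizes
+ # page parsing, and invalidate_cache/do_not_cache force a fresh parse on every upload
+ # (convenient while iterating, but slower and costlier than cached parses).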
+
+ # Function to encode a local image as a data URL
+ def local_image_to_data_url(image_path):
+     mime_type, _ = guess_type(image_path)
+     if mime_type is None:
+         mime_type = "image/png"
+     with open(image_path, "rb") as image_file:
+         base64_encoded_data = base64.b64encode(image_file.read()).decode("utf-8")
+     return f"data:{mime_type};base64,{base64_encoded_data}"
+
+ # Upload and process file; everything below runs only once a file is provided
+ if uploaded_file:
+     st.sidebar.write("Processing file...")
+     os.makedirs("files", exist_ok=True)
+     file_path = f"files/{uploaded_file.name}"
+     with open(file_path, "wb") as f:
+         f.write(uploaded_file.read())
+
+     # Parse the uploaded image
+     md_json_objs = parser.get_json_result([file_path])
+     image_dicts = parser.get_images(md_json_objs, download_path="data_images")
+
+     # Extract and display parsed information
+     st.write("File successfully processed!")
+     st.write(f"Processed file: {uploaded_file.name}")
+
+     # Helpers to sort parsed page images by page number
+     def get_page_number(file_name):
+         match = re.search(r"-page-(\d+)\.jpg$", str(file_name))
+         if match:
+             return int(match.group(1))
+         return 0
+
+     def _get_sorted_image_files(image_dir):
+         raw_files = [f for f in Path(image_dir).iterdir() if f.is_file()]
+         sorted_files = sorted(raw_files, key=get_page_number)
+         return sorted_files
+
+     # Pair each page's markdown with its page image as a TextNode
+     def get_text_nodes(md_json_objs, image_dir) -> List[TextNode]:
+         nodes = []
+         for result in md_json_objs:
+             json_dicts = result["pages"]
+             document_name = result["file_path"].split("/")[-1]
+             docs = [doc["md"] for doc in json_dicts]
+             image_files = _get_sorted_image_files(image_dir)
+             for idx, doc in enumerate(docs):
+                 node = TextNode(
+                     text=doc,
+                     metadata={"image_path": str(image_files[idx]), "page_num": idx + 1, "document_name": document_name},
+                 )
+                 nodes.append(node)
+         return nodes
+
+     # Load text nodes
+     text_nodes = get_text_nodes(md_json_objs, "data_images")
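+     # Each node looks roughly like:
+     #   TextNode(text="<page markdown>",
+     #            metadata={"image_path": "data_images/...-page-1.jpg",
+     #                      "page_num": 1, "document_name": "<uploaded file>"})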
+
+     # Setup index and LLM
+     embed_model = OpenAIEmbedding(model="text-embedding-3-large")
+     llm = OpenAI(model="gpt-4o-mini-2024-07-18")
+     Settings.llm = llm
+     Settings.embed_model = embed_model
+
+     # Build and persist the index on first run; reload it afterwards
+     # (note: a persisted index from an earlier upload will be reused as-is)
+     if not os.path.exists("storage_manuals"):
+         index = VectorStoreIndex(text_nodes, embed_model=embed_model)
+         index.storage_context.persist(persist_dir="./storage_manuals")
+     else:
+         ctx = StorageContext.from_defaults(persist_dir="./storage_manuals")
+         index = load_index_from_storage(ctx)
+
+     retriever = index.as_retriever()
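+     # as_retriever() uses the library's default similarity_top_k (2 at the time of
+     # writing); pass as_retriever(similarity_top_k=5) to ground answers in more pages.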
+
+     # Query input
+     st.subheader("Ask a Question")
+     query_text = st.text_input("Enter your query:")
+     uploaded_query_image = st.file_uploader("Upload a query image (if any):", type=["jpg", "png"])
+
+     # Encode query image if provided
+     encoded_image_url = None
+     if uploaded_query_image:
+         os.makedirs("query_images", exist_ok=True)
+         query_image_path = f"query_images/{uploaded_query_image.name}"
+         with open(query_image_path, "wb") as img_file:
+             img_file.write(uploaded_query_image.read())
+         encoded_image_url = local_image_to_data_url(query_image_path)
+
+     # Setup query engine prompt; {encoded_image_url} inlines the query image (if any)
+     QA_PROMPT_TMPL = """
+ You are a friendly medical chatbot designed to assist users by providing accurate and detailed responses to medical questions based on information from medical books.
+
+ ### Context:
+ ---------------------
+ {context_str}
+ ---------------------
+
+ ### Query Text:
+ {query_str}
+
+ ### Query Image:
+ ---------------------
+ {encoded_image_url}
+ ---------------------
+
+ ### Answer:
+ """
+     QA_PROMPT = PromptTemplate(QA_PROMPT_TMPL)
+     gpt_4o_mm = OpenAIMultiModal(model="gpt-4o-mini-2024-07-18")
+
+     class MultimodalQueryEngine(CustomQueryEngine):
+         # CustomQueryEngine is Pydantic-based, so fields must be declared on the class
+         qa_prompt: PromptTemplate
+         retriever: BaseRetriever
+         multi_modal_llm: OpenAIMultiModal
+
+         def custom_query(self, query_str):
+             # Retrieve relevant pages, then pass both their text and their images to the LLM
+             nodes = self.retriever.retrieve(query_str)
+             image_nodes = [NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"])) for n in nodes]
+             ctx_str = "\n\n".join([r.node.get_content().strip() for r in nodes])
+             fmt_prompt = self.qa_prompt.format(context_str=ctx_str, query_str=query_str, encoded_image_url=encoded_image_url)
+             llm_response = self.multi_modal_llm.complete(prompt=fmt_prompt, image_documents=[image_node.node for image_node in image_nodes])
+             return Response(response=str(llm_response), source_nodes=nodes, metadata={"text_nodes": text_nodes, "image_nodes": image_nodes})
+
+     query_engine = MultimodalQueryEngine(qa_prompt=QA_PROMPT, retriever=retriever, multi_modal_llm=gpt_4o_mm)
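+     # Illustrative call (hypothetical question):
+     #   query_engine.custom_query("What are the first-line treatments for hypertension?")
+     #   -> Response: .response holds the answer, .source_nodes the retrieved pages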
+
+     # Handle query
+     if query_text:
+         st.write("Querying...")
+         response = query_engine.custom_query(query_text)
+         st.markdown(response.response)
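+
+ # Run locally with the standard Streamlit entry point: streamlit run app.py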