Daemontatox committed on
Commit f9b55bc · verified · 1 Parent(s): 41a5ebf

Update app.py

Files changed (1)
  1. app.py +92 -55
app.py CHANGED
@@ -16,14 +16,28 @@ model = MllamaForConditionalGeneration.from_pretrained(ckpt,
                                                        torch_dtype=torch.bfloat16).to("cuda")
 processor = AutoProcessor.from_pretrained(ckpt)
 
+class DocumentState:
+    def __init__(self):
+        self.current_doc_images = []
+        self.current_doc_text = ""
+        self.doc_type = None  # 'pdf' or 'image'
+
+    def clear(self):
+        self.current_doc_images = []
+        self.current_doc_text = ""
+        self.doc_type = None
+
+doc_state = DocumentState()
+
 def process_pdf_file(file_path):
     """Convert PDF to images and extract text using PyMuPDF."""
     doc = fitz.open(file_path)
     images = []
     text = ""
 
-    for page in doc:
+    for page_num, page in enumerate(doc):
         # Extract text
+        text += f"\n=== Page {page_num + 1} ===\n"
         text += page.get_text() + "\n"
 
         # Convert page to image
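The rasterization step itself falls outside this hunk's context (it is cut off at the `# Convert page to image` comment). As a hedged sketch only, this is how PyMuPDF pages are commonly rendered into the PIL images that `process_pdf_file` returns; the `get_pixmap` call and `zoom` factor are assumptions, not code from this commit:

```python
# Sketch (assumption, not from app.py): rasterize a PyMuPDF page to a PIL image.
import fitz  # PyMuPDF
from PIL import Image

def page_to_image(page, zoom=2.0):
    # Render at 2x resolution so small text stays legible for the vision model.
    pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
    return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
```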
@@ -35,43 +49,53 @@ def process_pdf_file(file_path):
     doc.close()
     return images, text
 
+def process_file(file):
+    """Process either PDF or image file and update document state."""
+    doc_state.clear()
+
+    if isinstance(file, dict):
+        file_path = file["path"]
+    else:
+        file_path = file
+
+    if file_path.lower().endswith('.pdf'):
+        doc_state.doc_type = 'pdf'
+        doc_state.current_doc_images, doc_state.current_doc_text = process_pdf_file(file_path)
+        return f"PDF processed successfully. {len(doc_state.current_doc_images)} pages loaded. You can now ask questions about the content."
+    else:
+        doc_state.doc_type = 'image'
+        doc_state.current_doc_images = [Image.open(file_path).convert("RGB")]
+        return "Image loaded successfully. You can now ask questions about the content."
+
 @spaces.GPU()
 def bot_streaming(message, history, max_new_tokens=2048):
     txt = message["text"]
-    ext_buffer = f"{txt}"
-
     messages = []
     images = []
 
-    # Process history
+    # Process new file if provided
+    if message.get("files") and len(message["files"]) > 0:
+        process_file(message["files"][0])
+
+    # Process history and maintain context
     for i, msg in enumerate(history):
         if isinstance(msg[0], tuple):
             messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
             messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
-            images.append(Image.open(msg[0][0]).convert("RGB"))
         elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
             pass
         elif isinstance(history[i-1][0], str) and isinstance(msg[0], str):
             messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
             messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})
 
-    # Process current message
-    if len(message["files"]) == 1:
-        file_data = message["files"][0]
-        file_path = file_data["path"] if isinstance(file_data, dict) else file_data
-
-        # Check if file is PDF
-        if file_path.lower().endswith('.pdf'):
-            # Process PDF
-            pdf_images, pdf_text = process_pdf_file(file_path)
-            images.extend(pdf_images)
-            txt = f"{txt}\nExtracted text from PDF:\n{pdf_text}"
-        else:
-            # Handle regular image
-            image = Image.open(file_path).convert("RGB")
-            images.append(image)
-
-        messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
+    # Include document context in the current message
+    if doc_state.current_doc_images:
+        images.extend(doc_state.current_doc_images)
+        context = ""
+        if doc_state.doc_type == 'pdf':
+            context = f"\nContext from PDF:\n{doc_state.current_doc_text}"
+        current_msg = f"{txt}{context}"
+        messages.append({"role": "user", "content": [{"type": "text", "text": current_msg}, {"type": "image"}]})
     else:
         messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
 
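Two notes on the new flow. `doc_state` is a module-level singleton, so on a shared Space every concurrent session reads and writes the same document context; `gr.State` would be the per-session alternative. Below is a hypothetical quick check of the helper (the file path is a placeholder; `process_file` and `doc_state` are the objects defined in app.py above):

```python
# Hypothetical smoke test of the new document-state flow (path is a placeholder).
status = process_file("examples/invoice.png")
print(status)                              # "Image loaded successfully. ..."
print(doc_state.doc_type)                  # 'image'
print(len(doc_state.current_doc_images))   # 1

doc_state.clear()                          # reset before loading another document
```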
@@ -80,11 +104,13 @@ def bot_streaming(message, history, max_new_tokens=2048):
     if not images:
         inputs = processor(text=texts, return_tensors="pt").to("cuda")
     else:
-        # Handle multiple images if needed
-        max_images = 4  # Limit number of images to process
+        # Process images in batches if needed
+        max_images = 12  # Increased maximum number of images/pages
         if len(images) > max_images:
-            images = images[:max_images]
-            txt += f"\n(Note: Only processing first {max_images} pages of the PDF)"
+            # Take evenly spaced samples if we have too many pages
+            indices = np.linspace(0, len(images) - 1, max_images, dtype=int)
+            images = [images[i] for i in indices]
+            txt += f"\n(Note: Analyzing {max_images} evenly distributed pages from the document)"
 
         inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
 
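The new sampling keeps the first and last page and spaces the rest evenly instead of truncating to the first few pages. Note that `np.linspace` presumes `import numpy as np` earlier in app.py, which is not visible in this diff. A small standalone illustration of the sampling:

```python
# Standalone illustration of the evenly spaced page sampling used above.
import numpy as np

pages = list(range(30))                                  # e.g. a 30-page PDF
indices = np.linspace(0, len(pages) - 1, 12, dtype=int)  # 12 evenly spaced indices
sampled = [pages[i] for i in indices]
print(sampled)  # first and last pages included, the rest roughly evenly spread
```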
@@ -100,36 +126,47 @@ def bot_streaming(message, history, max_new_tokens=2048):
         time.sleep(0.01)
         yield buffer
 
-# Create the Gradio interface
-demo = gr.ChatInterface(
-    fn=bot_streaming,
-    title="Document Analyzer",
-    examples=[
-        [{"text": "Which era does this piece belong to? Give details about the era.", "files":["./examples/rococo.jpg"]}, 200],
-        [{"text": "Where do the droughts happen according to this diagram?", "files":["./examples/weather_events.png"]}, 250],
-        [{"text": "What happens when you take out white cat from this chain?", "files":["./examples/ai2d_test.jpg"]}, 250],
-        [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files":["./examples/invoice.png"]}, 250],
-        [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files":["./examples/wat_arun.jpg"]}, 250],
-    ],
-    textbox=gr.MultimodalTextbox(),
-    additional_inputs=[
-        gr.Slider(
-            minimum=10,
-            maximum=500,
-            value=2048,
-            step=10,
-            label="Maximum number of new tokens to generate",
-        )
-    ],
-    cache_examples=False,
-    description="MllM Document and PDF Analyzer",
-    stop_btn="Stop Generation",
-    fill_height=True,
-    multimodal=True
-)
 
-# Update accepted file types
-demo.textbox.file_types = ["image", "pdf"]
+def clear_context():
+    """Clear the current document context."""
+    doc_state.clear()
+    return "Document context cleared. You can upload a new document."
+
+# Create the Gradio interface with enhanced features
+with gr.Blocks() as demo:
+    gr.Markdown("# Document Analyzer with Chat Support")
+    gr.Markdown("Upload a PDF or image and chat about its contents. The context is maintained throughout the conversation.")
+
+    chatbot = gr.ChatInterface(
+        fn=bot_streaming,
+        title="Document Chat",
+        examples=[
+            [{"text": "Which era does this piece belong to? Give details about the era.", "files":["./examples/rococo.jpg"]}, 200],
+            [{"text": "Where do the droughts happen according to this diagram?", "files":["./examples/weather_events.png"]}, 250],
+            [{"text": "What happens when you take out white cat from this chain?", "files":["./examples/ai2d_test.jpg"]}, 250],
+            [{"text": "How long does it take from invoice date to due date? Be short and concise.", "files":["./examples/invoice.png"]}, 250],
+            [{"text": "Where to find this monument? Can you give me other recommendations around the area?", "files":["./examples/wat_arun.jpg"]}, 250],
+        ],
+        textbox=gr.MultimodalTextbox(),
+        additional_inputs=[
+            gr.Slider(
+                minimum=10,
+                maximum=2048,
+                value=2048,
+                step=10,
+                label="Maximum number of new tokens to generate",
+            )
+        ],
+        cache_examples=False,
+        stop_btn="Stop Generation",
+        fill_height=True,
+        multimodal=True
+    )
+
+    clear_btn = gr.Button("Clear Document Context")
+    clear_btn.click(fn=clear_context)
+
+    # Update accepted file types
+    chatbot.textbox.file_types = ["image", "pdf"]
 
 # Launch the interface
 demo.launch(debug=True)
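One wiring detail: `clear_context` returns a status string, but `clear_btn.click(fn=clear_context)` is registered without `outputs`, so that message is never shown in the UI. A hypothetical variant (not part of this commit) that surfaces it:

```python
# Hypothetical variant: route clear_context's return value to a visible component.
import gradio as gr

# clear_context is the function defined in app.py above.
with gr.Blocks() as demo_variant:
    status_box = gr.Textbox(label="Document status", interactive=False)
    clear_btn = gr.Button("Clear Document Context")
    clear_btn.click(fn=clear_context, outputs=status_box)
```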
 