wjbmattingly committed
Commit 1596eba (verified) · 1 Parent(s): a281612

Upload 5 files

Files changed (5)
  1. .gitattributes +2 -0
  2. 32600926.jpg +3 -0
  3. app.py +283 -0
  4. call.py +73 -0
  5. test.jpg +3 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ 32600926.jpg filter=lfs diff=lfs merge=lfs -text
+ test.jpg filter=lfs diff=lfs merge=lfs -text
32600926.jpg ADDED

Git LFS Details

  • SHA256: e7b075fa32a57b9a527521ed20b638a901866ac53943b103255808b5269292d7
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
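
The hash and size above come from the Git LFS pointer that is committed in place of the image itself. A minimal sketch (assuming the file has actually been pulled locally, for example with git lfs pull) of checking a local copy against the recorded SHA256:

import hashlib

# SHA256 recorded in the LFS details above for 32600926.jpg
EXPECTED_SHA256 = "e7b075fa32a57b9a527521ed20b638a901866ac53943b103255808b5269292d7"

with open("32600926.jpg", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("match" if digest == EXPECTED_SHA256 else f"mismatch: {digest}")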
app.py ADDED
@@ -0,0 +1,283 @@
+ import gradio as gr
+ import spaces
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, Qwen2_5_VLForConditionalGeneration
+ from qwen_vl_utils import process_vision_info
+ import torch
+ from PIL import Image
+ import subprocess
+ from datetime import datetime
+ import numpy as np
+ import os
+ from gliner import GLiNER
+ import json
+ import tempfile
+ import zipfile
+ import base64
+ import io
+
+ # Initialize GLiNER model
+ gliner_model = GLiNER.from_pretrained("knowledgator/modern-gliner-bi-large-v1.0")
+
+ DEFAULT_NER_LABELS = "person, organization, location, date, event"
+
+ # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ # models = {
+ #     "Qwen/Qwen2-VL-7B-Instruct": AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", trust_remote_code=True, torch_dtype="auto", _attn_implementation="flash_attention_2").cuda().eval()
+
+ # }
+
+ class TextWithMetadata(list):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args)
+         self.original_text = kwargs.get('original_text', '')
+         self.entities = kwargs.get('entities', [])
+
+ def array_to_image_path(image_array):
+     # Convert numpy array to PIL Image
+     img = Image.fromarray(np.uint8(image_array))
+     img.thumbnail((1024, 1024))
+
+     # Generate a unique filename using timestamp
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     filename = f"image_{timestamp}.png"
+
+     # Save the image
+     img.save(filename)
+
+     # Get the full path of the saved image
+     full_path = os.path.abspath(filename)
+
+     return full_path
+
+ models = {
+     "Qwen/Qwen2.5-VL-7B-Instruct": Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True, torch_dtype="auto").cuda().eval()
+
+ }
+
+ processors = {
+     "Qwen/Qwen2.5-VL-7B-Instruct": AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True)
+ }
+
+ DESCRIPTION = "This demo uses [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)"
+
+ kwargs = {}
+ kwargs['torch_dtype'] = torch.bfloat16
+
+ user_prompt = '<|user|>\n'
+ assistant_prompt = '<|assistant|>\n'
+ prompt_suffix = "<|end|>\n"
+
+ @spaces.GPU
+ def run_example(image, model_id="Qwen/Qwen2.5-VL-7B-Instruct", run_ner=False, ner_labels=DEFAULT_NER_LABELS):
+     # First get the OCR text
+     text_input = "Convert the image to text."
+
+     # Handle various image input formats
+     if image is None:
+         raise ValueError("Image path is None.")
+
+     # Case 1: Image is a dictionary with base64 data (from API calls)
+     if isinstance(image, dict) and 'data' in image and isinstance(image['data'], str):
+         if image['data'].startswith('data:image'):
+             # Extract the base64 part after the comma
+             base64_data = image['data'].split(',', 1)[1]
+             # Convert base64 to bytes and then to PIL Image
+             image_bytes = base64.b64decode(base64_data)
+             pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+             # Convert to numpy array for further processing
+             image = np.array(pil_image)
+
+     # Convert numpy array to image path
+     image_path = array_to_image_path(image)
+
+     model = models[model_id]
+     processor = processors[model_id]
+
+     prompt = f"{user_prompt}<|image_1|>\n{text_input}{prompt_suffix}{assistant_prompt}"
+     image = Image.fromarray(image).convert("RGB")
+     messages = [
+         {
+             "role": "user",
+             "content": [
+                 {
+                     "type": "image",
+                     "image": image_path,
+                 },
+                 {"type": "text", "text": text_input},
+             ],
+         }
+     ]
+
+     # Preparation for inference
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     image_inputs, video_inputs = process_vision_info(messages)
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+     )
+     inputs = inputs.to("cuda")
+
+     # Inference: Generation of the output
+     generated_ids = model.generate(**inputs, max_new_tokens=1024)
+     generated_ids_trimmed = [
+         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+     ]
+     output_text = processor.batch_decode(
+         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )
+
+     ocr_text = output_text[0]
+
+     # If NER is enabled, process the OCR text
+     if run_ner:
+         ner_results = gliner_model.predict_entities(
+             ocr_text,
+             ner_labels.split(","),
+             threshold=0.3
+         )
+
+         # Create a list of tuples (text, label) for highlighting
+         highlighted_text = []
+         last_end = 0
+
+         # Sort entities by start position
+         sorted_entities = sorted(ner_results, key=lambda x: x["start"])
+
+         # Process each entity and add non-entity text segments
+         for entity in sorted_entities:
+             # Add non-entity text before the current entity
+             if last_end < entity["start"]:
+                 highlighted_text.append((ocr_text[last_end:entity["start"]], None))
+
+             # Add the entity text with its label
+             highlighted_text.append((
+                 ocr_text[entity["start"]:entity["end"]],
+                 entity["label"]
+             ))
+             last_end = entity["end"]
+
+         # Add any remaining text after the last entity
+         if last_end < len(ocr_text):
+             highlighted_text.append((ocr_text[last_end:], None))
+
+         # Create TextWithMetadata instance with the highlighted text and metadata
+         result = TextWithMetadata(highlighted_text, original_text=ocr_text, entities=ner_results)
+         return result, result  # Return twice: once for display, once for state
+
+     # If NER is disabled, return the text without highlighting
+     result = TextWithMetadata([(ocr_text, None)], original_text=ocr_text, entities=[])
+     return result, result  # Return twice: once for display, once for state
+
+
+ # NOTE: `css` is never defined in this file as committed, which would raise a
+ # NameError at startup; a placeholder keeps gr.Blocks(css=css) working.
+ css = ""
+
+ with gr.Blocks(css=css) as demo:
+     # Add state variables to store OCR results
+     ocr_state = gr.State()
+
+     gr.Image("Caracal.jpg", interactive=False)
+     with gr.Tab(label="Image Input", elem_classes="tabs"):
+         with gr.Row():
+             with gr.Column(elem_classes="input-container"):
+                 input_img = gr.Image(label="Input Picture", elem_classes="gr-image-input")
+                 model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="Qwen/Qwen2.5-VL-7B-Instruct", elem_classes="gr-dropdown")
+
+                 # Add NER controls
+                 with gr.Row():
+                     ner_checkbox = gr.Checkbox(label="Run Named Entity Recognition", value=False)
+                     ner_labels = gr.Textbox(
+                         label="NER Labels (comma-separated)",
+                         value=DEFAULT_NER_LABELS,
+                         visible=False
+                     )
+
+                 submit_btn = gr.Button(value="Submit", elem_classes="submit-btn")
+             with gr.Column(elem_classes="output-container"):
+                 output_text = gr.HighlightedText(label="Output Text", elem_id="output")
+
+         # Show/hide NER labels based on checkbox
+         ner_checkbox.change(
+             lambda x: gr.update(visible=x),
+             inputs=[ner_checkbox],
+             outputs=[ner_labels]
+         )
+
+         # Modify the submit button click handler to update state
+         submit_btn.click(
+             run_example,
+             inputs=[input_img, model_selector, ner_checkbox, ner_labels],
+             outputs=[output_text, ocr_state]  # Add ocr_state to outputs
+         )
+         with gr.Row():
+             filename = gr.Textbox(label="Save filename (without extension)", placeholder="Enter filename to save")
+             download_btn = gr.Button("Download Image & Text", elem_classes="submit-btn")
+             download_output = gr.File(label="Download")
+
+         # Modify create_zip to use the state data
+         def create_zip(image, fname, ocr_result):
+             # Validate inputs
+             if not fname or image is None:  # Changed the validation check
+                 return None
+
+             try:
+                 # Convert numpy array to PIL Image if needed
+                 if isinstance(image, np.ndarray):
+                     image = Image.fromarray(image)
+                 elif not isinstance(image, Image.Image):
+                     return None
+
+                 with tempfile.TemporaryDirectory() as temp_dir:
+                     # Save image
+                     img_path = os.path.join(temp_dir, f"{fname}.png")
+                     image.save(img_path)
+
+                     # Use the OCR result from state
+                     original_text = ocr_result.original_text if ocr_result else ""
+                     entities = ocr_result.entities if ocr_result else []
+
+                     # Save text
+                     txt_path = os.path.join(temp_dir, f"{fname}.txt")
+                     with open(txt_path, 'w', encoding='utf-8') as f:
+                         f.write(original_text)
+
+                     # Create JSON with text and entities
+                     json_data = {
+                         "text": original_text,
+                         "entities": entities,
+                         "image_file": f"{fname}.png"
+                     }
+
+                     # Save JSON
+                     json_path = os.path.join(temp_dir, f"{fname}.json")
+                     with open(json_path, 'w', encoding='utf-8') as f:
+                         json.dump(json_data, f, indent=2, ensure_ascii=False)
+
+                     # Create zip file
+                     output_dir = "downloads"
+                     os.makedirs(output_dir, exist_ok=True)
+                     zip_path = os.path.join(output_dir, f"{fname}.zip")
+
+                     with zipfile.ZipFile(zip_path, 'w') as zipf:
+                         zipf.write(img_path, os.path.basename(img_path))
+                         zipf.write(txt_path, os.path.basename(txt_path))
+                         zipf.write(json_path, os.path.basename(json_path))
+
+                     return zip_path
+
+             except Exception as e:
+                 print(f"Error creating zip: {str(e)}")
+                 return None
+
+         # Update the download button click handler to include state
+         download_btn.click(
+             create_zip,
+             inputs=[input_img, filename, ocr_state],
+             outputs=[download_output]
+         )
+
+ demo.queue(api_open=False)
+ demo.launch(debug=True)
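
run_example returns the OCR result as a list of (text, label) pairs wrapped in TextWithMetadata, where label is None for plain text and a GLiNER label for recognized spans. A small sketch, using hand-made stand-in values rather than real app output, of flattening such a list back into plain text and a per-label summary:

# Hypothetical (text, label) pairs in the shape produced by run_example.
highlighted = [
    ("Letter from ", None),
    ("Jane Doe", "person"),
    (" to the ", None),
    ("Smithsonian Institution", "organization"),
    (", ", None),
    ("12 March 1921", "date"),
    (".", None),
]

# Rebuild the raw text and group the labelled spans by label.
plain_text = "".join(text for text, _ in highlighted)
entities = {}
for text, label in highlighted:
    if label is not None:
        entities.setdefault(label, []).append(text)

print(plain_text)
print(entities)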
call.py ADDED
@@ -0,0 +1,73 @@
+ import os
+ import requests
+ import json
+ import dotenv
+ from PIL import Image
+ import io
+ import base64
+ import numpy as np
+ from gradio_client import Client
+
+ # Load environment variables from .env file
+ dotenv.load_dotenv()
+
+ # Get Hugging Face token from environment
+ hf_token = os.getenv("HF_TOKEN")
+
+ # Create client for the Hugging Face Space with authentication
+ client = Client("wjbmattingly/caracal", hf_token=hf_token)
+
+ # Example usage
+ if __name__ == "__main__":
+     # Path to input image
+     image_path = "test.jpg"
+
+     # Check if image exists
+     if not os.path.exists(image_path):
+         print(f"Error: Image file not found: {image_path}")
+         exit(1)
+
+     # Print available API endpoints (for reference)
+     print("Available API endpoints:")
+     for endpoint in client.endpoints:
+         print(f"  - {endpoint}")
+
+     print(f"\nProcessing image: {image_path}")
+
+     # Convert the image to base64
+     with open(image_path, "rb") as image_file:
+         image_bytes = image_file.read()
+
+     # Convert to base64 string
+     base64_image = base64.b64encode(image_bytes).decode("utf-8")
+
+     # Format as expected by Gradio's ImageData model
+     image_data = {
+         "data": f"data:image/jpeg;base64,{base64_image}",
+         "is_file": False
+     }
+
+     # Call the run_example function with the appropriate parameters
+     result = client.predict(
+         image_data,                                     # Base64 encoded image
+         "Qwen/Qwen2.5-VL-7B-Instruct",                  # Model selection
+         False,                                          # NER disabled
+         "person, organization, location, date, event",  # Default NER labels
+         api_name="/run_example"                         # API endpoint name
+     )
+
+     # Process and display the result
+     if result:
+         print("\nExtracted Text:")
+         if isinstance(result, list):
+             full_text = ""
+             for segment in result:
+                 if isinstance(segment, list) and len(segment) > 0:
+                     text = segment[0]
+                     full_text += text
+                     print(text, end="")
+             print("\n")
+         else:
+             print(result)
+     else:
+         print("No result returned from API")
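
call.py runs with NER disabled; the same endpoint also accepts the NER flag and a custom label list. Below is a minimal standalone variation, assuming the same Space and HF_TOKEN setup as above. Since the serialized shape of the HighlightedText output can differ between gradio_client versions, it simply prints the raw response:

import base64
import os

from gradio_client import Client

client = Client("wjbmattingly/caracal", hf_token=os.getenv("HF_TOKEN"))

# Build the same base64 payload that call.py sends.
with open("test.jpg", "rb") as f:
    image_data = {
        "data": "data:image/jpeg;base64," + base64.b64encode(f.read()).decode("utf-8"),
        "is_file": False,
    }

# Same argument order as call.py, but with NER enabled and a custom label set.
ner_result = client.predict(
    image_data,
    "Qwen/Qwen2.5-VL-7B-Instruct",
    True,                               # run_ner
    "person, organization, location",   # ner_labels
    api_name="/run_example",
)

print(ner_result)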
test.jpg ADDED

Git LFS Details

  • SHA256: 8fd2d36a5bccfb9318fb5bb14df325f0d25a1a44704c7c83b1f634c6e15f1249
  • Pointer size: 131 Bytes
  • Size of remote file: 196 kB