Tzktz commited on
Commit
b816ac7
·
verified ·
1 Parent(s): 6fc683c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unilm.dit.object_detection.ditod import add_vit_config
2
+ import torch
3
+ from detectron2.config import get_cfg
4
+ from detectron2.utils.visualizer import ColorMode, Visualizer
5
+ from detectron2.data import MetadataCatalog
6
+ from detectron2.engine import DefaultPredictor
7
+ import gradio as gr
8
+
9
+ cfg = get_cfg()
10
+ add_vit_config(cfg)
11
+ cfg.merge_from_file("cascade_dit_base.yml")
12
+
13
+ cfg.MODEL.WEIGHTS = "publaynet_dit-b_cascade.pth"
14
+
15
+ cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+
17
+ predictor = DefaultPredictor(cfg)
18
+
19
+
20
+ def analyze_image(img):
21
+ md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
22
+ if cfg.DATASETS.TEST[0] == 'icdar2019_test':
23
+ md.set(thing_classes=["table"])
24
+ else:
25
+ md.set(thing_classes=["text", "title", "list", "table", "figure"])
26
+
27
+ output = predictor(img)["instances"]
28
+ v = Visualizer(img[:, :, ::-1],
29
+ md,
30
+ scale=1.0,
31
+ instance_mode=ColorMode.SEGMENTATION)
32
+ result = v.draw_instance_predictions(output.to("cpu"))
33
+ result_image = result.get_image()[:, :, ::-1]
34
+
35
+ return result_image
36
+
37
+
38
+ title = " Table Detection with DiT"
39
+ css = ".output-image, .input-image, .image-preview {height: 600px !important}"
40
+
41
+ iface = gr.Interface(
42
+ fn=analyze_image,
43
+ inputs=[gr.Image(type="numpy", label="document image")],
44
+ outputs=[gr.Image(type="numpy", label="detected tables")],
45
+ title=title,
46
+
47
+ css=css,
48
+ )
49
+ iface.launch(debug=True, share=True)