paugar commited on
Commit
2ea737d
·
verified ·
1 Parent(s): 8a22c18

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -4
app.py CHANGED
@@ -5,8 +5,7 @@ import torch
5
  import torch.nn.functional as F
6
  import torchvision.transforms as transforms
7
  import torchvision
8
- from torchkeras.data import get_url_img
9
- from torchkeras import plots
10
  import numpy as np
11
  import yaml
12
  from huggingface_hub import hf_hub_download
@@ -28,8 +27,26 @@ def process_img(image):
28
  with torch.no_grad():
29
  result = model(source=image)
30
  if len(result[0].boxes)>0:
31
- vis = plots.plot_detection(image,boxes=result[0].boxes,
32
- class_names=list(result[0].names.values()), min_score=0.2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  else:
34
  vis = img
35
  return vis
 
5
  import torch.nn.functional as F
6
  import torchvision.transforms as transforms
7
  import torchvision
8
+ from ultralytics.utils.plotting import Annotator, colors
 
9
  import numpy as np
10
  import yaml
11
  from huggingface_hub import hf_hub_download
 
27
  with torch.no_grad():
28
  result = model(source=image)
29
  if len(result[0].boxes)>0:
30
+ ann=Annotator(im=img)
31
+ boxes=result[0].boxes.xyxy
32
+ for element in boxes:
33
+ box=np.array(element.cpu()).flatten()
34
+ if result[0].boxes.cls[0].cpu().numpy()==2:
35
+ label='car'
36
+ if result[0].boxes.cls[0].cpu().numpy()==0:
37
+ label='bicycle'
38
+ if result[0].boxes.cls[0].cpu().numpy()==1:
39
+ label='bus'
40
+ if result[0].boxes.cls[0].cpu().numpy()==3:
41
+ label='motorcycle'
42
+ if result[0].boxes.cls[0].cpu().numpy()==4:
43
+ label='person'
44
+ if result[0].boxes.cls[0].cpu().numpy()==5:
45
+ label='train'
46
+ if result[0].boxes.cls[0].cpu().numpy()==6:
47
+ label='truck'
48
+ ann.box_label(box=box, label=label, color=(0,128,0))
49
+ vis=ann.result()
50
  else:
51
  vis = img
52
  return vis