Tahir5 commited on
Commit
f34f75c
·
verified ·
1 Parent(s): 5c08674

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import matplotlib.pyplot as plt
3
+ import torch
4
+ import numpy as np
5
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
6
+ from PIL import Image
7
+ import requests
8
+
9
+ # Load model and processor
10
+ st.title("Depth Estimation using DPT")
11
+ st.write("Upload an image to estimate its depth map.")
12
+
13
+ @st.cache_resource
14
+ def load_model():
15
+ processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
16
+ model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
17
+ return processor, model
18
+
19
+ processor, model = load_model()
20
+
21
+ # File uploader
22
+ uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
23
+
24
+ if uploaded_file is not None:
25
+ image = Image.open(uploaded_file)
26
+ st.image(image, caption="Uploaded Image", use_column_width=True)
27
+
28
+ # Process image
29
+ inputs = processor(images=image, return_tensors="pt")
30
+ with torch.no_grad():
31
+ outputs = model(**inputs)
32
+ predicted_depth = outputs.predicted_depth
33
+
34
+ # Interpolate to original size
35
+ prediction = torch.nn.functional.interpolate(
36
+ predicted_depth.unsqueeze(1),
37
+ size=image.size[::-1],
38
+ mode="bicubic",
39
+ align_corners=False,
40
+ )
41
+
42
+ # Convert to NumPy array
43
+ output = prediction.squeeze().cpu().numpy()
44
+ normalized_depth = (output - output.min()) / (output.max() - output.min()) # Normalize to [0, 1]
45
+
46
+ # Plot the results
47
+ fig, ax = plt.subplots(1, 2, figsize=(12, 6))
48
+ ax[0].imshow(image)
49
+ ax[0].set_title("Original Image")
50
+ ax[0].axis("off")
51
+ ax[1].imshow(normalized_depth, cmap="inferno")
52
+ ax[1].set_title("Predicted Depth Map")
53
+ ax[1].axis("off")
54
+
55
+ # Display result
56
+ st.pyplot(fig)