Pavan2k4 commited on
Commit
950ad22
·
verified ·
1 Parent(s): f555a0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -4
app.py CHANGED
@@ -9,8 +9,7 @@ import cv2
9
  import numpy as np
10
  from PIL import Image
11
  import torch
12
-
13
- # Adjust import paths as needed
14
  sys.path.append('Utils')
15
  sys.path.append('model')
16
  from model.CBAM.reunet_cbam import reunet_cbam
@@ -21,7 +20,7 @@ from split_merge import split, merge
21
  from Utils.convert import read_pansharpened_rgb
22
 
23
  # Define base directory for Hugging Face Spaces
24
- BASE_DIR = "/home/user"
25
 
26
  # Define subdirectories
27
  UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_images")
@@ -34,6 +33,18 @@ CSV_LOG_PATH = os.path.join(BASE_DIR, "image_log.csv")
34
  for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
35
  os.makedirs(directory, exist_ok=True)
36
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Load model
38
  @st.cache_resource
39
  def load_model():
@@ -49,6 +60,18 @@ def predict(image):
49
  output = model(image.unsqueeze(0))
50
  return output.squeeze().cpu().numpy()
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def log_image_details(image_id, image_filename, mask_filename):
53
  file_exists = os.path.exists(CSV_LOG_PATH)
54
 
@@ -70,6 +93,9 @@ def log_image_details(image_id, image_filename, mask_filename):
70
  sno = 1
71
 
72
  writer.writerow([sno, date, time, image_id, image_filename, mask_filename])
 
 
 
73
 
74
  def upload_page():
75
  if 'file_uploaded' not in st.session_state:
@@ -106,6 +132,9 @@ def upload_page():
106
 
107
  st.success(f"Image saved to {filepath}")
108
 
 
 
 
109
  # Check if the uploaded file is a GeoTIFF
110
  if file_extension in ['.tiff', '.tif']:
111
  st.info('Processing GeoTIFF image...')
@@ -169,6 +198,13 @@ def upload_page():
169
  Image.fromarray(mask).save(mask_filepath)
170
  st.session_state.mask_filename = mask_filename
171
 
 
 
 
 
 
 
 
172
  st.session_state.file_uploaded = True
173
 
174
  except Exception as e:
@@ -222,7 +258,7 @@ def result_page():
222
  original_np = cv2.imread(original_img_path)
223
  mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
224
 
225
- # Ensure mask is binary
226
  _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
227
 
228
  # Resize mask to match original image size if necessary
 
9
  import numpy as np
10
  from PIL import Image
11
  import torch
12
+ from huggingface_hub import HfApi
 
13
  sys.path.append('Utils')
14
  sys.path.append('model')
15
  from model.CBAM.reunet_cbam import reunet_cbam
 
20
  from Utils.convert import read_pansharpened_rgb
21
 
22
  # Define base directory for Hugging Face Spaces
23
+ BASE_DIR = "data"
24
 
25
  # Define subdirectories
26
  UPLOAD_DIR = os.path.join(BASE_DIR, "uploaded_images")
 
33
  for directory in [UPLOAD_DIR, MASK_DIR, PATCHES_DIR, PRED_PATCHES_DIR]:
34
  os.makedirs(directory, exist_ok=True)
35
 
36
+ # Initialize Hugging Face API
37
+ hf_api = HfApi()
38
+
39
+ # Get the token from environment variable
40
+ HF_TOKEN = os.environ.get("HF_TOKEN")
41
+ if not HF_TOKEN:
42
+ st.error("HF_TOKEN not found in environment variables. Please set it in your Space settings.")
43
+ st.stop()
44
+
45
+
46
+ REPO_ID = "Pavan2k4/Building_area"
47
+
48
  # Load model
49
  @st.cache_resource
50
  def load_model():
 
60
  output = model(image.unsqueeze(0))
61
  return output.squeeze().cpu().numpy()
62
 
63
+ def save_to_hf_repo(local_path, repo_path):
64
+ try:
65
+ hf_api.upload_file(
66
+ path_or_fileobj=local_path,
67
+ path_in_repo=repo_path,
68
+ repo_id=REPO_ID,
69
+ token=HF_TOKEN
70
+ )
71
+ st.success(f"File uploaded successfully to {repo_path}")
72
+ except Exception as e:
73
+ st.error(f"Error uploading file: {str(e)}")
74
+
75
  def log_image_details(image_id, image_filename, mask_filename):
76
  file_exists = os.path.exists(CSV_LOG_PATH)
77
 
 
93
  sno = 1
94
 
95
  writer.writerow([sno, date, time, image_id, image_filename, mask_filename])
96
+
97
+ # Save CSV to Hugging Face repo
98
+ save_to_hf_repo(CSV_LOG_PATH, 'image_log.csv')
99
 
100
  def upload_page():
101
  if 'file_uploaded' not in st.session_state:
 
132
 
133
  st.success(f"Image saved to {filepath}")
134
 
135
+ # Save image to Hugging Face repo
136
+ save_to_hf_repo(filepath, f'uploaded_images/{filename}')
137
+
138
  # Check if the uploaded file is a GeoTIFF
139
  if file_extension in ['.tiff', '.tif']:
140
  st.info('Processing GeoTIFF image...')
 
198
  Image.fromarray(mask).save(mask_filepath)
199
  st.session_state.mask_filename = mask_filename
200
 
201
+ # Save mask to Hugging Face repo
202
+ mask_filepath = os.path.join(MASK_DIR, st.session_state.mask_filename)
203
+ save_to_hf_repo(mask_filepath, f'generated_masks/{st.session_state.mask_filename}')
204
+
205
+ # Log image details
206
+ log_image_details(timestamp, converted_filename, st.session_state.mask_filename)
207
+
208
  st.session_state.file_uploaded = True
209
 
210
  except Exception as e:
 
258
  original_np = cv2.imread(original_img_path)
259
  mask_np = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
260
 
261
+ # mask is binary
262
  _, mask_np = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
263
 
264
  # Resize mask to match original image size if necessary