Alex Hortua committed · 9a6ea32
1 Parent(s): 2add545

Add min configuration for this.

Files changed:
- config.yaml (+4 -4)
- src/train.py (+3 -1)
config.yaml
CHANGED
@@ -6,8 +6,8 @@ model:
   epochs: 10
   batch_size: 8
   optimizer: adam
-  image_sample_size:
-  trainable_backbone_layers:
+  image_sample_size: 10000
+  trainable_backbone_layers: 1

 dataset:
   image_dir: datasets/images
@@ -16,10 +16,10 @@ dataset:
   val_split: 0.2

 notebooks:
-  visualization:
+  visualization: logs/training_log.json

 logs:
-  log_dir:
+  log_dir: logs/training_log.txt
   log_interval: 10

 evaluation:
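Note (not part of this commit): the train.py change below reads the new notebooks.visualization value through a module-level config dict. A minimal sketch of how that dict is presumably populated from config.yaml, assuming PyYAML and a config.yaml at the repository root, and assuming the sample-size keys sit under model: as the hunk header suggests:

# Sketch only; the exact loading code is not shown in this commit.
import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

# With the values added in this commit, these lookups resolve to concrete settings:
log_json_path = config["notebooks"]["visualization"]  # "logs/training_log.json"
sample_size = config["model"]["image_sample_size"]    # 10000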
src/train.py
CHANGED
@@ -59,6 +59,8 @@ def train_one_epoch(model, optimizer, data_loader, device):
 
         loss_dict = model(images, targets)
         loss = sum(loss for loss in loss_dict.values())
+        if(batch_idx % 100 == 0):
+            log_message(f"Iteration {batch_idx+1}, Loss: {loss.item()}")
 
         optimizer.zero_grad()
         loss.backward()
@@ -70,7 +72,7 @@ def train_one_epoch(model, optimizer, data_loader, device):
 
 def store_log(loss, mAP):
     # Save log in JSON for visualization
-    log_json = config["
+    log_json = config["notebooks"]["visualization"]
     if not os.path.exists(log_json):
         log_data = {"loss": [], "mAP": []}
     else:
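Note (not part of the commit): the second hunk cuts off after the else: branch. A plausible completion of store_log, assuming the standard-library json module, a log_message helper already defined in train.py, and config being the dict loaded from config.yaml, might look like this sketch; the append and write-back steps are assumptions, not the author's exact code.

import json
import os

def store_log(loss, mAP):
    # Save log in JSON for visualization (path added to config.yaml in this commit)
    log_json = config["notebooks"]["visualization"]  # "logs/training_log.json"
    if not os.path.exists(log_json):
        # First call: start a fresh history
        log_data = {"loss": [], "mAP": []}
    else:
        # Subsequent calls: extend the existing history
        with open(log_json) as f:
            log_data = json.load(f)

    # Record the latest metrics and write the file back out
    log_data["loss"].append(float(loss))
    log_data["mAP"].append(float(mAP))
    os.makedirs(os.path.dirname(log_json), exist_ok=True)
    with open(log_json, "w") as f:
        json.dump(log_data, f)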