# -*- coding: utf-8 -*-
"""Good.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1AkM6wLyspo4q2ScK_pIkTAFDV-Q-JCgh
"""

!pip install torchinfo

from google.colab import files
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

# Upload helper scripts into the Colab runtime (helper_functions.py, which
# defines set_seeds, can be uploaded here; it is also written by a
# %%writefile cell later in this notebook)
files.upload()

from helper_functions import set_seeds

device = "cuda" if torch.cuda.is_available() else "cpu"
device

# Commented out IPython magic to ensure Python compatibility.
# %%writefile predict.py
# 
# #predict
# 
# 
# """
# Utility functions to make predictions.
# 
# Main reference for code creation: https://www.learnpytorch.io/06_pytorch_transfer_learning/#6-make-predictions-on-images-from-the-test-set
# """
# import torch
# import torchvision
# from torchvision import transforms
# import matplotlib.pyplot as plt
# 
# from typing import List, Tuple
# 
# from PIL import Image
# 
# # Set device
# device = "cuda" if torch.cuda.is_available() else "cpu"
# 
# # Predict on a target image with a target model
# # Function created in: https://www.learnpytorch.io/06_pytorch_transfer_learning/#6-make-predictions-on-images-from-the-test-set
# def pred_and_plot_image(
#     model: torch.nn.Module,
#     class_names: List[str],
#     image_path: str,
#     image_size: Tuple[int, int] = (224, 224),
#     transform: transforms.Compose = None,
#     device: torch.device = device,
# ):
#     """Predicts on a target image with a target model.
# 
#     Args:
#         model (torch.nn.Module): A trained (or untrained) PyTorch model to predict on an image.
#         class_names (List[str]): A list of target classes to map predictions to.
#         image_path (str): Filepath to target image to predict on.
#         image_size (Tuple[int, int], optional): Size to transform target image to. Defaults to (224, 224).
#         transform (torchvision.transforms, optional): Transform to perform on image. Defaults to None which uses ImageNet normalization.
#         device (torch.device, optional): Target device to perform prediction on. Defaults to device.
#     """
# 
#     # Open image
#     img = Image.open(image_path)
# 
#     # Create transformation for image (if one doesn't exist)
#     if transform is not None:
#         image_transform = transform
#     else:
#         image_transform = transforms.Compose(
#             [
#                 transforms.Resize(image_size),
#                 transforms.ToTensor(),
#                 transforms.Normalize(
#                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
#                 ),
#             ]
#         )
# 
#     ### Predict on image ###
# 
#     # Make sure the model is on the target device
#     model.to(device)
# 
#     # Turn on model evaluation mode and inference mode
#     model.eval()
#     with torch.inference_mode():
#         # Transform and add an extra dimension to image (model requires samples in [batch_size, color_channels, height, width])
#         transformed_image = image_transform(img).unsqueeze(dim=0)
# 
#         # Make a prediction on image with an extra dimension and send it to the target device
#         target_image_pred = model(transformed_image.to(device))
# 
#     # Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
#     target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
# 
#     # Convert prediction probabilities -> prediction labels
#     target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
# 
#     # Plot image with predicted label and probability
#     plt.figure()
#     plt.imshow(img)
#     plt.title(
#         f"Pred: {class_names[target_image_pred_label]} | Prob: {target_image_pred_probs.max():.3f}"
#     )
#     plt.axis(False)
#

from google.colab import drive
drive.mount('/content/drive')

# Commented out IPython magic to ensure Python compatibility.
# %%writefile model_builder.py
# 
# #model_builder
# 
# """
# Contains PyTorch model code to instantiate a TinyVGG model.
# """
# import torch
# from torch import nn
# 
# class TinyVGG(nn.Module):
#     """Creates the TinyVGG architecture.
# 
#     Replicates the TinyVGG architecture from the CNN explainer website in PyTorch.
#     See the original architecture here: https://poloclub.github.io/cnn-explainer/
# 
#     Args:
#     input_shape: An integer indicating number of input channels.
#     hidden_units: An integer indicating number of hidden units between layers.
#     output_shape: An integer indicating number of output units.
#     """
#     def __init__(self, input_shape: int, hidden_units: int, output_shape: int) -> None:
#         super().__init__()
#         self.conv_block_1 = nn.Sequential(
#           nn.Conv2d(in_channels=input_shape,
#                     out_channels=hidden_units,
#                     kernel_size=3,
#                     stride=1,
#                     padding=0),
#           nn.ReLU(),
#           nn.Conv2d(in_channels=hidden_units,
#                     out_channels=hidden_units,
#                     kernel_size=3,
#                     stride=1,
#                     padding=0),
#           nn.ReLU(),
#           nn.MaxPool2d(kernel_size=2,
#                         stride=2)
#         )
#         self.conv_block_2 = nn.Sequential(
#           nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),
#           nn.ReLU(),
#           nn.Conv2d(hidden_units, hidden_units, kernel_size=3, padding=0),
#           nn.ReLU(),
#           nn.MaxPool2d(2)
#         )
#         self.classifier = nn.Sequential(
#           nn.Flatten(),
#           # Where did this in_features shape come from?
#           # It's because each layer of our network compresses and changes the shape of our inputs data.
#           nn.Linear(in_features=hidden_units*13*13,
#                     out_features=output_shape)
#         )
# 
#     def forward(self, x: torch.Tensor):
#         x = self.conv_block_1(x)
#         x = self.conv_block_2(x)
#         x = self.classifier(x)
#         return x
#         # return self.classifier(self.block_2(self.block_1(x))) # <- leverage the benefits of operator fusion

# Commented out IPython magic to ensure Python compatibility.
# %%writefile utils.py
# 
# #utils.py
# 
# """
# Contains various utility functions for PyTorch model training and saving.
# """
# import torch
# from pathlib import Path
# 
# def save_model(model: torch.nn.Module,
#                target_dir: str,
#                model_name: str):
#     """Saves a PyTorch model to a target directory.
# 
#     Args:
#     model: A target PyTorch model to save.
#     target_dir: A directory for saving the model to.
#     model_name: A filename for the saved model. Should include
#       either ".pth" or ".pt" as the file extension.
# 
#     Example usage:
#     save_model(model=model_0,
#                target_dir="models",
#                model_name="05_going_modular_tingvgg_model.pth")
#     """
#     # Create target directory
#     target_dir_path = Path(target_dir)
#     target_dir_path.mkdir(parents=True,
#                         exist_ok=True)
# 
#     # Create model save path
#     assert model_name.endswith(".pth") or model_name.endswith(".pt"), "model_name should end with '.pt' or '.pth'"
#     model_save_path = target_dir_path / model_name
# 
#     # Save the model state_dict()
#     print(f"[INFO] Saving model to: {model_save_path}")
#     torch.save(obj=model.state_dict(),
#              f=model_save_path)

# Commented out IPython magic to ensure Python compatibility.
# %%writefile engine.py
# #engine.py
# 
# """
# Contains functions for training and testing a PyTorch model.
# """
# import torch
# 
# from tqdm.auto import tqdm
# from typing import Dict, List, Tuple
# 
# def train_step(model: torch.nn.Module,
#                dataloader: torch.utils.data.DataLoader,
#                loss_fn: torch.nn.Module,
#                optimizer: torch.optim.Optimizer,
#                device: torch.device) -> Tuple[float, float]:
#     """Trains a PyTorch model for a single epoch.
# 
#     Turns a target PyTorch model to training mode and then
#     runs through all of the required training steps (forward
#     pass, loss calculation, optimizer step).
# 
#     Args:
#     model: A PyTorch model to be trained.
#     dataloader: A DataLoader instance for the model to be trained on.
#     loss_fn: A PyTorch loss function to minimize.
#     optimizer: A PyTorch optimizer to help minimize the loss function.
#     device: A target device to compute on (e.g. "cuda" or "cpu").
# 
#     Returns:
#     A tuple of training loss and training accuracy metrics.
#     In the form (train_loss, train_accuracy). For example:
# 
#     (0.1112, 0.8743)
#     """
#     # Put model in train mode
#     model.train()
# 
#     # Setup train loss and train accuracy values
#     train_loss, train_acc = 0, 0
# 
#     # Loop through data loader data batches
#     for batch, (X, y) in enumerate(dataloader):
#         # Send data to target device
#         X, y = X.to(device), y.to(device)
# 
#         # 1. Forward pass
#         y_pred = model(X)
# 
#         # 2. Calculate and accumulate loss
#         loss = loss_fn(y_pred, y)
#         train_loss += loss.item()
# 
#         # 3. Optimizer zero grad
#         optimizer.zero_grad()
# 
#         # 4. Loss backward
#         loss.backward()
# 
#         # 5. Optimizer step
#         optimizer.step()
# 
#         # Calculate and accumulate accuracy metric across all batches
#         y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
#         train_acc += (y_pred_class == y).sum().item()/len(y_pred)
# 
#     # Adjust metrics to get average loss and accuracy per batch
#     train_loss = train_loss / len(dataloader)
#     train_acc = train_acc / len(dataloader)
#     return train_loss, train_acc
# 
# def test_step(model: torch.nn.Module,
#               dataloader: torch.utils.data.DataLoader,
#               loss_fn: torch.nn.Module,
#               device: torch.device) -> Tuple[float, float]:
#     """Tests a PyTorch model for a single epoch.
# 
#     Turns a target PyTorch model to "eval" mode and then performs
#     a forward pass on a testing dataset.
# 
#     Args:
#     model: A PyTorch model to be tested.
#     dataloader: A DataLoader instance for the model to be tested on.
#     loss_fn: A PyTorch loss function to calculate loss on the test data.
#     device: A target device to compute on (e.g. "cuda" or "cpu").
# 
#     Returns:
#     A tuple of testing loss and testing accuracy metrics.
#     In the form (test_loss, test_accuracy). For example:
# 
#     (0.0223, 0.8985)
#     """
#     # Put model in eval mode
#     model.eval()
# 
#     # Setup test loss and test accuracy values
#     test_loss, test_acc = 0, 0
# 
#     # Turn on inference context manager
#     with torch.inference_mode():
#         # Loop through DataLoader batches
#         for batch, (X, y) in enumerate(dataloader):
#             # Send data to target device
#             X, y = X.to(device), y.to(device)
# 
#             # 1. Forward pass
#             test_pred_logits = model(X)
# 
#             # 2. Calculate and accumulate loss
#             loss = loss_fn(test_pred_logits, y)
#             test_loss += loss.item()
# 
#             # Calculate and accumulate accuracy
#             test_pred_labels = test_pred_logits.argmax(dim=1)
#             test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))
# 
#     # Adjust metrics to get average loss and accuracy per batch
#     test_loss = test_loss / len(dataloader)
#     test_acc = test_acc / len(dataloader)
#     return test_loss, test_acc
# 
# def train(model: torch.nn.Module,
#           train_dataloader: torch.utils.data.DataLoader,
#           test_dataloader: torch.utils.data.DataLoader,
#           optimizer: torch.optim.Optimizer,
#           loss_fn: torch.nn.Module,
#           epochs: int,
#           device: torch.device) -> Dict[str, List]:
#     """Trains and tests a PyTorch model.
# 
#     Passes a target PyTorch model through train_step() and test_step()
#     functions for a number of epochs, training and testing the model
#     in the same epoch loop.
# 
#     Calculates, prints and stores evaluation metrics throughout.
# 
#     Args:
#     model: A PyTorch model to be trained and tested.
#     train_dataloader: A DataLoader instance for the model to be trained on.
#     test_dataloader: A DataLoader instance for the model to be tested on.
#     optimizer: A PyTorch optimizer to help minimize the loss function.
#     loss_fn: A PyTorch loss function to calculate loss on both datasets.
#     epochs: An integer indicating how many epochs to train for.
#     device: A target device to compute on (e.g. "cuda" or "cpu").
# 
#     Returns:
#     A dictionary of training and testing loss as well as training and
#     testing accuracy metrics. Each metric has a value in a list for
#     each epoch.
#     In the form: {train_loss: [...],
#               train_acc: [...],
#               test_loss: [...],
#               test_acc: [...]}
#     For example if training for epochs=2:
#              {train_loss: [2.0616, 1.0537],
#               train_acc: [0.3945, 0.3945],
#               test_loss: [1.2641, 1.5706],
#               test_acc: [0.3400, 0.2973]}
#     """
#     # Create empty results dictionary
#     results = {"train_loss": [],
#                "train_acc": [],
#                "test_loss": [],
#                "test_acc": []
#     }
# 
#     # Make sure model on target device
#     model.to(device)
# 
#     # Loop through training and testing steps for a number of epochs
#     for epoch in tqdm(range(epochs)):
#         train_loss, train_acc = train_step(model=model,
#                                           dataloader=train_dataloader,
#                                           loss_fn=loss_fn,
#                                           optimizer=optimizer,
#                                           device=device)
#         test_loss, test_acc = test_step(model=model,
#           dataloader=test_dataloader,
#           loss_fn=loss_fn,
#           device=device)
# 
#         # Print out what's happening
#         print(
#           f"Epoch: {epoch+1} | "
#           f"train_loss: {train_loss:.4f} | "
#           f"train_acc: {train_acc:.4f} | "
#           f"test_loss: {test_loss:.4f} | "
#           f"test_acc: {test_acc:.4f}"
#         )
# 
#         # Update results dictionary
#         results["train_loss"].append(train_loss)
#         results["train_acc"].append(train_acc)
#         results["test_loss"].append(test_loss)
#         results["test_acc"].append(test_acc)
# 
#     # Return the filled results at the end of the epochs
#     return results

# Commented out IPython magic to ensure Python compatibility.
# %%writefile data_setup.py
# #data_setup.py
# """
# Contains functionality for creating PyTorch DataLoaders for
# image classification data.
# """
# import os
# 
# from torchvision import datasets, transforms
# from torch.utils.data import DataLoader
# 
# NUM_WORKERS = os.cpu_count()
# 
# def create_dataloaders(
#     train_dir: str,
#     test_dir: str,
#     transform: transforms.Compose,
#     batch_size: int,
#     num_workers: int=NUM_WORKERS
# ):
#   """Creates training and testing DataLoaders.
# 
#   Takes in a training directory and testing directory path and turns
#   them into PyTorch Datasets and then into PyTorch DataLoaders.
# 
#   Args:
#     train_dir: Path to training directory.
#     test_dir: Path to testing directory.
#     transform: torchvision transforms to perform on training and testing data.
#     batch_size: Number of samples per batch in each of the DataLoaders.
#     num_workers: An integer for number of workers per DataLoader.
# 
#   Returns:
#     A tuple of (train_dataloader, test_dataloader, class_names).
#     Where class_names is a list of the target classes.
#     Example usage:
#       train_dataloader, test_dataloader, class_names = \
#         create_dataloaders(train_dir=path/to/train_dir,
#                              test_dir=path/to/test_dir,
#                              transform=some_transform,
#                              batch_size=32,
#                              num_workers=4)
#   """
#   # Use ImageFolder to create dataset(s)
#   train_data = datasets.ImageFolder(train_dir, transform=transform)
#   test_data = datasets.ImageFolder(test_dir, transform=transform)
# 
#   # Get class names
#   class_names = train_data.classes
# 
#   # Turn images into data loaders
#   train_dataloader = DataLoader(
#       train_data,
#       batch_size=batch_size,
#       shuffle=True,
#       num_workers=num_workers,
#       pin_memory=True,
#   )
#   test_dataloader = DataLoader(
#       test_data,
#       batch_size=batch_size,
#       shuffle=False,
#       num_workers=num_workers,
#       pin_memory=True,
#   )
# 
#   return train_dataloader, test_dataloader, class_names

# Commented out IPython magic to ensure Python compatibility.
# %%writefile train.py
# #train.py only in this cell
# 
# """
# Trains a PyTorch image classification model using device-agnostic code.
# """
# 
# import os
# import torch
# import data_setup, engine, model_builder, utils
# 
# from torchvision import transforms
# 
# # Setup hyperparameters
# NUM_EPOCHS = 5
# BATCH_SIZE = 32
# HIDDEN_UNITS = 10
# LEARNING_RATE = 0.001
# 
# # Setup directories
# train_dir = "data/pizza_steak_sushi/train"
# test_dir = "data/pizza_steak_sushi/test"
# 
# # Setup target device
# device = "cuda" if torch.cuda.is_available() else "cpu"
# 
# # Create transforms
# data_transform = transforms.Compose([
#   transforms.Resize((64, 64)),
#   transforms.ToTensor()
# ])
# 
# # Create DataLoaders with help from data_setup.py
# train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
#     train_dir=train_dir,
#     test_dir=test_dir,
#     transform=data_transform,
#     batch_size=BATCH_SIZE
# )
# 
# # Create model with help from model_builder.py
# model = model_builder.TinyVGG(
#     input_shape=3,
#     hidden_units=HIDDEN_UNITS,
#     output_shape=len(class_names)
# ).to(device)
# 
# # Set loss and optimizer
# loss_fn = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(),
#                              lr=LEARNING_RATE)
# 
# # Start training with help from engine.py
# engine.train(model=model,
#              train_dataloader=train_dataloader,
#              test_dataloader=test_dataloader,
#              loss_fn=loss_fn,
#              optimizer=optimizer,
#              epochs=NUM_EPOCHS,
#              device=device)
# 
# # Save the model with help from utils.py
# utils.save_model(model=model,
#                  target_dir="models",
#                  model_name="05_going_modular_script_mode_tinyvgg_model.pth")
# 
# 
# 
#

# Run the training script written above (note: train.py hard-codes its
# hyperparameters, so these flags have no effect without argparse support)
!python /content/train.py --batch_size 64 --learning_rate 0.001 --num_epochs 25
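
# (Hypothetical sketch, not part of the original notebook: for the flags above
# to take effect, train.py would need an argument parser along these lines,
# with the parsed values replacing the hard-coded hyperparameters.)
# import argparse
# parser = argparse.ArgumentParser(description="Train TinyVGG")
# parser.add_argument("--num_epochs", type=int, default=5)
# parser.add_argument("--batch_size", type=int, default=32)
# parser.add_argument("--hidden_units", type=int, default=10)
# parser.add_argument("--learning_rate", type=float, default=0.001)
# args = parser.parse_args()
# NUM_EPOCHS, BATCH_SIZE = args.num_epochs, args.batch_size
# HIDDEN_UNITS, LEARNING_RATE = args.hidden_units, args.learning_rate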

# 1. Get pretrained weights for ViT-Base
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

# 2. Setup a ViT model instance with pretrained weights
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

# 3. Freeze the base parameters
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# 4. Change the classifier head
class_names = ['Bad_tire','Good_tire']

set_seeds()
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)
# pretrained_vit # uncomment for model output
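
# (Optional sanity check, a minimal sketch not in the original notebook: after
# freezing the backbone and swapping the head, only the new Linear layer's
# parameters should still require gradients.)
trainable = [name for name, p in pretrained_vit.named_parameters() if p.requires_grad]
print(trainable)  # expect only ['heads.weight', 'heads.bias']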

from torchinfo import summary

# Print a summary of the model using torchinfo
summary(model=pretrained_vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        #col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

# Setup directory paths to train and test images
train_dir = '/content/drive/MyDrive/Train/train'
test_dir = '/content/drive/MyDrive/Test/test'

# Get automatic transforms from pretrained ViT weights
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)

import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

NUM_WORKERS = os.cpu_count()

def create_dataloaders(
    train_dir: str,
    test_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_workers: int=NUM_WORKERS
):

  # Use ImageFolder to create dataset(s)
  train_data = datasets.ImageFolder(train_dir, transform=transform)
  test_data = datasets.ImageFolder(test_dir, transform=transform)

  # Get class names
  class_names = train_data.classes

  # Turn images into data loaders
  train_dataloader = DataLoader(
      train_data,
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers,
      pin_memory=True,
  )
  test_dataloader = DataLoader(
      test_data,
      batch_size=batch_size,
      shuffle=False,
      num_workers=num_workers,
      pin_memory=True,
  )

  return train_dataloader, test_dataloader, class_names

# Setup dataloaders
train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    transform=pretrained_vit_transforms,
    batch_size=32) # Could increase if we had more samples, such as here: https://arxiv.org/abs/2205.01580 (there are other improvements there too...)
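
# (Optional sanity check, not in the original notebook: pull one batch and
# confirm its shape matches what ViT-B/16 expects, [batch, 3, 224, 224].)
images, labels = next(iter(train_dataloader_pretrained))
print(images.shape, labels.shape)  # expect torch.Size([32, 3, 224, 224]) and torch.Size([32])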

!python train.py --batch_size 64 --learning_rate 0.001 --num_epochs 25

import engine

# Create optimizer and loss function
optimizer = torch.optim.Adam(params=pretrained_vit.parameters(),
                             lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()

# Train the classifier head of the pretrained ViT feature extractor model
set_seeds()
pretrained_vit_results = engine.train(model=pretrained_vit,
                                      train_dataloader=train_dataloader_pretrained,
                                      test_dataloader=test_dataloader_pretrained,
                                      optimizer=optimizer,
                                      loss_fn=loss_fn,
                                      epochs=10,
                                      device=device)
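
# (Optional, a minimal sketch not in the original notebook: persist the
# fine-tuned ViT with the save_model helper from utils.py above. The
# filename is illustrative.)
import utils
utils.save_model(model=pretrained_vit,
                 target_dir="models",
                 model_name="pretrained_vit_tire_classifier.pth")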

# Commented out IPython magic to ensure Python compatibility.
# %%writefile helper_functions.py
# 
# # helper_functions.py
# 
# """
# A series of helper functions used throughout the course.
# 
# If a function gets defined once and could be used over and over, it'll go in here.
# """
# import torch
# import matplotlib.pyplot as plt
# import numpy as np
# 
# from torch import nn
# import os
# import zipfile
# from pathlib import Path
# import requests
# 
# 
# 
# # Plot linear data or training and test and predictions (optional)
# def plot_predictions(
#     train_data, train_labels, test_data, test_labels, predictions=None
# ):
#     """
#   Plots linear training data and test data and compares predictions.
#   """
#     plt.figure(figsize=(10, 7))
# 
#     # Plot training data in blue
#     plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")
# 
#     # Plot test data in green
#     plt.scatter(test_data, test_labels, c="g", s=4, label="Testing data")
# 
#     if predictions is not None:
#         # Plot the predictions in red (predictions were made on the test data)
#         plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")
# 
#     # Show the legend
#     plt.legend(prop={"size": 14})
# 
# 
# # Calculate accuracy (a classification metric)
# def accuracy_fn(y_true, y_pred):
#     """Calculates accuracy between truth labels and predictions.
# 
#     Args:
#         y_true (torch.Tensor): Truth labels for predictions.
#         y_pred (torch.Tensor): Predictions to be compared to truth labels.
# 
#     Returns:
#         [torch.float]: Accuracy value between y_true and y_pred, e.g. 78.45
#     """
#     correct = torch.eq(y_true, y_pred).sum().item()
#     acc = (correct / len(y_pred)) * 100
#     return acc
# 
# 
# def print_train_time(start, end, device=None):
#     """Prints difference between start and end time.
# 
#     Args:
#         start (float): Start time of computation (preferred in timeit format).
#         end (float): End time of computation.
#         device ([type], optional): Device that compute is running on. Defaults to None.
# 
#     Returns:
#         float: time between start and end in seconds (higher is longer).
#     """
#     total_time = end - start
#     print(f"\nTrain time on {device}: {total_time:.3f} seconds")
#     return total_time
# 
# 
# # Plot loss curves of a model
# def plot_loss_curves(results):
#     """Plots training curves of a results dictionary.
# 
#     Args:
#         results (dict): dictionary containing list of values, e.g.
#             {"train_loss": [...],
#              "train_acc": [...],
#              "test_loss": [...],
#              "test_acc": [...]}
#     """
#     loss = results["train_loss"]
#     test_loss = results["test_loss"]
# 
#     accuracy = results["train_acc"]
#     test_accuracy = results["test_acc"]
# 
#     epochs = range(len(results["train_loss"]))
# 
#     plt.figure(figsize=(15, 7))
# 
#     # Plot loss
#     plt.subplot(1, 2, 1)
#     plt.plot(epochs, loss, label="train_loss")
#     plt.plot(epochs, test_loss, label="test_loss")
#     plt.title("Loss")
#     plt.xlabel("Epochs")
#     plt.legend()
# 
#     # Plot accuracy
#     plt.subplot(1, 2, 2)
#     plt.plot(epochs, accuracy, label="train_accuracy")
#     plt.plot(epochs, test_accuracy, label="test_accuracy")
#     plt.title("Accuracy")
#     plt.xlabel("Epochs")
#     plt.legend()
# 
# 
# # Pred and plot image function from notebook 04
# # See creation: https://www.learnpytorch.io/04_pytorch_custom_datasets/#113-putting-custom-image-prediction-together-building-a-function
# from typing import List
# import torchvision
# 
# 
# def pred_and_plot_image(
#     model: torch.nn.Module,
#     image_path: str,
#     class_names: List[str] = None,
#     transform=None,
#     device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
# ):
#     """Makes a prediction on a target image with a trained model and plots the image.
# 
#     Args:
#         model (torch.nn.Module): trained PyTorch image classification model.
#         image_path (str): filepath to target image.
#         class_names (List[str], optional): different class names for target image. Defaults to None.
#         transform (_type_, optional): transform of target image. Defaults to None.
#         device (torch.device, optional): target device to compute on. Defaults to "cuda" if torch.cuda.is_available() else "cpu".
# 
#     Returns:
#         Matplotlib plot of target image and model prediction as title.
# 
#     Example usage:
#         pred_and_plot_image(model=model,
#                             image="some_image.jpeg",
#                             class_names=["class_1", "class_2", "class_3"],
#                             transform=torchvision.transforms.ToTensor(),
#                             device=device)
#     """
# 
#     # 1. Load in image and convert the tensor values to float32
#     target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)
# 
#     # 2. Divide the image pixel values by 255 to get them between [0, 1]
#     target_image = target_image / 255.0
# 
#     # 3. Transform if necessary
#     if transform:
#         target_image = transform(target_image)
# 
#     # 4. Make sure the model is on the target device
#     model.to(device)
# 
#     # 5. Turn on model evaluation mode and inference mode
#     model.eval()
#     with torch.inference_mode():
#         # Add an extra dimension to the image
#         target_image = target_image.unsqueeze(dim=0)
# 
#         # Make a prediction on image with an extra dimension and send it to the target device
#         target_image_pred = model(target_image.to(device))
# 
#     # 6. Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
#     target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
# 
#     # 7. Convert prediction probabilities -> prediction labels
#     target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
# 
#     # 8. Plot the image alongside the prediction and prediction probability
#     plt.imshow(
#         target_image.squeeze().permute(1, 2, 0)
#     )  # make sure it's the right size for matplotlib
#     if class_names:
#         title = f"Pred: {class_names[target_image_pred_label.cpu()]} | Prob: {target_image_pred_probs.max().cpu():.3f}"
#     else:
#         title = f"Pred: {target_image_pred_label} | Prob: {target_image_pred_probs.max().cpu():.3f}"
#     plt.title(title)
#     plt.axis(False)
# 
# def set_seeds(seed: int=42):
#     """Sets random sets for torch operations.
# 
#     Args:
#         seed (int, optional): Random seed to set. Defaults to 42.
#     """
#     # Set the seed for general torch operations
#     torch.manual_seed(seed)
#     # Set the seed for CUDA torch operations (ones that happen on the GPU)
#     torch.cuda.manual_seed(seed)
#

# Plot the loss curves
from helper_functions import plot_loss_curves

plot_loss_curves(pretrained_vit_results)

import requests

# Import function to make predictions on images and plot them
from predict import pred_and_plot_image

# Setup custom image path
custom_image_path = "/content/drive/MyDrive/validation/Bad_Tire (3).jpg"

# Predict on custom image
pred_and_plot_image(model=pretrained_vit,
                    image_path=custom_image_path,
                    class_names=class_names)

# Import function to make predictions on images and plot them
from predict import pred_and_plot_image

# Setup custom image path
custom_image_path = "/content/drive/MyDrive/validation/Good_Tire (4).jpg"

# Predict on custom image
pred_and_plot_image(model=pretrained_vit,
                    image_path=custom_image_path,
                    class_names=class_names)
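
# (Optional closing check, a minimal sketch not in the original notebook:
# re-evaluate the fine-tuned model on the test DataLoader using the
# test_step helper from engine.py above.)
final_test_loss, final_test_acc = engine.test_step(model=pretrained_vit,
                                                   dataloader=test_dataloader_pretrained,
                                                   loss_fn=loss_fn,
                                                   device=device)
print(f"test_loss: {final_test_loss:.4f} | test_acc: {final_test_acc:.4f}")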