Commit 77e830c (verified) · ymcmy committed · Parent(s): 845a877

Create train.py

Files changed (1):
  1. train.py +263 -0

train.py ADDED
@@ -0,0 +1,263 @@
+ # remember to run preprocess.py before training
+ # preprocessing on the fly while training is not as efficient
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import MultiheadAttention
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import json
+ import os
+ import h5py
+ import numpy as np
+ from tqdm import tqdm
+
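+ # The loader below assumes each .h5 file written by preprocess.py holds three
+ # datasets: 'rich' and 'poor' feature arrays of shape (N, 128, 256) and a
+ # 'labels' array of shape (N,). A quick way to inspect one file (the path is
+ # a placeholder):
+ #
+ #     with h5py.File('example.h5', 'r') as h5f:
+ #         for key in ('rich', 'poor', 'labels'):
+ #             print(key, h5f[key].shape, h5f[key].dtype)
+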
+ class AttentionBlock(nn.Module):
+     def __init__(self, input_dim, num_heads, key_dim, ff_dim, rate=0.1):
+         super(AttentionBlock, self).__init__()
+         # key_dim is unused: MultiheadAttention derives the per-head size
+         # from embed_dim // num_heads. batch_first=True because the dataset
+         # yields (batch, seq_len, feature_dim) tensors.
+         self.multihead_attn = MultiheadAttention(embed_dim=input_dim,
+                                                  num_heads=num_heads,
+                                                  batch_first=True)
+         self.dropout1 = nn.Dropout(rate)
+         self.layer_norm1 = nn.LayerNorm(input_dim, eps=1e-6)
+
+         self.ffn = nn.Sequential(
+             nn.Linear(input_dim, ff_dim),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(ff_dim, input_dim),
+             nn.Dropout(rate)
+         )
+         self.layer_norm2 = nn.LayerNorm(input_dim, eps=1e-6)
+
+     def forward(self, x):
+         # Self-attention with a residual connection and post-norm.
+         attn_output, _ = self.multihead_attn(x, x, x)
+         attn_output = self.dropout1(attn_output)
+         out1 = self.layer_norm1(x + attn_output)
+
+         # Position-wise feed-forward block, also residual.
+         ffn_output = self.ffn(out1)
+         out2 = self.layer_norm2(out1 + ffn_output)
+
+         return out2
+
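+ # Quick shape check (a minimal sketch; the dimensions mirror the
+ # input_shape = (128, 256) used below): the block maps
+ # (batch, seq, feature) to the same shape.
+ _block = AttentionBlock(input_dim=256, num_heads=4, key_dim=64, ff_dim=256)
+ assert _block(torch.randn(2, 128, 256)).shape == (2, 128, 256)
+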
+ class TextureContrastClassifier(nn.Module):
+     def __init__(self, input_shape, num_heads=4, key_dim=64, ff_dim=256, rate=0.5):
+         super(TextureContrastClassifier, self).__init__()
+         input_dim = input_shape[1]  # input shape is (seq_len, feature_dim)
+
+         self.rich_texture_attention = AttentionBlock(input_dim, num_heads, key_dim, ff_dim, rate)
+         self.poor_texture_attention = AttentionBlock(input_dim, num_heads, key_dim, ff_dim, rate)
+
+         self.rich_texture_dense = nn.Sequential(
+             nn.Linear(input_dim, 128),
+             nn.ReLU(),
+             nn.Dropout(rate)
+         )
+
+         self.poor_texture_dense = nn.Sequential(
+             nn.Linear(input_dim, 128),
+             nn.ReLU(),
+             nn.Dropout(rate)
+         )
+
+         # Classifier head over the flattened (seq_len, 128) difference map.
+         self.fc = nn.Sequential(
+             nn.Flatten(),
+             nn.Linear(input_shape[0] * 128, 256),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(256, 128),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(128, 64),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(64, 32),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(32, 16),
+             nn.ReLU(),
+             nn.Dropout(rate),
+             nn.Linear(16, 1),
+             nn.Sigmoid()
+         )
+
+     def forward(self, rich_texture, poor_texture):
+         rich_texture = self.rich_texture_attention(rich_texture)
+         rich_texture = self.rich_texture_dense(rich_texture)
+
+         poor_texture = self.poor_texture_attention(poor_texture)
+         poor_texture = self.poor_texture_dense(poor_texture)
+
+         # Classify the contrast between the two texture embeddings.
+         difference = rich_texture - poor_texture
+         output = self.fc(difference)
+
+         return output
+
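+ # Smoke test for the full model (a sketch; batch size 2 is arbitrary):
+ # scoring the difference between the two attended embeddings yields one
+ # probability per sample.
+ _model = TextureContrastClassifier((128, 256))
+ assert _model(torch.randn(2, 128, 256), torch.randn(2, 128, 256)).shape == (2, 1)
+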
+ def load_and_split_data(h5_dir, train_ratio=0.8, max_num=60):
+     train_rich, train_poor, train_labels = [], [], []
+     test_rich, test_poor, test_labels = [], [], []
+
+     # Read at most max_num files from the preprocessed dataset.
+     for file_name in tqdm(os.listdir(h5_dir)[:max_num]):
+         if file_name.endswith('.h5'):
+             file_path = os.path.join(h5_dir, file_name)
+             try:
+                 with h5py.File(file_path, 'r') as h5f:
+                     rich = h5f['rich'][:]
+                     poor = h5f['poor'][:]
+                     labels = h5f['labels'][:]
+
+                     # Split this file's samples with a random permutation.
+                     dataset_size = len(labels)
+                     train_size = int(train_ratio * dataset_size)
+                     indices = np.random.permutation(dataset_size)
+                     train_indices = indices[:train_size]
+                     test_indices = indices[train_size:]
+
+                     train_rich.append(rich[train_indices])
+                     train_poor.append(poor[train_indices])
+                     train_labels.append(labels[train_indices])
+
+                     test_rich.append(rich[test_indices])
+                     test_poor.append(poor[test_indices])
+                     test_labels.append(labels[test_indices])
+
+             except Exception as e:
+                 print(f"Error processing {file_name}: {e}")
+
+     train_rich = np.concatenate(train_rich, axis=0)
+     train_poor = np.concatenate(train_poor, axis=0)
+     train_labels = np.concatenate(train_labels, axis=0)
+
+     test_rich = np.concatenate(test_rich, axis=0)
+     test_poor = np.concatenate(test_poor, axis=0)
+     test_labels = np.concatenate(test_labels, axis=0)
+
+     return train_rich, train_poor, train_labels, test_rich, test_poor, test_labels
+
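+ # Note: the split is drawn per file rather than over the pooled dataset, so
+ # every file contributes to both the train and test sets in roughly the
+ # train_ratio proportion.
+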
+ class TextureDataset(Dataset):
+     def __init__(self, rich, poor, labels):
+         self.rich = rich
+         self.poor = poor
+         self.labels = labels
+
+     def __len__(self):
+         return len(self.labels)
+
+     def __getitem__(self, idx):
+         rich = torch.tensor(self.rich[idx], dtype=torch.float32)
+         poor = torch.tensor(self.poor[idx], dtype=torch.float32)
+         label = torch.tensor(self.labels[idx], dtype=torch.float32)
+         return rich, poor, label
+
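+ # Each item comes back as a (rich, poor, label) triple of float32 tensors,
+ # matching the (rich, poor, labels) unpacking in the loops below.
+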
+ def validate(model, test_loader, criterion, device):
+     model.eval()
+     val_loss = 0.0
+     correct = 0
+     total = 0
+
+     with torch.no_grad():
+         for rich, poor, labels in test_loader:
+             rich, poor, labels = rich.to(device), poor.to(device), labels.to(device)
+
+             outputs = model(rich, poor)
+             # squeeze(-1) keeps a 1-D batch dimension even when the last
+             # batch holds a single sample.
+             outputs = outputs.squeeze(-1)
+
+             loss = criterion(outputs, labels)
+             val_loss += loss.item()
+
+             # Threshold the sigmoid output at 0.5 for the binary decision.
+             predicted = (outputs > 0.5).float()
+             total += labels.size(0)
+             correct += (predicted == labels).sum().item()
+
+     val_loss /= len(test_loader)
+     val_accuracy = correct / total
+     print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
+     return val_loss, val_accuracy
+
+
+ h5_dir = '/content/drive/MyDrive/h5saves'
+ train_rich, train_poor, train_labels, test_rich, test_poor, test_labels = load_and_split_data(h5_dir, train_ratio=0.8)
+ print(f"Training data: {len(train_labels)} samples")
+ print(f"Testing data: {len(test_labels)} samples")
+
+ train_dataset = TextureDataset(train_rich, train_poor, train_labels)
+ test_dataset = TextureDataset(test_rich, test_poor, test_labels)
+ batch_size = 2048
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
+
+ input_shape = (128, 256)  # (seq_len, feature_dim)
+ model = TextureContrastClassifier(input_shape)
+ criterion = nn.BCELoss()
+ optimizer = optim.Adam(model.parameters(), lr=0.0001)
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model.to(device)
+
+ history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}
+ save_dir = '/content/drive/MyDrive/model_checkpoints'
+ os.makedirs(save_dir, exist_ok=True)
+ num_epochs = 100
+
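+ # Design note: the model ends in Sigmoid, so plain BCELoss is used here. An
+ # equivalent alternative (not used in this script) is to drop the final
+ # Sigmoid and train with nn.BCEWithLogitsLoss, which is more numerically
+ # stable.
+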
+ for epoch in range(num_epochs):
+     model.train()
+     running_loss = 0.0
+     correct = 0
+     total = 0
+
+     batch_loss = 0.0
+
+     for batch_idx, (rich, poor, labels) in enumerate(train_loader):
+         rich, poor, labels = rich.to(device), poor.to(device), labels.to(device)
+
+         optimizer.zero_grad()
+
+         outputs = model(rich, poor)
+         outputs = outputs.squeeze(-1)
+
+         loss = criterion(outputs, labels)
+         loss.backward()
+         optimizer.step()
+
+         running_loss += loss.item()
+         batch_loss += loss.item()
+
+         predicted = (outputs > 0.5).float()
+         total += labels.size(0)
+         correct += (predicted == labels).sum().item()
+
+         # Report a running average every 5 batches.
+         if (batch_idx + 1) % 5 == 0:
+             print(f'\rEpoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {batch_loss / 5:.4f}, Accuracy: {correct / total:.2f}', end='')
+             batch_loss = 0.0
+
+     avg_train_loss = running_loss / len(train_loader)
+     train_accuracy = correct / total
+
+     val_loss, val_accuracy = validate(model, test_loader, criterion, device)
+
+     history['train_loss'].append(avg_train_loss)
+     history['val_loss'].append(val_loss)
+     history['val_accuracy'].append(val_accuracy)
+     history['train_accuracy'].append(train_accuracy)
+
+     # Reduce the learning rate when the validation loss plateaus.
+     scheduler.step(val_loss)
+
+     checkpoint_path = os.path.join(save_dir, f'model_epoch_{epoch+1}.pth')
+     torch.save(model.state_dict(), checkpoint_path)
+     print(f'\nModel checkpoint saved for epoch {epoch+1}')
+
+     print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
+
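+ # To resume evaluation from a saved checkpoint (a sketch; the epoch number
+ # is an example):
+ #
+ #     model.load_state_dict(torch.load(os.path.join(save_dir, 'model_epoch_10.pth')))
+ #     model.eval()
+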
+ history_path = os.path.join(save_dir, 'training_history.json')
+ with open(history_path, 'w') as f:
+     json.dump(history, f)
+
+ print('Finished Training')
+ print(f'Training history saved at {history_path}')