# Remember to run preprocess.py before training:
# preprocessing on the fly while training is not as efficient.
import json
import os
import time

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import MultiheadAttention
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# Upper bound on the number of directory entries inspected by
# load_and_split_data (applied before the '.h5' suffix filter).
_MAX_LISTED_FILES = 60


class AttentionBlock(nn.Module):
    """Self-attention + feed-forward block, each followed by residual + LayerNorm.

    Args:
        input_dim: Feature (embedding) size of the input; also the output size.
        num_heads: Number of attention heads; must divide ``input_dim``.
        key_dim: Accepted for signature compatibility; not used by this block.
        ff_dim: Hidden width of the feed-forward sub-network.
        rate: Dropout probability used after attention and inside the FFN.
    """

    def __init__(self, input_dim, num_heads, key_dim, ff_dim, rate=0.1):
        super().__init__()
        # batch_first=True: the rest of the model (Flatten + Linear sized by
        # input_shape[0] * 128) treats tensors as (batch, seq_len, feature).
        # The MultiheadAttention default is (seq_len, batch, feature), which
        # would make attention run across the samples of a batch instead of
        # across the sequence positions of each sample.
        self.multihead_attn = MultiheadAttention(
            embed_dim=input_dim, num_heads=num_heads, batch_first=True
        )
        self.dropout1 = nn.Dropout(rate)
        self.layer_norm1 = nn.LayerNorm(input_dim, eps=1e-6)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, ff_dim),
            nn.ReLU(),
            nn.Dropout(rate),
            nn.Linear(ff_dim, input_dim),
            nn.Dropout(rate),
        )
        self.layer_norm2 = nn.LayerNorm(input_dim, eps=1e-6)

    def forward(self, x):
        """Apply self-attention and the FFN to ``x`` of shape (batch, seq_len, input_dim)."""
        attn_output, _ = self.multihead_attn(x, x, x)
        attn_output = self.dropout1(attn_output)
        out1 = self.layer_norm1(x + attn_output)
        ffn_output = self.ffn(out1)
        out2 = self.layer_norm2(out1 + ffn_output)
        return out2


class TextureContrastClassifier(nn.Module):
    """Binary classifier over the difference of two attended texture branches.

    Each branch (rich / poor) goes through its own AttentionBlock and a dense
    projection to 128 features per position; the element-wise difference of
    the two branches is flattened and fed to an MLP ending in a sigmoid.

    Args:
        input_shape: ``(seq_len, feature_dim)`` of a single (unbatched) input.
        num_heads: Attention heads per branch.
        key_dim: Forwarded to AttentionBlock (unused there).
        ff_dim: Feed-forward width inside each AttentionBlock.
        rate: Dropout probability used throughout.
    """

    def __init__(self, input_shape, num_heads=4, key_dim=64, ff_dim=256, rate=0.5):
        super().__init__()
        input_dim = input_shape[1]  # input shape is (seq_len, feature_dim)

        self.rich_texture_attention = AttentionBlock(input_dim, num_heads, key_dim, ff_dim, rate)
        self.poor_texture_attention = AttentionBlock(input_dim, num_heads, key_dim, ff_dim, rate)

        self.rich_texture_dense = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(rate),
        )
        self.poor_texture_dense = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(rate),
        )

        # Layer order is kept exactly as before so existing checkpoints
        # (state_dict keys fc.1, fc.4, ...) remain loadable.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_shape[0] * 128, 256),
            nn.ReLU(),
            nn.Dropout(rate),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(rate),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(rate),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(rate),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(rate),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, rich_texture, poor_texture):
        """Return sigmoid scores of shape (batch, 1) for the given branch inputs."""
        rich_texture = self.rich_texture_attention(rich_texture)
        rich_texture = self.rich_texture_dense(rich_texture)

        poor_texture = self.poor_texture_attention(poor_texture)
        poor_texture = self.poor_texture_dense(poor_texture)

        difference = rich_texture - poor_texture
        output = self.fc(difference)
        return output


def load_and_split_data(h5_dir, train_ratio=0.8, max_num=40):
    """Load 'rich' / 'poor' / 'labels' arrays from .h5 files and split each file.

    Every readable file is shuffled independently and split into a training
    part (the first ``int(train_ratio * n)`` shuffled rows) and a test part
    (the remainder). The per-file parts are concatenated along axis 0.

    Args:
        h5_dir: Directory containing the ``.h5`` files.
        train_ratio: Fraction of each file's rows assigned to the training set.
        max_num: Accepted for backward compatibility; currently not used. The
            number of directory entries inspected is capped by
            ``_MAX_LISTED_FILES`` instead.

    Returns:
        Tuple ``(train_rich, train_poor, train_labels,
        test_rich, test_poor, test_labels)`` of NumPy arrays.

    Raises:
        ValueError: If no file could be loaded.
    """
    train_rich, train_poor, train_labels = [], [], []
    test_rich, test_poor, test_labels = [], [], []
    buckets = (train_rich, train_poor, train_labels, test_rich, test_poor, test_labels)

    for file_name in tqdm(os.listdir(h5_dir)[:_MAX_LISTED_FILES]):
        if not file_name.endswith('.h5'):
            continue

        file_path = os.path.join(h5_dir, file_name)
        try:
            with h5py.File(file_path, 'r') as h5f:
                rich = h5f['rich'][:]
                poor = h5f['poor'][:]
                labels = h5f['labels'][:]

            dataset_size = len(labels)
            train_size = int(train_ratio * dataset_size)

            indices = np.random.permutation(dataset_size)
            train_indices = indices[:train_size]
            test_indices = indices[train_size:]

            # Compute every slice before touching the output lists so that a
            # failure on one array cannot leave the lists out of sync.
            parts = (
                rich[train_indices],
                poor[train_indices],
                labels[train_indices],
                rich[test_indices],
                poor[test_indices],
                labels[test_indices],
            )
        except Exception as e:
            # Best effort: report the bad file and keep going with the rest.
            print(f"Error processing {file_name}: {e}")
            continue

        for bucket, part in zip(buckets, parts):
            bucket.append(part)

    if not train_labels:
        # np.concatenate([]) would raise an opaque ValueError; be explicit.
        raise ValueError(f"No usable .h5 files could be loaded from {h5_dir!r}")

    train_rich = np.concatenate(train_rich, axis=0)
    train_poor = np.concatenate(train_poor, axis=0)
    train_labels = np.concatenate(train_labels, axis=0)
    test_rich = np.concatenate(test_rich, axis=0)
    test_poor = np.concatenate(test_poor, axis=0)
    test_labels = np.concatenate(test_labels, axis=0)

    return train_rich, train_poor, train_labels, test_rich, test_poor, test_labels


class TextureDataset(Dataset):
    """Dataset yielding ``(rich, poor, label)`` float32 tensors by index."""

    def __init__(self, rich, poor, labels):
        self.rich = rich
        self.poor = poor
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        rich = torch.tensor(self.rich[idx], dtype=torch.float32)
        poor = torch.tensor(self.poor[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return rich, poor, label


def validate(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and print / return the metrics.

    Args:
        model: Module called as ``model(rich, poor)``.
        test_loader: Iterable of ``(rich, poor, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(val_loss, val_accuracy)``: mean of the per-batch losses and the
        fraction of samples whose thresholded (> 0.5) output equals the label.
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for rich, poor, labels in test_loader:
            rich, poor, labels = rich.to(device), poor.to(device), labels.to(device)
            outputs = model(rich, poor)
            # squeeze(-1) drops only the trailing size-1 dimension; a bare
            # squeeze() would also drop the batch dimension for a batch of
            # one sample and break the loss' shape check.
            outputs = outputs.squeeze(-1)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            predicted = (outputs > 0.5).float()
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(test_loader)
    val_accuracy = correct / total
    print(f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
    return val_loss, val_accuracy


# ---------------------------------------------------------------------------
# Training script (runs at module level).
# ---------------------------------------------------------------------------
h5_dir = '/content/drive/MyDrive/h5saves'
train_rich, train_poor, train_labels, test_rich, test_poor, test_labels = load_and_split_data(h5_dir, train_ratio=0.8)

print(f"Training data: {len(train_labels)} samples")
print(f"Testing data: {len(test_labels)} samples")

train_dataset = TextureDataset(train_rich, train_poor, train_labels)
test_dataset = TextureDataset(test_rich, test_poor, test_labels)

batch_size = 2048
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

input_shape = (128, 256)
model = TextureContrastClassifier(input_shape)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# The scheduler's `verbose` flag is deprecated in recent PyTorch releases;
# learning-rate changes are reported explicitly in the epoch loop instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

history = {'train_loss': [], 'val_loss': [], 'train_accuracy': [], 'val_accuracy': []}

save_dir = '/content/drive/MyDrive/model_checkpoints'
os.makedirs(save_dir, exist_ok=True)
history_path = os.path.join(save_dir, 'training_history.json')

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    batch_loss = 0.0  # loss accumulated since the last progress print

    for batch_idx, (rich, poor, labels) in enumerate(train_loader):
        rich, poor, labels = rich.to(device), poor.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(rich, poor)
        # Keep the batch dimension even when the last batch has one sample.
        outputs = outputs.squeeze(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batch_loss += loss.item()
        predicted = (outputs > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (batch_idx + 1) % 5 == 0:
            print(f'\rEpoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {batch_loss / 5:.4f}, Accuracy: {correct / total:.2f}', end='')
            batch_loss = 0.0

    avg_train_loss = running_loss / len(train_loader)
    train_accuracy = correct / total

    val_loss, val_accuracy = validate(model, test_loader, criterion, device)

    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(val_loss)
    history['val_accuracy'].append(val_accuracy)
    history['train_accuracy'].append(train_accuracy)

    # Step the plateau scheduler and report any learning-rate change.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f'Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}')

    checkpoint_path = os.path.join(save_dir, f'model_epoch_{epoch+1}.pth')
    torch.save(model.state_dict(), checkpoint_path)
    print(f'\nModel checkpoint saved for epoch {epoch+1}')

    # num_epochs is an integer count; print it as such.
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f} Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

    # Rewrite the history file every epoch so an interrupted run keeps its
    # metrics; the file after the final epoch is the complete history.
    with open(history_path, 'w') as f:
        json.dump(history, f)

print('Finished Training')
print(f'Training history saved at {history_path}')