"""Train, evaluate, save, reload and query a small MLP classifier on FashionMNIST.

The script runs top to bottom at module level: it loads the dataset from the
local ``data`` directory, trains for a fixed number of epochs, writes the
weights to ``model.pth``, reloads them into a fresh model and prints one
prediction for the first test sample.
"""

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Load the training split from the local "data" directory.
# NOTE: download=False, so nothing is fetched here; the dataset files are
# presumably expected to exist under root already (set download=True to
# fetch them) -- confirm against the torchvision documentation.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=False,
    transform=ToTensor(),
)

# Load the test split from the same local directory (also not downloaded).
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=False,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Print the shape of a single batch, then stop iterating.
for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Use the current accelerator when one is available, otherwise the CPU.
device = (
    torch.accelerator.current_accelerator().type
    if torch.accelerator.is_available()
    else "cpu"
)
print(f"Using {device} device")


# Define model
class NeuralNetwork(nn.Module):
    """Three-layer fully connected network: 28*28 inputs -> 512 -> 512 -> 10 logits."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Flatten ``x`` and return the raw (un-normalized) class logits."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


model = NeuralNetwork().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)


def train(
    dataloader,
    model,
    loss_fn,
    optimizer,
):
    """Run one pass over ``dataloader``, updating ``model`` after every batch.

    Batches are moved to the module-level ``device``. The loss and the number
    of samples processed so far are printed every 100 batches.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; must expose ``.dataset``.
        model: Module to optimize; switched to training mode.
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
    """
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation; gradients are cleared after the step so the next
        # batch starts from zero.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            # Keep the tensor and its Python float under separate names.
            loss_value = loss.item()
            current = (batch + 1) * len(X)
            print(f"loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and average loss.

    The reported loss is the mean of the per-batch losses; accuracy is the
    number of correct predictions divided by the dataset size.

    Args:
        dataloader: Iterable of ``(X, y)`` batches; must expose ``.dataset``.
        model: Module to evaluate; switched to evaluation mode.
        loss_fn: Callable mapping ``(pred, y)`` to a scalar loss tensor.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )


epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(
        train_dataloader,
        model,
        loss_fn,
        optimizer,
    )
    test(test_dataloader, model, loss_fn)
print("Done!")

torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

# Rebuild the model from scratch and restore the weights that were just saved.
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("model.pth", weights_only=True))

# Class names, ordered so that the list index is the integer label.
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

model.eval()
# Fetch the first test sample once, so the transform runs a single time.
x, y = test_data[0]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = (
        classes[pred[0].argmax(0)],
        classes[y],
    )
    print(f'Predicted: "{predicted}", Actual: "{actual}"')