import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os


def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    # gloo also works on GPU, though NCCL is the usual backend for multi-GPU training
    init_process_group(backend="gloo", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        # Wrap the model in DDP so gradients are synchronized across processes
        self.model = DDP(model, device_ids=[gpu_id])

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        # Reshuffle the sampler so each epoch uses a different data ordering per process
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def _save_checkpoint(self, epoch):
        # Unwrap the DDP container so the saved state_dict loads into a plain model
        ckp = self.model.module.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            # Only rank 0 saves checkpoints; DDP keeps all ranks' weights identical
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
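
# Illustrative shape check (not part of the training script): a batch of
# FashionMNIST images of shape (N, 1, 28, 28) is flattened to (N, 784) and
# mapped to (N, 10) class logits, e.g.
#   NeuralNetwork()(torch.randn(8, 1, 28, 28)).shape  # -> torch.Size([8, 10])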


train_set = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)


def load_train_objs():
    model = NeuralNetwork()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        # DistributedSampler partitions the dataset across processes;
        # shuffle must stay False because the sampler handles shuffling
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    ddp_setup(rank, world_size)
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, rank, save_every)
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('save_every', type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    print(f"total_epochs: {args.total_epochs}, save_every: {args.save_every}, batch_size: {args.batch_size}")
    world_size = torch.cuda.device_count()
    # mp.spawn launches world_size processes and passes each one its rank as the first argument to main
    mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)
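
# Example invocation (the filename "multigpu.py" is an assumption; use whatever name this
# script is saved under):
#   python multigpu.py 50 10 --batch_size 32
# trains for 50 epochs, checkpoints every 10 epochs, and spawns one process per visible GPU.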