import os

# Restrict this process to GPUs 0 and 1. This is done before `import torch`
# below so the setting is in place before any CUDA initialisation happens.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
print(os.environ["CUDA_VISIBLE_DEVICES"])
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, Sampler
import argparse
import torch.optim as optim
import numpy as np
import random
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.distributed import init_process_group
import torch.distributed as dist
class data_set(Dataset):
    """Map-style dataset yielding ``(index, df[index])`` pairs.

    Returning the position alongside the row lets a DataLoader batch carry
    the sample indices that produced it.

    Args:
        df: Any sized, indexable container (the script passes a tensor).
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df[index]
        return (index, row)
class NeuralNetwork(nn.Module):
    """Bias-free linear map ``dsize -> 1`` whose forward returns a scalar loss.

    Every weight is initialised to exactly 1.0, so the starting state does
    not depend on the random number generator.

    Args:
        dsize: Number of input features.
    """

    def __init__(self, dsize):
        super().__init__()
        self.linear = nn.Linear(dsize, 1, bias=False)
        # Deterministic init: fill the whole weight matrix with ones.
        nn.init.ones_(self.linear.weight)

    def forward(self, x):
        # The sum of all outputs is used directly as the loss to backprop.
        return self.linear(x).sum()
class DummySampler(Sampler):
    """Single-process sampler emulating a multi-GPU ``DistributedSampler`` split.

    The dataset is walked in chunks of ``batch_size * n_gpus`` indices. Within
    each chunk, emulated rank ``r`` takes every ``n_gpus``-th index starting
    at offset ``r``. The ranks' batches are emitted one after another (rank 0,
    rank 1, ...), so a ``DataLoader`` using the same ``batch_size`` sees each
    emulated GPU's batch in turn. Every index is yielded exactly once.

    NOTE(review): no padding to a multiple of ``n_gpus`` is done; when the
    dataset size is not a multiple of ``batch_size * n_gpus``, the final
    batches may not line up with a real distributed run.

    Args:
        data: Sized dataset; only ``len(data)`` is used.
        batch_size: Batch size of each emulated GPU.
        n_gpus: Number of GPUs/ranks to emulate.
    """

    def __init__(self, data, batch_size, n_gpus=2):
        self.num_samples = len(data)
        self.b_size = batch_size
        self.n_gpus = n_gpus

    def __iter__(self):
        all_ids = np.arange(self.num_samples)
        chunk = self.b_size * self.n_gpus
        ids = []
        for start in range(0, self.num_samples, chunk):
            # One strided slice per emulated rank, so all n_gpus interleaved
            # streams are produced (not just the first two).
            for rank in range(self.n_gpus):
                lo = start + rank
                ids.append(all_ids[lo : lo + chunk : self.n_gpus])
        if not ids:
            # Empty dataset: np.concatenate rejects an empty list.
            return iter(all_ids)
        return iter(np.concatenate(ids))

    def __len__(self):
        return self.num_samples
def main(args=None):
    """Train ``NeuralNetwork`` on random data, distributed or single-GPU.

    With ``args.distributed``, an NCCL process group is initialised and the
    GPU is taken from the ``LOCAL_RANK`` environment variable (as set by a
    launcher such as torchrun). Each rank steps after every batch, and the
    ``DistributedDataParallel`` wrapper averages gradients across ranks.

    Without it, a single process on ``cuda:0`` iterates in ``DummySampler``
    order. Each micro-batch loss is divided by ``args.acc_steps``, and the
    optimizer steps once every ``args.acc_steps`` micro-batches, emulating
    the averaged multi-GPU update.

    Args:
        args: Parsed namespace providing ``data_size``, ``distributed``,
            ``seed``, ``batch_size``, ``acc_steps``, ``epochs`` and
            ``verbose``.
    """
    d_size = args.data_size
    if args.distributed:
        init_process_group(backend="nccl")
        device = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device)
    else:
        device = "cuda:0"

    # Fix the seeds for reproducibility. After seeding, every rank generates
    # the same random data below.
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # cudnn autotuning may choose different kernels between runs, which
    # works against the seeding above.
    cudnn.benchmark = False
    cudnn.deterministic = True

    data = torch.rand(d_size, d_size)
    model = NeuralNetwork(d_size).to(device)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device]
        )
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    dataset = data_set(data)
    if args.distributed:
        sampler = torch.utils.data.DistributedSampler(dataset, shuffle=False)
    else:
        # `DummySampler` reproduces, in one process, the way
        # `DistributedSampler` splits the data across the GPUs.
        sampler = DummySampler(dataset, args.batch_size)
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=0,
        pin_memory=True,
        sampler=sampler,
        shuffle=False,
        collate_fn=None,
    )

    # Only rank 0 prints in distributed mode; a single process always prints.
    # Computed once, up front, so it exists even if the loop never runs.
    is_main = (not args.distributed) or dist.get_rank() == 0

    if not args.distributed:
        # ACC_STEPS should equal the number of GPUs being emulated. Dividing
        # each micro-batch loss by it turns the accumulated gradient into the
        # same average DDP computes across GPUs.
        ACC_STEPS = args.acc_steps
        optimizer.zero_grad()

    for epoch in range(args.epochs):
        if args.distributed:
            loader.sampler.set_epoch(epoch)
        for i, (_idxs, row) in enumerate(loader):
            if args.distributed:
                optimizer.zero_grad()
            row = row.to(device, non_blocking=True)
            loss = model(row)
            if not args.distributed:
                loss = loss / ACC_STEPS
            # Under DDP, backward() also averages gradients across ranks.
            loss.backward()
            if i == 0 and is_main:
                print(f"Epoch {epoch} {100 * '='}")
            if args.distributed:
                optimizer.step()
            elif (i + 1) % ACC_STEPS == 0:
                # Step once per ACC_STEPS micro-batches. zero_grad() is not
                # called at epoch boundaries, so a trailing partial group
                # carries its gradients into the next epoch.
                optimizer.step()
                optimizer.zero_grad()

    if args.verbose and is_main:
        if args.distributed:
            weights = model.module.linear.weight
        else:
            weights = model.linear.weight
        print(100 * "=")
        print("Model weights : ", weights)
        print(100 * "=")

    if args.distributed:
        # Release NCCL resources explicitly instead of relying on teardown.
        dist.destroy_process_group()
def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--distributed",
        action="store_true",
        help="train with DistributedDataParallel (NCCL, one process per GPU)",
    )
    parser.add_argument(
        "--seed", default=0, type=int, help="seed for torch, numpy and random"
    )
    parser.add_argument(
        "--epochs", default=2, type=int, help="number of training epochs"
    )
    parser.add_argument(
        "--batch_size", default=4, type=int, help="per-process batch size"
    )
    parser.add_argument(
        "--data_size",
        default=16,
        type=int,
        help="N for the random N x N dataset and the model input width",
    )
    # NOTE(review): the default of 3 differs from the 2 GPUs emulated by
    # DummySampler; pass --acc_steps 2 to compare against --distributed.
    parser.add_argument(
        "--acc_steps",
        default=3,
        type=int,
        help="gradient-accumulation steps for the single-GPU run "
        "(should equal the number of emulated GPUs)",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="print the final model weights"
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args)
    main(args)