defallreduce(data): for i inrange(1, len(data)): data[0][:] += data[i].to(data[0].device) for i inrange(1, len(data)): data[i][:] = data[0].to(data[i].device)
deftrain_batch(X, y, device_params, devices, lr): # 将一个batch的数据拆分到多个GPU上 X_shards, y_shards = split_batch(X, y, devices) # 在每个GPU上分别计算损失 ls = [loss(net(X_shard, device_W), y_shard).sum() for X_shard, y_shard, device_W inzip( X_shards, y_shards, device_params)] for l in ls: # 反向传播在每个GPU上分别执行 l.backward() # 将每个GPU的所有梯度相加,并将其广播到所有GPU with torch.no_grad(): for i inrange(len(device_params[0])): allreduce( [device_params[c][i].grad for c inrange(len(devices))]) # 在每个GPU上分别更新模型参数 for param in device_params: sgd(param, lr, X.shape[0]) # 在这里,我们使用全尺寸的小批量