Byzantine Fault-Tolerant Distributed Machine Learning Using Stochastic Gradient Descent (SGD) and Norm-Based Comparative Gradient Elimination (CGE)
Abstract
This paper considers the Byzantine fault-tolerance problem in the distributed stochastic gradient descent (D-SGD) method, a popular algorithm for distributed multi-agent machine learning. In this problem, each agent samples data points independently from a certain data-generating distribution. In the fault-free case, the D-SGD method allows all the agents to learn a mathematical model that best fits the data collectively sampled by all agents. We consider the case where a fraction of the agents may be Byzantine faulty. Such faulty agents may not follow the prescribed algorithm correctly, and may render the traditional D-SGD method ineffective by sharing arbitrary incorrect stochastic gradients. We propose a norm-based gradient-filter, named comparative gradient elimination (CGE), that robustifies the D-SGD method against Byzantine agents. We show that the CGE gradient-filter guarantees fault-tolerance against a bounded fraction of Byzantine agents under standard stochastic assumptions, and is computationally simpler than many existing gradient-filters such as multi-KRUM, geometric median-of-means, and the spectral filters. We empirically show, by simulating distributed learning on neural networks, that the fault-tolerance of CGE is comparable to that of existing gradient-filters. We also empirically show that exponential averaging of stochastic gradients improves the fault-tolerance of a generic gradient-filter.
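As a rough illustration of what a norm-based gradient-filter might look like, the sketch below assumes that CGE discards the f received gradients with the largest Euclidean norms and averages the remaining n - f. The abstract does not spell out this rule, so treat it as an assumption. The function names (`cge_filter`, `exponential_average`, `sgd_step`), the averaging weight `beta`, and the learning rate `lr` are hypothetical and chosen only for this example.

```python
import math
from typing import List, Sequence

Vector = List[float]


def l2_norm(v: Sequence[float]) -> float:
    """Euclidean norm of a vector."""
    return math.sqrt(sum(x * x for x in v))


def cge_filter(gradients: Sequence[Sequence[float]], f: int) -> Vector:
    """Assumed CGE rule: drop the f largest-norm gradients, average the rest."""
    n = len(gradients)
    if not 0 <= f < n:
        raise ValueError("f must satisfy 0 <= f < n")
    # Sort by norm (ascending) and keep the n - f smallest-norm gradients.
    kept = sorted(gradients, key=l2_norm)[: n - f]
    dim = len(kept[0])
    return [sum(g[i] for g in kept) / len(kept) for i in range(dim)]


def exponential_average(previous: Sequence[float],
                        gradient: Sequence[float],
                        beta: float = 0.9) -> Vector:
    """Exponentially weighted average of an agent's stochastic gradients.

    beta is a hypothetical smoothing weight, not a value from the paper.
    """
    return [beta * p + (1.0 - beta) * g for p, g in zip(previous, gradient)]


def sgd_step(weights: Sequence[float],
             gradients: Sequence[Sequence[float]],
             f: int,
             lr: float = 0.1) -> Vector:
    """One server-side D-SGD update using the filtered aggregate."""
    aggregate = cge_filter(gradients, f)
    return [w - lr * a for w, a in zip(weights, aggregate)]


# Three honest gradients plus one large-norm Byzantine gradient.
honest = [[1.0, 1.0], [1.2, 0.8], [0.8, 1.2]]
byzantine = [[100.0, -100.0]]
filtered = cge_filter(honest + byzantine, f=1)
new_weights = sgd_step([0.0, 0.0], honest + byzantine, f=1)
```

In this toy run the Byzantine gradient has by far the largest norm, so it is eliminated and the aggregate is the mean of the three honest gradients. A filter of this kind needs only a sort by norm, which is consistent with the abstract's claim that CGE is computationally simpler than filters such as multi-KRUM. A faulty gradient with a small norm would pass the filter, which is why the guarantee covers only a bounded fraction of Byzantine agents.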
- Publication: arXiv e-prints
- Pub Date: August 2020
- arXiv: arXiv:2008.04699
- Bibcode: 2020arXiv200804699G
- Keywords: Computer Science - Machine Learning; Computer Science - Distributed, Parallel, and Cluster Computing; Statistics - Machine Learning
- E-Print: The report includes 52 pages and 16 figures. Extension of our prior work on Byzantine fault-tolerant distributed optimization (arXiv:1903.08752 and doi:10.1145/3382734.3405748) to Byzantine fault-tolerant distributed machine learning.