STAT: Shrinking Transformers After Training
Abstract
We present STAT: a simple algorithm to prune transformer models without any fine-tuning. STAT removes both attention heads and neurons from the network and preserves accuracy by computing a correction to the weights of the next layer. Each layer block in the network is compressed using a series of principled matrix factorizations that preserve the network structure. Our entire algorithm takes minutes to compress BERT and less than three hours to compress models with 7B parameters on a single GPU. Using only several hundred data examples, STAT preserves the output of the network and improves upon existing gradient-free pruning methods. It is even competitive with methods that include significant fine-tuning. We demonstrate our method on both encoder and decoder architectures, including BERT, DistilBERT, and Llama-2, using benchmarks such as GLUE, SQuAD, and WikiText2.
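The abstract describes removing neurons and compensating by correcting the weights of the next layer, using only a small calibration set. As an illustration of that general idea (not the authors' STAT implementation), the sketch below prunes the hidden layer of a single two-layer MLP block in NumPy. It assumes a greedy column selection in the style of QR with column pivoting on calibration activations, followed by a least-squares interpolation matrix that is folded into the next layer. The function names `pivoted_qr_select`, `prune_mlp`, and `mlp_forward` are hypothetical.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def mlp_forward(X, W1, b1, W2, b2):
    """Two-layer MLP block: relu(X @ W1 + b1) @ W2 + b2."""
    return relu(X @ W1 + b1) @ W2 + b2


def pivoted_qr_select(H, k, tol=1e-12):
    """Greedily pick up to k columns of H, as in QR with column pivoting.

    At each step the column with the largest residual norm is chosen and its
    direction is projected out of every column (Gram-Schmidt deflation).
    Hypothetical helper; an assumption, not the paper's factorization.
    """
    R = np.array(H, dtype=float)  # work on a float copy
    selected = []
    for _ in range(k):
        norms = np.linalg.norm(R, axis=0)
        norms[selected] = -1.0  # never pick the same column twice
        j = int(np.argmax(norms))
        if norms[j] <= tol:  # remaining columns already lie in the span
            break
        selected.append(j)
        q = R[:, j] / norms[j]
        R -= np.outer(q, q @ R)  # remove the chosen direction from all columns
    return np.array(selected, dtype=int)


def prune_mlp(W1, b1, W2, X_calib, k):
    """Keep up to k hidden neurons and fold a correction into the next layer."""
    H = relu(X_calib @ W1 + b1)  # (n, m) hidden activations on calibration data
    S = pivoted_qr_select(H, k)
    # Least-squares interpolation matrix T (shape len(S) x m) with H approx H[:, S] @ T.
    T, *_ = np.linalg.lstsq(H[:, S], H, rcond=None)
    # Kept neurons reuse their first-layer weights; the next layer absorbs T.
    return W1[:, S], b1[S], T @ W2
```

If the hidden activations on the calibration data have rank at most `k`, the pruned block reproduces the original outputs on that data up to floating-point error; otherwise the correction `T @ W2` gives the least-squares best fit for the selected neurons. The choice of a pivoted-QR-style selection here is one common way to pick a well-conditioned subset of columns, used only to make the example concrete.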
- Publication: arXiv e-prints
- Pub Date: May 2024
- DOI: 10.48550/arXiv.2406.00061
- arXiv: arXiv:2406.00061
- Bibcode: 2024arXiv240600061F
- Keywords: Computer Science - Machine Learning; Computer Science - Artificial Intelligence; Computer Science - Computation and Language