Direct Alignment of Draft Model for Speculative Decoding with Chat-Fine-Tuned LLMs
Abstract
Text generation with Large Language Models (LLMs) is known to be memory-bound due to the combination of their autoregressive nature, huge parameter counts, and limited memory bandwidth, which often results in low token rates. Speculative decoding has been proposed as a solution for LLM inference acceleration. However, draft models are often unavailable in modern open-source LLM families (e.g., for Llama 2 7B), so a high-quality draft model must be trained before speculative decoding can be used to accelerate inference. In this paper, we propose a simple draft model training framework for direct alignment to chat-capable target models. With the proposed framework, we train Llama 2 Chat Drafter 115M, a draft model for Llama 2 Chat 7B or larger, with only 1.64% of the original model's size. Our training framework consists only of pretraining, distillation dataset generation, and fine-tuning with knowledge distillation, with no additional alignment procedure. For the fine-tuning step, we use instruction-response pairs generated by the target model, so that distillation is performed on a plausible data distribution, and we propose a new Total Variation Distance++ (TVD++) loss that incorporates variance reduction techniques inspired by the policy gradient method in reinforcement learning. Our empirical results show that Llama 2 Chat Drafter 115M with speculative decoding achieves a block efficiency of up to 2.3 and a speed-up of up to 2.4× relative to autoregressive decoding on various tasks, with no further task-specific fine-tuning.
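One way to see why a TVD-based distillation objective fits speculative decoding is the standard speculative sampling rule: a draft token x ~ q is kept with probability min(1, p(x)/q(x)), so the per-token acceptance probability is Σ min(p, q) = 1 − TVD(p, q). The sketch below illustrates that rule on toy distributions. It is not the authors' code and does not implement TVD++ or its variance reduction. All function names are hypothetical, and `expected_block_efficiency` assumes that block efficiency means expected tokens produced per target-model call, with γ draft tokens per block each accepted independently at a fixed rate α.

```python
import random
from typing import List, Sequence, Tuple


def total_variation(p: Sequence[float], q: Sequence[float]) -> float:
    """TVD(p, q) = 0.5 * sum_x |p(x) - q(x)| over a shared vocabulary."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


def acceptance_probability(p: Sequence[float], q: Sequence[float]) -> float:
    """Chance that a draft token x ~ q survives verification against target p.

    Summing q(x) * min(1, p(x) / q(x)) over x gives sum_x min(p(x), q(x)),
    which equals 1 - TVD(p, q).
    """
    return sum(min(pi, qi) for pi, qi in zip(p, q))


def sample(dist: Sequence[float], rng: random.Random) -> int:
    """Draw an index from a categorical distribution by inverse CDF."""
    r = rng.random()
    cumulative = 0.0
    for i, w in enumerate(dist):
        cumulative += w
        if r < cumulative:
            return i
    return len(dist) - 1  # guard against floating-point round-off


def verify_token(p: Sequence[float], q: Sequence[float], x: int,
                 rng: random.Random) -> Tuple[bool, int]:
    """One speculative-sampling check for a draft token x drawn from q.

    Keep x with probability min(1, p[x] / q[x]); otherwise resample from the
    normalized residual max(0, p - q), so the emitted token follows p.
    """
    if q[x] > 0 and rng.random() < min(1.0, p[x] / q[x]):
        return True, x
    residual: List[float] = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)  # equals TVD(p, q) > 0 whenever a rejection occurs
    return False, sample([w / total for w in residual], rng)


def expected_block_efficiency(alpha: float, gamma: int) -> float:
    """Expected tokens per target call with gamma draft tokens per block,
    assuming each draft token is accepted independently with rate alpha."""
    if alpha >= 1.0:
        return float(gamma + 1)
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)
```

For example, with a toy target p = [0.5, 0.3, 0.2] and draft q = [0.4, 0.4, 0.2], the TVD is 0.1 and the acceptance probability is 0.9, up to floating-point round-off. Lowering the draft model's TVD to the target therefore raises the acceptance rate directly, and under the i.i.d. assumption a higher α raises the expected block efficiency.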
- Publication: arXiv e-prints
- Pub Date: February 2024
- DOI: 10.48550/arXiv.2403.00858
- arXiv: arXiv:2403.00858
- Bibcode: 2024arXiv240300858G
- Keywords: Computer Science - Machine Learning; Computer Science - Artificial Intelligence; Computer Science - Computation and Language
- E-Print: 8 pages, 3 figures, Published at the ICLR 2024 Workshop on Understanding of Foundation Models (ME-FoMo)