Motivation
Improving model performance often requires training bigger models. However, bigger models come with higher computational cost and inference latency, which can make them impractical for some real-world deployments. Mistral 7B is a 7-billion-parameter language model that uses several techniques to improve efficiency while maintaining strong performance.
Method
- Grouped-Query Attention (GQA): Standard decoder-only Transformer with fewer key/value heads than query heads, which speeds up decoding and allows larger batches (higher throughput); see the GQA sketch after this list.
- Sliding Window Attention: Each token attends to at most the last 4096 tokens of the previous layer (window size W = 4096). Because each layer's hidden states already summarize the previous window, stacking layers expands the effective receptive field, so the final layer has a theoretical attention span of ~131K tokens (32 × 4096); see the mask sketch below.
- Rolling Buffer Cache: Caps KV-cache memory at a fixed size by writing the keys/values for position i into slot i mod 4096, overwriting entries more than 4096 tokens old; see the cache sketch below.
- Pre-fill Cache: Since the prompt is known in advance, its keys/values are pre-computed and written to the cache. Long prompts are processed in window-sized chunks, with each chunk attending to the cache plus the current chunk; see the chunking sketch below.
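A minimal PyTorch sketch of grouped-query attention, assuming Mistral 7B's head counts (32 query heads, 8 KV heads); the function and its signature are illustrative, not the reference implementation:

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v, n_q_heads=32, n_kv_heads=8):
    """q: (batch, seq, n_q_heads, head_dim); k, v: (batch, seq, n_kv_heads, head_dim).
    Each group of n_q_heads // n_kv_heads query heads shares one KV head,
    shrinking the KV cache and speeding up decoding."""
    group = n_q_heads // n_kv_heads
    # Repeat each KV head so shapes line up with the query heads.
    k = k.repeat_interleave(group, dim=2)
    v = v.repeat_interleave(group, dim=2)
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (batch, heads, seq, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2)  # back to (batch, seq, n_q_heads, head_dim)
```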
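A sketch of the sliding-window causal mask; the paper's window is W = 4096, but a tiny window is used here so the printed mask is readable:

```python
import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    """Boolean mask where True means 'may attend'.
    Token i attends to tokens j with i - window < j <= i (causal + local)."""
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (j <= i) & (j > i - window)

# With window 2, token 2 sees tokens 1 and 2 but not token 0; stacking layers
# still lets token 0's information reach token 2 indirectly via token 1's
# hidden state, which is why the span grows by one window per layer.
print(sliding_window_mask(5, 2).int())
```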
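A sketch of the rolling buffer cache, assuming the paper's scheme of storing the keys/values for absolute position i in slot i mod W; the class and method names are hypothetical:

```python
import torch

class RollingKVCache:
    """Fixed-size KV cache: the entry for position i lives in slot i % window,
    so anything older than `window` tokens is overwritten."""
    def __init__(self, window: int, n_kv_heads: int, head_dim: int):
        self.window = window
        self.k = torch.zeros(window, n_kv_heads, head_dim)
        self.v = torch.zeros(window, n_kv_heads, head_dim)
        self.pos = 0  # absolute position of the next token

    def append(self, k_new: torch.Tensor, v_new: torch.Tensor):
        slot = self.pos % self.window
        self.k[slot] = k_new
        self.v[slot] = v_new
        self.pos += 1

    def contents(self):
        """Return cached keys/values in chronological order."""
        n = min(self.pos, self.window)
        order = [(self.pos - n + t) % self.window for t in range(n)]
        return self.k[order], self.v[order]
```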
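A sketch of chunked pre-fill, assuming a hypothetical `forward_chunk(chunk, cache)` call that runs the model over one chunk while attending to the cached keys/values and returns that chunk's new K/V entries:

```python
def chunked_prefill(prompt_tokens, window, cache, forward_chunk):
    """Process a long prompt in window-sized chunks.
    `forward_chunk` is an assumed model call: it attends over the cache plus
    the current chunk and returns the chunk's keys/values to be cached."""
    for start in range(0, len(prompt_tokens), window):
        chunk = prompt_tokens[start:start + window]
        k_new, v_new = forward_chunk(chunk, cache)
        for k_t, v_t in zip(k_new, v_new):
            cache.append(k_t, v_t)
```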
Results
Performance-wise, Mistral 7B outperforms both the LLaMA 2 7B and 13B models on all tested benchmarks (MMLU, Hellaswag, WinoG, PIQA, Arc-e, Arc-c, NQ, TriviaQA, HumanEval, MBPP, MATH, GSM8K), and outperforms Code-Llama 7B on all but one benchmark (MBPP).
Efficiency-wise, the authors report that Sliding Window Attention, implemented with FlashAttention and xFormers, is ~2x faster than vanilla attention at a sequence length of 16K with a window size of 4096. The rolling buffer cache also cuts cache memory usage to 1/2 at a sequence length of 8K and to 1/8 at 32K.
Separately, the authors demonstrate system-prompt guardrails and a self-reflection moderation setup.
Contributions / Why It Matters
- Strong small model: Mistral 7B consistently surpasses the larger LLaMA 2 13B across the evaluated benchmarks.
- Practical long-context recipe: The combination of sliding window attention, the rolling buffer cache, and chunked pre-fill allows good long-context performance at lower latency and higher throughput.
❓ Remaining Questions
- Since each layer extends the attention span by only 4096 tokens, enough layers are needed to reach a useful span; with 32 layers, Mistral 7B's theoretical span is ~131K tokens. In what cases might this be insufficient? Long-document summarization? Repository-level coding?
- Sliding Window Attention combined with FlashAttention and xFormers yields a 2x speedup. How much of this gain is attributable to Sliding Window Attention itself rather than to the optimized kernels?
- The authors mention a difference in evaluation protocol for MBPP (hand-verified subset) and TriviaQA (no Wikipedia context). Why did they change the protocol?
- The content-moderation experiments have no baseline to compare against. Is this moderation ability unique to Mistral 7B, or is it general across LLMs?
- The content-moderation experiments also do not describe the dataset used. How was the data collected, and how were generated answers judged acceptable or unacceptable?
- Why do the reported sequence lengths differ across the efficiency experiments (16K for the speed comparison, 8K and 32K for memory usage)?