Skip to content

BaljinderHothi/Tiny-Llama-Optimization

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

3 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

tinyllama inference optimization

advanced ml optimization framework for tinyllama-1.1b demonstrating production-ready inference acceleration techniques with measurable performance gains.

overview

this project implements multiple state-of-the-art optimization techniques to accelerate inference and reduce memory consumption for tinyllama-1.1b while maintaining model quality. targets: 60%+ speedup, 35%+ memory reduction, 97%+ accuracy retention.

optimization techniques

1. mixed precision (fp16/fp32)

  • converts majority of model to fp16 for faster computation
  • keeps numerically sensitive layers (layernorm, embeddings) in fp32
  • benefits: 40-50% speedup, 50% memory reduction, no quality loss
  • tradeoffs: requires gpu/mps support for optimal performance

2. post-training quantization (int8)

  • applies pytorch dynamic quantization to linear layers
  • symmetric int8 quantization: maps [-max, max] → [-128, 127]
  • benefits: 60-75% model size reduction, lower memory footprint
  • tradeoffs: 5-10% slower on some hardware, slight quality degradation (<2%)

3. attention head pruning

  • identifies low-importance heads using l2 norm of output weights
  • zeros out least important 25% of attention heads
  • benefits: 15-25% speedup, reduced computation
  • tradeoffs: requires careful tuning, ~1-3% quality loss

4. kv-cache quantization

  • quantizes key-value cache to int8 during generation
  • reduces cache memory by 50-75%
  • benefits: enables longer context windows, lower peak memory
  • tradeoffs: minimal (~5%) compute overhead for quant/dequant

project structure

tinyllama-optimization/
├── README.md
├── requirements.txt
├── baseline.py
├── optimizations/
│   ├── mixed_precision.py
│   ├── quantization.py
│   ├── head_pruning.py
│   └── kv_cache_quant.py
├── benchmark.py
├── run_optimization.py
└── results/
    ├── baseline_results.json
    ├── consolidated_results.json
    └── visualizations/

usage

run baseline

python baseline.py

run individual optimizations

python optimizations/mixed_precision.py
python optimizations/quantization.py
python optimizations/head_pruning.py
python optimizations/kv_cache_quant.py

run complete pipeline

python run_optimization.py

Actual Results (M2 MacBook Pro, 16GB RAM)

Method Tokens/sec Memory (MB) Model Size (MB) Perplexity
Baseline 24.3 4521 4200 15.2
Mixed FP16 35.7 (+47%) 2340 (-48%) 2100 (-50%) 15.2 (100%)
INT8 Quant 21.8 (-10%) 2850 (-37%) 1260 (-70%) 15.6 (98%)
Head Prune 29.1 (+20%) 3900 (-14%) 4200 (0%) 15.7 (97%)

Real-World Impact

  • Edge Deployment: 70% size reduction enables on-device inference
  • Cost Savings: 50% memory reduction = 2x throughput per GPU
  • Latency: 45% speedup improves user experience in chatbots

implementation details

mixed precision

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16
)

for module in model.modules():
    if isinstance(module, (nn.LayerNorm, nn.Embedding)):
        module.to(torch.float32)

int8 quantization

scale = 127.0 / max_val
quantized = torch.clamp(torch.round(tensor * scale), -128, 127)
dequantized = quantized.float() / scale

head pruning

importance = torch.norm(attention.o_proj.weight, p=2, dim=0)
threshold = torch.quantile(importance, 0.25)
mask = importance > threshold

relevant papers

  1. mixed precision training - micikevicius et al. (2017)
  2. quantization and training of neural networks - jacob et al. (2018)
  3. are sixteen heads really better than one? - michel et al. (2019)
  4. kv-cache quantization - sheng et al. (2023)

troubleshooting

out of memory errors

  • reduce max_new_tokens in generation
  • use quantization first before other optimizations
  • close other applications

slow generation on m2

  • ensure model uses mps device: device="mps"
  • check activity monitor for thermal throttling
  • mixed precision has best m2 performance

quality degradation

  • reduce pruning ratio (try 0.15 instead of 0.25)
  • use fp16 instead of int8 quantization
  • combine fewer optimizations

customization

adjust pruning ratio

pruned = PrunedInference(prune_ratio=0.15)

modify benchmark prompts

test_prompts = [
    "your custom prompt 1:",
    "your custom prompt 2:",
]

license

mit license

About

Uploading it to GitHub to make it easier to share with others!

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages