安爸-超级家庭

LLM-Pruning Collection:一个基于JAX的结构化与非结构化LLM压缩的仓库

安爸 发布于

Zlab 普林斯顿的研究员发布了 LLM-Pruning Collection,这是一个基于 JAX 的存储库,将大型语言模型的多种剪枝算法整合到一个可复制的框架中。它的目标非常明确,就是在统一训练和评估环境中,方便比较块级、层级和权重级剪枝方法,无论是在 GPU 还是 TPU 上。

LLM-Pruning Collection 包含的内容

它被描述为一个基于 JAX 的 LLM 剪枝仓库。它组织成三个主要目录:

  • pruning 存储了多种剪枝方法的实现:Minitron、ShortGPT、Wanda、SparseGPT、Magnitude、Sheared Llama 和 LLM-Pruner。
  • training 提供了与 FMS-FSDP 集成,支持 GPU 训练,以及与 MaxText 集成,支持 TPU 训练。
  • eval 提供了基于 lm-eval-harness 的 JAX 兼容评估脚本,有基于加速技术的 MaxText 支持,可提供大约 2 到 4 倍的速度提升。

覆盖的剪枝方法

LLM-Pruning Collection 涵盖了几个具有不同粒度级别的剪枝算法系列:

Minitron

Minitron 是 NVIDIA 开发的一个实用剪枝和蒸馏食谱,它可以将 Llama 3.1 8B 和 Mistral NeMo 12B 压缩到 4B 和 8B,同时保持性能。它探讨了深度剪枝和隐藏大小、注意力和 MLP 的联合宽度剪枝,随后进行蒸馏。

在 LLM-Pruning Collection 中,pruning/minitron 目录提供了 prune_llama3.1-8b.sh 等脚本,它们在 Llama 3.1 8B 上执行 Minitron 风格的剪枝。

ShortGPT

ShortGPT 基于这样一个观察:许多 Transformer 层都是冗余的。该方法定义了块影响,这是一个衡量每个层贡献的指标,然后通过直接删除低影响层来移除低影响层。实验表明,ShortGPT 在多选题和生成任务中优于之前的剪枝方法。

在集合中,通过 Minitron 目录实现 ShortGPT,有一个专门脚本来 prune_llama2-7b.sh

Wanda、SparseGPT、Magnitude

Wanda 是一种博士后剪枝方法,它根据每个输出上的权重幅度和相应的输入激活度的乘积来对权重进行评分。它剪除最小的评分,不需要重新训练,并且产生的稀疏度即使在大规模参数上也工作良好。

SparseGPT 是另一种博士后剪枝方法,它使用二阶启发的重建步骤在具有高稀疏比率的 GPT 风格模型上进行剪枝。幅度剪枝是经典的基线,它移除绝对值较小的权重。

在 LLM-Pruning Collection 中,这三个都在 pruning/wanda 下,共享安装路径。README 包括一个关于 Llama 2 7B 结果的密集表格,比较了 Wanda、SparseGPT 和 Magnitude 在 BoolQ、RTE、HellaSwag、Winogrande、ARC E、ARC C 和 OBQA 上的表现,以及在 4:8 和 2:4 这样的无结构和有结构的稀疏模式下的表现。

Sheared LLaMA

Sheared LLaMA 是一种结构化剪枝方法,它学习层的掩码、注意力头和隐藏维度,然后重新训练剪枝的架构。原始发布提供了包括 2.7B 和 1.3B 在内的多个缩放级别的模型。

LLM-Pruning Collection 中的 pruning/llmshearing 目录集成了这个食谱。它使用一个 RedPajama 子集进行校准,通过 Hugging Face 访问,并使用辅助脚本来在 Hugging Face 和 MosaicML Composer 格式间转换。

LLM-Pruner

LLM-Pruner 是一个大语言模型结构化剪枝框架。它使用基于梯度的重要性分数去除非关键耦合结构,如注意力头或 MLP 通道,然后通过约 50K 个样本的短 LoRA 调优阶段恢复性能。集合中包括 LLM-Pruner 在 pruning/LLM-Pruner 下,有 LLaMA、LLaMA 2 和 Llama 3.1 8B 的脚本。

主要结论

  • LLM-Pruning Collection 是一个来自 zlab-princeton 的基于 JAX 的 Apache-2.0 存储库,它统一了现代 LLM 剪枝方法,并提供了共享剪枝、训练和评估管道,适用于 GPU 和 TPU。
  • 代码库实现了块级、层级和权重级剪枝方法,包括 Minitron、ShortGPT、Wanda、SparseGPT、Sheared LLaMA、Magnitude 剪枝和 LLM-Pruner,并为 Llama 系列模型提供了方法专用的脚本。
  • 训练集成了 GPU 上的 FMS-FSDP 和 TPU 上的 MaxText,以及基于 lm-eval-harness 构建的 JAX 兼容评估脚本,通过 accelerate 提供大约 2 到 4 倍的 MaxText 检查点评估速度。
  • 存储库重现了之前剪枝工作的关键结果,为如 Wanda、SparseGPT、Sheared LLaMA 和 LLM-Pruner 等方法发布了“论文与重现”并排表格,工程师可以将其运行与已知的基线对比验证。

查看 GitHub 仓库。也请随意关注我们的 Twitter,别忘了加入我们的 10 万+ 机器学习 SubReddit。订阅我们的 通讯。等等!你在 Telegram 上吗?现在你也可以加入我们

原文链接:LLM-Pruning Collection:面向结构和无结构 LLM 压缩的基于 JAX 的存储库 首次发布于 MarkTechPost


扫描二维码,在手机上阅读