Description: 技术博客
前几天看到知乎上的文章 FLOPs与模型推理速度 ,文中提到一个比较耗时又占显存的pointwise操作 x * sigmoid(x) ,这实际上是 swish activation ;暂且不提它背后的争议,本文主要想从这个结构入手来优化它的显存占用以及耗时,并讨论更广泛的训练时显存优化技术。
要分析清楚swish activation为什么会比较占显存,我们首先需要搞清楚反向传播是如何工作的,或者更进一步说,现有的自动求导框架是如何求出梯度的。
先明确一点,所谓自动求导框架实际上是“半自动”的:它并非直接求出一个复杂函数导数的解析形式,而是通过构建计算图和预先写好的基础函数的求导规则,结合链式求导法则实现的自动求导。