论文标题
自适应检查点的伴随方法,用于神经ode的梯度估计
Adaptive Checkpoint Adjoint Method for Gradient Estimation in Neural ODE
论文作者
论文摘要
神经普通微分方程(节点)最近引起了人们越来越多的关注。但是,它们在基准任务(例如图像分类)上的经验性表现明显较低,而不是离散层模型。我们证明了其性能较差的解释是现有梯度估计方法的不准确性:伴随方法在反向模式集成中具有数值错误;幼稚的方法直接通过ode求解器向后传播,但是在搜索最佳步骤尺寸时,会遭受冗余的计算图。我们提出了自适应检查点的伴随(ACA)方法:在自动分化中,ACA应用了轨迹检查点策略,该策略将前向模式轨迹记录为反向模式轨迹以确保准确性; ACA删除浅计算图的冗余组件; ACA支持自适应求解器。在图像分类任务上,与伴随和幼稚的方法相比,ACA在训练时间的一半中达到了一半的错误率。接受ACA培训的节点的精度和重新测试可靠性都优于重新连接。在时间序列建模上,ACA的表现优于竞争方法。最后,在三体问题的示例中,我们显示了带有ACA的节点可以结合物理知识以获得更好的准确性。我们提供ACA的Pytorch实现:\ url {https://github.com/juntang-zhuang/torch-aca}。
Neural ordinary differential equations (NODEs) have recently attracted increasing attention; however, their empirical performance on benchmark tasks (e.g. image classification) are significantly inferior to discrete-layer models. We demonstrate an explanation for their poorer performance is the inaccuracy of existing gradient estimation methods: the adjoint method has numerical errors in reverse-mode integration; the naive method directly back-propagates through ODE solvers, but suffers from a redundantly deep computation graph when searching for the optimal stepsize. We propose the Adaptive Checkpoint Adjoint (ACA) method: in automatic differentiation, ACA applies a trajectory checkpoint strategy which records the forward-mode trajectory as the reverse-mode trajectory to guarantee accuracy; ACA deletes redundant components for shallow computation graphs; and ACA supports adaptive solvers. On image classification tasks, compared with the adjoint and naive method, ACA achieves half the error rate in half the training time; NODE trained with ACA outperforms ResNet in both accuracy and test-retest reliability. On time-series modeling, ACA outperforms competing methods. Finally, in an example of the three-body problem, we show NODE with ACA can incorporate physical knowledge to achieve better accuracy. We provide the PyTorch implementation of ACA: \url{https://github.com/juntang-zhuang/torch-ACA}.