论文标题
Rieoptax:JAX中的Riemannian优化
Rieoptax: Riemannian Optimization in JAX
论文作者
论文摘要
我们提出了Rieoptax,这是一个用于JAX中Riemannian优化的开源Python库。我们表明,在Rieoptax中,许多差异几何原始原始图(例如Riemannian指数和对数图)通常比Python的现有框架更快,无论是在CPU和GPU上。我们支持各种基本和先进的随机优化求解器,例如Riemannian随机梯度,随机方差降低和自适应梯度方法。所提出的工具箱的一个区别特征是,我们还支持对Riemannian歧管的私有优化。
We present Rieoptax, an open source Python library for Riemannian optimization in JAX. We show that many differential geometric primitives, such as Riemannian exponential and logarithm maps, are usually faster in Rieoptax than existing frameworks in Python, both on CPU and GPU. We support various range of basic and advanced stochastic optimization solvers like Riemannian stochastic gradient, stochastic variance reduction, and adaptive gradient methods. A distinguishing feature of the proposed toolbox is that we also support differentially private optimization on Riemannian manifolds.