论文标题

Rieoptax:JAX中的Riemannian优化

Rieoptax: Riemannian Optimization in JAX

论文作者

Utpala, Saiteja, Han, Andi, Jawanpuria, Pratik, Mishra, Bamdev

论文摘要

我们提出了Rieoptax,这是一个用于JAX中Riemannian优化的开源Python库。我们表明,在Rieoptax中,许多差异几何原始原始图(例如Riemannian指数和对数图)通常比Python的现有框架更快,无论是在CPU和GPU上。我们支持各种基本和先进的随机优化求解器,例如Riemannian随机梯度,随机方差降低和自适应梯度方法。所提出的工具箱的一个区别特征是,我们还支持对Riemannian歧管的私有优化。

We present Rieoptax, an open source Python library for Riemannian optimization in JAX. We show that many differential geometric primitives, such as Riemannian exponential and logarithm maps, are usually faster in Rieoptax than existing frameworks in Python, both on CPU and GPU. We support various range of basic and advanced stochastic optimization solvers like Riemannian stochastic gradient, stochastic variance reduction, and adaptive gradient methods. A distinguishing feature of the proposed toolbox is that we also support differentially private optimization on Riemannian manifolds.

扫码加入交流群

加入微信交流群

微信交流群二维码

扫码加入学术交流群,获取更多资源