Flax - 专为灵活性设计的JAX神经网络库及生态系统

发布于:2025-05-13 ⋅ 阅读:(9) ⋅ 点赞:(0)

本文翻译整理自:https://github.com/google/flax


一、关于 Flax

Flax NNX 是2024年发布的全新简化版Flax API,旨在简化在JAX中创建、检查、调试和分析神经网络的过程。

它通过原生支持Python引用语义来实现这一目标,允许用户使用常规Python对象表达模型,支持引用共享和可变性。

Flax NNX 由Flax Linen API演进而来,后者是2020年由Google Brain工程师和研究人员与JAX团队紧密合作发布的。

您可以在Flax专属文档站点了解更多关于Flax NNX的信息,特别推荐以下内容:

注:Flax Linen的文档有独立站点

Flax团队的使命是服务于不断增长的JAX神经网络研究生态系统——包括Alphabet内部和更广泛的社区,并探索JAX的闪光用例。我们几乎所有的协调和规划工作都在GitHub上进行,包括讨论即将进行的设计变更。欢迎您在我们的讨论区、问题区和拉取请求线程中提供反馈。

您可以在Flax GitHub讨论区提出功能请求、分享工作内容、报告问题或提问。

我们期待改进Flax,但不预期核心API会有重大破坏性变更。我们会尽可能使用更新日志和弃用警告。

如需直接联系我们,请发邮件至flax-dev@google.com。


相关链接资源


关键功能特性


二、安装

Flax基于JAX,请先查看JAX在CPU、GPU和TPU上的安装说明

需要Python 3.8或更高版本。从PyPi安装Flax:

pip install flax

升级到最新版Flax:

pip install --upgrade git+https://github.com/google/flax.git

安装额外依赖(如matplotlib):

pip install "flax[all]"

三、Flax代码示例

我们提供三个使用Flax API的示例:简单多层感知机、CNN和自动编码器。

要了解Module抽象,请查阅我们的文档Module抽象介绍。更多最佳实践示例,请参考我们的指南开发者笔记

多层感知机示例:

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

CNN示例:

class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

自动编码器示例:

Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)


伊织 xAI 2025-04-27(日)