Add a description of Flux vs Lux #819

Merged 2 commits on Mar 8, 2023
29 changes: 26 additions & 3 deletions docs/src/index.md
@@ -38,10 +38,33 @@ The following layer functions exist:
Examples of how to build architectures from scratch, with tutorials on things
like Graph Neural ODEs, can be found in the [SciMLSensitivity.jl documentation](https://docs.sciml.ai/SciMLSensitivity/stable/).

WIP:
## Flux.jl vs Lux.jl

- Lagrangian Neural Networks
- Galerkin Neural ODEs
Both Flux- and Lux-defined neural networks are supported by DiffEqFlux.jl. However, Lux.jl neural networks are strongly preferred for
correctness reasons. In particular, a Flux `Chain` does not respect Julia's type promotion rules. This causes major problems:
restructuring a Flux neural network will not respect the types chosen by the solver. Demonstration:

```julia
using Flux

x = [0.8; 0.8]
ann = Chain(Dense(2, 10, tanh), Dense(10, 1))
p, re = Flux.destructure(ann)
# Note the broadcast: `p` is a flat parameter vector, so convert it elementwise
z = re(Float64.(p))
z(x)
```

While one may think this recreates the neural network to act in `Float64` precision, [it does not](/~https://github.com/FluxML/Flux.jl/pull/2156)
and instead its values will silently downgrade everything to `Float32`. This is only fixed by `Chain(Dense(2, 10, tanh), Dense(10, 1)) |> f64`
(see the sketch below). Similar cases will [lead to dropped gradients with complex numbers](/~https://github.com/FluxML/Optimisers.jl/issues/95).
This is not an issue with the automatic differentiation library commonly associated with Flux (Zygote.jl), but rather a consequence of the
neural network library's choices for type handling and precision. Thus, when using DiffEqFlux.jl with Flux, the user must be very careful to
ensure that the precision of the arguments is correct, and anything that requires alternative number types (such as `TrackerAdjoint` tracked
values and `ForwardDiffSensitivity` dual numbers) is suspect.
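
As a minimal sketch of the `|> f64` workaround mentioned above (the exact behavior of `destructure` can vary across Flux releases), converting
the model to double precision before destructuring keeps both the parameter vector and the reconstructed network in `Float64`:

```julia
using Flux

# Convert the entire model to Float64 first, then destructure
ann64 = Chain(Dense(2, 10, tanh), Dense(10, 1)) |> f64
p64, re64 = Flux.destructure(ann64)

eltype(p64)            # Float64
re64(p64)([0.8; 0.8])  # the forward pass now stays in Float64
```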

Lux.jl has none of these issues, is simpler to work with because the parameters in its function calls are explicit rather than implicit global
references, and achieves higher performance. It is built on the same foundations as Flux.jl, such as Zygote.jl and NNlib.jl, and thus it supports
the same layers underneath and calls the same kernels. The better performance comes from avoiding the overhead of `restructure`.
We therefore highly recommend using Lux, and keeping the Flux fallbacks only for legacy code.
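
For comparison, here is a rough Lux sketch of the same two-layer network (the names `rng`, `ps`, and `st` are just this example's choices; see
the Lux.jl documentation for the full API). Parameters and state are created by `Lux.setup` and passed explicitly into every call:

```julia
using Lux, Random

rng = Random.default_rng()
ann = Lux.Chain(Lux.Dense(2 => 10, tanh), Lux.Dense(10 => 1))
ps, st = Lux.setup(rng, ann)     # explicit parameters and state

y, st = ann([0.8; 0.8], ps, st)  # parameters are passed in explicitly
```

Because the parameters are a plain data structure living outside the layer definitions, the types chosen by the solver propagate through the
call directly instead of being fixed inside the model.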

## Citation
