This repository provides an introduction to JAX, a high-performance numerical computing library from Google Research.
It is aimed at practitioners who are already familiar with NumPy and PyTorch.
- 00_from_zero_to_noob.ipynb: An introductory tutorial notebook that covers the essentials:
  - NumPy-style array operations
  - Automatic differentiation and optimization
  - Automatic vectorization (batching) with vmap
  - Introduction to PyTrees
  - Just-In-Time compilation with jit
  - Random numbers
  - Introduction to Flax
  - Recurrent nets with scan
- 01_meta_learning_maml.ipynb: Meta-learning with MAML (implementation sketch)
- 02_meta_learning_hypernet.ipynb: Meta-learning with hypernetworks (implementation sketch)
Worked-out examples are available in the gallery folder.
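As a taste of the NumPy-style API and automatic differentiation covered in 00_from_zero_to_noob.ipynb, here is a minimal sketch (not code from the notebook) that fits a linear model with plain gradient descent. The data-generating line y = 3x - 1, the learning rate, and the number of steps are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Mean squared error of a linear model w * x + b, written with NumPy-style jnp ops.
def loss(params, x, y):
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (params).
grad_loss = jax.grad(loss)

# Noise-free synthetic data from y = 3x - 1, so the optimum is w = 3, b = -1.
x = jnp.linspace(-1.0, 1.0, 50)
y = 3.0 * x - 1.0

params = (0.0, 0.0)
learning_rate = 0.1
for step in range(300):
    g_w, g_b = grad_loss(params, x, y)  # gradients have the same tuple structure as params
    params = (params[0] - learning_rate * g_w, params[1] - learning_rate * g_b)

print(params)  # close to (3.0, -1.0)
```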
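vmap turns a function written for a single example into one that operates on a whole batch, without rewriting it by hand. The sketch below is illustrative rather than taken from the notebook; predict is a hypothetical per-example function.

```python
import jax
import jax.numpy as jnp

# Written for ONE input vector x of shape (3,).
def predict(w, x):
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 inputs, each of length 3

# in_axes=(None, 0): share w across the batch, map over axis 0 of xs.
batched_predict = jax.vmap(predict, in_axes=(None, 0))
out = batched_predict(w, xs)
print(out)  # values 8, 26, 44, 62, i.e. the same as xs @ w
```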
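PyTrees let JAX transformations act on nested containers of arrays, such as a dict of layer parameters, and jit compiles the whole update into a single function. A minimal sketch, assuming a hypothetical two-layer MLP trained with hand-written SGD (no Flax or Optax involved):

```python
import jax
import jax.numpy as jnp

# Nested dict of arrays: JAX treats this as a pytree with 4 leaves.
params = {
    "layer1": {"w": jnp.ones((3, 4)), "b": jnp.zeros(4)},
    "layer2": {"w": jnp.ones((4, 1)), "b": jnp.zeros(1)},
}

def mlp(params, x):
    h = jnp.tanh(x @ params["layer1"]["w"] + params["layer1"]["b"])
    return h @ params["layer2"]["w"] + params["layer2"]["b"]

def loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # grads mirrors the nested structure of params.
    grads = jax.grad(loss)(params, x, y)
    # tree_map applies the update leaf by leaf across both pytrees.
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

x = jnp.ones((8, 3))
y = jnp.zeros((8, 1))
new_params = sgd_step(params, x, y)

print(len(jax.tree_util.tree_leaves(params)))  # 4
print(loss(params, x, y), loss(new_params, x, y))
```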
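Random numbers in JAX come from explicit PRNG keys rather than hidden global state, so the same key always yields the same numbers. A short illustration, not taken from the notebook:

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing a key reproduces exactly the same samples.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# For fresh randomness, split the key and use the new subkey.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(a)
print(b)  # identical to a
print(c)  # different from a
```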
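jax.lax.scan threads a carry through a sequence and stacks the per-step outputs, which is the natural fit for recurrent nets. A minimal sketch of a vanilla tanh RNN, assuming hypothetical sizes (hidden size 4, input size 2, sequence length 5) and random initialization:

```python
import jax
import jax.numpy as jnp

def run_rnn(params, h0, xs):
    W_h, W_x, b = params

    def step(h, x_t):
        h_new = jnp.tanh(h @ W_h + x_t @ W_x + b)
        return h_new, h_new  # (new carry, output for this time step)

    # scan iterates over the leading axis of xs.
    h_final, hs = jax.lax.scan(step, h0, xs)
    return h_final, hs

hidden, inp, T = 4, 2, 5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    0.1 * jax.random.normal(k1, (hidden, hidden)),  # recurrent weights
    0.1 * jax.random.normal(k2, (inp, hidden)),     # input weights
    jnp.zeros(hidden),                              # bias
)
xs = jax.random.normal(k3, (T, inp))  # sequence with time on the leading axis
h0 = jnp.zeros(hidden)

h_final, hs = run_rnn(params, h0, xs)
print(hs.shape)  # (5, 4): one hidden state per time step
```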
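01_meta_learning_maml.ipynb is described as an implementation sketch; what follows is an independent, even smaller sketch of the MAML idea, not the notebook's code. It assumes a hypothetical family of 1-D linear regression tasks, a single inner SGD step, and made-up step sizes. Because jax.grad differentiates through adapt, the meta-gradient includes the second-order terms of full MAML.

```python
import jax
import jax.numpy as jnp

INNER_LR = 0.1   # hypothetical inner-loop step size
OUTER_LR = 0.01  # hypothetical meta (outer-loop) step size

def mse(params, x, y):
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)

def adapt(params, x_support, y_support):
    # Inner loop: one SGD step on the task's support set.
    grads = jax.grad(mse)(params, x_support, y_support)
    return jax.tree_util.tree_map(lambda p, g: p - INNER_LR * g, params, grads)

def maml_loss(params, x_support, y_support, x_query, y_query):
    # Outer objective: loss on the query set AFTER adaptation.
    return mse(adapt(params, x_support, y_support), x_query, y_query)

def batched_maml_loss(params, tasks):
    # vmap over the leading task axis; params are shared (in_axes=None).
    losses = jax.vmap(maml_loss, in_axes=(None, 0, 0, 0, 0))(params, *tasks)
    return jnp.mean(losses)

@jax.jit
def meta_step(params, tasks):
    loss, grads = jax.value_and_grad(batched_maml_loss)(params, tasks)
    new_params = jax.tree_util.tree_map(lambda p, g: p - OUTER_LR * g, params, grads)
    return new_params, loss

def sample_tasks(key, n_tasks=8, n_points=10):
    # Hypothetical task family: y = slope * x + bias with random slope and bias.
    k_slope, k_bias, k_x = jax.random.split(key, 3)
    slopes = jax.random.uniform(k_slope, (n_tasks, 1), minval=-2.0, maxval=2.0)
    biases = jax.random.uniform(k_bias, (n_tasks, 1), minval=-1.0, maxval=1.0)
    x = jax.random.uniform(k_x, (n_tasks, 2 * n_points), minval=-1.0, maxval=1.0)
    y = slopes * x + biases
    # Split each task's points into support and query sets.
    return x[:, :n_points], y[:, :n_points], x[:, n_points:], y[:, n_points:]

params = (jnp.array(0.0), jnp.array(0.0))
key = jax.random.PRNGKey(0)
for step in range(100):
    key, subkey = jax.random.split(key)
    params, loss = meta_step(params, sample_tasks(subkey))

print(params, loss)
```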
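The core idea behind 02_meta_learning_hypernet.ipynb is that one network (the hypernetwork) outputs the weights of another. The following independent sketch is not the notebook's code: it assumes a hypothetical linear hypernetwork that maps a task embedding z to the weights of a linear target network, with made-up sizes. Gradients flow through the generated weights back into the hypernetwork's own parameters.

```python
import jax
import jax.numpy as jnp

EMBED_DIM, IN_DIM, OUT_DIM = 4, 3, 2
N_TARGET = IN_DIM * OUT_DIM + OUT_DIM  # number of target weights to generate

def init_hypernet(key):
    return {
        "W": 0.1 * jax.random.normal(key, (EMBED_DIM, N_TARGET)),
        "b": jnp.zeros(N_TARGET),
    }

def hypernet(h_params, z):
    # Produce a flat weight vector, then reshape it into the target's W and b.
    flat = z @ h_params["W"] + h_params["b"]
    W = flat[: IN_DIM * OUT_DIM].reshape(IN_DIM, OUT_DIM)
    b = flat[IN_DIM * OUT_DIM:]
    return W, b

def target_net(target_params, x):
    W, b = target_params
    return x @ W + b

def loss(h_params, z, x, y):
    return jnp.mean((target_net(hypernet(h_params, z), x) - y) ** 2)

kp, kz, kx = jax.random.split(jax.random.PRNGKey(0), 3)
h_params = init_hypernet(kp)
z = jax.random.normal(kz, (EMBED_DIM,))  # hypothetical task embedding
x = jax.random.normal(kx, (5, IN_DIM))
y = jnp.zeros((5, OUT_DIM))

# Only the hypernetwork's parameters are trained.
grads = jax.grad(loss)(h_params, z, x, y)
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, h_params, grads)
print(loss(h_params, z, x, y), loss(new_params, z, x, y))
```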
- Make sure you have Python installed.
- Install the dependencies (the Flax part of the first notebook also needs flax, e.g. pip install flax):
  pip install jax jaxlib jupyter
- Launch Jupyter Notebook:
  jupyter notebook
- Open 00_from_zero_to_noob.ipynb to begin learning JAX.
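Before working through the notebook, a quick sanity check in a Python shell or notebook cell can confirm that JAX imports, sees at least one device, and can compile a function. This is a minimal sketch; the device list depends on your installation (a CPU-only install reports CPU devices).

```python
import jax
import jax.numpy as jnp

print(jax.__version__)
print(jax.devices())  # at least one device; CPU if no accelerator is configured

x = jnp.arange(5.0)
total = jax.jit(lambda v: (v ** 2).sum())(x)  # 0 + 1 + 4 + 9 + 16
print(total)  # 30.0
```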