Loading datasets
A neural net written in JAX+Flax expects its input data as `jax.numpy` array instances. Therefore, loading a dataset from any source is as simple as converting it to `jax.numpy` types and reshaping it to the appropriate dimensions for your network.
As an example, this guide demonstrates how to import MNIST using the APIs from Torchvision, TensorFlow, and Hugging Face. We’ll load the whole dataset into memory. For datasets that don’t fit into memory, the process is analogous but should be done in a batchwise fashion.
The MNIST dataset consists of greyscale images of 28x28 pixels of handwritten digits, and has a designated 60k/10k train/test split. The task is to predict the correct class (digit 0, …, 9) of each image.
Assuming a CNN-based classifier, the input data should have shape (B, 28, 28, 1), where B is the number of images and the trailing singleton dimension denotes the greyscale image channel.
The labels are simply the integers denoting the digit shown in each image. Labels should therefore have shape (B,), to enable loss computation with `optax.softmax_cross_entropy_with_integer_labels`.
import numpy as np
import jax.numpy as jnp
Loading from torchvision.datasets
import torchvision
def get_dataset_torch():
    """Load MNIST through torchvision and convert it to ``jax.numpy`` arrays.

    Both splits are requested with ``download=True`` into ``./data``.

    Returns:
        A ``(train, test)`` tuple. Each element is a dict with:
          - ``'image'``: float32 array of shape ``(N, 28, 28, 1)``, pixel
            values rescaled from [0, 255] to [0, 1];
          - ``'label'``: int16 array of shape ``(N,)``.
    """
    mnist = {
        'train': torchvision.datasets.MNIST('./data', train=True, download=True),
        'test': torchvision.datasets.MNIST('./data', train=False, download=True),
    }

    ds = {}
    for split in ['train', 'test']:
        ds[split] = {
            'image': mnist[split].data.numpy(),
            'label': mnist[split].targets.numpy(),
        }

        # cast from np to jnp and rescale the pixel values from [0,255] to [0,1]
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # torchvision returns shape (B, 28, 28).
        # hence, append the trailing channel dimension.
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_torch()
# Report shape and dtype of every array: train split first, image before label.
for split_data in (train, test):
    for field in ('image', 'label'):
        arr = split_data[field]
        print(arr.shape, arr.dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16
Loading from tensorflow_datasets
import tensorflow_datasets as tfds
def get_dataset_tf():
    """Load MNIST through TensorFlow Datasets and convert it to ``jax.numpy`` arrays.

    Returns:
        A ``(train, test)`` tuple. Each element is a dict with:
          - ``'image'``: float32 array of shape ``(N, 28, 28, 1)``, pixel
            values rescaled from [0, 255] to [0, 1];
          - ``'label'``: int16 array of shape ``(N,)``.
    """
    mnist = tfds.builder('mnist')
    # as_dataset() only reads data that has already been prepared on disk;
    # without this call it fails on a machine where MNIST was never
    # downloaded. The call skips work if the data is already prepared.
    mnist.download_and_prepare()

    ds = {}
    for split in ['train', 'test']:
        # batch_size=-1 yields the whole split as one batch, so as_numpy
        # returns a dict of full numpy arrays instead of an example iterator.
        ds[split] = tfds.as_numpy(mnist.as_dataset(split=split, batch_size=-1))

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])
        # Unlike the torchvision/HF loaders, no expand_dims is applied here:
        # per the printed output, tfds images already carry the channel axis.

    return ds['train'], ds['test']
train, test = get_dataset_tf()
# Report shape and dtype of every array: train split first, image before label.
for split_data in (train, test):
    for field in ('image', 'label'):
        arr = split_data[field]
        print(arr.shape, arr.dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16
Loading from 🤗 Hugging Face datasets
#!pip install datasets # datasets isn't preinstalled on Colab; uncomment to install
from datasets import load_dataset
def get_dataset_hf():
    """Load MNIST through Hugging Face ``datasets`` and convert it to ``jax.numpy`` arrays.

    Returns:
        A ``(train, test)`` tuple. Each element is a dict with:
          - ``'image'``: float32 array of shape ``(N, 28, 28, 1)``, pixel
            values rescaled from [0, 255] to [0, 1];
          - ``'label'``: int16 array of shape ``(N,)``.
    """
    mnist = load_dataset("mnist")

    ds = {}
    for split in ['train', 'test']:
        ds[split] = {
            # Each stored image (presumably a PIL image -- confirm against the
            # datasets version in use) is converted to a (28, 28) array, then
            # the per-image arrays are stacked into one (N, 28, 28) array.
            'image': np.array([np.array(im) for im in mnist[split]['image']]),
            'label': np.array(mnist[split]['label']),
        }

        # cast to jnp and rescale pixel values
        ds[split]['image'] = jnp.float32(ds[split]['image']) / 255
        ds[split]['label'] = jnp.int16(ds[split]['label'])

        # append trailing channel dimension
        ds[split]['image'] = jnp.expand_dims(ds[split]['image'], 3)

    return ds['train'], ds['test']
train, test = get_dataset_hf()
# Report shape and dtype of every array: train split first, image before label.
for split_data in (train, test):
    for field in ('image', 'label'):
        arr = split_data[field]
        print(arr.shape, arr.dtype)
(60000, 28, 28, 1) float32
(60000,) int16
(10000, 28, 28, 1) float32
(10000,) int16