Nx_datasets
Load common machine learning datasets and generate synthetic data for testing.
Overview
Nx_datasets provides two categories of data:
- Real datasets: Downloaded and cached locally (MNIST, CIFAR-10, Iris, etc.)
- Synthetic generators: Create data on the fly for testing (blobs, moons, regression problems)
Real datasets are automatically downloaded on first use and cached in your platform's cache directory.
Available Datasets Reference
Real Datasets
| Dataset | Function | Samples | Features | Task | 
|---|---|---|---|---|
| MNIST | load_mnist | 70,000 | 28×28×1 | Classification | 
| Fashion-MNIST | load_fashion_mnist | 70,000 | 28×28×1 | Classification | 
| CIFAR-10 | load_cifar10 | 60,000 | 32×32×3 | Classification | 
| Iris | load_iris | 150 | 4 | Classification | 
| Breast Cancer | load_breast_cancer | 569 | 30 | Classification | 
| Diabetes | load_diabetes | 442 | 10 | Regression | 
| California Housing | load_california_housing | 20,640 | 8 | Regression | 
| Airline Passengers | load_airline_passengers | 144 | 1 | Time Series | 
Synthetic Generators
| Generator | Function | Purpose | Parameters | 
|---|---|---|---|
| Gaussian Blobs | make_blobs | Clustering | centers, cluster_std | 
| Two Moons | make_moons | Non-linear classification | noise, n_samples | 
| Concentric Circles | make_circles | Non-linear classification | noise, factor | 
| Classification | make_classification | Classification with controlled feature structure | n_informative, n_redundant |
| Regression | make_regression | Linear relationships | noise, n_features | 
| Friedman | make_friedman1/2/3 | Non-linear regression | n_samples |
| Swiss Roll | make_swiss_roll | Manifold learning | n_samples | 
| S-Curve | make_s_curve | Manifold learning | n_samples | 
Loading Real Datasets
Image Datasets
MNIST
Classic handwritten digits dataset:
let (x_train, y_train), (x_test, y_test) = Nx_datasets.load_mnist () in
Printf.printf "Train: %s, Test: %s\n" 
  (Nx.shape_to_string x_train) 
  (Nx.shape_to_string x_test)
(* Train: [60000, 28, 28, 1], Test: [10000, 28, 28, 1] *)
Images are uint8 arrays with values 0-255. Labels are single digits 0-9.
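Because MNIST pixels arrive as raw uint8 values and labels as integer classes, a common first step is to rescale the pixels to [0, 1] and one-hot encode the labels. The sketch below shows that arithmetic on plain OCaml arrays. `normalize_pixel` and `one_hot` are hypothetical helpers, not part of Nx_datasets; with real data you would apply the equivalent element-wise operations to the Nx tensors.

```ocaml
(* Rescale a raw uint8 pixel value (0-255) to a float in [0, 1]. *)
let normalize_pixel (p : int) : float = float_of_int p /. 255.0

(* Encode a class label as a one-hot vector of length [n_classes]. *)
let one_hot ~n_classes (label : int) : float array =
  if label < 0 || label >= n_classes then invalid_arg "one_hot: label out of range";
  Array.init n_classes (fun i -> if i = label then 1.0 else 0.0)

let () =
  let scaled = Array.map normalize_pixel [| 0; 128; 255 |] in
  Array.iter (fun v -> Printf.printf "%.3f " v) scaled;
  print_newline ();
  (* Digit 3 as a 10-element one-hot vector. *)
  Array.iter (fun v -> Printf.printf "%.0f" v) (one_hot ~n_classes:10 3);
  print_newline ()
```

This prints the scaled values 0.000, 0.502 and 1.000, followed by the one-hot vector 0001000000.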
Fashion-MNIST
Clothing classification with the same format as MNIST:
let (x_train, y_train), (x_test, y_test) = Nx_datasets.load_fashion_mnist ()
(* 10 classes: T-shirt, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot *)
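Labels come back as integers, so printing predictions usually needs a lookup table from label to class name. This sketch assumes label `i` corresponds to the `i`-th name in the list above (the standard Fashion-MNIST order); `fashion_mnist_classes` and `class_name` are hypothetical names, not library functions.

```ocaml
(* Fashion-MNIST class names, indexed by label value
   (label 0 = T-shirt, ..., label 9 = Ankle boot). *)
let fashion_mnist_classes =
  [| "T-shirt"; "Trouser"; "Pullover"; "Dress"; "Coat";
     "Sandal"; "Shirt"; "Sneaker"; "Bag"; "Ankle boot" |]

(* Translate an integer label into its class name. *)
let class_name (label : int) : string =
  if label < 0 || label >= Array.length fashion_mnist_classes then
    invalid_arg (Printf.sprintf "class_name: unknown label %d" label)
  else fashion_mnist_classes.(label)

let () =
  List.iter (fun l -> Printf.printf "%d -> %s\n" l (class_name l)) [ 0; 7; 9 ]
```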
CIFAR-10
Color images in 10 categories:
let (x_train, y_train), (x_test, y_test) = Nx_datasets.load_cifar10 ()
(* x_train shape: [50000, 32, 32, 3] *)
(* Classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck *)
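The shape [50000, 32, 32, 3] puts the colour channel last. Assuming the tensor is stored contiguously in row-major order (an assumption about the memory layout, not something stated here), the flat position of a given image, pixel and channel can be computed as below. `nhwc_offset` is a hypothetical helper for illustration.

```ocaml
(* Flat offset of element [n; h; w; c] in a contiguous row-major tensor
   of shape [|batch; height; width; channels|] (channels-last layout). *)
let nhwc_offset ~height ~width ~channels ~n ~h ~w ~c =
  ((((n * height) + h) * width) + w) * channels + c

let () =
  (* Channel 2 of the top-left pixel of the second image (n = 1). *)
  let off = nhwc_offset ~height:32 ~width:32 ~channels:3 ~n:1 ~h:0 ~w:0 ~c:2 in
  Printf.printf "offset = %d\n" off
```

Each image occupies 32 × 32 × 3 = 3072 consecutive elements, so the call above yields 3072 + 2 = 3074.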
Tabular Datasets
Iris
Classic flower classification:
let x, y = Nx_datasets.load_iris ()
(* x shape: [150, 4] - sepal length/width, petal length/width *)
(* y shape: [150, 1] - 0=setosa, 1=versicolor, 2=virginica *)
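With only 150 samples, Iris is usually split into training and test sets by hand. If the rows come grouped by species, as in the classic Iris data file, shuffling the row indices first keeps either split from being dominated by one class. The sketch below is a generic Fisher-Yates shuffle over indices; `train_test_indices` is a hypothetical helper, and gathering the selected rows from `x` and `y` is left to your tensor-indexing code. It shuffles but does not stratify, so class balance is likely rather than guaranteed.

```ocaml
(* In-place Fisher-Yates shuffle driven by an explicit RNG state. *)
let shuffle rng (a : 'a array) : unit =
  for i = Array.length a - 1 downto 1 do
    let j = Random.State.int rng (i + 1) in
    let tmp = a.(i) in
    a.(i) <- a.(j);
    a.(j) <- tmp
  done

(* Shuffle row indices 0 .. n_rows-1 and hold out [test_fraction] of them.
   Returns (train_indices, test_indices). *)
let train_test_indices ?(seed = 42) ~n_rows ~test_fraction () =
  let rng = Random.State.make [| seed |] in
  let idx = Array.init n_rows (fun i -> i) in
  shuffle rng idx;
  let n_test =
    int_of_float (Float.round (test_fraction *. float_of_int n_rows))
  in
  (Array.sub idx n_test (n_rows - n_test), Array.sub idx 0 n_test)

let () =
  let train, test = train_test_indices ~n_rows:150 ~test_fraction:0.2 () in
  Printf.printf "train: %d rows, test: %d rows\n"
    (Array.length train) (Array.length test)
```

With a 0.2 test fraction this prints `train: 120 rows, test: 30 rows`.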
Breast Cancer
Binary classification for cancer diagnosis:
let x, y = Nx_datasets.load_breast_cancer ()
(* x shape: [569, 30] - 30 features per sample *)
(* y shape: [569, 1] - 0=malignant, 1=benign *)
Regression Datasets
(* Diabetes regression *)
let x, y = Nx_datasets.load_diabetes ()
(* x: [442, 10], y: [442, 1] - diabetes progression *)
(* California housing prices *)
let x, y = Nx_datasets.load_california_housing ()
(* x: [20640, 8], y: [20640, 1] - median house values *)
Time Series
let passengers = Nx_datasets.load_airline_passengers ()
(* Monthly airline passenger counts 1949-1960 *)
(* shape: [144] *)
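A single monthly series like this is typically turned into supervised examples by sliding a fixed window over it: each input is a run of consecutive months and the target is the month that follows. The sketch below works on a plain `float array`; `sliding_windows` is a hypothetical helper, the 12-month window is an arbitrary choice, and the placeholder series only stands in for the real counts.

```ocaml
(* Turn a univariate series into (input window, next value) pairs for
   one-step-ahead forecasting. *)
let sliding_windows ~window (series : float array) : (float array * float) array =
  let n = Array.length series - window in
  if window <= 0 || n <= 0 then [||]
  else Array.init n (fun i -> (Array.sub series i window, series.(i + window)))

let () =
  (* Placeholder of the same length as the airline data (144 months);
     substitute the real values converted to float. *)
  let series = Array.init 144 float_of_int in
  let pairs = sliding_windows ~window:12 series in
  Printf.printf "%d training pairs\n" (Array.length pairs)
```

With 144 values and a 12-month window this yields 144 - 12 = 132 pairs.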
Generating Synthetic Data
Classification Datasets
Gaussian Blobs
Generate isotropic Gaussian blobs for clustering:
let x, y = Nx_datasets.make_blobs 
  ~n_samples:300 
  ~centers:(`N 3)
  ~cluster_std:0.5 
  ()
(* 3 well-separated clusters *)
Specify exact cluster centers:
let centers = Nx.create Nx.float32 [|3; 2|]
  [|-10.; -10.; 0.; 0.; 10.; 10.|]
let x, y = Nx_datasets.make_blobs ~centers:(`Array centers) ()
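"Isotropic" means each blob is a spherical Gaussian cloud with the same standard deviation along every axis. The pure-OCaml sketch below generates such blobs around explicit centers, drawing normal samples with the Box-Muller transform. It illustrates the idea only and is not Nx_datasets' implementation, which may assign samples to centers and seed its random generator differently; `blobs` and `std_normal` are hypothetical names.

```ocaml
(* One standard-normal sample via the Box-Muller transform. *)
let rec std_normal rng =
  let u1 = Random.State.float rng 1.0 in
  if u1 <= 0.0 then std_normal rng (* resample to avoid log 0 *)
  else
    let u2 = Random.State.float rng 1.0 in
    sqrt (-2.0 *. log u1) *. cos (2.0 *. Float.pi *. u2)

(* Every coordinate of every point gets independent Gaussian noise with the
   same standard deviation [cluster_std]. Samples are assigned to centers
   round-robin; the label is the index of the center. *)
let blobs ?(seed = 0) ~n_samples ~(centers : float array array) ~cluster_std () =
  let rng = Random.State.make [| seed |] in
  let k = Array.length centers in
  if k = 0 then invalid_arg "blobs: need at least one center";
  let labels = Array.init n_samples (fun i -> i mod k) in
  let points =
    Array.map
      (fun lbl ->
        Array.map (fun c -> c +. (cluster_std *. std_normal rng)) centers.(lbl))
      labels
  in
  (points, labels)

let () =
  (* The same three centers as the Nx example above. *)
  let centers = [| [| -10.; -10. |]; [| 0.; 0. |]; [| 10.; 10. |] |] in
  let points, labels = blobs ~n_samples:300 ~centers ~cluster_std:0.5 () in
  Printf.printf "%d points, first label %d\n" (Array.length points) labels.(0)
```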
Two Moons
Binary classification with interleaving half circles:
let x, y = Nx_datasets.make_moons 
  ~n_samples:200 
  ~noise:0.1 
  ()
(* Ideal for testing non-linear classifiers *)
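The two interleaving half circles can be written down directly. The sketch below follows the construction used by scikit-learn's function of the same name; that Nx_datasets uses identical geometry is an assumption. It is noise-free (equivalent to `noise = 0`); the `noise` parameter would add independent Gaussian jitter to each coordinate. `moons` is a hypothetical name.

```ocaml
(* Noise-free two-moons geometry. Label 0: upper half of a unit circle
   centred at (0, 0). Label 1: lower half of a unit circle centred at
   (1, 0.5). *)
let moons ~n_samples : (float array * int) array =
  let n_outer = n_samples / 2 in
  let n_inner = n_samples - n_outer in
  (* [n] evenly spaced angles from 0 to pi, endpoints included. *)
  let angles n =
    Array.init n (fun i ->
        if n = 1 then 0.0
        else Float.pi *. float_of_int i /. float_of_int (n - 1))
  in
  let outer = Array.map (fun t -> ([| cos t; sin t |], 0)) (angles n_outer) in
  let inner =
    Array.map (fun t -> ([| 1.0 -. cos t; 0.5 -. sin t |], 1)) (angles n_inner)
  in
  Array.append outer inner

let () =
  let data = moons ~n_samples:200 in
  Printf.printf "%d points\n" (Array.length data)
```

No straight line separates the two arcs, which is why this shape is useful for testing non-linear classifiers.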
Concentric Circles
Nested circles for non-linear separation:
let x, y = Nx_datasets.make_circles 
  ~n_samples:200 
  ~noise:0.05 
  ~factor:0.5  (* Ratio of inner to outer circle radius *)
  ()
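Here `factor` scales the inner circle relative to the outer one. A noise-free sketch of the geometry, modelled on scikit-learn's function of the same name (an assumption about Nx_datasets' exact construction), places the outer ring at radius 1 and the inner ring at radius `factor`. The `[0, 1)` range check is this sketch's own choice, and `circles` is a hypothetical name.

```ocaml
(* Noise-free concentric circles: outer ring of radius 1 (label 0) and
   inner ring of radius [factor] (label 1). *)
let circles ~n_samples ~factor : (float array * int) array =
  if factor < 0.0 || factor >= 1.0 then
    invalid_arg "circles: factor must be in [0, 1)";
  let n_outer = n_samples / 2 in
  let n_inner = n_samples - n_outer in
  (* [n] angles evenly spaced over [0, 2*pi); the endpoint is excluded so
     the first and last points do not coincide. *)
  let ring n radius label =
    Array.init n (fun i ->
        let t = 2.0 *. Float.pi *. float_of_int i /. float_of_int n in
        ([| radius *. cos t; radius *. sin t |], label))
  in
  Array.append (ring n_outer 1.0 0) (ring n_inner factor 1)

let () =
  let data = circles ~n_samples:200 ~factor:0.5 in
  Printf.printf "%d points\n" (Array.length data)
```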
Complex Classification
Control how many features are informative and how many are redundant:
let x, y = Nx_datasets.make_classification
  ~n_samples:1000
  ~n_features:20
  ~n_informative:15  (* Useful features *)
  ~n_redundant:5     (* Linear combinations of the informative features *)
  ~n_classes:3
  ~n_clusters_per_class:2
  ()
Regression Datasets
Linear Regression
Generate a linear regression problem in which only some features influence the target:
let x, y, coef_opt = Nx_datasets.make_regression
  ~n_samples:100
  ~n_features:5
  ~n_informative:3  (* Only 3 features affect output *)
  ~noise:10.0       (* Gaussian noise std dev *)
  ~coef:true        (* Also return the ground-truth coefficients *)
  ()
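The relationship behind this generator can be pictured as a dot product between each sample and a coefficient vector in which only `n_informative` entries are non-zero, with Gaussian noise added on top. The sketch below assumes the informative features come first (an illustrative choice; this page does not say which columns are informative) and omits the noise so the arithmetic is exact. `linear_target` and the coefficient values are hypothetical.

```ocaml
(* Linear target y = X . coef for each row of X (noise omitted).
   Features whose coefficient is 0 have no effect on y. *)
let linear_target ~(coef : float array) (x : float array array) : float array =
  Array.map
    (fun row ->
      if Array.length row <> Array.length coef then
        invalid_arg "linear_target: row length <> number of coefficients";
      let acc = ref 0.0 in
      Array.iteri (fun j v -> acc := !acc +. (v *. coef.(j))) row;
      !acc)
    x

let () =
  (* Hypothetical coefficients: 5 features, only the first 3 informative. *)
  let coef = [| 2.0; -1.0; 0.5; 0.0; 0.0 |] in
  let x = [| [| 1.; 1.; 1.; 1.; 1. |]; [| 1.; 1.; 1.; 9.; -4. |] |] in
  Array.iter (Printf.printf "%.2f\n") (linear_target ~coef x)
```

Both rows print 1.50: changing the two uninformative features does not move the target.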
Friedman Benchmarks
Standard non-linear regression problems:
(* Friedman #1: y = 10*sin(π*x1*x2) + 20*(x3-0.5)² + 10*x4 + 5*x5 + noise *)
let x, y = Nx_datasets.make_friedman1 ~n_samples:100 ()
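Evaluating the Friedman #1 formula for a single sample makes the role of each of the five inputs explicit. `friedman1` is a hypothetical helper that computes only the noise-free target; a full generator would also sample the inputs and add the noise term.

```ocaml
(* Noise-free Friedman #1 target for one sample:
   10*sin(pi*x1*x2) + 20*(x3 - 0.5)^2 + 10*x4 + 5*x5.
   Only the first five entries are used; any extra columns are ignored. *)
let friedman1 (x : float array) : float =
  if Array.length x < 5 then invalid_arg "friedman1: need at least 5 features";
  (10.0 *. sin (Float.pi *. x.(0) *. x.(1)))
  +. (20.0 *. ((x.(2) -. 0.5) ** 2.0))
  +. (10.0 *. x.(3))
  +. (5.0 *. x.(4))

let () = Printf.printf "%.4f\n" (friedman1 [| 0.5; 0.5; 0.5; 0.5; 0.5 |])
```

For the all-0.5 input the quadratic term vanishes, leaving 10·sin(π/4) + 5 + 2.5 ≈ 14.5711.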
Manifold Data
Swiss Roll
3D manifold for dimensionality reduction:
let x, color = Nx_datasets.make_swiss_roll ~n_samples:1000 ()
(* x shape: [1000, 3], color: [1000] - position along roll *)
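A Swiss roll is a flat 2D sheet wound into a spiral in 3D, and the returned "color" is the position along that spiral. The sketch below maps two uniform numbers to a point using scikit-learn's `make_swiss_roll` parameterization (angle from 1.5π to 4.5π, height from 0 to 21); that Nx_datasets uses the same constants is an assumption. `swiss_roll_point` is a hypothetical name.

```ocaml
(* Map u, v in [0, 1] to a point on a Swiss roll. Returns the 3D point and
   t, the position along the spiral (the "color"). *)
let swiss_roll_point ~u ~v : float array * float =
  let t = 1.5 *. Float.pi *. (1.0 +. (2.0 *. u)) in
  ([| t *. cos t; 21.0 *. v; t *. sin t |], t)

let () =
  let rng = Random.State.make [| 0 |] in
  let samples =
    Array.init 1000 (fun _ ->
        let u = Random.State.float rng 1.0 in
        let v = Random.State.float rng 1.0 in
        swiss_roll_point ~u ~v)
  in
  let _, t0 = samples.(0) in
  Printf.printf "%d points, first color %.3f\n" (Array.length samples) t0
```

Because the distance from the roll's axis equals t, points with nearby colors are neighbours on the unrolled sheet even when they are far apart in 3D.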
S-Curve
Another 3D manifold:
let x, color = Nx_datasets.make_s_curve ~n_samples:1000 ()
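Like the Swiss roll, the S-curve can be written as a map from two uniform numbers to 3D. This sketch uses scikit-learn's `make_s_curve` parameterization (t from -1.5π to 1.5π, height from 0 to 2); that Nx_datasets uses the same one is an assumption, and `s_curve_point` is a hypothetical name.

```ocaml
(* Map u, v in [0, 1] to a point on an S-shaped surface.
   t is the position along the S (the "color"). *)
let s_curve_point ~u ~v : float array * float =
  let t = 3.0 *. Float.pi *. (u -. 0.5) in
  let sign = if t > 0.0 then 1.0 else if t < 0.0 then -1.0 else 0.0 in
  ([| sin t; 2.0 *. v; sign *. (cos t -. 1.0) |], t)

let () =
  let p, t = s_curve_point ~u:0.75 ~v:0.5 in
  Printf.printf "t = %.3f, point = (%.3f, %.3f, %.3f)\n" t p.(0) p.(1) p.(2)
```

The sign flip on the third coordinate bends the two halves of the curve in opposite directions, producing the S shape.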