mirror of https://github.com/kubeflow/trainer.git
716 lines
118 KiB
Plaintext
716 lines
118 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "629c902e-6cd0-4475-b6ce-5d6e37d7e2f3",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Distributed MNIST with MLX and Kubeflow Trainer\n",
|
|
"\n",
|
|
"This Notebook will show how to run distributed MLX on Kubernetes with Kubeflow Trainer.\n",
|
|
"\n",
|
|
"We will use the MLX Runtime to create a distributed training job using OpenMPI and a local minikube cluster.\n",
|
|
"\n",
|
|
"MLX Distributed: https://ml-explore.github.io/mlx/build/html/usage/distributed.html\n",
|
|
"\n",
|
|
"Minikube: https://minikube.sigs.k8s.io/docs/\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "34b60ad6-36d6-430f-992e-51440b0bed8f",
|
|
"metadata": {},
|
|
"source": [
|
|
"## (Optional) Create minikube cluster with shared volume\n",
|
|
"\n",
|
|
"This notebook exports the trained model to your local volume for evaluation.\n",
|
|
"\n",
|
|
"To do this, follow these steps:\n",
|
|
"\n",
|
|
"Create a Minikube cluster with a mounted path. The `mlx-model` folder needs to be accessible by the `mpiuser`:\n",
|
|
"\n",
|
|
"```sh\n",
|
|
"mkdir mlx-model\n",
|
|
"chmod 777 mlx-model\n",
|
|
"minikube start --cpus=8 --mount --mount-string=\"$(pwd):/mnt/data\"\n",
|
|
"```\n",
|
|
"\n",
|
|
"After that you can patch the ClusterTrainingRuntime using the following command:\n",
|
|
"\n",
|
|
"```sh\n",
|
|
"kubectl patch clustertrainingruntime mlx-distributed --type='json' -p='[\n",
|
|
" {\n",
|
|
" \"op\": \"add\",\n",
|
|
" \"path\": \"/spec/template/spec/replicatedJobs/0/template/spec/template/spec/containers/0/volumeMounts\",\n",
|
|
" \"value\": [\n",
|
|
" {\n",
|
|
" \"name\": \"mlx-model\",\n",
|
|
" \"mountPath\": \"/home/mpiuser/mlx-model\"\n",
|
|
" }\n",
|
|
" ]\n",
|
|
" },\n",
|
|
" {\n",
|
|
" \"op\": \"add\",\n",
|
|
" \"path\": \"/spec/template/spec/replicatedJobs/0/template/spec/template/spec/volumes\",\n",
|
|
" \"value\": [\n",
|
|
" {\n",
|
|
" \"name\": \"mlx-model\",\n",
|
|
" \"hostPath\": {\n",
|
|
" \"path\": \"/mnt/data/mlx-model\"\n",
|
|
" }\n",
|
|
" }\n",
|
|
" ]\n",
|
|
" }\n",
|
|
"]'\n",
|
|
"```"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b5ad0e01-d893-484c-8988-d25c2f322b80",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Create MLX training script\n",
|
|
"\n",
|
|
"We need to wrap our training script into a function to create a Kubeflow TrainJob.\n",
|
|
"\n",
|
|
"This is the simple MLP model to recognize digits from the MNIST dataset."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "57952ce1-5752-4976-8a35-c71d25935b74",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def mlx_train_mnist(args):\n",
|
|
" import time\n",
|
|
" from functools import partial\n",
|
|
" import mlx.core as mx\n",
|
|
" import mlx.nn as nn\n",
|
|
" import mlx.optimizers as optim\n",
|
|
" from mlx.data.datasets import load_mnist\n",
|
|
"\n",
|
|
" # Define a simple MLP model with MLX.\n",
|
|
" class MLP(nn.Module):\n",
|
|
" def __init__(\n",
|
|
" self, in_dims: int, hidden_dims: int, num_layers: int, out_dims: int\n",
|
|
" ):\n",
|
|
" super().__init__()\n",
|
|
" layer_sizes = [in_dims] + [hidden_dims] * num_layers + [out_dims]\n",
|
|
" self.layers = [\n",
|
|
" nn.Linear(idim, odim)\n",
|
|
" for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])\n",
|
|
" ]\n",
|
|
"\n",
|
|
" def __call__(self, x):\n",
|
|
" for layer in self.layers[:-1]:\n",
|
|
" x = nn.relu(layer(x))\n",
|
|
" return self.layers[-1](x)\n",
|
|
"\n",
|
|
" # Try to initialize MLX distributed, otherwise run code in non-distributed.\n",
|
|
" try:\n",
|
|
" dist = mx.distributed.init(strict=True, backend=\"mpi\")\n",
|
|
" world_size = dist.size()\n",
|
|
" rank = dist.rank()\n",
|
|
" print(f\"Start Distributed Training, WORLD_SIZE: {world_size}, RANK: {rank}\")\n",
|
|
" except Exception:\n",
|
|
" world_size = 1\n",
|
|
" rank = 0\n",
|
|
" print(\"Start non-Distributed Training\")\n",
|
|
"\n",
|
|
" # Load MNIST dataset and partition it.\n",
|
|
" BATCH_SIZE = 128\n",
|
|
" train_dataset = load_mnist()\n",
|
|
"\n",
|
|
" distributed_ds = (\n",
|
|
" train_dataset.shuffle()\n",
|
|
" .partition_if(world_size > 1, world_size, rank)\n",
|
|
" .key_transform(\"image\", lambda x: (x.astype(\"float32\") / 255.0).ravel())\n",
|
|
" )\n",
|
|
"\n",
|
|
"\n",
|
|
" # Create the MLP model and SGD optimizer\n",
|
|
" model = MLP(\n",
|
|
" in_dims=distributed_ds[0][\"image\"].shape[-1],\n",
|
|
" hidden_dims=32,\n",
|
|
" num_layers=2,\n",
|
|
" out_dims=10,\n",
|
|
" )\n",
|
|
" optimizer = optim.SGD(learning_rate=0.01)\n",
|
|
"\n",
|
|
" # Define function to calculate loss and accuracy.\n",
|
|
" def loss_fn(model, x, y):\n",
|
|
" output = model(x)\n",
|
|
" loss = mx.mean(nn.losses.cross_entropy(output, y))\n",
|
|
" acc = mx.mean(mx.argmax(output, axis=1) == y)\n",
|
|
" return loss, acc\n",
|
|
"\n",
|
|
" # Define single training step.\n",
|
|
" @partial(mx.compile, inputs=model.state, outputs=model.state)\n",
|
|
" def step(x, y):\n",
|
|
" loss_and_grad_fn = nn.value_and_grad(model, loss_fn)\n",
|
|
" (loss, acc), grads = loss_and_grad_fn(model, x, y)\n",
|
|
" # Average grads to aggregate them across distributed nodes.\n",
|
|
" grads = nn.utils.average_gradients(grads)\n",
|
|
" optimizer.update(model, grads)\n",
|
|
" return loss, acc\n",
|
|
"\n",
|
|
" # Average statistic across distributed nodes.\n",
|
|
" def average_stats(stats, count):\n",
|
|
" with mx.stream(mx.cpu):\n",
|
|
" stats = mx.distributed.all_sum(mx.array(stats))\n",
|
|
" count = mx.distributed.all_sum(count)\n",
|
|
" return (stats / count).tolist()\n",
|
|
"\n",
|
|
" # Start distributed training.\n",
|
|
" for epoch in range(10):\n",
|
|
" epoch_start = time.perf_counter()\n",
|
|
" losses = accuracies = count = 0\n",
|
|
"\n",
|
|
" for batch_idx, batch_sample in enumerate(distributed_ds.batch(BATCH_SIZE)):\n",
|
|
" x = mx.array(batch_sample[\"image\"])\n",
|
|
" y = mx.array(batch_sample[\"label\"])\n",
|
|
" loss, acc = step(x, y)\n",
|
|
" mx.eval(loss, acc, model.state)\n",
|
|
"\n",
|
|
" losses += loss.item()\n",
|
|
" accuracies += acc.item()\n",
|
|
" count += 1\n",
|
|
"\n",
|
|
" # Print the results.\n",
|
|
" if batch_idx % 10 == 0:\n",
|
|
" loss, acc = average_stats([losses, accuracies],count)\n",
|
|
" if rank == 0:\n",
|
|
" print(\n",
|
|
" \"Epoch: {} [{}/{} ({:.0f}%)] \\tTrain loss: {:.3f}, acc: {:.3f}\".format(\n",
|
|
" epoch,\n",
|
|
" batch_idx * len(x),\n",
|
|
" len(train_dataset),\n",
|
|
" 100.0 * batch_idx * len(x) / len(train_dataset),\n",
|
|
" loss,\n",
|
|
" acc,\n",
|
|
" )\n",
|
|
" )\n",
|
|
" if rank == 0:\n",
|
|
" print(\n",
|
|
" \"Epoch: {}, time: {:.2f} seconds\\n\\n\".format(\n",
|
|
" epoch, time.perf_counter() - epoch_start\n",
|
|
" )\n",
|
|
" )\n",
|
|
" if rank == 0: \n",
|
|
" # Finally, save the trained model to the disk. \n",
|
|
" model.save_weights(args[\"MODEL_PATH\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "875abc44-69dd-41d8-bb2a-436f730a2dd0",
|
|
"metadata": {},
|
|
"source": [
|
|
"## List Available Kubeflow Trainer Runtimes\n",
|
|
"\n",
|
|
"\n",
|
|
"Get available Kubeflow Trainer Runtimes with the `list_runtimes()` API.\n",
|
|
"\n",
|
|
"You can inspect Runtime details, including the name, framework, entry point, and number of accelerators.\n",
|
|
"\n",
|
|
"- Runtimes with **CustomTrainer**: You must write the training script within the function.\n",
|
|
"\n",
|
|
"- Runtimes with **BuiltinTrainer**: You can configure settings (e.g., LoRA Config) for LLM fine-tuning Job.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "51d8bc1d-8d9b-48f7-866f-c6ad4ad4241b",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Name: mlx-distributed, Framework: mlx, Trainer Type: CustomTrainer\n",
|
|
"\n",
|
|
"Entrypoint: ['mpirun', '--hostfile', '/etc/mpi/hostfile']\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from kubeflow.trainer import TrainerClient, CustomTrainer\n",
|
|
"\n",
|
|
"for r in TrainerClient().list_runtimes():\n",
|
|
" print(f\"Name: {r.name}, Framework: {r.trainer.framework.value}, Trainer Type: {r.trainer.trainer_type.value}\\n\")\n",
|
|
" print(f\"Entrypoint: {r.trainer.entrypoint[:3]}\\n\")\n",
|
|
"\n",
|
|
" if r.name == \"mlx-distributed\":\n",
|
|
" mlx_runtime = r"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "206c389d-9dcb-474f-b0f5-cc098de2e183",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Create TrainJob for Distributed Training\n",
|
|
"\n",
|
|
"Use the `train()` API to create a distributed TrainJob on 3 MPI nodes."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "76dab189-f184-4e48-be74-f32c0dea675b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# The `mlx-model` folder must be created.\n",
|
|
"MODEL_PATH = \"mlx-model/model.npz\"\n",
|
|
"args = {\n",
|
|
" \"MODEL_PATH\": MODEL_PATH\n",
|
|
"}\n",
|
|
"\n",
|
|
"job_id = TrainerClient().train(\n",
|
|
" trainer=CustomTrainer(\n",
|
|
" func=mlx_train_mnist,\n",
|
|
" func_args=args,\n",
|
|
" num_nodes=3,\n",
|
|
" ),\n",
|
|
" runtime=mlx_runtime,\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "6e1811dd-4bf2-40cf-ad35-35a87271eb21",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"'y5e1863ee6de'"
|
|
]
|
|
},
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Train API generates a random TrainJob id.\n",
|
|
"job_id"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f0864121-895f-4e5a-87b8-7ea2d92e6630",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Check the TrainJob Info\n",
|
|
"\n",
|
|
"Use the `list_jobs()` and `get_job()` APIs to get information about the created TrainJob and its steps."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "8a945930-6cfe-4388-8ab8-462474f3f21f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"TrainJob: b85d5f28b0ea, Status: Succeeded, Created at: 2025-03-26 17:10:13+00:00\n",
|
|
"TrainJob: g02078f3e66f, Status: Succeeded, Created at: 2025-03-26 17:14:49+00:00\n",
|
|
"TrainJob: y5e1863ee6de, Status: Created, Created at: 2025-03-26 17:18:18+00:00\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for job in TrainerClient().list_jobs():\n",
|
|
" print(f\"TrainJob: {job.name}, Status: {job.status}, Created at: {job.creation_timestamp}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "75e8c741-6d31-49a2-8667-d38e40d62430",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Step: node-0, Status: Running, Devices: cpu x 2\n",
|
|
"\n",
|
|
"Step: node-1, Status: Running, Devices: cpu x 2\n",
|
|
"\n",
|
|
"Step: node-2, Status: Running, Devices: cpu x 2\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# We execute mpirun command on node-0, which functions as the MPI Launcher node.\n",
|
|
"for c in TrainerClient().get_job(name=job_id).steps:\n",
|
|
" print(f\"Step: {c.name}, Status: {c.status}, Devices: {c.device} x {c.device_count}\\n\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d4fa3ea2-bdf5-40ca-a9ec-cc948949ea59",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Get the TrainJob Logs\n",
|
|
"\n",
|
|
"Use the `get_job_logs()` API to retrieve the TrainJob logs.\n",
|
|
"\n",
|
|
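"Since we distribute the dataset across 3 nodes, each rank processes `round(60,000 / 3) = 20,000` samples. The progress counter in the logs is measured against the full 60,000-sample dataset, so each epoch ends at roughly 19,200/60,000 (32%)."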
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "6e630fd3-f061-4fea-8024-7bffcefb257c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[node-0]: Warning: Permanently added '[y5e1863ee6de-node-0-1.y5e1863ee6de]:2222' (ECDSA) to the list of known hosts.\n",
|
|
"[node-0]: Warning: Permanently added '[y5e1863ee6de-node-0-0.y5e1863ee6de]:2222' (ECDSA) to the list of known hosts.\n",
|
|
"[node-0]: Start Distributed Training, WORLD_SIZE: 3, RANK: 1\n",
|
|
"[node-0]: Start Distributed Training, WORLD_SIZE: 3, RANK: 0\n",
|
|
"[node-0]: Start Distributed Training, WORLD_SIZE: 3, RANK: 2\n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/train-images-idx3-ubyte.gz 9.5MiB (12.6MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/t10k-images-idx3-ubyte.gz 1.6MiB (10.1MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/train-images-idx3-ubyte.gz 9.5MiB (7.5MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/train-labels-idx1-ubyte.gz 32.0KiB (30.9MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/t10k-images-idx3-ubyte.gz 1.6MiB (13.7MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/train-labels-idx1-ubyte.gz 32.0KiB (18.5MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/t10k-labels-idx1-ubyte.gz 8.0KiB (47.4MiB/s) iB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/t10k-labels-idx1-ubyte.gz 8.0KiB (62.9MiB/s) iB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/train-images-idx3-ubyte.gz 9.5MiB (5.9MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/t10k-images-idx3-ubyte.gz 1.6MiB (18.6MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/train-labels-idx1-ubyte.gz 32.0KiB (4.5MiB/s) \n",
|
|
"Downloading https://raw.githubusercontent.com/fgnt/mnist/master/t10k-labels-idx1-ubyte.gz 8.0KiB (1.8MiB/s) \n",
|
|
"[node-0]: Epoch: 0 [0/60000 (0%)] \tTrain loss: 2.301, acc: 0.115\n",
|
|
"[node-0]: Epoch: 0 [1280/60000 (2%)] \tTrain loss: 2.306, acc: 0.100\n",
|
|
"[node-0]: Epoch: 0 [2560/60000 (4%)] \tTrain loss: 2.307, acc: 0.096\n",
|
|
"[node-0]: Epoch: 0 [3840/60000 (6%)] \tTrain loss: 2.306, acc: 0.098\n",
|
|
"[node-0]: Epoch: 0 [5120/60000 (9%)] \tTrain loss: 2.306, acc: 0.097\n",
|
|
"[node-0]: Epoch: 0 [6400/60000 (11%)] \tTrain loss: 2.305, acc: 0.098\n",
|
|
"[node-0]: Epoch: 0 [7680/60000 (13%)] \tTrain loss: 2.305, acc: 0.097\n",
|
|
"[node-0]: Epoch: 0 [8960/60000 (15%)] \tTrain loss: 2.305, acc: 0.098\n",
|
|
"[node-0]: Epoch: 0 [10240/60000 (17%)] \tTrain loss: 2.304, acc: 0.099\n",
|
|
"[node-0]: Epoch: 0 [11520/60000 (19%)] \tTrain loss: 2.303, acc: 0.099\n",
|
|
"[node-0]: Epoch: 0 [12800/60000 (21%)] \tTrain loss: 2.303, acc: 0.100\n",
|
|
"[node-0]: Epoch: 0 [14080/60000 (23%)] \tTrain loss: 2.302, acc: 0.100\n",
|
|
"[node-0]: Epoch: 0 [15360/60000 (26%)] \tTrain loss: 2.302, acc: 0.100\n",
|
|
"[node-0]: Epoch: 0 [16640/60000 (28%)] \tTrain loss: 2.301, acc: 0.100\n",
|
|
"[node-0]: Epoch: 0 [17920/60000 (30%)] \tTrain loss: 2.300, acc: 0.101\n",
|
|
"[node-0]: Epoch: 0 [19200/60000 (32%)] \tTrain loss: 2.300, acc: 0.102\n",
|
|
"[node-0]: Epoch: 0, time: 16.08 seconds\n",
|
|
"[node-0]: Epoch: 1 [0/60000 (0%)] \tTrain loss: 2.285, acc: 0.130\n",
|
|
"[node-0]: Epoch: 1 [1280/60000 (2%)] \tTrain loss: 2.289, acc: 0.121\n",
|
|
"[node-0]: Epoch: 1 [2560/60000 (4%)] \tTrain loss: 2.290, acc: 0.116\n",
|
|
"[node-0]: Epoch: 1 [3840/60000 (6%)] \tTrain loss: 2.288, acc: 0.120\n",
|
|
"[node-0]: Epoch: 1 [5120/60000 (9%)] \tTrain loss: 2.288, acc: 0.120\n",
|
|
"[node-0]: Epoch: 1 [6400/60000 (11%)] \tTrain loss: 2.287, acc: 0.121\n",
|
|
"[node-0]: Epoch: 1 [7680/60000 (13%)] \tTrain loss: 2.287, acc: 0.123\n",
|
|
"[node-0]: Epoch: 1 [8960/60000 (15%)] \tTrain loss: 2.286, acc: 0.125\n",
|
|
"[node-0]: Epoch: 1 [10240/60000 (17%)] \tTrain loss: 2.285, acc: 0.126\n",
|
|
"[node-0]: Epoch: 1 [11520/60000 (19%)] \tTrain loss: 2.285, acc: 0.129\n",
|
|
"[node-0]: Epoch: 1 [12800/60000 (21%)] \tTrain loss: 2.284, acc: 0.131\n",
|
|
"[node-0]: Epoch: 1 [14080/60000 (23%)] \tTrain loss: 2.283, acc: 0.134\n",
|
|
"[node-0]: Epoch: 1 [15360/60000 (26%)] \tTrain loss: 2.282, acc: 0.136\n",
|
|
"[node-0]: Epoch: 1 [16640/60000 (28%)] \tTrain loss: 2.281, acc: 0.139\n",
|
|
"[node-0]: Epoch: 1 [17920/60000 (30%)] \tTrain loss: 2.280, acc: 0.141\n",
|
|
"[node-0]: Epoch: 1 [19200/60000 (32%)] \tTrain loss: 2.279, acc: 0.144\n",
|
|
"[node-0]: Epoch: 1, time: 6.10 seconds\n",
|
|
"[node-0]: Epoch: 2 [0/60000 (0%)] \tTrain loss: 2.261, acc: 0.195\n",
|
|
"[node-0]: Epoch: 2 [1280/60000 (2%)] \tTrain loss: 2.263, acc: 0.189\n",
|
|
"[node-0]: Epoch: 2 [2560/60000 (4%)] \tTrain loss: 2.264, acc: 0.185\n",
|
|
"[node-0]: Epoch: 2 [3840/60000 (6%)] \tTrain loss: 2.262, acc: 0.190\n",
|
|
"[node-0]: Epoch: 2 [5120/60000 (9%)] \tTrain loss: 2.261, acc: 0.190\n",
|
|
"[node-0]: Epoch: 2 [6400/60000 (11%)] \tTrain loss: 2.260, acc: 0.190\n",
|
|
"[node-0]: Epoch: 2 [7680/60000 (13%)] \tTrain loss: 2.259, acc: 0.192\n",
|
|
"[node-0]: Epoch: 2 [8960/60000 (15%)] \tTrain loss: 2.258, acc: 0.193\n",
|
|
"[node-0]: Epoch: 2 [10240/60000 (17%)] \tTrain loss: 2.257, acc: 0.194\n",
|
|
"[node-0]: Epoch: 2 [11520/60000 (19%)] \tTrain loss: 2.256, acc: 0.196\n",
|
|
"[node-0]: Epoch: 2 [12800/60000 (21%)] \tTrain loss: 2.255, acc: 0.197\n",
|
|
"[node-0]: Epoch: 2 [14080/60000 (23%)] \tTrain loss: 2.253, acc: 0.199\n",
|
|
"[node-0]: Epoch: 2 [15360/60000 (26%)] \tTrain loss: 2.252, acc: 0.200\n",
|
|
"[node-0]: Epoch: 2 [16640/60000 (28%)] \tTrain loss: 2.251, acc: 0.202\n",
|
|
"[node-0]: Epoch: 2 [17920/60000 (30%)] \tTrain loss: 2.249, acc: 0.204\n",
|
|
"[node-0]: Epoch: 2 [19200/60000 (32%)] \tTrain loss: 2.248, acc: 0.206\n",
|
|
"[node-0]: Epoch: 2, time: 6.07 seconds\n",
|
|
"[node-0]: Epoch: 3 [0/60000 (0%)] \tTrain loss: 2.224, acc: 0.237\n",
|
|
"[node-0]: Epoch: 3 [1280/60000 (2%)] \tTrain loss: 2.225, acc: 0.245\n",
|
|
"[node-0]: Epoch: 3 [2560/60000 (4%)] \tTrain loss: 2.225, acc: 0.243\n",
|
|
"[node-0]: Epoch: 3 [3840/60000 (6%)] \tTrain loss: 2.221, acc: 0.252\n",
|
|
"[node-0]: Epoch: 3 [5120/60000 (9%)] \tTrain loss: 2.220, acc: 0.256\n",
|
|
"[node-0]: Epoch: 3 [6400/60000 (11%)] \tTrain loss: 2.218, acc: 0.258\n",
|
|
"[node-0]: Epoch: 3 [7680/60000 (13%)] \tTrain loss: 2.217, acc: 0.262\n",
|
|
"[node-0]: Epoch: 3 [8960/60000 (15%)] \tTrain loss: 2.215, acc: 0.263\n",
|
|
"[node-0]: Epoch: 3 [10240/60000 (17%)] \tTrain loss: 2.213, acc: 0.264\n",
|
|
"[node-0]: Epoch: 3 [11520/60000 (19%)] \tTrain loss: 2.212, acc: 0.267\n",
|
|
"[node-0]: Epoch: 3 [12800/60000 (21%)] \tTrain loss: 2.210, acc: 0.269\n",
|
|
"[node-0]: Epoch: 3 [14080/60000 (23%)] \tTrain loss: 2.207, acc: 0.272\n",
|
|
"[node-0]: Epoch: 3 [15360/60000 (26%)] \tTrain loss: 2.206, acc: 0.274\n",
|
|
"[node-0]: Epoch: 3 [16640/60000 (28%)] \tTrain loss: 2.203, acc: 0.277\n",
|
|
"[node-0]: Epoch: 3 [17920/60000 (30%)] \tTrain loss: 2.201, acc: 0.280\n",
|
|
"[node-0]: Epoch: 3 [19200/60000 (32%)] \tTrain loss: 2.199, acc: 0.283\n",
|
|
"[node-0]: Epoch: 3, time: 6.04 seconds\n",
|
|
"[node-0]: Epoch: 4 [0/60000 (0%)] \tTrain loss: 2.163, acc: 0.326\n",
|
|
"[node-0]: Epoch: 4 [1280/60000 (2%)] \tTrain loss: 2.161, acc: 0.329\n",
|
|
"[node-0]: Epoch: 4 [2560/60000 (4%)] \tTrain loss: 2.160, acc: 0.333\n",
|
|
"[node-0]: Epoch: 4 [3840/60000 (6%)] \tTrain loss: 2.155, acc: 0.340\n",
|
|
"[node-0]: Epoch: 4 [5120/60000 (9%)] \tTrain loss: 2.152, acc: 0.345\n",
|
|
"[node-0]: Epoch: 4 [6400/60000 (11%)] \tTrain loss: 2.149, acc: 0.344\n",
|
|
"[node-0]: Epoch: 4 [7680/60000 (13%)] \tTrain loss: 2.146, acc: 0.345\n",
|
|
"[node-0]: Epoch: 4 [8960/60000 (15%)] \tTrain loss: 2.143, acc: 0.345\n",
|
|
"[node-0]: Epoch: 4 [10240/60000 (17%)] \tTrain loss: 2.140, acc: 0.344\n",
|
|
"[node-0]: Epoch: 4 [11520/60000 (19%)] \tTrain loss: 2.137, acc: 0.345\n",
|
|
"[node-0]: Epoch: 4 [12800/60000 (21%)] \tTrain loss: 2.134, acc: 0.345\n",
|
|
"[node-0]: Epoch: 4 [14080/60000 (23%)] \tTrain loss: 2.130, acc: 0.346\n",
|
|
"[node-0]: Epoch: 4 [15360/60000 (26%)] \tTrain loss: 2.127, acc: 0.346\n",
|
|
"[node-0]: Epoch: 4 [16640/60000 (28%)] \tTrain loss: 2.124, acc: 0.346\n",
|
|
"[node-0]: Epoch: 4 [17920/60000 (30%)] \tTrain loss: 2.120, acc: 0.347\n",
|
|
"[node-0]: Epoch: 4 [19200/60000 (32%)] \tTrain loss: 2.116, acc: 0.348\n",
|
|
"[node-0]: Epoch: 4, time: 5.87 seconds\n",
|
|
"[node-0]: Epoch: 5 [0/60000 (0%)] \tTrain loss: 2.061, acc: 0.346\n",
|
|
"[node-0]: Epoch: 5 [1280/60000 (2%)] \tTrain loss: 2.057, acc: 0.364\n",
|
|
"[node-0]: Epoch: 5 [2560/60000 (4%)] \tTrain loss: 2.053, acc: 0.367\n",
|
|
"[node-0]: Epoch: 5 [3840/60000 (6%)] \tTrain loss: 2.044, acc: 0.374\n",
|
|
"[node-0]: Epoch: 5 [5120/60000 (9%)] \tTrain loss: 2.039, acc: 0.377\n",
|
|
"[node-0]: Epoch: 5 [6400/60000 (11%)] \tTrain loss: 2.035, acc: 0.378\n",
|
|
"[node-0]: Epoch: 5 [7680/60000 (13%)] \tTrain loss: 2.030, acc: 0.377\n",
|
|
"[node-0]: Epoch: 5 [8960/60000 (15%)] \tTrain loss: 2.026, acc: 0.379\n",
|
|
"[node-0]: Epoch: 5 [10240/60000 (17%)] \tTrain loss: 2.022, acc: 0.379\n",
|
|
"[node-0]: Epoch: 5 [11520/60000 (19%)] \tTrain loss: 2.018, acc: 0.381\n",
|
|
"[node-0]: Epoch: 5 [12800/60000 (21%)] \tTrain loss: 2.013, acc: 0.381\n",
|
|
"[node-0]: Epoch: 5 [14080/60000 (23%)] \tTrain loss: 2.007, acc: 0.384\n",
|
|
"[node-0]: Epoch: 5 [15360/60000 (26%)] \tTrain loss: 2.003, acc: 0.385\n",
|
|
"[node-0]: Epoch: 5 [16640/60000 (28%)] \tTrain loss: 1.997, acc: 0.387\n",
|
|
"[node-0]: Epoch: 5 [17920/60000 (30%)] \tTrain loss: 1.991, acc: 0.389\n",
|
|
"[node-0]: Epoch: 5 [19200/60000 (32%)] \tTrain loss: 1.986, acc: 0.392\n",
|
|
"[node-0]: Epoch: 5, time: 6.06 seconds\n",
|
|
"[node-0]: Epoch: 6 [0/60000 (0%)] \tTrain loss: 1.895, acc: 0.404\n",
|
|
"[node-0]: Epoch: 6 [1280/60000 (2%)] \tTrain loss: 1.895, acc: 0.429\n",
|
|
"[node-0]: Epoch: 6 [2560/60000 (4%)] \tTrain loss: 1.888, acc: 0.431\n",
|
|
"[node-0]: Epoch: 6 [3840/60000 (6%)] \tTrain loss: 1.876, acc: 0.441\n",
|
|
"[node-0]: Epoch: 6 [5120/60000 (9%)] \tTrain loss: 1.869, acc: 0.443\n",
|
|
"[node-0]: Epoch: 6 [6400/60000 (11%)] \tTrain loss: 1.862, acc: 0.444\n",
|
|
"[node-0]: Epoch: 6 [7680/60000 (13%)] \tTrain loss: 1.855, acc: 0.444\n",
|
|
"[node-0]: Epoch: 6 [8960/60000 (15%)] \tTrain loss: 1.850, acc: 0.445\n",
|
|
"[node-0]: Epoch: 6 [10240/60000 (17%)] \tTrain loss: 1.843, acc: 0.446\n",
|
|
"[node-0]: Epoch: 6 [11520/60000 (19%)] \tTrain loss: 1.838, acc: 0.448\n",
|
|
"[node-0]: Epoch: 6 [12800/60000 (21%)] \tTrain loss: 1.831, acc: 0.450\n",
|
|
"[node-0]: Epoch: 6 [14080/60000 (23%)] \tTrain loss: 1.823, acc: 0.453\n",
|
|
"[node-0]: Epoch: 6 [15360/60000 (26%)] \tTrain loss: 1.816, acc: 0.455\n",
|
|
"[node-0]: Epoch: 6 [16640/60000 (28%)] \tTrain loss: 1.809, acc: 0.458\n",
|
|
"[node-0]: Epoch: 6 [17920/60000 (30%)] \tTrain loss: 1.801, acc: 0.461\n",
|
|
"[node-0]: Epoch: 6 [19200/60000 (32%)] \tTrain loss: 1.793, acc: 0.463\n",
|
|
"[node-0]: Epoch: 6, time: 6.54 seconds\n",
|
|
"[node-0]: Epoch: 7 [0/60000 (0%)] \tTrain loss: 1.653, acc: 0.508\n",
|
|
"[node-0]: Epoch: 7 [1280/60000 (2%)] \tTrain loss: 1.664, acc: 0.501\n",
|
|
"[node-0]: Epoch: 7 [2560/60000 (4%)] \tTrain loss: 1.655, acc: 0.503\n",
|
|
"[node-0]: Epoch: 7 [3840/60000 (6%)] \tTrain loss: 1.641, acc: 0.512\n",
|
|
"[node-0]: Epoch: 7 [5120/60000 (9%)] \tTrain loss: 1.633, acc: 0.514\n",
|
|
"[node-0]: Epoch: 7 [6400/60000 (11%)] \tTrain loss: 1.624, acc: 0.516\n",
|
|
"[node-0]: Epoch: 7 [7680/60000 (13%)] \tTrain loss: 1.616, acc: 0.516\n",
|
|
"[node-0]: Epoch: 7 [8960/60000 (15%)] \tTrain loss: 1.610, acc: 0.517\n",
|
|
"[node-0]: Epoch: 7 [10240/60000 (17%)] \tTrain loss: 1.602, acc: 0.518\n",
|
|
"[node-0]: Epoch: 7 [11520/60000 (19%)] \tTrain loss: 1.596, acc: 0.519\n",
|
|
"[node-0]: Epoch: 7 [12800/60000 (21%)] \tTrain loss: 1.588, acc: 0.521\n",
|
|
"[node-0]: Epoch: 7 [14080/60000 (23%)] \tTrain loss: 1.579, acc: 0.524\n",
|
|
"[node-0]: Epoch: 7 [15360/60000 (26%)] \tTrain loss: 1.572, acc: 0.525\n",
|
|
"[node-0]: Epoch: 7 [16640/60000 (28%)] \tTrain loss: 1.564, acc: 0.527\n",
|
|
"[node-0]: Epoch: 7 [17920/60000 (30%)] \tTrain loss: 1.555, acc: 0.529\n",
|
|
"[node-0]: Epoch: 7 [19200/60000 (32%)] \tTrain loss: 1.547, acc: 0.531\n",
|
|
"[node-0]: Epoch: 7, time: 5.89 seconds\n",
|
|
"[node-0]: Epoch: 8 [0/60000 (0%)] \tTrain loss: 1.388, acc: 0.578\n",
|
|
"[node-0]: Epoch: 8 [1280/60000 (2%)] \tTrain loss: 1.412, acc: 0.552\n",
|
|
"[node-0]: Epoch: 8 [2560/60000 (4%)] \tTrain loss: 1.404, acc: 0.556\n",
|
|
"[node-0]: Epoch: 8 [3840/60000 (6%)] \tTrain loss: 1.390, acc: 0.564\n",
|
|
"[node-0]: Epoch: 8 [5120/60000 (9%)] \tTrain loss: 1.384, acc: 0.565\n",
|
|
"[node-0]: Epoch: 8 [6400/60000 (11%)] \tTrain loss: 1.375, acc: 0.567\n",
|
|
"[node-0]: Epoch: 8 [7680/60000 (13%)] \tTrain loss: 1.367, acc: 0.569\n",
|
|
"[node-0]: Epoch: 8 [8960/60000 (15%)] \tTrain loss: 1.362, acc: 0.571\n",
|
|
"[node-0]: Epoch: 8 [10240/60000 (17%)] \tTrain loss: 1.356, acc: 0.572\n",
|
|
"[node-0]: Epoch: 8 [11520/60000 (19%)] \tTrain loss: 1.351, acc: 0.573\n",
|
|
"[node-0]: Epoch: 8 [12800/60000 (21%)] \tTrain loss: 1.344, acc: 0.576\n",
|
|
"[node-0]: Epoch: 8 [14080/60000 (23%)] \tTrain loss: 1.336, acc: 0.579\n",
|
|
"[node-0]: Epoch: 8 [15360/60000 (26%)] \tTrain loss: 1.330, acc: 0.581\n",
|
|
"[node-0]: Epoch: 8 [16640/60000 (28%)] \tTrain loss: 1.323, acc: 0.583\n",
|
|
"[node-0]: Epoch: 8 [17920/60000 (30%)] \tTrain loss: 1.315, acc: 0.585\n",
|
|
"[node-0]: Epoch: 8 [19200/60000 (32%)] \tTrain loss: 1.308, acc: 0.587\n",
|
|
"[node-0]: Epoch: 8, time: 5.90 seconds\n",
|
|
"[node-0]: Epoch: 9 [0/60000 (0%)] \tTrain loss: 1.161, acc: 0.628\n",
|
|
"[node-0]: Epoch: 9 [1280/60000 (2%)] \tTrain loss: 1.192, acc: 0.616\n",
|
|
"[node-0]: Epoch: 9 [2560/60000 (4%)] \tTrain loss: 1.184, acc: 0.620\n",
|
|
"[node-0]: Epoch: 9 [3840/60000 (6%)] \tTrain loss: 1.171, acc: 0.626\n",
|
|
"[node-0]: Epoch: 9 [5120/60000 (9%)] \tTrain loss: 1.167, acc: 0.627\n",
|
|
"[node-0]: Epoch: 9 [6400/60000 (11%)] \tTrain loss: 1.159, acc: 0.631\n",
|
|
"[node-0]: Epoch: 9 [7680/60000 (13%)] \tTrain loss: 1.151, acc: 0.634\n",
|
|
"[node-0]: Epoch: 9 [8960/60000 (15%)] \tTrain loss: 1.147, acc: 0.637\n",
|
|
"[node-0]: Epoch: 9 [10240/60000 (17%)] \tTrain loss: 1.141, acc: 0.639\n",
|
|
"[node-0]: Epoch: 9 [11520/60000 (19%)] \tTrain loss: 1.138, acc: 0.641\n",
|
|
"[node-0]: Epoch: 9 [12800/60000 (21%)] \tTrain loss: 1.132, acc: 0.644\n",
|
|
"[node-0]: Epoch: 9 [14080/60000 (23%)] \tTrain loss: 1.125, acc: 0.648\n",
|
|
"[node-0]: Epoch: 9 [15360/60000 (26%)] \tTrain loss: 1.120, acc: 0.650\n",
|
|
"[node-0]: Epoch: 9 [16640/60000 (28%)] \tTrain loss: 1.114, acc: 0.653\n",
|
|
"[node-0]: Epoch: 9 [17920/60000 (30%)] \tTrain loss: 1.107, acc: 0.656\n",
|
|
"[node-0]: Epoch: 9 [19200/60000 (32%)] \tTrain loss: 1.101, acc: 0.659\n",
|
|
"[node-0]: Epoch: 9, time: 6.25 seconds\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"_ = TrainerClient().get_job_logs(name=job_id, follow=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "030ea83c-9b1c-477a-a3f3-72d659066678",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Evaluate the Trained Model\n",
|
|
"\n",
|
|
"Since the volume is shared between the Minikube cluster and the local directory, you can evaluate the trained model directly.\n",
|
|
"\n",
|
|
"We will use test images from the MNIST dataset for prediction.\n",
|
|
"\n",
|
|
"- <span style=\"color:green\">Green labels</span> indicate correct predictions.\n",
|
|
"- <span style=\"color:red\">Red labels</span> indicate incorrect predictions, with the correct value shown in parentheses."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "b11d5623-3425-41ea-87e7-ec6992540994",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 1500x500 with 20 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"import mlx.core as mx\n",
|
|
"import mlx.nn as nn\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"from mlx.data.datasets import load_mnist\n",
|
|
"\n",
|
|
"\n",
|
|
"# Load test dataset and take the random batch from it.\n",
|
|
"test_batch = (\n",
|
|
" load_mnist(train=False)\n",
|
|
" .key_transform(\"image\", lambda x: (x.astype(\"float32\") / 255.0).ravel())\n",
|
|
" .batch(20)[mx.random.randint(10, 500)]\n",
|
|
")\n",
|
|
"\n",
|
|
"class MLP(nn.Module):\n",
|
|
" def __init__(self, in_dims: int, hidden_dims: int, num_layers: int, out_dims: int):\n",
|
|
" super().__init__()\n",
|
|
" layer_sizes = [in_dims] + [hidden_dims] * num_layers + [out_dims]\n",
|
|
" self.layers = [\n",
|
|
" nn.Linear(idim, odim)\n",
|
|
" for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])\n",
|
|
" ]\n",
|
|
"\n",
|
|
" def __call__(self, x):\n",
|
|
" for layer in self.layers[:-1]:\n",
|
|
" x = nn.relu(layer(x))\n",
|
|
" return self.layers[-1](x)\n",
|
|
"\n",
|
|
"# Load weights from the trained model.\n",
|
|
"model = MLP(\n",
|
|
" in_dims=test_batch[\"image\"][0].shape[-1],\n",
|
|
" hidden_dims=32,\n",
|
|
" num_layers=2,\n",
|
|
" out_dims=10,\n",
|
|
").load_weights(MODEL_PATH)\n",
|
|
"\n",
|
|
"# Send test batch to the pre-trained MLP model.\n",
|
|
"x = mx.array(test_batch[\"image\"])\n",
|
|
"output = model(x)\n",
|
|
"fig = plt.figure(figsize=(15, 5))\n",
|
|
"for i in range(20):\n",
|
|
" # Format the input image and the model output.\n",
|
|
" image = test_batch[\"image\"][i].reshape((28, 28))\n",
|
|
" pred_label = mx.argmax(output[i])\n",
|
|
"\n",
|
|
" # Add data to the plot.\n",
|
|
" ax = fig.add_subplot(4, 5, i + 1, xticks=[], yticks=[])\n",
|
|
" ax.imshow(image, cmap=\"gray\")\n",
|
|
" if test_batch[\"label\"][i] == pred_label:\n",
|
|
" ax.set_title(test_batch[\"label\"][i], color=\"green\")\n",
|
|
" else:\n",
|
|
" ax.set_title(\"{} ({})\".format(pred_label, test_batch[\"label\"][i]), color=\"red\")\n",
|
|
"\n",
|
|
" ax.title.set_fontsize(20)\n",
|
|
" fig.tight_layout()\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "673897de-adfd-4c29-84ba-4c0f81b8d25b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "cda451fb-b96a-42da-aebf-59a1d6c3ec35",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a6f1983c-243a-4e6d-aa9c-e97fbed2bfcf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.13.2"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|