Cloud TPU
SkyPilot supports running jobs on Google’s Cloud TPU, a specialized hardware accelerator for ML workloads.
Free TPUs via TPU Research Cloud (TRC)
ML researchers and students are encouraged to apply for free TPU access through the TPU Research Cloud (TRC) program!
Getting TPUs in one command
As with GPUs, SkyPilot provides a simple command to quickly get TPUs for development:
sky tpunode # By default TPU v2-8 is used
sky tpunode --use-spot # Preemptible TPUs
sky tpunode --tpus tpu-v3-8 # Change TPU type to tpu-v3-8
sky tpunode --instance-type n1-highmem-16 # Change the host VM type to n1-highmem-16
sky tpunode --tpu-vm # Use TPU VM (instead of TPU Node)
After the command finishes, you will be dropped into a TPU host VM and can start developing code right away.
Below, we show examples of using SkyPilot to run MNIST training on (1) TPU VMs and (2) TPU Nodes.
TPU Architectures
Two different TPU architectures are available on GCP: TPU VMs and TPU Nodes.
Both are supported by SkyPilot. We recommend TPU VMs, the newer architecture encouraged by GCP.
The two architectures differ as follows. With TPU VMs, you can SSH directly into the “TPU host” VM that is physically connected to the TPU device. With TPU Nodes, a separate user VM (an n1 instance) must be provisioned to communicate over gRPC with a TPU host that you cannot access directly. More details can be found in the GCP documentation.
TPU VMs
To use TPU VMs, set the following in the resources field of a task YAML:
resources:
  accelerators: tpu-v2-8
  accelerator_args:
    tpu_vm: True
    runtime_version: tpu-vm-base  # optional
The accelerators field specifies the TPU type. The accelerator_args dict includes the tpu_vm bool (defaults to false, which means a TPU Node is used) and an optional TPU runtime_version field.
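The defaulting rule for tpu_vm can be written down as a tiny function. This is a hypothetical helper for illustration only (it is not part of SkyPilot), and it assumes the task YAML has already been parsed so that tpu_vm is a Python bool.

```python
from typing import Any, Mapping, Optional


def tpu_architecture(accelerator_args: Optional[Mapping[str, Any]] = None) -> str:
    """Return which TPU architecture an accelerator_args dict selects.

    Hypothetical helper mirroring the documented default; not a SkyPilot API.
    Assumes YAML booleans were already parsed into Python bools.
    """
    args = accelerator_args or {}
    # tpu_vm defaults to false, which means a TPU Node is used.
    if args.get("tpu_vm", False):
        return "TPU VM"
    return "TPU Node"


if __name__ == "__main__":
    # Mirrors the two resources examples on this page.
    print(tpu_architecture({"tpu_vm": True, "runtime_version": "tpu-vm-base"}))  # TPU VM
    print(tpu_architecture({"runtime_version": "2.5.0"}))  # TPU Node
```

Omitting accelerator_args entirely also yields a TPU Node, which is why the TPU VM examples below always set tpu_vm: True explicitly.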
To show which TPU types are supported, run sky show-gpus.
Here is a complete task YAML that runs MNIST training on a TPU VM using JAX.
name: mnist-tpu-vm
resources:
  accelerators: tpu-v2-8
  accelerator_args:
    tpu_vm: True
    runtime_version: tpu-vm-base
setup: |
  git clone https://github.com/google/flax.git
  conda activate flax
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n flax python=3.8 -y
    conda activate flax
    # Make sure to install TPU related packages in a conda env to avoid package conflicts.
    pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip install --upgrade clu
    pip install -e flax
    pip install tensorflow tensorflow-datasets
  fi
run: |
  conda activate flax
  cd flax/examples/mnist
  python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=10
This YAML lives under the SkyPilot repo (examples/tpu/tpuvm_mnist.yaml), or you can paste it into a local file.
Launch it with:
$ sky launch examples/tpu/tpuvm_mnist.yaml -c mycluster
You should see the following output when the job finishes.
$ sky launch examples/tpu/tpuvm_mnist.yaml -c mycluster
...
(mnist-tpu-vm pid=10155) I0823 07:49:25.468526 139641357117440 train.py:146] epoch: 9, train_loss: 0.0120, train_accuracy: 99.64, test_loss: 0.0278, test_accuracy: 99.02
(mnist-tpu-vm pid=10155) I0823 07:49:26.966874 139641357117440 train.py:146] epoch: 10, train_loss: 0.0095, train_accuracy: 99.73, test_loss: 0.0264, test_accuracy: 99.19
TPU Nodes
In a TPU Node, a normal CPU VM (an n1 instance) needs to be provisioned to communicate with the TPU host/device.
To use a TPU Node, set the following in the resources field of a task YAML:
resources:
  instance_type: n1-highmem-8
  accelerators: tpu-v2-8
  accelerator_args:
    runtime_version: 2.5.0  # optional, TPU runtime version.
The YAML above uses n1-highmem-8 as the host machine and tpu-v2-8 as the TPU Node resource.
You can change either the host instance type or the TPU type.
Here is a complete task YAML that runs MNIST training on a TPU Node using TensorFlow.
name: mnist-tpu-node
resources:
  accelerators: tpu-v2-8
  accelerator_args:
    runtime_version: 2.5.0  # optional, TPU runtime version.
# TPU node requires loading data from a GCS bucket.
# We use SkyPilot Storage to mount a GCS bucket to /dataset.
file_mounts:
  /dataset:
    name: mnist-tpu-node
    store: gcs
    mode: MOUNT
setup: |
  git clone https://github.com/tensorflow/models.git
  conda activate mnist
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n mnist python=3.8 -y
    conda activate mnist
    pip install tensorflow==2.5.0 tensorflow-datasets tensorflow-model-optimization cloud-tpu-client
  fi
run: |
  conda activate mnist
  cd models/official/legacy/image_classification/
  export STORAGE_BUCKET=gs://mnist-tpu-node
  export MODEL_DIR=${STORAGE_BUCKET}/mnist
  export DATA_DIR=${STORAGE_BUCKET}/data
  export PYTHONPATH=/home/gcpuser/sky_workdir/models
  python3 mnist_main.py \
    --tpu=${TPU_NAME} \
    --model_dir=${MODEL_DIR} \
    --data_dir=${DATA_DIR} \
    --train_epochs=10 \
    --distribution_strategy=tpu \
    --download
Note
TPU Nodes require loading data from a GCS bucket. The file_mounts spec above simplifies this by using SkyPilot Storage to create a new bucket or mount an existing one.
If you encounter a bucket Permission denied error, make sure the bucket is in the same region as the host VM and TPU Node, and that IAM permissions for Cloud TPU are set up correctly (follow the instructions here).
Note
The special environment variable $TPU_NAME is automatically set by SkyPilot at run time, so it can be used in the run commands.
This YAML lives under the SkyPilot repo (examples/tpu/tpu_node_mnist.yaml). Launch it with:
$ sky launch examples/tpu/tpu_node_mnist.yaml -c mycluster
...
(mnist-tpu-node pid=28961) Epoch 9/10
(mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 19ms/step - loss: 0.1181 - sparse_categorical_accuracy: 0.9646 - val_loss: 0.0921 - val_sparse_categorical_accuracy: 0.9719
(mnist-tpu-node pid=28961) Epoch 10/10
(mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 20ms/step - loss: 0.1139 - sparse_categorical_accuracy: 0.9655 - val_loss: 0.0831 - val_sparse_categorical_accuracy: 0.9742
...
(mnist-tpu-node pid=28961) {'accuracy_top_1': 0.9741753339767456, 'eval_loss': 0.0831054300069809, 'loss': 0.11388632655143738, 'training_accuracy_top_1': 0.9654667377471924}
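Because $TPU_NAME is an ordinary environment variable, a Python script started from the run commands can also read it directly instead of taking a --tpu flag. The sketch below is a hypothetical helper (the function name is ours, not SkyPilot's) that fails with a clear message when the variable is missing, for example when the script is run outside a SkyPilot task.

```python
import os
from typing import Mapping, Optional


def get_tpu_name(environ: Optional[Mapping[str, str]] = None) -> str:
    """Return the TPU name exported as $TPU_NAME at run time.

    Hypothetical helper; `environ` defaults to the real process environment
    and can be overridden for testing.
    """
    env = os.environ if environ is None else environ
    name = env.get("TPU_NAME")
    if not name:
        raise RuntimeError(
            "TPU_NAME is not set; run this from a SkyPilot task's run commands")
    return name
```

In a TensorFlow script, the returned name could then be passed to tf.distribute.cluster_resolver.TPUClusterResolver(tpu=get_tpu_name()), which is what the --tpu=${TPU_NAME} flag in the YAML above ultimately feeds.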
Using TPU Pods
A TPU Pod is a collection of TPU devices connected by dedicated high-speed network interfaces for high-performance training.
To use a TPU Pod, simply change the accelerators field in the task YAML (e.g., v2-8 -> v2-32).
resources:
  accelerators: tpu-v2-32  # Pods have > 8 cores (the last number)
  accelerator_args:
    runtime_version: tpu-vm-base
    tpu_vm: True
Note
Both TPU architectures, TPU VMs and TPU Nodes, can be used with TPU Pods. The example below is based on TPU VMs.
To show all available TPU Pod types, run sky show-gpus (types with more than 8 cores are Pods):
GOOGLE_TPU AVAILABLE_QUANTITIES
tpu-v2-8 1
tpu-v2-32 1
tpu-v2-128 1
tpu-v2-256 1
tpu-v2-512 1
tpu-v3-8 1
tpu-v3-32 1
tpu-v3-64 1
tpu-v3-128 1
tpu-v3-256 1
tpu-v3-512 1
tpu-v3-1024 1
tpu-v3-2048 1
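The naming convention in this listing (tpu-vN-K, where K is the core count and K > 8 means a Pod) can be captured in a small parser. This is a hypothetical helper, not a SkyPilot API. The host-count property rests on an assumption we state explicitly: each v2/v3 host VM drives 8 cores, which is consistent with tpu-v2-32 launching 4 host VMs. Newer TPU generations may use different names and per-host core counts, so only the v2/v3 names listed above are handled.

```python
import re
from typing import NamedTuple

_TPU_NAME = re.compile(r"tpu-(v\d+)-(\d+)")

# Assumption: each v2/v3 host VM drives 8 TPU cores (tpu-v2-32 -> 4 hosts).
CORES_PER_HOST = 8


class TpuType(NamedTuple):
    generation: str  # e.g. "v2"
    cores: int       # trailing number of the name, e.g. 32

    @property
    def is_pod(self) -> bool:
        # More than 8 cores means the type is a TPU Pod.
        return self.cores > 8

    @property
    def num_hosts(self) -> int:
        # Number of TPU host VMs under the 8-cores-per-host assumption.
        return max(1, self.cores // CORES_PER_HOST)


def parse_tpu_type(name: str) -> TpuType:
    """Parse a name such as "tpu-v2-32"; raise ValueError otherwise."""
    match = _TPU_NAME.fullmatch(name)
    if match is None:
        raise ValueError(f"unrecognized TPU accelerator name: {name!r}")
    return TpuType(generation=match.group(1), cores=int(match.group(2)))


if __name__ == "__main__":
    for name in ["tpu-v2-8", "tpu-v2-32", "tpu-v3-2048"]:
        t = parse_tpu_type(name)
        kind = "pod" if t.is_pod else "single device"
        print(name, kind, t.num_hosts, "host(s)")
```

Using a NamedTuple keeps the parsed type immutable and comparable, so it can be used as a dict key when grouping jobs by TPU type.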
After a TPU Pod is created, multiple host VMs are launched (e.g., v2-32 comes with 4 host VMs).
Normally, the user needs to SSH into every host (either the n1 user VMs or the TPU host VMs, depending on the architecture) to prepare files and set up environments, and then launch the job on each host, which is tedious and error-prone.
SkyPilot automates away this complexity. From your laptop, a single sky launch command will:
sync the workdir and file_mounts; and
execute the setup/run commands on every host of the pod.
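To make concrete what "every host of the pod" involves, here is a sketch of the manual fan-out that SkyPilot spares you: running the same command on all hosts in parallel and collecting exit codes. This is an illustration only, not how SkyPilot is implemented. The runner is pluggable; the default ssh_runner is hypothetical and assumes passwordless SSH to each host is already configured.

```python
import subprocess
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List

# A runner takes (host, command) and returns the command's exit code.
Runner = Callable[[str, str], int]


def ssh_runner(host: str, command: str) -> int:
    """Hypothetical runner: `ssh <host> <command>`; assumes key-based SSH access."""
    return subprocess.run(["ssh", host, command]).returncode


def run_on_all_hosts(hosts: List[str], command: str,
                     runner: Runner = ssh_runner) -> Dict[str, int]:
    """Run the same command on every host in parallel; map host -> exit code."""
    with ThreadPoolExecutor(max_workers=max(1, len(hosts))) as pool:
        exit_codes = list(pool.map(lambda host: runner(host, command), hosts))
    return dict(zip(hosts, exit_codes))
```

Doing this by hand would also require copying files to each host first and repeating it for every setup and run step, which is exactly the bookkeeping a single sky launch replaces.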
Here is a task YAML for a CIFAR-10 training job on a v2-32 TPU Pod with JAX (code repo):
name: cifar-tpu-pod
resources:
  accelerators: tpu-v2-32
  accelerator_args:
    runtime_version: tpu-vm-base
    tpu_vm: True
setup: |
  git clone https://github.com/infwinston/tpu-example.git
  cd tpu-example
  pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  pip install -r requirements.txt
run: |
  python -u tpu-example/train.py
Launch it with:
$ sky launch examples/tpu/cifar_pod.yaml -c mycluster
You should see the following output.
(node-0 pid=57977, ip=10.164.0.24) JAX process: 1 / 4
(node-3 pid=57963, ip=10.164.0.26) JAX process: 3 / 4
(node-2 pid=57922, ip=10.164.0.25) JAX process: 2 / 4
(node-1 pid=63223) JAX process: 0 / 4
...
(node-0 pid=57977, ip=10.164.0.24) [ 1000/100000] time 0.034 ( 0.063) data 0.008 ( 0.008) loss 1.215 ( 1.489) acc 68.750 (46.163)
Note
By default, outputs from all hosts are shown with the node-<i> prefix. Use jax.process_index() to control which host prints messages.
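A common pattern is to print only from JAX process 0. The sketch below wraps that check in a hypothetical helper (the name is ours); jax.process_index() is the real JAX call, and the optional process_index argument is there only so the logic can be exercised without TPU hardware. Note that in the sample output above, JAX process 0 ran on node-1, so the node-<i> prefix and the JAX process index need not match.

```python
from typing import Optional


def print_on_host0(message: str, process_index: Optional[int] = None) -> bool:
    """Print message only from JAX process 0; return True if it printed.

    Passing process_index explicitly bypasses JAX (useful for testing).
    """
    if process_index is None:
        import jax  # installed on each host by the task's setup commands
        process_index = jax.process_index()
    if process_index == 0:
        print(message, flush=True)
        return True
    return False
```

Inside a training loop, calling print_on_host0(f"step {step}: loss {loss:.3f}") would emit each log line once instead of once per host.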
To submit more jobs to the same TPU Pod, use sky exec:
$ sky exec mycluster examples/tpu/cifar_pod.yaml