Running Spot Instances

VESSL supports Amazon EC2 Spot Instances on Amazon Elastic Kubernetes Service. Spot instances offer substantial cost savings over on-demand instances and are a good fit for stateless, fault-tolerant container runs.

Be aware that spot instances are subject to interruption: a claimed spot instance can be reclaimed with two minutes of notice when the capacity is needed elsewhere. Saving and loading model checkpoints at every epoch is therefore highly recommended. Fortunately, most ML toolkits, such as Fairseq and Detectron2, provide checkpointing that keeps the best-performing model. Refer to the following documents for more information about checkpointing:

  • PyTorch: Saving and Loading Models

  • TensorFlow: Save and Load Models

Refer to the example code in the VESSL GitHub repository. The overall save-and-resume pattern is sketched below; sections 1 and 2 walk through each half in detail.
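Putting the two steps together, a spot-safe training script loads the latest checkpoint at startup and writes a new checkpoint every epoch, so at most one epoch of work is lost when the instance is reclaimed. The snippet below is a minimal, framework-level sketch rather than VESSL-specific code; the tiny model, train_one_epoch, evaluate, and the checkpoint path are illustrative placeholders to replace with your own.

import os
import torch
import torch.nn as nn

# Stand-in model and settings so the sketch runs as-is; replace with your own.
model = nn.Linear(4, 2)
total_epochs = 5
checkpoint_file = 'checkpoints/checkpoint.pt'  # example path, not a VESSL convention
os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)

def train_one_epoch(model):
    pass  # placeholder for one epoch of training

def evaluate(model):
    return 0.0  # placeholder returning a validation metric

# 1) Resume from the last checkpoint if one exists.
start_epoch, best_accuracy = 0, 0.0
if os.path.isfile(checkpoint_file):
    state = torch.load(checkpoint_file, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    start_epoch = state['epoch']
    best_accuracy = state['best_accuracy']

# 2) Train, checkpointing every epoch so the run can pick up where it
#    left off after a spot interruption.
for epoch in range(start_epoch, total_epochs):
    train_one_epoch(model)
    accuracy = evaluate(model)
    best_accuracy = max(best_accuracy, accuracy)
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_accuracy': best_accuracy,
    }, checkpoint_file)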

1. Save Checkpoints

While training a model, save it periodically. The following PyTorch and Keras examples compare validation accuracy at each epoch and save the best-performing model. Note that the epoch number is stored in each checkpoint so that it can later be restored as the start_epoch value.

PyTorch

import torch

def save_checkpoint(state, is_best, filename):
    # Only keep the checkpoint when validation accuracy improves.
    if is_best:
        print("=> Saving a new best")
        torch.save(state, filename)
    else:
        print("=> Validation Accuracy did not improve")


for epoch in range(epochs):
    train(...)
    # Evaluate the model; test(...) stands in for your validation routine
    # and should return the validation accuracy.
    test_accuracy = test(...)

    test_accuracy = torch.FloatTensor([test_accuracy])
    is_best = bool(test_accuracy.numpy() > best_accuracy.numpy())
    best_accuracy = torch.FloatTensor(
        max(test_accuracy.numpy(), best_accuracy.numpy()))

    # start_epoch and best_accuracy come from load_checkpoint() in section 2,
    # so the epoch count keeps increasing across interruptions.
    save_checkpoint({
        'epoch': start_epoch + epoch + 1,
        'state_dict': model.state_dict(),
        'best_accuracy': best_accuracy,
    }, is_best, checkpoint_file_path)

Keras

import os

import tensorflow as tf
from keras.callbacks import ModelCheckpoint
from savvihub.keras import SavviHubCallback

checkpoint_path = os.path.join(args.checkpoint_path, 'checkpoints-{epoch:04d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)

checkpoint_callback = ModelCheckpoint(
    checkpoint_path,
    monitor='val_accuracy',
    verbose=1,
    save_weights_only=True,
    mode='max',
    save_freq=args.save_model_freq,
)

# Compile model
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

model.save_weights(checkpoint_path.format(epoch=0))

model.fit(x_train, y_train,
          batch_size=args.batch_size,
          validation_data=(x_val, y_val),
          epochs=args.epochs,
          callbacks=[
              SavviHubCallback(
                  data_type='image',
                  validation_data=(x_val, y_val),
                  num_images=5,
                  start_epoch=start_epoch,
                  save_image=args.save_image,
              ),
              checkpoint_callback,
          ])

2. Load Checkpoints

When a spot instance is interrupted, the run is restarted from the beginning. To avoid losing progress, write code that loads the saved checkpoint before training resumes.

PyTorch

import os
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# checkpoint_file_path should point to the same file written by
# save_checkpoint() in section 1.

def load_checkpoint(checkpoint_file_path):
    print(f"=> Loading checkpoint '{checkpoint_file_path}' ...")
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_file_path)
    else:
        # Remap storages onto the CPU when no GPU is available.
        checkpoint = torch.load(checkpoint_file_path,
                                map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint.get('state_dict'))
    print(f"=> Loaded checkpoint (trained for {checkpoint.get('epoch')} epochs)")
    return checkpoint.get('epoch'), checkpoint.get('best_accuracy')


if os.path.exists(args.checkpoint_path) and os.path.isfile(checkpoint_file_path):
    start_epoch, best_accuracy = load_checkpoint(checkpoint_file_path)
else:
    print("=> No checkpoint found! Training from scratch")
    start_epoch, best_accuracy = 0, torch.FloatTensor([0])
    if not os.path.exists(args.checkpoint_path):
        print(f" [*] Make directories : {args.checkpoint_path}")
        os.makedirs(args.checkpoint_path)

Keras

import os
import tensorflow as tf

def parse_epoch(file_path):
    # Recover the epoch number from a filename like 'checkpoints-0004.ckpt'.
    return int(os.path.splitext(os.path.basename(file_path))[0].split('-')[1])


checkpoint_path = os.path.join(args.checkpoint_path, 'checkpoints-{epoch:04d}.ckpt')
checkpoint_dir = os.path.dirname(checkpoint_path)
if os.path.exists(checkpoint_dir) and len(os.listdir(checkpoint_dir)) > 0:
    # Restore the most recent checkpoint and resume from its epoch.
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    print(f"=> Loading checkpoint '{latest}' ...")
    model.load_weights(latest)
    start_epoch = parse_epoch(latest)
    print(f'start_epoch:{start_epoch}')
else:
    start_epoch = 0
    if not os.path.exists(args.checkpoint_path):
        print(f" [*] Make directories : {args.checkpoint_path}")
        os.makedirs(args.checkpoint_path)
The start_epoch value is also a useful workaround for logging metrics to the VESSL server: offset the logging step by start_epoch, otherwise the metrics graph may break when the spot instance is interrupted and the run resumes.

PyTorch

import savvihub

def train(...):
    ...
    savvihub.log(
        step=epoch+start_epoch+1, 
        row={'loss': loss.item()}
    )

Keras

from savvihub.keras import SavviHubCallback

model.fit(...,
    callbacks=[SavviHubCallback(
        ...,
        start_epoch=start_epoch,
        ...,
    )]
)

3. Use the Spot Instance Option

To use a spot instance on VESSL, select the Use Spot Instance checkbox. Every spot instance resource type is also marked with the *.spot postfix. More resource types will be added in the future.
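Spot resource specs can also be requested by name from the Python SDK when creating an experiment. The snippet below is a sketch assuming the vessl.create_experiment() parameters documented in the Experiment API reference; the cluster, image, and resource spec names (including the .spot suffix) are illustrative placeholders, so substitute the resource specs listed for your organization.

import vessl

# Pick an organization and project (interactive; can also be done with `vessl configure`).
vessl.configure()

# A resource spec name ending in ".spot" requests a spot instance.
# All names below are illustrative; adjust the cluster, resource spec,
# image, and command for your own project.
vessl.create_experiment(
    cluster_name="aws-apne2-prod1",
    kernel_resource_spec_name="v1.cpu-4.mem-13.spot",
    kernel_image_url="public.ecr.aws/vessl/kernels:py38.full-cpu",
    start_command="python main.py",
)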

