Chaim Rand

Posted on • Originally published at towardsdatascience.com

On the Programmability of AWS Trainium and Inferentia

Accelerating AI/ML Model Training with Custom Operators — Part 4

Photo by Agata Bres on Unsplash

In this post we continue our exploration of the opportunities for runtime optimization of machine learning (ML) workloads through custom operator development. This time, we focus on the tools provided by the AWS Neuron SDK for developing and running new kernels on AWS Trainium and AWS Inferentia. With the rapid development of the low-level model components (e.g., attention layers) driving the AI revolution, the programmability of the accelerators used for training and running ML models is crucial. Dedicated AI chips, in particular, must offer a worthy alternative to the widely used and highly impactful general-purpose GPU (GPGPU) development frameworks, such as CUDA and Triton.

In previous posts (e.g., here and here) we explored the opportunity for building and running ML models on AWS's custom-built AI chips using the dedicated AWS Neuron SDK. In its most recent release of the SDK (version 2.20.0), AWS introduced the Neuron Kernel Interface (NKI) for developing custom kernels for NeuronCore-v2, the underlying accelerator powering both Trainium and Inferentia2. The NKI interface joins another API that enables NeuronCore-v2 programmability, Neuron Custom C++ Operators. In this post we will explore both opportunities and demonstrate them in action.

Disclaimers

Importantly, this post should not be viewed as a substitute for the official AWS Neuron SDK documentation. At the time of this writing the Neuron SDK APIs for custom kernel development are in Beta and may change by the time you read this. The examples we share are intended for demonstrative purposes only. We make no claims as to their optimality, robustness, durability, or accuracy. Please do not view our mention of any platforms, tools, APIs, etc., as an endorsement for their use. The best choices for any project depend on the specifics of the use-case at hand and warrant appropriate investigation and analysis.

Developing Custom Kernels for Neuron Cores

Although the list of ML models supported by the Neuron SDK is continuously growing, some operations remain either unsupported or implemented suboptimally. By exposing APIs for Neuron kernel customization, the SDK empowers developers to create and/or optimize the low-level operations that they need, greatly increasing the opportunity for running ML workloads on Trainium and Inferentia.

As discussed in our previous posts in this series, fully leveraging the power of these AI chips requires a detailed understanding of their low-level architecture.

The Neuron Core Architecture

The NKI documentation includes a dedicated section on the architecture design of NeuronCore-v2 and its implications on custom operator development. Importantly, there are many differences between Neuron cores and their AI accelerator counterparts (e.g., GPUs and TPUs). Optimizing for Neuron cores requires a unique set of strategies and skills.

Similar to other dedicated AI chips, NeuronCore-v2 includes several internal acceleration engines, each of which specializes in performing certain types of computations. The engines can be run asynchronously and in parallel. The Neuron Compiler is responsible for transforming ML models into low-level operations and optimizing the choice of compute engine for each one.

The Tensor engine specializes in matrix multiplication. The Vector and Scalar engines both operate on tensors, with the Vector engine specializing in reduction operations and the Scalar engine in non-linear functions. GpSimd is a general-purpose engine capable of running arbitrary C/C++ programs. Note that while the NKI interface exposes access to all four compute engines, custom C++ operators are designed specifically for the GpSimd engine.

More details on the capabilities of each engine can be found in the architecture documentation. Furthermore, the NKI Instruction Set Architecture (ISA) documentation provides details on the engines on which different low-level operations are run.

Another important aspect of the Neuron chip is its memory architecture. A Neuron device includes three types of memory: HBM, SBUF, and PSUM. An intimate understanding of the capacities and capabilities of each one is crucial for optimal kernel development.

Given the architecture overview, you might conclude that Neuron kernel development requires high expertise. While this may be true for creating fully optimized kernels that leverage all the capabilities of the Neuron core, our aim is to demonstrate the accessibility, value, and potential of the Neuron custom kernel APIs - even for non-expert developers.

Custom NKI Kernels

The NKI interface is a Python-level API that exposes the use of the Neuron core compute engines and memory resources to ML developers. The NKI Getting Started guide details the setup instructions and provides a soft landing with a simple, "hello world", kernel. The NKI Programming Model guide details the three stages of a typical NKI kernel (loading inputs, running operations on the computation engines, and storing outputs) and introduces the NKI Tile and Tile-based operations. The NKI tutorials demonstrate a variety of NKI kernel sample applications, with each one introducing new core NKI APIs and capabilities. Given the presumed optimality of the sample kernels, one possible strategy for developing new kernels could be to 1) identify a sample that is similar to the operation you wish to implement and then 2) use it as a baseline and iteratively refine and adjust it to achieve the specific functionality you require.
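
The three-stage structure described in the Programming Model guide (load, compute, store) can be mimicked in plain Python. The sketch below is only a host-side emulation: the helper name `tiled_apply` and the toy tile sizes are our own assumptions, it does not call any NKI API, and it says nothing about how the hardware actually schedules the work.

```python
import math

def tiled_apply(src, tile_rows, tile_cols, op):
    """Emulate load -> compute -> store over 2D tiles of a nested list."""
    n_rows, n_cols = len(src), len(src[0])
    assert n_rows % tile_rows == 0 and n_cols % tile_cols == 0
    dst = [[0.0] * n_cols for _ in range(n_rows)]
    for r0 in range(0, n_rows, tile_rows):
        for c0 in range(0, n_cols, tile_cols):
            # Stage 1: "load" one tile from the large source buffer
            tile = [row[c0:c0 + tile_cols] for row in src[r0:r0 + tile_rows]]
            # Stage 2: compute on the tile
            result = [[op(x) for x in row] for row in tile]
            # Stage 3: "store" the result tile into the output buffer
            for i, row in enumerate(result):
                dst[r0 + i][c0:c0 + tile_cols] = row
    return dst

data = [[float(r * 8 + c) for c in range(8)] for r in range(4)]
out = tiled_apply(data, tile_rows=2, tile_cols=4, op=math.exp)
```

In a real NKI kernel the load and store stages move data between device memory and on-chip buffers, which is where tile-size limits come from.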

The NKI API Reference Manual details the Python API for kernel development. With a syntax and semantics that are similar to Triton and NumPy, the NKI language definition aims to maximize accessibility and ease of use. However, it is important to note that NKI kernel development is limited to the operations defined in the NKI library, which (as of the time of this writing) are fewer and more constrained than in libraries such as Triton and NumPy.

Toy Example - A GIOU Kernel

As in our previous posts, we assess the use of NKI by building a custom implementation of the Generalized Intersection Over Union (GIOU) operation on a pair of batches of input boxes. Since GIOU involves pixel-wise operations, we used the exp kernel from the NKI Programming guide as a reference point and incorporated the use of NKI's advanced tensor indexing in our implementation. To facilitate debugging in a CPU environment, we also added options to run the code using the nki.simulate_kernel and nki.language.device_print APIs.

import torch
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
import numpy as np

simulate = False

try:
    # if torch libraries are installed assume that we are running on Neuron
    import torch_xla.core.xla_model as xm
    import torch_neuronx
    from torch_neuronx import nki_jit

    device = xm.xla_device()

    # empty implementation
    def debug_print(*args, **kwargs):
        pass
except:
    # if torch libraries are not installed assume that we are running on CPU
    # and program script to use nki simulation
    simulate = True
    nki_jit = nki.trace
    debug_print = nl.device_print
    device = 'cpu'

@nki_jit
def giou_kernel(preds_ptr,
                targets_ptr,
                output_ptr):
    epsilon = 1e-5
    TILE_M = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.psum_fmax  # 512
    TILE_N_OUT = TILE_N // 4

    p_1, p_2 = preds_ptr.shape
    t_1, t_2 = targets_ptr.shape
    o_1, o_2 = output_ptr.shape

    #  verify input
    # batch size must be multiple of 128
    assert p_1 % TILE_M == 0
    assert p_1 == t_1
    assert p_1 == o_1
    # number of boxes * 4 must be a multiple of 512
    assert p_2 % TILE_N == 0
    assert p_2 == t_2
    assert p_2 // 4 == o_2

    num_tiles_m = p_1 // TILE_M
    num_tiles_n = p_2 // TILE_N

    # Generate tensors for advanced indexing
    i_p = nl.arange(TILE_M)[:, None]
    i_f = nl.arange(TILE_N // 4)[None, :]
    i_f_0 = (4 * i_f)
    i_f_1 = (4 * i_f + 1)
    i_f_2 = (4 * i_f + 2)
    i_f_3 = (4 * i_f + 3)

    # Use affine_range to loop over tiles
    for m in nl.affine_range(num_tiles_m):
        for n in nl.affine_range(num_tiles_n):
            # Load input data from HBM
            preds = nl.load(preds_ptr[m * TILE_M:(m + 1) * TILE_M,
                            n * TILE_N:(n + 1) * TILE_N])
            targets = nl.load(targets_ptr[m * TILE_M:(m + 1) * TILE_M,
                              n * TILE_N:(n + 1) * TILE_N])
            debug_print('preds', preds)
            preds_left = preds[i_p, i_f_0]
            preds_top = preds[i_p, i_f_1]
            preds_right = preds[i_p, i_f_2]
            preds_bottom = preds[i_p, i_f_3]

            gt_left = targets[i_p, i_f_0]
            gt_top = targets[i_p, i_f_1]
            gt_right = targets[i_p, i_f_2]
            gt_bottom = targets[i_p, i_f_3]

            # Compute the area of each box
            area1 = (preds_right - preds_left) * (preds_bottom - preds_top)
            area2 = (gt_right - gt_left) * (gt_bottom - gt_top)

            # Compute the intersection
            left = nl.maximum(preds_left, gt_left)
            top = nl.maximum(preds_top, gt_top)
            right = nl.minimum(preds_right, gt_right)
            bottom = nl.minimum(preds_bottom, gt_bottom)

            inter_w = nl.maximum(right - left, 0)
            inter_h = nl.maximum(bottom - top, 0)
            inter_area = inter_w * inter_h

            union_area = area1 + area2 - inter_area

            iou_val = inter_area / nl.maximum(union_area, epsilon)

            # Compute the smallest enclosing box
            enclose_left = nl.minimum(preds_left, gt_left)
            enclose_top = nl.minimum(preds_top, gt_top)
            enclose_right = nl.maximum(preds_right, gt_right)
            enclose_bottom = nl.maximum(preds_bottom, gt_bottom)

            enclose_w = nl.maximum(enclose_right - enclose_left, 0)
            enclose_h = nl.maximum(enclose_bottom - enclose_top, 0)
            enclose_area = enclose_w * enclose_h

            # Compute GIOU
            delta_area = (enclose_area - union_area)
            enclose_area = nl.maximum(enclose_area, epsilon)
            giou = iou_val - delta_area / enclose_area

            # Store results
            nl.store(output_ptr[m * TILE_M:(m + 1) * TILE_M,
                     n * TILE_N_OUT:(n + 1) * TILE_N_OUT],
                     giou)
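
The tiling arithmetic and the strided column indices used by the kernel can be checked on the host before running anything on a device. The following is a minimal sketch in plain Python; the helper names `plan_tiles` and `coordinate_columns` are hypothetical, and the tile sizes are simply the values quoted in the kernel comments (128 and 512).

```python
TILE_M = 128   # partition-dimension tile size quoted in the kernel comments
TILE_N = 512   # free-dimension tile size quoted in the kernel comments

def plan_tiles(batch_size, n_boxes):
    """Return (num_tiles_m, num_tiles_n, out_tile_width) for flattened boxes."""
    flat_width = n_boxes * 4          # four coordinates per box
    assert batch_size % TILE_M == 0
    assert flat_width % TILE_N == 0
    return batch_size // TILE_M, flat_width // TILE_N, TILE_N // 4

def coordinate_columns(tile_width, offset):
    """Columns holding one coordinate (0=left, 1=top, 2=right, 3=bottom)."""
    return [4 * k + offset for k in range(tile_width // 4)]

num_tiles_m, num_tiles_n, out_width = plan_tiles(batch_size=1024, n_boxes=256)
print(num_tiles_m, num_tiles_n, out_width)   # 8 2 128
print(coordinate_columns(16, 0))             # [0, 4, 8, 12]
print(coordinate_columns(16, 2))             # [2, 6, 10, 14]
```

For the shapes used in this post, each input tile of 512 columns therefore yields 128 output scores, which matches TILE_N_OUT in the kernel.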

To run our GIOU kernel, we generate two batches of random boxes and feed them to our function:

# generate random data in np
np.random.seed(0)
batch_size = 1024
n_boxes = 256
img_size = 256
boxes = []

for i in range(2):
    # Randomly generate box sizes and positions
    box_sizes = np.random.randint(1, img_size, size=(batch_size,n_boxes,2))
    top_left = np.random.randint(0, img_size-1, size=(batch_size,n_boxes,2))
    bottom_right = np.clip(top_left + box_sizes, 0, img_size - 1)

    # Concatenate top-left and bottom-right coordinates
    rand_boxes = np.concatenate((top_left, bottom_right), axis=2)

    boxes.append(rand_boxes.astype(np.float32))

out = np.empty((batch_size, n_boxes), np.float32)

# convert tensors to PyTorch
t_boxes_0 = torch.tensor(boxes[0]).to(device)
t_boxes_1 = torch.tensor(boxes[1]).to(device)
t_out = torch.tensor(out).to(device)

if simulate:
    # the simulation API requires numpy input
    nki.simulate_kernel(giou_kernel,
                        boxes[0].reshape((batch_size, -1)),
                        boxes[1].reshape((batch_size, -1)),
                        out)
else:
    giou_kernel(t_boxes_0.view((batch_size, -1)),
                t_boxes_1.view((batch_size, -1)),
                t_out)

To assess the performance of our NKI kernel, we will compare it with the following naive implementation of GIOU in PyTorch:

def torch_giou(boxes1, boxes2):
    # loosely based on torchvision generalized_box_iou_loss code
    epsilon = 1e-5

    # Compute areas of both sets of boxes
    area1 = (boxes1[...,2]-boxes1[...,0])*(boxes1[...,3]-boxes1[...,1])
    area2 = (boxes2[...,2]-boxes2[...,0])*(boxes2[...,3]-boxes2[...,1])

    # Corners of intersection
    lt = torch.max(boxes1[..., :2], boxes2[..., :2])
    rb = torch.min(boxes1[..., 2:], boxes2[..., 2:])

    # Width and height of intersection
    wh = (rb - lt).clamp(min=0)

    # Area of the intersection
    inter = wh[..., 0] * wh[..., 1]

    # Union of the two boxes
    union = area1 + area2 - inter
    iou = inter / union.clamp(epsilon)

    # Corners of enclosing box
    lti = torch.min(boxes1[..., :2], boxes2[..., :2])
    rbi = torch.max(boxes1[..., 2:], boxes2[..., 2:])

    # Width and height of the enclosing box
    whi = (rbi - lti).clamp(min=0)

    # Area of the enclosing box
    areai = (whi[..., 0] * whi[..., 1]).clamp(epsilon)

    return iou - (areai - union) / areai
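
For spot-checking individual results, it can help to have a scalar version of the same formula. The sketch below is our own plain-Python reference for a single pair of boxes given as (left, top, right, bottom); the function name `giou_reference` is hypothetical, and the code follows the same steps as the PyTorch function rather than any library API.

```python
def giou_reference(box1, box2, epsilon=1e-5):
    """GIOU of two boxes given as (left, top, right, bottom)."""
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

    # Intersection, clamped to zero when the boxes do not overlap
    inter_w = max(min(box1[2], box2[2]) - max(box1[0], box2[0]), 0.0)
    inter_h = max(min(box1[3], box2[3]) - max(box1[1], box2[1]), 0.0)
    inter = inter_w * inter_h
    union = area1 + area2 - inter
    iou = inter / max(union, epsilon)

    # Smallest box enclosing both inputs
    enclose_w = max(max(box1[2], box2[2]) - min(box1[0], box2[0]), 0.0)
    enclose_h = max(max(box1[3], box2[3]) - min(box1[1], box2[1]), 0.0)
    enclose = max(enclose_w * enclose_h, epsilon)
    return iou - (enclose - union) / enclose

print(giou_reference((0, 0, 2, 2), (0, 0, 2, 2)))   # 1.0 (identical boxes)
print(giou_reference((0, 0, 1, 1), (2, 0, 3, 1)))   # about -0.333 (disjoint)
print(giou_reference((0, 0, 2, 2), (1, 1, 3, 3)))   # about -0.079 (partial overlap)
```

Identical boxes score 1, while disjoint boxes score below zero because the enclosing box contains area that belongs to neither input.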

We use the following benchmarking utility to compare the runtime performance of our two functions:

import time
def benchmark(f, warmup_iters=20, ntrials: int = 100):
    def run(*args, **kwargs):
        # warmup
        for _ in range(warmup_iters):
            f(*args, **kwargs)
        start_time = time.time()
        for _ in range(ntrials):
            f(*args, **kwargs)
        end_time = time.time()
        # Calculate average time per iteration
        avg_time = (end_time - start_time) / ntrials
        return avg_time

    return run

avg_time = benchmark(torch_giou)(t_boxes_0, t_boxes_1)
print(f'torch_giou: {avg_time}')

avg_time = benchmark(giou_kernel)(t_boxes_0.view((batch_size, -1)),
                                  t_boxes_1.view((batch_size, -1)),
                                  t_out)
print(f'giou_kernel: {avg_time}')
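
Accelerator runtimes often queue work asynchronously, so a host-side timer can stop before the device has finished. A possible variant of the utility above, shown here only as a sketch, accepts an optional `sync` callable and uses `time.perf_counter`. The `sync` hook and the name `benchmark_synced` are our own assumptions; which blocking call (if any) is appropriate for your runtime should be checked against its documentation.

```python
import time

def benchmark_synced(f, sync=None, warmup_iters=20, ntrials=100):
    """Average seconds per call of f, calling `sync` before each clock read."""
    def run(*args, **kwargs):
        for _ in range(warmup_iters):
            f(*args, **kwargs)
        if sync is not None:
            sync()                      # drain any pending warmup work
        start = time.perf_counter()
        for _ in range(ntrials):
            f(*args, **kwargs)
        if sync is not None:
            sync()                      # wait for the timed work to finish
        return (time.perf_counter() - start) / ntrials
    return run

# Toy usage on the host: the "sync" hook just records that it was called
calls = []
avg = benchmark_synced(lambda x: x * x, sync=lambda: calls.append(1),
                       warmup_iters=2, ntrials=5)(3)
```

With `sync=None` the function behaves like the original utility, apart from the choice of clock.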

Runtime Environment

We ran our script on an Amazon EC2 inf2.xlarge instance (containing two Neuron cores and four vCPUs). We used the most recent version of the Deep Learning AMI for Neuron available at the time of this writing, "Deep Learning AMI Neuron (Ubuntu 22.04) 20241027", with AWS Neuron 2.20.1 and PyTorch 2.1.

Results

Our custom GIOU kernel demonstrated an average runtime of 0.211 milliseconds, compared to 0.293 milliseconds for the naive PyTorch implementation, amounting to a speedup of roughly 39% (0.293 / 0.211 ≈ 1.39). Keep in mind that these results are unique to our toy example. Other operators, particularly ones that include matrix multiplications (and utilize the Tensor engine), are likely to exhibit different comparative results.

Optimizing NKI Kernel Performance

The next step in our kernel development - beyond the scope of this post - would be to analyze the performance of the GIOU kernel using the dedicated Neuron Profiler in order to identify bottlenecks and optimize our implementation. Please see the NKI performance guide for more details.

Neuron Custom C++ Operators

The second method for creating a custom Neuron kernel is to build a C++ operator for the GpSimd engine. This method is described in the Neuron Custom C++ Operators Developer Guide and demonstrated in the Neuron Custom C++ Operators in MLP and Neuron Custom C++ Operators Performance Optimization tutorials.

Neuron Custom C++ Operators presents an opportunity for "kernel fusion" on the GpSimd engine by facilitating the combination of multiple low-level operations into a single kernel execution. This approach can significantly reduce the overhead associated with: 1) loading multiple individual kernels, and 2) transferring data between different memory regions.
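
The idea behind kernel fusion can be illustrated without any accelerator. In the plain-Python sketch below (our own illustration with hypothetical function names, not Neuron code), the unfused version runs each step as a separate pass that materializes a full intermediate list, while the fused version keeps intermediates in local variables and makes a single pass over the data.

```python
def iou_unfused(boxes1, boxes2):
    """Each step is a separate pass that materializes a full intermediate list."""
    passes = []
    inter_w = [max(min(a[2], b[2]) - max(a[0], b[0]), 0.0)
               for a, b in zip(boxes1, boxes2)]
    passes.append(inter_w)
    inter_h = [max(min(a[3], b[3]) - max(a[1], b[1]), 0.0)
               for a, b in zip(boxes1, boxes2)]
    passes.append(inter_h)
    inter = [w * h for w, h in zip(inter_w, inter_h)]
    passes.append(inter)
    union = [(a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i
             for a, b, i in zip(boxes1, boxes2, inter)]
    passes.append(union)
    iou = [i / u for i, u in zip(inter, union)]
    return iou, len(passes)

def iou_fused(boxes1, boxes2):
    """One pass: intermediates live only in local variables."""
    out = []
    for a, b in zip(boxes1, boxes2):
        w = max(min(a[2], b[2]) - max(a[0], b[0]), 0.0)
        h = max(min(a[3], b[3]) - max(a[1], b[1]), 0.0)
        inter = w * h
        union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
        out.append(inter / union)
    return out

b1 = [(0.0, 0.0, 2.0, 2.0), (0.0, 0.0, 1.0, 1.0)]
b2 = [(1.0, 1.0, 3.0, 3.0), (0.0, 0.0, 1.0, 1.0)]
unfused, n_intermediates = iou_unfused(b1, b2)
fused = iou_fused(b1, b2)
```

Both functions return the same values; the difference is the four intermediate buffers that the unfused version writes and reads back, which is the kind of memory traffic fusion aims to avoid.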

Toy Example - A GIOU C++ Kernel

In the code block below we implement a C++ GIOU operator for Neuron and save it to a file named giou.cpp. Our kernel uses the TCM accessor for optimizing memory read and write performance and applies the multicore setting in order to use all eight of the GpSimd's internal processors.

#include <stdint.h>
#include <stdlib.h>
#include <torch/torch.h>
#include <neuron/neuron-utils.hpp>
#include <algorithm>

// input boxes of shape 1024x256x4
// output scores of shape 1024x256
torch::Tensor giou(const torch::Tensor& t_pred,
                   const torch::Tensor& t_target) {
  size_t num_samples = t_pred.sizes()[0];
  size_t num_boxes = t_pred.sizes()[1];
  torch::Tensor t_out = get_dst_tensor();

  // get the number of GpSimd processors (8 in NeuronCoreV2)
  uint32_t cpu_count = get_cpu_count();
  // get index of current processor
  uint32_t cpu_id = get_cpu_id();

  // divide the batch size into 8 partitions
  uint32_t partition = num_samples / cpu_count;

  // use tcm buffers to load and write data
  size_t tcm_in_size = num_boxes*4;
  size_t tcm_out_size = num_boxes;
  float *tcm_pred = (float*)torch::neuron::tcm_malloc(
                                             sizeof(float)*tcm_in_size);
  float *tcm_target = (float*)torch::neuron::tcm_malloc(
                                             sizeof(float)*tcm_in_size);
  float *tcm_output = (float*)torch::neuron::tcm_malloc(
                                             sizeof(float)*tcm_out_size);
  auto t_pred_tcm_acc = t_pred.tcm_accessor();
  auto t_target_tcm_acc = t_target.tcm_accessor();
  auto t_out_tcm_acc = t_out.tcm_accessor();

  // iterate over each of the entries in the partition
  for (size_t i = 0; i < partition; i++) {
    // load the pred and target boxes into local memory
    // offsets are counted in elements: (row index) * (elements per row)
    t_pred_tcm_acc.tensor_to_tcm<float>(tcm_pred,
                                        (partition*cpu_id + i)*tcm_in_size,
                                        tcm_in_size);
    t_target_tcm_acc.tensor_to_tcm<float>(tcm_target,
                                          (partition*cpu_id + i)*tcm_in_size,
                                          tcm_in_size);

    // iterate over each of the boxes in the entry
    for (size_t j = 0; j < num_boxes; j++) {
      const float epsilon = 1e-5;
      const float* box1 = &tcm_pred[j * 4];
      const float* box2 = &tcm_target[j * 4];
      // Compute area of each box
      float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
      float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);

      // Compute the intersection
      float left = std::max(box1[0], box2[0]);
      float top = std::max(box1[1], box2[1]);
      float right = std::min(box1[2], box2[2]);
      float bottom = std::min(box1[3], box2[3]);

      float inter_w = std::max(right - left, 0.f);
      float inter_h = std::max(bottom - top, 0.f);
      float inter_area = inter_w * inter_h;

      // Compute the union area
      float union_area = area1 + area2 - inter_area;

      // IoU
      float iou_val = inter_area / std::max(union_area, epsilon);

      // Compute the smallest enclosing box
      float enclose_left = std::min(box1[0], box2[0]);
      float enclose_top = std::min(box1[1], box2[1]);
      float enclose_right = std::max(box1[2], box2[2]);
      float enclose_bottom = std::max(box1[3], box2[3]);

      float enclose_w = std::max(enclose_right - enclose_left, 0.f);
      float enclose_h = std::max(enclose_bottom - enclose_top, 0.f);
      float enclose_area = std::max(enclose_w * enclose_h, epsilon);

      float result = iou_val - (enclose_area-union_area)/enclose_area;
      tcm_output[j] = result;
    }

    // write the giou scores of all boxes in the current entry
    t_out_tcm_acc.tcm_to_tensor<float>(tcm_output,
                                       (partition*cpu_id + i)*tcm_out_size,
                                       tcm_out_size);
  }

  torch::neuron::tcm_free(tcm_pred);
  torch::neuron::tcm_free(tcm_target);
  torch::neuron::tcm_free(tcm_output);
  return t_out;
}
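
To reason about how the batch is divided among the GpSimd processors, the row-to-processor mapping can be modeled on the host. The sketch below is in plain Python and assumes that accessor offsets are counted in elements of the flattened tensor, as the element counts passed to the accessor suggest; the helper name `row_offsets` is hypothetical.

```python
def row_offsets(num_samples, row_elems, cpu_count):
    """Map each processor id to the flat element offsets of the rows it owns."""
    assert num_samples % cpu_count == 0
    partition = num_samples // cpu_count      # rows per processor
    return {
        cpu_id: [(partition * cpu_id + i) * row_elems for i in range(partition)]
        for cpu_id in range(cpu_count)
    }

# Shapes from this post: 1024 samples, 256 boxes * 4 coordinates, 8 processors
offsets = row_offsets(num_samples=1024, row_elems=256 * 4, cpu_count=8)
all_offsets = sorted(o for per_cpu in offsets.values() for o in per_cpu)
print(len(all_offsets), all_offsets[:3], offsets[1][0])
# 1024 [0, 1024, 2048] 131072
```

A check like this confirms that every row is visited exactly once and that no two processors touch the same row.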

We require a separate shape.cpp file that defines the output shape of our GIOU function and registers our custom operator with the Neuron library:

#include <stdint.h>
#include <stdlib.h>
#include <torch/torch.h>
#include "torchneuron/register.h"

torch::Tensor giou_shape(torch::Tensor boxes1, torch::Tensor boxes2) {
    torch::Tensor t_out = torch::zeros({boxes1.sizes()[0],
                                        boxes1.sizes()[1]},
                                       torch::kFloat);
    return t_out;
}

NEURON_LIBRARY(my_ops, m) {
  m.def("giou", &giou_shape, "giou");
}

The build.py script compiles the C++ operator and exposes it as a Python API:

import os
import torch_neuronx
from torch_neuronx.xla_impl import custom_op

custom_op.load(
    name='giou',
    compute_srcs=['giou.cpp'],
    shape_srcs=['shape.cpp'],
    build_directory=os.getcwd(),
    multicore=True,
    verbose=True
)

The compilation script generates a libgiou.so library containing the implementation of our C++ GIOU operator. In the code block below we load the library and measure the performance of our custom kernel using the benchmarking utility defined above:

from torch_neuronx.xla_impl import custom_op
custom_op.load_library('libgiou.so')

avg_time = benchmark(torch.ops.my_ops.giou)(t_boxes_0, t_boxes_1)
print(f'C++ giou: {avg_time}')

Runtime Environment

We used the same Neuron environment from our NKI experiments to compile and test our C++ kernel. Please note the installation steps that are required for custom C++ operator development.

Results

Our C++ GIOU kernel demonstrated an average runtime of 0.061 milliseconds - nearly five times faster than our baseline implementation. This is presumably a result of "kernel fusion", as discussed above.

Conclusion

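The table below summarizes the runtime results of our experiments.

Implementation                      Avg time (ms)
PyTorch baseline (torch_giou)       0.293
NKI kernel                          0.211
C++ custom operator                 0.061

Avg time of different GIOU implementations (lower is better) - by Author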

Please keep in mind that these results are specific to the toy example and runtime environment used in this study. The comparative results of other kernels might be very different - depending on the degree to which they can leverage the Neuron core's internal compute engines.
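
The ratios quoted in the results sections follow directly from the measured averages. The short calculation below uses only the three timings reported in this post; the variable names are our own.

```python
baseline_ms = 0.293      # naive PyTorch implementation (torch_giou)
timings_ms = {"NKI kernel": 0.211, "C++ operator": 0.061}

# Speedup relative to the baseline, and the share of runtime saved
speedups = {name: baseline_ms / t for name, t in timings_ms.items()}
savings = {name: 1 - t / baseline_ms for name, t in timings_ms.items()}

for name in timings_ms:
    print(f"{name}: {speedups[name]:.2f}x faster, "
          f"{100 * savings[name]:.0f}% less time")
# NKI kernel: 1.39x faster, 28% less time
# C++ operator: 4.80x faster, 79% less time
```

Note that a 1.39x speedup corresponds to a 28% reduction in runtime, so it matters which of the two figures is being quoted.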

The table below summarizes some of the differences we observed between the two methods of AWS Neuron kernel customization.

Comparison between kernel customization tools (by Author)

Through their high-level Python interface, the NKI APIs expose the power of the Neuron acceleration engines to ML developers in an accessible and user-friendly manner. The low-level C++ Custom Operators library enables even greater programmability, but is limited to the GpSimd engine. By effectively combining both tools, developers can fully leverage the AWS Neuron architecture's capabilities.

Summary

With the AI revolution in full swing, many companies are developing advanced new AI chips to meet the growing demand for compute. While public announcements often highlight these chips' runtime performance, cost savings, and energy efficiency, several core capabilities are essential to make these chips and their software stacks truly viable for ML development. These capabilities include robust debugging tools, performance analysis and optimization utilities, programmability, and more.

In this post, we focused on the utilities available for programming AWS's homegrown AI accelerators, Trainium and Inferentia, and demonstrated their use in building custom ML operations. These tools empower developers to optimize the performance of their ML models on AWS's AI chips and open up new opportunities for innovation and creativity.
