Custom Structs

Julia compiles functions as and when they are called, specialised for the concrete argument types (for C++ people: a bit like every argument being a template parameter by default). This means custom structs and functions defined outside AcceleratedKernels.jl are inlined and optimised as if they were hardcoded within the library. Normal Julia functions and code can be used directly, without special annotations like __device__ or KOKKOS_LAMBDA, and without wrapping them in classes with an overloaded operator().
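Here is a minimal, CPU-only sketch of this per-type specialisation. It uses only Base Julia, not AcceleratedKernels. The helper total is hypothetical and has no argument type annotations. Julia still compiles a separate, type-stable method instance for each element type it is called with, including a custom struct once + and zero are defined for it.

```julia
# A small custom struct, mirroring the Point type used below
struct Point
    x::Float32
    y::Float32
end

# Teach Base how to add two Points and what the additive identity is
Base.:+(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y)
Base.zero(::Type{Point}) = Point(0f0, 0f0)

# Hypothetical generic helper: no argument types are annotated, so Julia
# compiles a specialised version for each concrete input type on first call
function total(v)
    acc = zero(eltype(v))
    for item in v
        acc += item
    end
    return acc
end

total(Float32[1, 2, 3])              # 6.0f0, compiled for Vector{Float32}
total([Point(1, 2), Point(3, 4)])    # Point(4.0f0, 6.0f0), compiled for Vector{Point}

# Inference shows each specialisation returns a single concrete type
Base.return_types(total, (Vector{Float32},))   # Any[Float32]
Base.return_types(total, (Vector{Point},))     # Any[Point]
```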

As an example, let's compute the coordinate-wise minima of some points:

import AcceleratedKernels as AK
using Metal

struct Point
    x::Float32
    y::Float32
end

function compute_minima(points)
    AK.mapreduce(
        point -> (point.x, point.y),                    # Extract fields into tuple
        (a, b) -> (min(a[1], b[1]), min(a[2], b[2])),   # Keep each coordinate's minimum
        points,
        init=(typemax(Float32), typemax(Float32)),
    )
end

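# Example output for Random.seed!(0):
#   minima = compute_minima(points) = (1.7966056f-5, 1.7797855f-6)
using Random
Random.seed!(0)

# rand() returns Float64; Point's default constructor converts each value to Float32
points = MtlArray([Point(rand(), rand()) for _ in 1:100_000])
@show minima = compute_minima(points)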
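A CPU reference can help when checking a GPU result. You can pass the same two anonymous functions to Base.mapreduce on a plain CPU Vector. The sketch below does this using only Base and Random. The name compute_minima_cpu is hypothetical, and the struct is redeclared so the block runs on its own.

```julia
using Random

struct Point
    x::Float32
    y::Float32
end

# Same mapping and reduction functions as compute_minima, but run with
# Base.mapreduce on an ordinary CPU Vector (no AcceleratedKernels involved)
function compute_minima_cpu(points)
    mapreduce(
        point -> (point.x, point.y),                    # Extract fields into tuple
        (a, b) -> (min(a[1], b[1]), min(a[2], b[2])),   # Keep each coordinate's minimum
        points;
        init=(typemax(Float32), typemax(Float32)),
    )
end

Random.seed!(0)
points_cpu = [Point(rand(), rand()) for _ in 1:100_000]
minima_cpu = compute_minima_cpu(points_cpu)

# Cross-check each coordinate against a direct minimum over that field
@assert minima_cpu == (minimum(p.x for p in points_cpu), minimum(p.y for p in points_cpu))
```

Because min is associative and commutative, the result does not depend on the order in which the reduction combines elements. A parallel reduction should therefore match this sequential one exactly.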

Note that we did not have to explicitly type the function arguments in compute_minima. The argument types are known when the function is called, and the code is compiled for the right backend automatically, e.g. CPU, oneAPI, ROCm, CUDA or Metal. Also, we used the standard Julia function min. It was not special-cased anywhere: KernelAbstractions.jl and the Julia compiler simply inline and compile normal code, even code from Julia's Base library.
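Your own methods work the same way as Base ones. The CPU-only sketch below uses Base.reduce, not AcceleratedKernels. It gives Point its own coordinate-wise min method and reduces over the points directly, with init as the identity element for min. The same operator could presumably be passed to AK.reduce on a GPU array, but that call is an assumption and is not shown here.

```julia
struct Point
    x::Float32
    y::Float32
end

# A user-defined method on a user-defined type: coordinate-wise minimum
Base.min(a::Point, b::Point) = Point(min(a.x, b.x), min(a.y, b.y))

points = [Point(0.5, 2.0), Point(1.5, 0.25), Point(3.0, 1.0)]

# Reduce with the custom method; typemax(Float32) in both fields is min's identity
reduce(min, points; init=Point(typemax(Float32), typemax(Float32)))
# Point(0.5f0, 0.25f0)
```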

You can also use unmaterialised index ranges in GPU kernels. Unmaterialised means you do not need to waste memory creating a vector of indices, e.g.:

import AcceleratedKernels as AK
using CUDA

function complex_any(x, y)
    # Calling `any` on a normal Julia range, but running on x's backend
    AK.any(1:length(x), AK.get_backend(x)) do i
        x[i] < 0 && y[i] > 0
    end
end

# randn gives values of both signs; with rand (all in [0, 1)) x[i] < 0 would never hold
complex_any(CuArray(randn(Float32, 100)), CuArray(randn(Float32, 100)))
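You can check the predicate on small hand-made inputs by running the same logic on the CPU with Base.any over a range. This sketch uses only Base, and complex_any_cpu is a hypothetical name. The range is lazy here too, so no index vector is allocated.

```julia
# Same predicate as complex_any, run on the CPU with Base.any over a lazy range
function complex_any_cpu(x, y)
    any(1:length(x)) do i
        x[i] < 0 && y[i] > 0
    end
end

complex_any_cpu(Float32[1, -2, 3], Float32[0, 5, 0])   # true: index 2 has x < 0 and y > 0
complex_any_cpu(Float32[1, 2, 3], Float32[1, 2, 3])    # false: no x[i] is negative
```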

Note that you have to specify the backend explicitly in this case, as a range does not have a backend per se. When passed to a GPU kernel, a Base.UnitRange is just a small struct holding two numbers, its start and stop, rather than a whole vector of indices.
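You can verify this in plain Julia, without a GPU. A Base.UnitRange stores only its endpoints, whereas collecting it into a Vector stores every index. The byte counts in the comments assume a 64-bit machine, where Int is Int64.

```julia
r = 1:100_000

# A UnitRange is a small immutable struct with just two fields
fieldnames(typeof(r))   # (:start, :stop)
isbits(r)               # true: plain bits, cheap to pass by value into a kernel
sizeof(r)               # 16 bytes on a 64-bit system (two Int64 fields)
(first(r), last(r))     # (1, 100000)

# Materialising the same indices stores one Int per element
idx = collect(r)
sizeof(idx)             # 800000 bytes on a 64-bit system
```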