How to generate random numbers on the GPU?

Hi!

I’m wondering how to generate random numbers directly on the GPU. Generating many large tensors on the CPU and then moving them to the GPU is quite slow when I want a fast iteration loop. Also, I don’t think the CPU backend supports generating random bfloat16 numbers?
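
For reference, this is roughly the slow host-side pattern I’m trying to replace: fill a buffer with the stdlib randn on the CPU and let it sync to the device afterwards (just a sketch; I’m assuming the host-mapped buffer exposes an unsafe_ptr() I can hand to randn):

from gpu.host import DeviceContext
from random import randn

def main():
    var ctx = DeviceContext()
    var buf = ctx.enqueue_create_buffer[DType.float32](1000)
    # Fill on the CPU while the buffer is mapped into host memory;
    # the data syncs back to the device when the `with` block exits.
    with buf.map_to_host() as host_buf:
        randn[DType.float32](host_buf.unsafe_ptr(), 1000)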

To generate directly on the GPU instead, I tried the following code, but it segfaults:

from gpu.host import DeviceContext
from layout import Layout, LayoutTensor, IntTuple
from random import randn

alias x_layout = Layout(IntTuple(1000,), IntTuple(1,))

def main():
    var ctx = DeviceContext()
    var x_buffer = ctx.enqueue_create_buffer[DType.bfloat16](x_layout.size())
    var x = LayoutTensor[DType.bfloat16, x_layout](x_buffer)
    randn[DType.bfloat16](x.ptr, x_layout.size())  # this is the call that segfaults

Any tips/ideas would be appreciated!

Please use the gpu.random module.


Thanks for the quick answer! Follow-up: do you have any examples of how to use this? It looks like I’d have to write my own kernel to use it? I’m a bit confused, sorry.

I wrote a quick kernel using the gpu.random module, and I can get it to work with float16 (and float32) but not bfloat16. Here is the code:

from gpu.host import DeviceContext
from gpu.random import Random
from layout import Layout, LayoutTensor, IntTuple
from algorithm.functional import elementwise

from utils.index import IndexList, Index

def main():

    var ctx = DeviceContext()
    var x_buffer = ctx.enqueue_create_buffer[DType.bfloat16](8)
    alias x_layout = Layout(IntTuple(8,), IntTuple(1,))
    var x = LayoutTensor[DType.bfloat16, x_layout](x_buffer)

    @parameter
    @always_inline
    @__copy_capture(x)
    fn func[simd_width: Int, rank: Int](idx0: IndexList[rank]):
        var idx = rebind[IndexList[1]](idx0)
        # One generator per elementwise invocation, seeded by its index.
        var rng_state = Random(seed=idx0[0])
        # step_uniform() yields 4 uniform float32 values per call.
        var values: SIMD[DType.float32, 4] = rng_state.step_uniform()
        @parameter
        for i in range(simd_width):
            # Cast each float32 value down to bfloat16 before storing.
            x[idx[0] + i] = values[i].cast[DType.bfloat16]()

    elementwise[func, 4, target="gpu"](Index(8), ctx)
    
    # Map the device buffer back to the host and print the values.
    with x_buffer.map_to_host() as x_host:
        for i in range(8):
            print(x_host[i].cast[DType.float32]())

It seems like at some point one of the casts zeros out the numbers? Is there a different way I should be generating/printing bfloat16 numbers?
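
As a sanity check, I could isolate the cast path on the host with something like this (a minimal sketch, no GPU involved; 0.5 is exactly representable in bfloat16):

def main():
    # Round-trip through the same casts the kernel uses:
    # float32 -> bfloat16 -> float32.
    var v = Float32(0.5)
    var b = v.cast[DType.bfloat16]()
    print(b.cast[DType.float32]())  # should print 0.5 if the cast itself is fine

If that prints 0.5, the zeroing presumably happens in the kernel store or on the way back from the device rather than in the cast itself.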