I’m wondering how to generate random numbers directly on the GPU. When iterating quickly, generating many large tensors on the CPU and then moving them to the GPU is quite slow. Also, I don’t think the CPU backend supports generating random bfloat16 numbers?
I tried the following code and it segfaults:
```mojo
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor, IntTuple
from random import randn

alias x_layout = Layout(IntTuple(1000,), IntTuple(1,))

def main():
    var ctx = DeviceContext()
    var x_buffer = ctx.enqueue_create_buffer[DType.bfloat16](x_layout.size())
    var x = LayoutTensor[DType.bfloat16, x_layout](x_buffer)
    randn[DType.bfloat16](x.ptr, x_layout.size())
```
Thanks for the quick answer! Follow-up: do you have any examples of how to use this? It looks like I’d have to write my own kernel to use it. Sorry, I’m a bit confused.
I wrote a quick kernel using `gpu.random`, and I can get it to work with float16 (and float32) but not with bfloat16. Here is the code:
```mojo
from gpu import global_idx
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor, IntTuple
from gpu.random import Random
from algorithm.functional import elementwise
from utils.index import IndexList, Index

def main():
    var ctx = DeviceContext()
    var x_buffer = ctx.enqueue_create_buffer[DType.bfloat16](8)
    alias x_layout = Layout(IntTuple(8,), IntTuple(1,))
    var x = LayoutTensor[DType.bfloat16, x_layout](x_buffer)

    @parameter
    @always_inline
    @__copy_capture(x)
    fn func[simd_width: Int, rank: Int](idx0: IndexList[rank]):
        var idx = rebind[IndexList[1]](idx0)
        var rng_state = Random(seed=idx0[0])
        var values: SIMD[DType.float32, 4] = rng_state.step_uniform()

        @parameter
        for i in range(simd_width):
            x[idx[0] + i] = values[i].cast[DType.bfloat16]()

    elementwise[func, 4, target="gpu"](Index(8), ctx)

    with x_buffer.map_to_host() as x_host:
        for i in range(8):
            print(x_host[i].cast[DType.float32]())
```
It seems like one of the casts zeros out the numbers at some point. Is there a different way I should be generating or printing bfloat16 numbers?
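One way to narrow down which cast is responsible is to take the GPU out of the picture and round-trip a few known float32 values through bfloat16 on the host. This is a minimal sketch that uses only `SIMD.cast` and `print`; the input values are arbitrary, and it assumes your Mojo version allows bfloat16 casts on your CPU (which is part of what it exercises). If it prints zeros, the host-side `cast[DType.float32]()` in the print loop is the more likely suspect; if the round-tripped values are close to the originals, the device-side cast or the store into the `LayoutTensor` is more likely at fault.

```mojo
def main():
    # Arbitrary test values: 0.125 and 0.5 are exactly representable in
    # bfloat16; 0.7431 and 0.999 will be rounded to the nearest bfloat16.
    var values = SIMD[DType.float32, 4](0.125, 0.5, 0.7431, 0.999)

    # float32 -> bfloat16 -> float32, entirely on the host (no GPU involved).
    var as_bf16 = values.cast[DType.bfloat16]()
    var round_trip = as_bf16.cast[DType.float32]()

    print("original:  ", values)
    print("round-trip:", round_trip)
```

Because bfloat16 keeps the float32 exponent but only 7 explicit mantissa bits, values like 0.7431 should come back slightly rounded rather than exact; only a result of exactly zero would point to a broken cast.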