Scanner C++ API
halide.h
1 #pragma once
2 
3 #include "scanner/api/kernel.h"
4 #include "scanner/util/common.h"
5 
6 #include "HalideRuntime.h"
7 
8 #ifdef HAVE_CUDA
9 #include "HalideRuntimeCuda.h"
10 #include "scanner/util/halide_context.h"
11 #endif
12 
13 namespace scanner {
14 
15 void setup_halide_frame_buf(buffer_t& halide_buf, FrameInfo& frame_info) {
16  // Halide has the input format x * stride[0] + y * stride[1] + c * stride[2]
17  halide_buf.stride[0] = 3;
18  halide_buf.stride[1] = frame_info.width() * 3;
19  halide_buf.stride[2] = 1;
20  halide_buf.extent[0] = frame_info.width();
21  halide_buf.extent[1] = frame_info.height();
22  halide_buf.extent[2] = 3;
23  halide_buf.elem_size = 1;
24 }
25 
26 void set_halide_buf_ptr(const DeviceHandle& device, buffer_t& halide_buf,
27  u8* buf, size_t size) {
28  if (device.type == DeviceType::GPU) {
29  CUDA_PROTECT({
30  halide_buf.dev = (uintptr_t) nullptr;
31 
32  // "You likely want to set the dev_dirty flag for correctness. (It will
33  // not matter if all the code runs on the GPU.)"
34  halide_buf.dev_dirty = true;
35 
36  i32 err =
37  halide_cuda_wrap_device_ptr(nullptr, &halide_buf, (uintptr_t)buf);
38  LOG_IF(FATAL, err != 0) << "Halide wrap device ptr failed";
39 
40  // "You'll need to set the host field of the buffer_t structs to
41  // something other than nullptr as that is used to indicate bounds query
42  // calls" - Zalman Stern
43  halide_buf.host = (u8*)0xdeadbeef;
44  });
45  } else {
46  halide_buf.host = buf;
47  }
48 }
49 
50 void unset_halide_buf_ptr(const DeviceHandle& device, buffer_t& halide_buf) {
51  if (device.type == DeviceType::GPU) {
52  CUDA_PROTECT({ halide_cuda_detach_device_ptr(nullptr, &halide_buf); });
53  }
54 }
55 }
Definition: database.cpp:36