Note: this tutorial is still being updated.


    In this tutorial, we will show:

    • How to write your own operator in C++ and CUDA and compile it just-in-time
    • How to run your custom operator

    If you only need a very simple operator that fits in a few lines of code, use the code op instead; see help(jt.code). custom_op is meant for implementing complex operators, and a custom_op has exactly the same capabilities as a built-in op.
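    For comparison, here is a rough sketch of what a trivial element-wise operator might look like with the code op. It is loosely based on the examples shown by help(jt.code); the @in0/@out macro syntax and shape variables may differ between Jittor versions, so treat it as illustrative rather than exact:

    import jittor as jt

    a = jt.random([10])
    # jt.code compiles the inline C++ source string just-in-time.
    b = jt.code(a.shape, a.dtype, [a],
        cpu_src='''
            for (int i=0; i<in0_shape0; i++)
                @out(i) = @in0(i) * 2;
        ''')

    The rest of this tutorial implements the same kind of operator with custom_op: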
    import jittor as jt

    header = """
    #pragma once
    #include "op.h"

    namespace jittor {

    struct CustomOp : Op {
        Var* output;
        CustomOp(NanoVector shape, NanoString dtype=ns_float);
        const char* name() const override { return "custom"; }
        DECLARE_jit_run;
    };

    } // jittor
    """

    src = """
    #include "var.h"
    #include "custom_op.h"

    namespace jittor {

    #ifndef JIT
    // Non-JIT pass: construct the op and declare its output.
    CustomOp::CustomOp(NanoVector shape, NanoString dtype) {
        flags.set(NodeFlags::_cuda, 1);
        flags.set(NodeFlags::_cpu, 1);
        output = create_output(shape, dtype);
    }

    // Register the JIT define "T" (the output dtype) used by jit_run below.
    void CustomOp::jit_prepare() {
        add_jit_define("T", output->dtype());
    }

    #else // JIT
    #ifdef JIT_cpu
    // CPU kernel: fill the output with 0, 1, 2, ...
    void CustomOp::jit_run() {
        index_t num = output->num;
        auto* __restrict__ x = output->ptr<T>();
        for (index_t i=0; i<num; i++)
            x[i] = (T)i;
    }
    #else // JIT_cuda
    // CUDA kernel: fill the output with 0, -1, -2, ...
    __global__ void kernel(index_t n, T *x) {
        int index = blockIdx.x * blockDim.x + threadIdx.x;
        int stride = blockDim.x * gridDim.x;
        for (int i = index; i < n; i += stride)
            x[i] = (T)-i;
    }

    void CustomOp::jit_run() {
        index_t num = output->num;
        auto* __restrict__ x = output->ptr<T>();
        int blockSize = 256;
        int numBlocks = (num + blockSize - 1) / blockSize;
        kernel<<<numBlocks, blockSize>>>(num, x);
    }
    #endif // JIT_cpu
    #endif // JIT

    } // jittor
    """

    my_op = jt.compile_custom_op(header, src, "custom", warp=False)

    Let's check the result of this operator. The CPU version fills the output with 0, 1, 2, ..., while the CUDA version fills it with 0, -1, -2, ..., which is exactly what the assertions below verify.

    # run cpu version
    jt.flags.use_cuda = 0
    a = my_op([3,4,5], 'float').fetch_sync()
    assert (a.flatten() == range(3*4*5)).all()

    if jt.compiler.has_cuda:
        # run cuda version
        jt.flags.use_cuda = 1
        a = my_op([3,4,5], 'float').fetch_sync()
        assert (-a.flatten() == range(3*4*5)).all()
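
    Because jit_prepare() registers the output dtype as the JIT define T, calling the op with a different dtype should trigger a separate JIT compilation specialized for that type. A minimal sketch, assuming the usual Jittor dtype strings are accepted for the NanoString parameter:

    # Hypothetical follow-up call: requesting an 'int32' output compiles a
    # second specialization of CustomOp, with T defined as int32.
    jt.flags.use_cuda = 0
    b = my_op([2, 3], 'int32').fetch_sync()
    assert (b.flatten() == range(2*3)).all()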