at master 1.2 kB view raw
1diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c 2index ab24f7657..46dbaceb0 100644 3--- a/third_party/nvidia/backend/driver.c 4+++ b/third_party/nvidia/backend/driver.c 5@@ -1,4 +1,4 @@ 6-#include "cuda.h" 7+#include <cuda.h> 8 #include <dlfcn.h> 9 #include <stdbool.h> 10 #define PY_SSIZE_T_CLEAN 11diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py 12index 47544bd8e..d57c6a70f 100644 13--- a/third_party/nvidia/backend/driver.py 14+++ b/third_party/nvidia/backend/driver.py 15@@ -12,7 +12,8 @@ from triton.backends.compiler import GPUTarget 16 from triton.backends.driver import GPUDriver 17 18 dirname = os.path.dirname(os.path.realpath(__file__)) 19-include_dirs = [os.path.join(dirname, "include")] 20+import shlex 21+include_dirs = [*shlex.split("@cudaToolkitIncludeDirs@"), os.path.join(dirname, "include")] 22 libdevice_dir = os.path.join(dirname, "lib") 23 libraries = ['cuda'] 24 25@@ -256,7 +257,7 @@ def make_launcher(constants, signature, tensordesc_meta): 26 params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] 27 params.append("&global_scratch") 28 src = f""" 29-#include \"cuda.h\" 30+#include <cuda.h> 31 #include <stdbool.h> 32 #include <Python.h> 33 #include <dlfcn.h>