diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c index ab24f7657..46dbaceb0 100644 --- a/third_party/nvidia/backend/driver.c +++ b/third_party/nvidia/backend/driver.c @@ -1,4 +1,4 @@ -#include "cuda.h" +#include #include #include #define PY_SSIZE_T_CLEAN diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py index 47544bd8e..d57c6a70f 100644 --- a/third_party/nvidia/backend/driver.py +++ b/third_party/nvidia/backend/driver.py @@ -12,7 +12,8 @@ from triton.backends.compiler import GPUTarget from triton.backends.driver import GPUDriver dirname = os.path.dirname(os.path.realpath(__file__)) -include_dirs = [os.path.join(dirname, "include")] +import shlex +include_dirs = [*shlex.split("@cudaToolkitIncludeDirs@"), os.path.join(dirname, "include")] libdevice_dir = os.path.join(dirname, "lib") libraries = ['cuda'] @@ -256,7 +257,7 @@ def make_launcher(constants, signature, tensordesc_meta): params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"] params.append("&global_scratch") src = f""" -#include \"cuda.h\" +#include #include #include #include