at master 1.5 kB view raw
1diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py 2index a30c8f53..e2078ff6 100644 3--- a/tinygrad/runtime/autogen/cuda.py 4+++ b/tinygrad/runtime/autogen/cuda.py 5@@ -145,7 +145,19 @@ def char_pointer_cast(string, encoding='utf-8'): 6 7 8 _libraries = {} 9-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda')) 10+libcuda = None 11+try: 12+ libcuda = ctypes.CDLL('libcuda.so') 13+except OSError: 14+ pass 15+try: 16+ libcuda = ctypes.CDLL('@driverLink@/lib/libcuda.so') 17+except OSError: 18+ pass 19+if libcuda is None: 20+ raise RuntimeError(f"`libcuda.so` not found") 21+ 22+_libraries['libcuda.so'] = libcuda 23 24 25 cuuint32_t = ctypes.c_uint32 26diff --git a/tinygrad/runtime/autogen/nvrtc.py b/tinygrad/runtime/autogen/nvrtc.py 27index 6af74187..c5a6c6c4 100644 28--- a/tinygrad/runtime/autogen/nvrtc.py 29+++ b/tinygrad/runtime/autogen/nvrtc.py 30@@ -10,7 +10,18 @@ import ctypes, ctypes.util 31 32 33 _libraries = {} 34-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc')) 35+libnvrtc = None 36+try: 37+ libnvrtc = ctypes.CDLL('libnvrtc.so') 38+except OSError: 39+ pass 40+try: 41+ libnvrtc = ctypes.CDLL('@libnvrtc@') 42+except OSError: 43+ pass 44+if libnvrtc is None: 45+ raise RuntimeError(f"`libnvrtc.so` not found") 46+_libraries['libnvrtc.so'] = libnvrtc 47 def string_cast(char_pointer, encoding='utf-8', errors='strict'): 48 value = ctypes.cast(char_pointer, ctypes.c_char_p).value 49 if value is not None and encoding is not None: