My Nix configurations for my servers and desktop.
at main 2.0 kB view raw
{
  description = "A Nix-flake-based PyTorch development environment";

  # CUDA binaries are cached by the community. Adding this substituter lets
  # Nix download prebuilt CUDA packages instead of building them locally.
  nixConfig = {
    extra-substituters = [
      "https://nix-community.cachix.org"
    ];
    extra-trusted-public-keys = [
      "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
    ];
  };

  inputs.nixpkgs.url = "https://flakehub.com/f/NixOS/nixpkgs/0.1.*.tar.gz";

  outputs = {
    self,
    nixpkgs,
  }: let
    supportedSystems = ["x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin"];

    # Calls `f { pkgs }` once per entry in `supportedSystems` and collects the
    # results into an attrset keyed by system name. `allowUnfree` is enabled
    # so the unfree CUDA packages used below (e.g. cuDNN) can be evaluated.
    forEachSupportedSystem = f:
      nixpkgs.lib.genAttrs supportedSystems (system:
        f {
          pkgs = import nixpkgs {
            inherit system;
            config.allowUnfree = true;
          };
        });
  in {
    devShells = forEachSupportedSystem ({pkgs}: let
      # nixpkgs' CUDA packages do not support Darwin; referencing them there
      # aborts evaluation of the whole devShell. The CUDA-specific entries
      # are therefore added on Linux only. The Linux ordering is the same as
      # before: CUDA libs first, then the common libs, then the driver path.
      isLinux = pkgs.stdenv.hostPlatform.isLinux;

      # Directories (via their `lib/` subdirectory) exposed to pip-installed
      # wheels through LD_LIBRARY_PATH.
      libs =
        # PyTorch and NumPy depend on the following libraries.
        pkgs.lib.optionals isLinux [
          pkgs.cudaPackages.cudatoolkit
          pkgs.cudaPackages.cudnn
        ]
        ++ [
          pkgs.stdenv.cc.cc.lib
          pkgs.zlib
        ]
        # PyTorch also needs to know where your local "lib/libcuda.so" lives.
        # If you're not on NixOS, you should provide the right path (likely
        # another one).
        ++ pkgs.lib.optionals isLinux [
          "/run/opengl-driver"
        ];
    in {
      default = pkgs.mkShell {
        packages = [
          pkgs.python312
          pkgs.python312Packages.venvShellHook
        ];

        env = {
          CC = "${pkgs.gcc}/bin/gcc"; # For `torch.compile`.
          LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath libs;
        };

        # Directory of the virtualenv managed by venvShellHook.
        venvDir = ".venv";
        postVenvCreation = ''
          # This is run only when creating the virtual environment.
          pip install torch==2.5.1 numpy==2.2.2
        '';
        postShellHook = ''
          # This is run every time you enter the devShell.
          python3 -c "import torch; print('CUDA available' if torch.cuda.is_available() else 'CPU only')"
        '';
      };
    });
  };
}