My Nix configurations for my servers and desktop.
{
  description = "A Nix-flake-based PyTorch development environment";

  # CUDA binaries are cached by the community: opt in to the nix-community
  # binary cache (and trust its signing key) so they can be substituted.
  nixConfig = {
    extra-substituters = [
      "https://nix-community.cachix.org"
    ];
    extra-trusted-public-keys = [
      "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs="
    ];
  };

  inputs.nixpkgs.url = "https://flakehub.com/f/NixOS/nixpkgs/0.1.*.tar.gz";

  outputs = {
    self,
    nixpkgs,
  }: let
    supportedSystems = ["x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin"];

    # Calls `f { pkgs = ...; }` once per supported system and returns an
    # attribute set keyed by system name. `allowUnfree` is enabled because the
    # CUDA packages referenced below carry an unfree license.
    forEachSupportedSystem = f:
      nixpkgs.lib.genAttrs supportedSystems (system:
        f {
          pkgs = import nixpkgs {
            inherit system;
            config.allowUnfree = true;
          };
        });
  in {
    devShells = forEachSupportedSystem ({pkgs}: let
      isLinux = pkgs.stdenv.hostPlatform.isLinux;

      # nixpkgs only packages CUDA for Linux. Referencing these attributes
      # unconditionally makes the darwin dev shells fail to evaluate, so they
      # are included on Linux only. `optionals` never forces its list argument
      # when the condition is false.
      cudaLibs = pkgs.lib.optionals isLinux [
        pkgs.cudaPackages.cudatoolkit
        pkgs.cudaPackages.cudnn
      ];

      # PyTorch also needs to know where your local "lib/libcuda.so" lives.
      # "/run/opengl-driver" is the NixOS location; if you're not on NixOS,
      # you should provide the right path (likely another one).
      # `makeLibraryPath` appends "/lib" to every entry, hence no "/lib" here.
      driverLibs = pkgs.lib.optionals isLinux [
        "/run/opengl-driver"
      ];

      # PyTorch and NumPy depend on the following libraries. On Linux the
      # resulting order is identical to the original list: CUDA libraries,
      # then the C++ runtime and zlib, then the driver path.
      libs =
        cudaLibs
        ++ [
          pkgs.stdenv.cc.cc.lib
          pkgs.zlib
        ]
        ++ driverLibs;
    in {
      default = pkgs.mkShell {
        packages = [
          pkgs.python312
          # Creates (if missing) and activates the virtualenv in `venvDir`,
          # and runs the `postVenvCreation` / `postShellHook` snippets below.
          pkgs.python312Packages.venvShellHook
        ];

        env = {
          CC = "${pkgs.gcc}/bin/gcc"; # For `torch.compile`.
          LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath libs;
        };

        venvDir = ".venv";
        postVenvCreation = ''
          # This is run only when creating the virtual environment.
          pip install torch==2.5.1 numpy==2.2.2
        '';
        postShellHook = ''
          # This is run every time you enter the devShell.
          python3 -c "import torch; print('CUDA available' if torch.cuda.is_available() else 'CPU only')"
        '';
      };
    });
  };
}