# Package definition for Stable-Baselines3, built from the upstream GitHub
# tag with `buildPythonPackage` using a pyproject/setuptools build.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  cloudpickle,
  gymnasium,
  matplotlib,
  numpy,
  pandas,
  torch,

  # tests
  ale-py,
  pytestCheckHook,
  rich,
  tqdm,
}:

buildPythonPackage rec {
  pname = "stable-baselines3";
  version = "2.7.0";
  pyproject = true;

  # Sources are fetched from the release tag `v<version>`.
  src = fetchFromGitHub {
    owner = "DLR-RM";
    repo = "stable-baselines3";
    tag = "v${version}";
    hash = "sha256-Ms2qoq1fokhUQ1/Wus786oYPT6C2lnHOZ+D7E7qUbjI=";
  };

  # Environment version v0 for `CliffWalking` is deprecated, so the test file
  # is rewritten to request v1 instead. `--replace-fail` is used so a missing
  # match is reported rather than silently ignored.
  postPatch = ''
    substituteInPlace "tests/test_vec_normalize.py" \
      --replace-fail "CliffWalking-v0" "CliffWalking-v1"
  '';

  build-system = [ setuptools ];

  dependencies = [
    cloudpickle
    gymnasium
    matplotlib
    numpy
    pandas
    torch
  ];

  # Relax the version constraint declared upstream for gymnasium.
  pythonRelaxDeps = [ "gymnasium" ];

  # Smoke check: the top-level module must be importable after the build.
  pythonImportsCheck = [ "stable_baselines3" ];

  nativeCheckInputs = [
    ale-py
    pytestCheckHook
    rich
    tqdm
  ];

  disabledTestPaths = [
    # These tests start training a model, which takes too long
    "tests/test_cnn.py"
    "tests/test_dict_env.py"
    "tests/test_her.py"
    "tests/test_save_load.py"
  ];

  disabledTests = [
    # Flaky: can fail if it takes too long, which happens when the system is
    # under heavy load
    "test_fps_logger"

    # Tests that attempt to access the filesystem
    "test_make_atari_env"
    "test_vec_env_monitor_kwargs"
  ];

  meta = {
    description = "PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms";
    homepage = "https://github.com/DLR-RM/stable-baselines3";
    changelog = "https://github.com/DLR-RM/stable-baselines3/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.derdennisop ];
  };
}