# Package expression for stable-baselines3: fetches the tagged release from
# GitHub, builds it as a pyproject package with setuptools, and runs the
# upstream pytest suite minus the tests listed further down.
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build backend
  setuptools,

  # runtime dependencies
  cloudpickle,
  gymnasium,
  matplotlib,
  numpy,
  pandas,
  torch,

  # test-only inputs
  ale-py,
  pytestCheckHook,
  rich,
  tqdm,
}:

buildPythonPackage rec {
  pname = "stable-baselines3";
  version = "2.7.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "DLR-RM";
    repo = "stable-baselines3";
    # Upstream release tags are the version number prefixed with "v".
    tag = "v${version}";
    hash = "sha256-Ms2qoq1fokhUQ1/Wus786oYPT6C2lnHOZ+D7E7qUbjI=";
  };

  # The v0 environment version of `CliffWalking` is deprecated, so point the
  # affected test at v1 instead. `--replace-fail` makes the build fail if the
  # pattern is no longer present in the file.
  postPatch = ''
    substituteInPlace "tests/test_vec_normalize.py" \
      --replace-fail "CliffWalking-v0" "CliffWalking-v1"
  '';

  build-system = [ setuptools ];

  dependencies = [
    cloudpickle
    gymnasium
    matplotlib
    numpy
    pandas
    torch
  ];

  # Relax the version constraint declared upstream for gymnasium.
  pythonRelaxDeps = [ "gymnasium" ];

  nativeCheckInputs = [
    ale-py
    pytestCheckHook
    rich
    tqdm
  ];

  pythonImportsCheck = [ "stable_baselines3" ];

  disabledTests = [
    # Flaky: fails when it takes too long, which happens when the system is
    # under heavy load
    "test_fps_logger"

    # These tests attempt to access the filesystem
    "test_make_atari_env"
    "test_vec_env_monitor_kwargs"
  ];

  disabledTestPaths = [
    # These test files start training a model, which takes too long
    "tests/test_cnn.py"
    "tests/test_dict_env.py"
    "tests/test_her.py"
    "tests/test_save_load.py"
  ];

  meta = {
    homepage = "https://github.com/DLR-RM/stable-baselines3";
    changelog = "https://github.com/DLR-RM/stable-baselines3/releases/tag/v${version}";
    description = "PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.derdennisop ];
  };
}