# torchsde: differentiable SDE solvers for PyTorch (google-research/torchsde).
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,

  # dependencies
  boltons,
  numpy,
  scipy,
  torch,
  trampoline,

  # tests
  pytest7CheckHook,
}:

buildPythonPackage rec {
  pname = "torchsde";
  version = "0.2.6";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "google-research";
    repo = "torchsde";
    tag = "v${version}";
    hash = "sha256-D0p2tL/VvkouXrXfRhMuCq8wMtzeoBTppWEG5vM1qCo=";
  };

  # Drop upstream's exact numpy/scipy pins so nixpkgs' versions are accepted.
  # --replace-warn keeps the old non-fatal behaviour of the deprecated
  # --replace flag.
  # NOTE(review): confirm these pins still exist in v0.2.6's setup.py. If they
  # are gone, this patch is a no-op and can be removed. If they are present,
  # switch to --replace-fail.
  postPatch = ''
    substituteInPlace setup.py \
      --replace-warn "numpy==1.19.*" "numpy" \
      --replace-warn "scipy==1.5.*" "scipy"
  '';

  build-system = [ setuptools ];

  dependencies = [
    boltons
    numpy
    scipy
    torch
    trampoline
  ];

  pythonImportsCheck = [ "torchsde" ];

  nativeCheckInputs = [ pytest7CheckHook ];

  disabledTests = [
    # RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.
    "test_adjoint"
  ];

  meta = {
    changelog = "https://github.com/google-research/torchsde/releases/tag/${src.tag}";
    description = "Differentiable SDE solvers with GPU support and efficient sensitivity analysis";
    homepage = "https://github.com/google-research/torchsde";
    license = lib.licenses.asl20;
    teams = [ lib.teams.tts ];
  };
}