# Package expression for oryx: source fetched from PyPI, built with poetry-core,
# tested with pytest. The function arguments below are supplied by the caller
# (the Python package set).
{
  lib,
  buildPythonPackage,
  fetchPypi,

  # build-system
  poetry-core,

  # dependencies
  jax,
  jaxlib,
  tensorflow-probability,

  # tests
  inference-gym,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "oryx";
  version = "0.2.9";
  pyproject = true;

  # No more tags on GitHub. See https://github.com/jax-ml/oryx/issues/95
  src = fetchPypi {
    inherit pname version;
    hash = "sha256-HlKUnguTNfs7gSqIJ0n2EjjLXPUgtI2JsQM70wKMeXs=";
  };

  build-system = [ poetry-core ];

  dependencies = [
    jax
    jaxlib
    tensorflow-probability
  ];

  pythonImportsCheck = [ "oryx" ];

  nativeCheckInputs = [
    inference-gym
    pytestCheckHook
  ];

  # Individual tests that are skipped, grouped by the error each group raises.
  disabledTests = [
    # ValueError: Number of devices 1 must equal the product of mesh_shape (1, 2)
    "test_plant"
    "test_plant_before_shmap"
    "test_plant_inside_shmap_fails"
    "test_reap"
    "test_reap_before_shmap"
    "test_reap_inside_shmap_fails"

    # ValueError: Variable has already been reaped
    "test_call_list"
    "test_call_tuple"
    "test_dense_combinator"
    "test_dense_function"
    "test_dense_imperative"
    "test_function_in_combinator_in_function"
    "test_grad_of_function_with_literal"
    "test_grad_of_shared_layer"
    "test_grad_of_stateful_function"
    "test_kwargs_rng"
    "test_kwargs_training"
    "test_kwargs_training_rng"
    "test_reshape_call"
    "test_scale_by_adam_should_scale_by_adam"
    "test_scale_by_schedule_should_update_scale"
    "test_scale_by_stddev_should_scale_by_stddev"
    "test_trace_should_keep_track_of_momentum_with_nesterov"

    # NotImplementedError: No registered inverse for `split`
    "test_inverse_of_split"

    # jax.errors.UnexpectedTracerError: Encountered an unexpected tracer
    "test_can_plant_into_jvp_of_custom_jvp_function_unimplemented"
    "test_forward_Scale"

    # ValueError: No variable declared for assign: update_1
    "test_optimizer_adam"
    "test_optimizer_noisy_sgd"
    "test_optimizer_rmsprop"
    "test_optimizer_sgd"
    "test_optimizer_sgd_with_momentum"
    "test_optimizer_sgd_with_nesterov_momentum"

    # AssertionError
    # ACTUAL: array(-2.337877, dtype=float32)
    # DESIRED: array(0., dtype=float32)
    "test_can_map_over_batches_with_vmap_and_reduce_to_scalar_log_prob"
    "test_vmapping_distribution_reduces_to_scalar_log_prob"

    # TypeError: _dot_general_shape_rule() missing 1 required keyword-only argument: 'out_sharding'
    "test_can_rewrite_dot_to_einsum"

    # AttributeError: 'float' object has no attribute 'shape'
    "test_add_noise_should_add_noise"
    "test_apply_every_should_delay_updates"

    # TypeError: Error interpreting argument to functools.partial(...) as an abstract array
    "test_can_rewrite_nested_expression_into_single_einsum"
  ];

  # Whole test files that are skipped.
  disabledTestPaths = [
    # ValueError: Variable has already been reaped
    "oryx/experimental/nn/normalization_test.py"
    "oryx/experimental/nn/pooling_test.py"
  ];

  meta = {
    description = "Library for probabilistic programming and deep learning built on top of Jax";
    homepage = "https://github.com/jax-ml/oryx";
    # NOTE(review): the comment on `src` says upstream no longer publishes tags
    # on GitHub, so this release URL may not resolve for this version — confirm.
    changelog = "https://github.com/jax-ml/oryx/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    # oryx seems to be incompatible with jax 0.5.1
    # 237 additional test failures are resulting from the jax bump.
    broken = true;
  };
}