at master 3.4 kB view raw
# Nixpkgs derivation for oryx, a probabilistic-programming / deep-learning
# library built on top of JAX. Currently marked broken (see meta.broken).
{
  lib,
  buildPythonPackage,
  fetchPypi,

  # build-system
  poetry-core,

  # dependencies
  jax,
  jaxlib,
  tensorflow-probability,

  # tests
  inference-gym,
  pytestCheckHook,
}:

buildPythonPackage rec {
  pname = "oryx";
  version = "0.2.9";
  pyproject = true;

  # No more tags on GitHub. See https://github.com/jax-ml/oryx/issues/95
  # so the source is taken from the PyPI sdist instead.
  src = fetchPypi {
    inherit pname version;
    hash = "sha256-HlKUnguTNfs7gSqIJ0n2EjjLXPUgtI2JsQM70wKMeXs=";
  };

  build-system = [ poetry-core ];

  dependencies = [
    jax
    jaxlib
    tensorflow-probability
  ];

  pythonImportsCheck = [ "oryx" ];

  nativeCheckInputs = [
    inference-gym
    pytestCheckHook
  ];

  # Individual tests known to fail in the sandbox, grouped by failure message.
  disabledTests = [
    # ValueError: Number of devices 1 must equal the product of mesh_shape (1, 2)
    "test_plant"
    "test_plant_before_shmap"
    "test_plant_inside_shmap_fails"
    "test_reap"
    "test_reap_before_shmap"
    "test_reap_inside_shmap_fails"

    # ValueError: Variable has already been reaped
    "test_call_list"
    "test_call_tuple"
    "test_dense_combinator"
    "test_dense_function"
    "test_dense_imperative"
    "test_function_in_combinator_in_function"
    "test_grad_of_function_with_literal"
    "test_grad_of_shared_layer"
    "test_grad_of_stateful_function"
    "test_kwargs_rng"
    "test_kwargs_training"
    "test_kwargs_training_rng"
    "test_reshape_call"
    "test_scale_by_adam_should_scale_by_adam"
    "test_scale_by_schedule_should_update_scale"
    "test_scale_by_stddev_should_scale_by_stddev"
    "test_trace_should_keep_track_of_momentum_with_nesterov"

    # NotImplementedError: No registered inverse for `split`
    "test_inverse_of_split"

    # jax.errors.UnexpectedTracerError: Encountered an unexpected tracer
    "test_can_plant_into_jvp_of_custom_jvp_function_unimplemented"
    "test_forward_Scale"

    # ValueError: No variable declared for assign: update_1
    "test_optimizer_adam"
    "test_optimizer_noisy_sgd"
    "test_optimizer_rmsprop"
    "test_optimizer_sgd"
    "test_optimizer_sgd_with_momentum"
    "test_optimizer_sgd_with_nesterov_momentum"

    # AssertionError
    # ACTUAL: array(-2.337877, dtype=float32)
    # DESIRED: array(0., dtype=float32)
    "test_can_map_over_batches_with_vmap_and_reduce_to_scalar_log_prob"
    "test_vmapping_distribution_reduces_to_scalar_log_prob"

    # TypeError: _dot_general_shape_rule() missing 1 required keyword-only argument: 'out_sharding'
    # (full test name; previously truncated to "..._einsu")
    "test_can_rewrite_dot_to_einsum"

    # AttributeError: 'float' object has no attribute 'shape'
    "test_add_noise_should_add_noise"
    "test_apply_every_should_delay_updates"

    # TypeError: Error interpreting argument to functools.partial(...) as an abstract array
    "test_can_rewrite_nested_expression_into_single_einsum"
  ];

  # Whole test modules skipped because every test in them hits the same failure.
  disabledTestPaths = [
    # ValueError: Variable has already been reaped
    "oryx/experimental/nn/normalization_test.py"
    "oryx/experimental/nn/pooling_test.py"
  ];

  meta = {
    description = "Library for probabilistic programming and deep learning built on top of Jax";
    homepage = "https://github.com/jax-ml/oryx";
    # NOTE(review): upstream stopped tagging releases (see the src comment),
    # so this tag URL may not resolve for this version — confirm.
    changelog = "https://github.com/jax-ml/oryx/releases/tag/v${version}";
    license = lib.licenses.asl20;
    maintainers = with lib.maintainers; [ GaetanLepage ];
    # oryx seems to be incompatible with jax 0.5.1
    # 237 additional test failures are resulting from the jax bump.
    broken = true;
  };
}