# Python package `array-api-compat`, built from the data-apis GitHub repository.
{
  lib,
  config,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  setuptools,
  setuptools-scm,

  # tests
  pytestCheckHook,
  numpy,
  jaxlib,
  jax,
  torch,
  dask,
  sparse,
  array-api-strict,
  cupy,

  # When enabled, CuPy is added to the check-time inputs as well.
  cudaSupport ? config.cudaSupport,
}:

let
  # Array libraries placed in the check environment alongside pytest.
  arrayBackends = [
    numpy
    jaxlib
    jax
    torch
    dask
    sparse
    array-api-strict
  ];
in
buildPythonPackage rec {
  pname = "array-api-compat";
  version = "1.12";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "data-apis";
    repo = "array-api-compat";
    tag = version;
    hash = "sha256-Hb0bFjVMl4CBI3gN3abTO2QUPAOvUaFE0GdPjdops5E=";
  };

  build-system = [
    setuptools
    setuptools-scm
  ];

  nativeCheckInputs = [ pytestCheckHook ] ++ arrayBackends ++ lib.optional cudaSupport cupy;

  # The cupy tests need CUDA, which the build sandbox does not provide,
  # so they are deselected even when cupy itself is present.
  disabledTests = [ "cupy" ];

  pythonImportsCheck = [ "array_api_compat" ];

  meta = {
    description = "Compatibility layer for NumPy to support the Python array API";
    homepage = "https://data-apis.org/array-api-compat";
    changelog = "https://github.com/data-apis/array-api-compat/releases/tag/${src.tag}";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.berquist ];
  };
}