at master 573 B view raw
1diff --git a/objax/util/util.py b/objax/util/util.py 2index c31a356..344cf9a 100644 3--- a/objax/util/util.py 4+++ b/objax/util/util.py 5@@ -117,7 +117,8 @@ def get_local_devices(): 6 if _local_devices is None: 7 x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32) 8 sharded_x = map_to_device(x) 9- _local_devices = [b.device() for b in sharded_x.device_buffers] 10+ device_buffers = [buf.data for buf in sharded_x.addressable_shards] 11+ _local_devices = [list(b.devices())[0] for b in device_buffers] 12 return _local_devices 13 14