1diff --git a/objax/util/util.py b/objax/util/util.py
2index c31a356..344cf9a 100644
3--- a/objax/util/util.py
4+++ b/objax/util/util.py
5@@ -117,7 +117,8 @@ def get_local_devices():
6 if _local_devices is None:
7 x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
8 sharded_x = map_to_device(x)
9- _local_devices = [b.device() for b in sharded_x.device_buffers]
10+ device_buffers = [buf.data for buf in sharded_x.addressable_shards]
11+ _local_devices = [list(b.devices())[0] for b in device_buffers]
12 return _local_devices
13
14