Skip to content

Commit d662ac7

Browse files
Adding get_device_count function to the distribution_lib (#21791)
* adding get_device_count * adding tests * fixed docstring for get_device_count
1 parent 1519bcc commit d662ac7

File tree

5 files changed

+40
-2
lines changed

5 files changed

+40
-2
lines changed

keras/api/_tf_keras/keras/distribution/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
distribute_tensor as distribute_tensor,
1616
)
1717
from keras.src.distribution.distribution_lib import distribution as distribution
18+
from keras.src.distribution.distribution_lib import (
19+
get_device_count as get_device_count,
20+
)
1821
from keras.src.distribution.distribution_lib import initialize as initialize
1922
from keras.src.distribution.distribution_lib import list_devices as list_devices
2023
from keras.src.distribution.distribution_lib import (

keras/api/distribution/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
distribute_tensor as distribute_tensor,
1616
)
1717
from keras.src.distribution.distribution_lib import distribution as distribution
18+
from keras.src.distribution.distribution_lib import (
19+
get_device_count as get_device_count,
20+
)
1821
from keras.src.distribution.distribution_lib import initialize as initialize
1922
from keras.src.distribution.distribution_lib import list_devices as list_devices
2023
from keras.src.distribution.distribution_lib import (

keras/src/backend/jax/distribution_lib.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ def list_devices(device_type=None):
2727
return [f"{device.platform}:{device.id}" for device in jax_devices]
2828

2929

30+
def get_device_count(device_type=None):
31+
"""Returns the number of available JAX devices.
32+
Args:
33+
device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
34+
If `None`, it defaults to counting "gpu" or "tpu" devices if
35+
available, otherwise it counts "cpu" devices. It does not
36+
return the sum of all device types.
37+
Returns:
38+
int: The total number of JAX devices for the specified type.
39+
"""
40+
device_type = device_type.lower() if device_type else None
41+
return jax.device_count(device_type)
42+
43+
3044
def distribute_variable(value, layout):
3145
"""Create a distributed variable for JAX.
3246

keras/src/backend/jax/distribution_lib_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929

3030

3131
@pytest.mark.skipif(
32-
backend.backend() != "jax",
33-
reason="Backend specific test",
32+
backend.backend() != "jax" or len(jax.devices()) != 8,
33+
reason="Backend specific test and requires 8 devices",
3434
)
3535
class JaxDistributionLibTest(testing.TestCase):
3636
def _create_jax_layout(self, sharding):
@@ -42,6 +42,10 @@ def _create_jax_layout(self, sharding):
4242

4343
return sharding
4444

45+
def test_get_device_count(self):
46+
self.assertEqual(backend_dlib.get_device_count(), 8)
47+
self.assertEqual(backend_dlib.get_device_count("cpu"), 8)
48+
4549
def test_list_devices(self):
4650
self.assertEqual(len(distribution_lib.list_devices()), 8)
4751
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)

keras/src/distribution/distribution_lib.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ def list_devices(device_type=None):
3939
return distribution_lib.list_devices(device_type)
4040

4141

42+
@keras_export("keras.distribution.get_device_count")
43+
def get_device_count(device_type=None):
44+
"""Returns the number of available JAX devices.
45+
Args:
46+
device_type: Optional device type to count (e.g., "cpu", "gpu", "tpu").
47+
If `None`, it defaults to counting "gpu" or "tpu" devices if
48+
available, otherwise it counts "cpu" devices. It does not
49+
return the sum of all device types.
50+
Returns:
51+
int: The total number of JAX devices for the specified type.
52+
"""
53+
return distribution_lib.get_device_count(device_type=device_type)
54+
55+
4256
@keras_export("keras.distribution.initialize")
4357
def initialize(job_addresses=None, num_processes=None, process_id=None):
4458
"""Initialize the distribution system for multi-host/process setting.

0 commit comments

Comments
 (0)