@@ -408,21 +408,25 @@ def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
408408 return LightningEnvironment ()
409409
410410 def _choose_strategy (self ) -> Union [Strategy , str ]:
411- if self ._accelerator_flag == "hpu" :
412- if not _habana_available_and_importable ():
413- raise ImportError (
414- "You have asked for HPU but you miss install related integration."
415- " Please run `pip install lightning-habana` or see for further instructions"
416- " in https://github.com/Lightning-AI/lightning-Habana/."
417- )
418- if self ._parallel_devices and len (self ._parallel_devices ) > 1 :
419- from lightning_habana import HPUParallelStrategy
411+ if _habana_available_and_importable ():
412+ from lightning_habana import HPUAccelerator
420413
421- return HPUParallelStrategy .strategy_name
414+ if self ._accelerator_flag == "hpu" or isinstance (self ._accelerator_flag , HPUAccelerator ):
415+ if self ._parallel_devices and len (self ._parallel_devices ) > 1 :
416+ from lightning_habana import HPUParallelStrategy
422417
423- from lightning_habana import SingleHPUStrategy
418+ return HPUParallelStrategy .strategy_name
419+
420+ from lightning_habana import SingleHPUStrategy
421+
422+ return SingleHPUStrategy (device = torch .device ("hpu" ))
423+ if self ._accelerator_flag == "hpu" and not _habana_available_and_importable ():
424+ raise ImportError (
425+ "You asked to run with HPU but you are missing a required dependency."
426+ " Please run `pip install lightning-habana` or seek further instructions"
427+ " in https://github.com/Lightning-AI/lightning-Habana/."
428+ )
424429
425- return SingleHPUStrategy (device = torch .device ("hpu" ))
426430 if self ._accelerator_flag == "tpu" or isinstance (self ._accelerator_flag , XLAAccelerator ):
427431 if self ._parallel_devices and len (self ._parallel_devices ) > 1 :
428432 return XLAStrategy .strategy_name
0 commit comments