@@ -73,29 +73,23 @@ def extract_mesh_for_element(structural_element: StructuralElement,
7373 mask : np.ndarray, optional
7474 Optional mask to restrict the mesh extraction to specific regions.
7575 """
76- if os . environ [ "DEFAULT_BACKEND" ] == "PYTORCH" :
76+ if type ( scalar_field ). __module__ == 'torch' :
7777 import torch
7878 scalar_field = torch .to_numpy (scalar_field )
79- if mask .dtype == torch .bool :
80- mask = torch .to_numpy (mask )
81- verts , faces , _ , _ = measure .marching_cubes (
79+ if type (mask ).__module__ == "torch" :
80+ import torch
81+ mask = torch .to_numpy (mask )
82+
83+
84+ # Extract mesh using marching cubes
85+ verts , faces , _ , _ = measure .marching_cubes (
8286 volume = scalar_field .reshape (regular_grid .resolution ),
8387 level = structural_element .scalar_field_at_interface ,
8488 spacing = (regular_grid .dx , regular_grid .dy , regular_grid .dz ),
8589 mask = mask .reshape (regular_grid .resolution ) if mask is not None else None ,
8690 allow_degenerate = False ,
8791 method = "lewiner"
88- )
89- else :
90- # Extract mesh using marching cubes
91- verts , faces , _ , _ = measure .marching_cubes (
92- volume = scalar_field .reshape (regular_grid .resolution ),
93- level = structural_element .scalar_field_at_interface ,
94- spacing = (regular_grid .dx , regular_grid .dy , regular_grid .dz ),
95- mask = mask .reshape (regular_grid .resolution ) if mask is not None else None ,
96- allow_degenerate = False ,
97- method = "lewiner"
98- )
92+ )
9993
10094 # Adjust vertices to correct coordinates in the model's extent
10195 verts = (verts + [regular_grid .extent [0 ],
0 commit comments