Skip to content

Commit f19b02c

Browse files
NilsChudallaLeguark
authored andcommitted
More general check, not crashing in "no pytorch" environemtns
1 parent 94b200c commit f19b02c

File tree

2 files changed

+9
-17
lines changed

2 files changed

+9
-17
lines changed

gempy/modules/mesh_extranction/marching_cubes.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,29 +73,23 @@ def extract_mesh_for_element(structural_element: StructuralElement,
7373
mask : np.ndarray, optional
7474
Optional mask to restrict the mesh extraction to specific regions.
7575
"""
76-
if os.environ["DEFAULT_BACKEND"] == "PYTORCH":
76+
if type(scalar_field).__module__ == 'torch':
7777
import torch
7878
scalar_field = torch.to_numpy(scalar_field)
79-
if mask.dtype == torch.bool:
80-
mask = torch.to_numpy(mask)
81-
verts, faces, _, _ = measure.marching_cubes(
79+
if type(mask).__module__ == "torch":
80+
import torch
81+
mask = torch.to_numpy(mask)
82+
83+
84+
# Extract mesh using marching cubes
85+
verts, faces, _, _ = measure.marching_cubes(
8286
volume=scalar_field.reshape(regular_grid.resolution),
8387
level=structural_element.scalar_field_at_interface,
8488
spacing=(regular_grid.dx, regular_grid.dy, regular_grid.dz),
8589
mask=mask.reshape(regular_grid.resolution) if mask is not None else None,
8690
allow_degenerate=False,
8791
method="lewiner"
88-
)
89-
else:
90-
# Extract mesh using marching cubes
91-
verts, faces, _, _ = measure.marching_cubes(
92-
volume=scalar_field.reshape(regular_grid.resolution),
93-
level=structural_element.scalar_field_at_interface,
94-
spacing=(regular_grid.dx, regular_grid.dy, regular_grid.dz),
95-
mask=mask.reshape(regular_grid.resolution) if mask is not None else None,
96-
allow_degenerate=False,
97-
method="lewiner"
98-
)
92+
)
9993

10094
# Adjust vertices to correct coordinates in the model's extent
10195
verts = (verts + [regular_grid.extent[0],

test/test_modules/test_marching_cubes_pytorch.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def test_marching_cubes_implementation():
3434
grid_type=[model.grid.GridTypes.DENSE],
3535
reset=True
3636
)
37-
print("here")
3837
model.interpolation_options = gp.data.InterpolationOptions.init_dense_grid_options()
3938
gp.compute_model(model)
4039

@@ -57,7 +56,6 @@ def test_marching_cubes_implementation():
5756

5857
if PLOT:
5958
gpv = require_gempy_viewer()
60-
gpv.plot_2d(model=model)
6159
gtv: gpv.GemPyToVista = gpv.plot_3d(
6260
model=model,
6361
show_data=True,

0 commit comments

Comments
 (0)