Skip to content

Commit 02d9ece

Browse files
Merge pull request #352 from robbievanleeuwen/optional-numba-fix
Fix optional njit decorator
2 parents 5a9c602 + af61b45 commit 02d9ece

File tree

2 files changed

+56
-65
lines changed

2 files changed

+56
-65
lines changed

src/sectionproperties/analysis/fea.py

Lines changed: 48 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from __future__ import annotations
99

10-
import warnings
1110
from dataclasses import dataclass, field
1211
from functools import lru_cache
1312
from typing import TYPE_CHECKING, Any, Callable
@@ -23,54 +22,46 @@
2322
# numba is an optional dependency
2423
try:
2524
from numba import njit
26-
from numba.core.errors import NumbaPerformanceWarning
27-
28-
USE_NUMBA = True
2925
except ImportError:
3026

31-
def njit() -> None:
32-
"""Assigns empty function to njit if numba isn't installed.
27+
def njit(**options: Any) -> Callable[[Any], Any]:
28+
"""Empty decorator if numba is not installed.
29+
30+
Args:
31+
options: Optional keyword arguments for numba that are discarded.
3332
3433
Returns:
35-
None
34+
Empty njit decorator.
3635
"""
37-
return None
3836

39-
USE_NUMBA = False
37+
def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
38+
"""Decorator.
4039
40+
Args:
41+
func: Function to decorate.
4142
42-
def conditional_decorator(
43-
dec: Callable[[Any], Any],
44-
condition: bool,
45-
) -> Callable[[Any], Any]:
46-
"""A decorator that applies a decorator only if a condition is True.
43+
Returns:
44+
Decorated function.
45+
"""
4746

48-
Args:
49-
dec: Decorator to apply
50-
condition: Apply decorator if this is true
47+
def wrapper(*args: Any, **kwargs: Any) -> Callable[[Any], Any]:
48+
"""Wrapper.
5149
52-
Returns:
53-
Decorator wrapper
54-
"""
50+
Args:
51+
args: Arguments.
52+
kwargs: Keyword arguments.
5553
56-
def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
57-
"""Decorator wrapper.
58-
59-
Args:
60-
func: Function decorator operates on.
61-
62-
Returns:
63-
Original or decorated function.
64-
"""
65-
if not condition:
66-
return func
54+
Returns:
55+
Wrapped function.
56+
"""
57+
return func(*args, **kwargs) # type: ignore
6758

68-
return dec(func) # type: ignore
59+
return wrapper
6960

70-
return decorator
61+
return decorator
7162

7263

73-
@conditional_decorator(njit, USE_NUMBA)
64+
@njit(cache=True, nogil=True) # type: ignore
7465
def _assemble_torsion(
7566
k_el: npt.NDArray[np.float64],
7667
f_el: npt.NDArray[np.float64],
@@ -103,7 +94,7 @@ def _assemble_torsion(
10394
return k_el, f_el, c_el
10495

10596

106-
@conditional_decorator(njit, USE_NUMBA)
97+
@njit(cache=True, nogil=True) # type: ignore
10798
def _shear_parameter(
10899
nx: float, ny: float, ixx: float, iyy: float, ixy: float
109100
) -> tuple[float, float, float, float, float, float]:
@@ -129,7 +120,7 @@ def _shear_parameter(
129120
return r, q, d1, d2, h1, h2
130121

131122

132-
@conditional_decorator(njit, USE_NUMBA)
123+
@njit(cache=True, nogil=True) # type: ignore
133124
def _assemble_shear_load(
134125
f_psi: npt.NDArray[np.float64],
135126
f_phi: npt.NDArray[np.float64],
@@ -174,7 +165,7 @@ def _assemble_shear_load(
174165
return f_psi, f_phi
175166

176167

177-
@conditional_decorator(njit, USE_NUMBA)
168+
@njit(cache=True, nogil=True) # type: ignore
178169
def _assemble_shear_coefficients(
179170
kappa_x: float,
180171
kappa_y: float,
@@ -635,9 +626,9 @@ def element_stress(
635626
sig_zz_myy_gp = np.zeros(n_points)
636627
sig_zz_m11_gp = np.zeros(n_points)
637628
sig_zz_m22_gp = np.zeros(n_points)
638-
sig_zxy_mzz_gp = np.zeros((n_points, 2))
639-
sig_zxy_vx_gp = np.zeros((n_points, 2))
640-
sig_zxy_vy_gp = np.zeros((n_points, 2))
629+
sig_zxy_mzz_gp = np.zeros((n_points, 2), order="F")
630+
sig_zxy_vx_gp = np.zeros((n_points, 2), order="F")
631+
sig_zxy_vy_gp = np.zeros((n_points, 2), order="F")
641632

642633
# Gauss points for 6 point Gaussian integration
643634
gps = gauss_points(n=n_points)
@@ -694,21 +685,17 @@ def element_stress(
694685
* (b.dot(phi_shear) - nu / 2 * np.array([h1, h2]))
695686
)
696687

697-
# extrapolate results to nodes, ignore numba warnings about performance
698-
with warnings.catch_warnings():
699-
if USE_NUMBA:
700-
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
701-
702-
sig_zz_mxx = extrapolate_to_nodes(w=sig_zz_mxx_gp)
703-
sig_zz_myy = extrapolate_to_nodes(w=sig_zz_myy_gp)
704-
sig_zz_m11 = extrapolate_to_nodes(w=sig_zz_m11_gp)
705-
sig_zz_m22 = extrapolate_to_nodes(w=sig_zz_m22_gp)
706-
sig_zx_mzz = extrapolate_to_nodes(w=sig_zxy_mzz_gp[:, 0])
707-
sig_zy_mzz = extrapolate_to_nodes(w=sig_zxy_mzz_gp[:, 1])
708-
sig_zx_vx = extrapolate_to_nodes(w=sig_zxy_vx_gp[:, 0])
709-
sig_zy_vx = extrapolate_to_nodes(w=sig_zxy_vx_gp[:, 1])
710-
sig_zx_vy = extrapolate_to_nodes(w=sig_zxy_vy_gp[:, 0])
711-
sig_zy_vy = extrapolate_to_nodes(w=sig_zxy_vy_gp[:, 1])
688+
# extrapolate results to nodes
689+
sig_zz_mxx = extrapolate_to_nodes(w=sig_zz_mxx_gp)
690+
sig_zz_myy = extrapolate_to_nodes(w=sig_zz_myy_gp)
691+
sig_zz_m11 = extrapolate_to_nodes(w=sig_zz_m11_gp)
692+
sig_zz_m22 = extrapolate_to_nodes(w=sig_zz_m22_gp)
693+
sig_zx_mzz = extrapolate_to_nodes(w=sig_zxy_mzz_gp[:, 0])
694+
sig_zy_mzz = extrapolate_to_nodes(w=sig_zxy_mzz_gp[:, 1])
695+
sig_zx_vx = extrapolate_to_nodes(w=sig_zxy_vx_gp[:, 0])
696+
sig_zy_vx = extrapolate_to_nodes(w=sig_zxy_vx_gp[:, 1])
697+
sig_zx_vy = extrapolate_to_nodes(w=sig_zxy_vy_gp[:, 0])
698+
sig_zy_vy = extrapolate_to_nodes(w=sig_zxy_vy_gp[:, 1])
712699

713700
return (
714701
sig_zz_n,
@@ -1001,7 +988,7 @@ def gauss_points(*, n: int) -> npt.NDArray[np.float64]:
1001988

1002989

1003990
@lru_cache(maxsize=None)
1004-
@conditional_decorator(njit, USE_NUMBA)
991+
@njit(cache=True, nogil=True) # type: ignore
1005992
def __shape_function_cached(
1006993
coords: tuple[float, ...],
1007994
gauss_point: tuple[float, float, float],
@@ -1089,7 +1076,7 @@ def shape_function(
10891076

10901077

10911078
@lru_cache(maxsize=None)
1092-
@conditional_decorator(njit, USE_NUMBA)
1079+
@njit(cache=True, nogil=True) # type: ignore
10931080
def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64]:
10941081
"""The values of the ``Tri6`` shape function at a point ``p``.
10951082
@@ -1167,7 +1154,7 @@ def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64
11671154
)
11681155

11691156

1170-
@conditional_decorator(njit, USE_NUMBA)
1157+
@njit(cache=True, nogil=True) # type: ignore
11711158
def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11721159
"""Extrapolates results at six Gauss points to the six nodes of a ``Tri6`` element.
11731160
@@ -1180,7 +1167,7 @@ def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11801167
return h_inv @ w
11811168

11821169

1183-
@conditional_decorator(njit, USE_NUMBA)
1170+
@njit(cache=True, nogil=True) # type: ignore
11841171
def principal_coordinate(
11851172
phi: float,
11861173
x: float,
@@ -1203,7 +1190,7 @@ def principal_coordinate(
12031190
return x * cos_phi + y * sin_phi, y * cos_phi - x * sin_phi
12041191

12051192

1206-
@conditional_decorator(njit, USE_NUMBA)
1193+
@njit(cache=True, nogil=True) # type: ignore
12071194
def global_coordinate(
12081195
phi: float,
12091196
x11: float,
@@ -1226,7 +1213,7 @@ def global_coordinate(
12261213
return x11 * cos_phi - y22 * sin_phi, x11 * sin_phi + y22 * cos_phi
12271214

12281215

1229-
@conditional_decorator(njit, USE_NUMBA)
1216+
@njit(cache=True, nogil=True) # type: ignore
12301217
def point_above_line(
12311218
u: npt.NDArray[np.float64],
12321219
px: float,

src/sectionproperties/analysis/solver.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def solve_direct_lagrange(
121121
The solution vector to the linear system of equations
122122
123123
Raises:
124-
RuntimeError: If the Lagrangian multiplier method exceeds a tolerance of
125-
``1e-5``
124+
RuntimeError: If the Lagrangian multiplier method exceeds a relative tolerance
125+
of ``1e-7`` or absolute tolerance related to your machine's floating point
126+
precision.
126127
"""
127128
u = sp_solve(A=k_lg, b=np.append(f, 0))
128129

@@ -131,8 +132,11 @@ def solve_direct_lagrange(
131132
rel_error = multiplier / max(np.absolute(u))
132133

133134
if rel_error > 1e-7 and multiplier > 10.0 * np.finfo(float).eps:
134-
msg = "Lagrangian multiplier method error exceeds tolerance of 1e-5."
135-
raise RuntimeError(msg)
135+
raise RuntimeError(
136+
"Lagrangian multiplier method error exceeds the prescribed tolerance, "
137+
"consider refining your mesh. If this error is unexpected raise an issue "
138+
"at https://github.com/robbievanleeuwen/section-properties/issues."
139+
)
136140

137141
return u[:-1] # type: ignore
138142

0 commit comments

Comments
 (0)