Skip to content

Commit fb3af12

Browse files
Simplify njit decorator, clarify LM tolerance error
1 parent b90bb54 commit fb3af12

File tree

2 files changed

+54
-66
lines changed

2 files changed

+54
-66
lines changed

src/sectionproperties/analysis/fea.py

Lines changed: 46 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from __future__ import annotations
99

10-
import warnings
1110
from dataclasses import dataclass, field
1211
from functools import lru_cache
1312
from typing import TYPE_CHECKING, Any, Callable
@@ -23,54 +22,43 @@
2322
# numba is an optional dependency
2423
try:
2524
from numba import njit
26-
from numba.core.errors import NumbaPerformanceWarning
27-
28-
USE_NUMBA = True
2925
except ImportError:
3026

31-
def njit() -> None:
32-
"""Assigns empty function to njit if numba isn't installed.
27+
def njit(**options: Any) -> Callable[[Any], Any]:
28+
"""Empty decorator if numba is not installed.
3329
34-
Returns:
35-
None
30+
Args:
31+
options: Optional keyword arguments for numba that are discarded.
3632
"""
37-
return None
3833

39-
USE_NUMBA = False
34+
def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
35+
"""Decorator.
4036
37+
Args:
38+
func: Function to decorate.
4139
42-
def conditional_decorator(
43-
dec: Callable[[Any], Any],
44-
condition: bool,
45-
) -> Callable[[Any], Any]:
46-
"""A decorator that applies a decorator only if a condition is True.
40+
Returns:
41+
Decorated function.
42+
"""
4743

48-
Args:
49-
dec: Decorator to apply
50-
condition: Apply decorator if this is true
44+
def wrapper(*args: Any, **kwargs: Any) -> Callable[[Any], Any]:
45+
"""Wrapper.
5146
52-
Returns:
53-
Decorator wrapper
54-
"""
47+
Args:
48+
args: Arguments.
49+
kwargs: Keyword arguments.
5550
56-
def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
57-
"""Decorator wrapper.
58-
59-
Args:
60-
func: Function decorator operates on.
61-
62-
Returns:
63-
Original or decorated function.
64-
"""
65-
if not condition:
66-
return func
51+
Returns:
52+
Wrapped function.
53+
"""
54+
return func(*args, **kwargs) # type: ignore
6755

68-
return dec(func) # type: ignore
56+
return wrapper
6957

70-
return decorator
58+
return decorator
7159

7260

73-
@conditional_decorator(njit, USE_NUMBA)
61+
@njit(cache=True, nogil=True) # type: ignore
7462
def _assemble_torsion(
7563
k_el: npt.NDArray[np.float64],
7664
f_el: npt.NDArray[np.float64],
@@ -103,7 +91,7 @@ def _assemble_torsion(
10391
return k_el, f_el, c_el
10492

10593

106-
@conditional_decorator(njit, USE_NUMBA)
94+
@njit(cache=True, nogil=True) # type: ignore
10795
def _shear_parameter(
10896
nx: float, ny: float, ixx: float, iyy: float, ixy: float
10997
) -> tuple[float, float, float, float, float, float]:
@@ -129,7 +117,7 @@ def _shear_parameter(
129117
return r, q, d1, d2, h1, h2
130118

131119

132-
@conditional_decorator(njit, USE_NUMBA)
120+
@njit(cache=True, nogil=True) # type: ignore
133121
def _assemble_shear_load(
134122
f_psi: npt.NDArray[np.float64],
135123
f_phi: npt.NDArray[np.float64],
@@ -174,7 +162,7 @@ def _assemble_shear_load(
174162
return f_psi, f_phi
175163

176164

177-
@conditional_decorator(njit, USE_NUMBA)
165+
@njit(cache=True, nogil=True) # type: ignore
178166
def _assemble_shear_coefficients(
179167
kappa_x: float,
180168
kappa_y: float,
@@ -635,9 +623,9 @@ def element_stress(
635623
sig_zz_myy_gp = np.zeros(n_points)
636624
sig_zz_m11_gp = np.zeros(n_points)
637625
sig_zz_m22_gp = np.zeros(n_points)
638-
sig_zxy_mzz_gp = np.zeros((n_points, 2))
639-
sig_zxy_vx_gp = np.zeros((n_points, 2))
640-
sig_zxy_vy_gp = np.zeros((n_points, 2))
626+
sig_zxy_mzz_gp = np.zeros((n_points, 2), order="F")
627+
sig_zxy_vx_gp = np.zeros((n_points, 2), order="F")
628+
sig_zxy_vy_gp = np.zeros((n_points, 2), order="F")
641629

642630
# Gauss points for 6 point Gaussian integration
643631
gps = gauss_points(n=n_points)
@@ -694,21 +682,17 @@ def element_stress(
694682
* (b.dot(phi_shear) - nu / 2 * np.array([h1, h2]))
695683
)
696684

697-
# extrapolate results to nodes, ignore numba warnings about performance
698-
with warnings.catch_warnings():
699-
if USE_NUMBA:
700-
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
701-
702-
sig_zz_mxx = extrapolate_to_nodes(w=sig_zz_mxx_gp)
703-
sig_zz_myy = extrapolate_to_nodes(w=sig_zz_myy_gp)
704-
sig_zz_m11 = extrapolate_to_nodes(w=sig_zz_m11_gp)
705-
sig_zz_m22 = extrapolate_to_nodes(w=sig_zz_m22_gp)
706-
sig_zx_mzz = extrapolate_to_nodes(w=sig_zxy_mzz_gp[:, 0])
707-
sig_zy_mzz = extrapolate_to_nodes(w=sig_zxy_mzz_gp[:, 1])
708-
sig_zx_vx = extrapolate_to_nodes(w=sig_zxy_vx_gp[:, 0])
709-
sig_zy_vx = extrapolate_to_nodes(w=sig_zxy_vx_gp[:, 1])
710-
sig_zx_vy = extrapolate_to_nodes(w=sig_zxy_vy_gp[:, 0])
711-
sig_zy_vy = extrapolate_to_nodes(w=sig_zxy_vy_gp[:, 1])
685+
# extrapolate results to nodes
686+
sig_zz_mxx = extrapolate_to_nodes(w=sig_zz_mxx_gp)
687+
sig_zz_myy = extrapolate_to_nodes(w=sig_zz_myy_gp)
688+
sig_zz_m11 = extrapolate_to_nodes(w=sig_zz_m11_gp)
689+
sig_zz_m22 = extrapolate_to_nodes(w=sig_zz_m22_gp)
690+
sig_zx_mzz = extrapolate_to_nodes(w=sig_zxy_mzz_gp[:, 0])
691+
sig_zy_mzz = extrapolate_to_nodes(w=sig_zxy_mzz_gp[:, 1])
692+
sig_zx_vx = extrapolate_to_nodes(w=sig_zxy_vx_gp[:, 0])
693+
sig_zy_vx = extrapolate_to_nodes(w=sig_zxy_vx_gp[:, 1])
694+
sig_zx_vy = extrapolate_to_nodes(w=sig_zxy_vy_gp[:, 0])
695+
sig_zy_vy = extrapolate_to_nodes(w=sig_zxy_vy_gp[:, 1])
712696

713697
return (
714698
sig_zz_n,
@@ -1001,7 +985,7 @@ def gauss_points(*, n: int) -> npt.NDArray[np.float64]:
1001985

1002986

1003987
@lru_cache(maxsize=None)
1004-
@conditional_decorator(njit, USE_NUMBA)
988+
@njit(cache=True, nogil=True) # type: ignore
1005989
def __shape_function_cached(
1006990
coords: tuple[float, ...],
1007991
gauss_point: tuple[float, float, float],
@@ -1089,7 +1073,7 @@ def shape_function(
10891073

10901074

10911075
@lru_cache(maxsize=None)
1092-
@conditional_decorator(njit, USE_NUMBA)
1076+
@njit(cache=True, nogil=True) # type: ignore
10931077
def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64]:
10941078
"""The values of the ``Tri6`` shape function at a point ``p``.
10951079
@@ -1167,7 +1151,7 @@ def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64
11671151
)
11681152

11691153

1170-
@conditional_decorator(njit, USE_NUMBA)
1154+
@njit(cache=True, nogil=True) # type: ignore
11711155
def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11721156
"""Extrapolates results at six Gauss points to the six nodes of a ``Tri6`` element.
11731157
@@ -1180,7 +1164,7 @@ def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11801164
return h_inv @ w
11811165

11821166

1183-
@conditional_decorator(njit, USE_NUMBA)
1167+
@njit(cache=True, nogil=True) # type: ignore
11841168
def principal_coordinate(
11851169
phi: float,
11861170
x: float,
@@ -1203,7 +1187,7 @@ def principal_coordinate(
12031187
return x * cos_phi + y * sin_phi, y * cos_phi - x * sin_phi
12041188

12051189

1206-
@conditional_decorator(njit, USE_NUMBA)
1190+
@njit(cache=True, nogil=True) # type: ignore
12071191
def global_coordinate(
12081192
phi: float,
12091193
x11: float,
@@ -1226,7 +1210,7 @@ def global_coordinate(
12261210
return x11 * cos_phi - y22 * sin_phi, x11 * sin_phi + y22 * cos_phi
12271211

12281212

1229-
@conditional_decorator(njit, USE_NUMBA)
1213+
@njit(cache=True, nogil=True) # type: ignore
12301214
def point_above_line(
12311215
u: npt.NDArray[np.float64],
12321216
px: float,

src/sectionproperties/analysis/solver.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def solve_direct_lagrange(
121121
The solution vector to the linear system of equations
122122
123123
Raises:
124-
RuntimeError: If the Lagrangian multiplier method exceeds a tolerance of
125-
``1e-5``
124+
RuntimeError: If the Lagrangian multiplier method exceeds a relative tolerance
125+
of ``1e-7`` or absolute tolerance related to your machine's floating point
126+
precision.
126127
"""
127128
u = sp_solve(A=k_lg, b=np.append(f, 0))
128129

@@ -131,8 +132,11 @@ def solve_direct_lagrange(
131132
rel_error = multiplier / max(np.absolute(u))
132133

133134
if rel_error > 1e-7 and multiplier > 10.0 * np.finfo(float).eps:
134-
msg = "Lagrangian multiplier method error exceeds tolerance of 1e-5."
135-
raise RuntimeError(msg)
135+
raise RuntimeError(
136+
"Lagrangian multiplier method error exceeds the prescribed tolerance, "
137+
"consider refining your mesh. If this error is unexpected raise an issue "
138+
"at https://github.com/robbievanleeuwen/section-properties/issues."
139+
)
136140

137141
return u[:-1] # type: ignore
138142

0 commit comments

Comments
 (0)