Skip to content

Commit 1c271cb

Browse files
Make numba optional
1 parent 9030803 commit 1c271cb

File tree

6 files changed

+94
-21
lines changed

6 files changed

+94
-21
lines changed

.readthedocs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ python:
2323
extra_requirements:
2424
- dxf
2525
- rhino
26+
- numba

docs/installation.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ package index:
2626
2727
pip install sectionproperties
2828
29+
Installing ``Numba``
30+
--------------------
31+
32+
``Numba`` translates a subset of Python and NumPy code into fast machine code, allowing
33+
algorithms to approach the speeds of C. The speed of several ``sectionproperties``
34+
analysis functions have been enhanced with `numba <https://github.com/numba/numba>`_.
35+
To take advantage of this increase in performance you can install ``numba`` alongside
36+
``sectionproperties`` with:
37+
38+
.. code-block:: shell
39+
40+
pip install sectionproperties[numba]
2941
3042
Installing ``PARDISO`` Solver
3143
-----------------------------

noxfile.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,13 @@ def docs_build(session: Session) -> None:
212212
args.insert(0, "--color")
213213

214214
session.run_always(
215-
"poetry", "install", "--only", "main", "--extras", "dxf rhino", external=True
215+
"poetry",
216+
"install",
217+
"--only",
218+
"main",
219+
"--extras",
220+
"dxf rhino numba",
221+
external=True,
216222
)
217223
session.install(
218224
"furo",
@@ -243,7 +249,13 @@ def docs(session: Session) -> None:
243249
"""
244250
args = session.posargs or ["--open-browser", "docs", "docs/_build"]
245251
session.run_always(
246-
"poetry", "install", "--only", "main", "--extras", "dxf rhino", external=True
252+
"poetry",
253+
"install",
254+
"--only",
255+
"main",
256+
"--extras",
257+
"dxf rhino numba",
258+
external=True,
247259
)
248260
session.install(
249261
"furo",

poetry.lock

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ Changelog = "https://github.com/robbievanleeuwen/section-properties/releases"
4848

4949
[tool.poetry.dependencies]
5050
python = ">=3.9.0,<3.12"
51-
numpy = "^1.25.2" # numba requires numpy <1.26
51+
numpy = "^1.25.2"
5252
scipy = "^1.11.3"
5353
matplotlib = "^3.8.0"
5454
shapely = "^2.0.1"
5555
triangle = "^20230923"
5656
rich = "^13.6.0"
5757
click = "^8.1.7"
5858
more-itertools = "^10.1.0"
59-
numba = "^0.58.0"
59+
numba = { version = "^0.58.0", optional = true }
6060
cad-to-shapely = { version = "^0.3.1", optional = true }
6161
rhino-shapley-interop = { version = "^0.0.4", optional = true }
6262
rhino3dm = { version = "==8.0.0b3", optional = true }
@@ -97,6 +97,7 @@ sphinxext-opengraph = "^0.8.2"
9797
[tool.poetry.extras]
9898
dxf = ["cad-to-shapely"]
9999
rhino = ["rhino-shapley-interop", "rhino3dm"]
100+
numba = ["numba"]
100101
pardiso = ["pypardiso"]
101102

102103
[tool.poetry.scripts]

src/sectionproperties/analysis/fea.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,63 @@
1010
import warnings
1111
from dataclasses import dataclass, field
1212
from functools import lru_cache
13-
from typing import TYPE_CHECKING
13+
from typing import TYPE_CHECKING, Any, Callable
1414

1515
import numpy as np
1616
import numpy.typing as npt
17-
from numba import njit
18-
from numba.core.errors import NumbaPerformanceWarning
1917

2018

2119
if TYPE_CHECKING:
2220
from sectionproperties.pre.pre import Material
2321

2422

25-
@njit(cache=True, nogil=True) # type: ignore
23+
# numba is an optional dependency
24+
try:
25+
from numba import njit
26+
from numba.core.errors import NumbaPerformanceWarning
27+
28+
USE_NUMBA = True
29+
except ImportError:
30+
31+
def njit() -> None:
32+
"""Assigns empty function to njit if numba isn't installed."""
33+
return None
34+
35+
USE_NUMBA = False
36+
37+
38+
def conditional_decorator(
39+
dec: Callable[[Any], Any],
40+
condition: bool,
41+
) -> Callable[[Any], Any]:
42+
"""A decorator that applies a decorator only if a condition is True.
43+
44+
Args:
45+
dec: Decorator to apply
46+
condition: Apply decorator if this is true
47+
48+
Returns:
49+
Decorator wrapper
50+
"""
51+
52+
def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
53+
"""Decorator wrapper.
54+
55+
Args:
56+
func: Function decorator operates on.
57+
58+
Returns:
59+
Original or decorated function.
60+
"""
61+
if not condition:
62+
return func
63+
64+
return dec(func) # type: ignore
65+
66+
return decorator
67+
68+
69+
@conditional_decorator(njit, USE_NUMBA)
2670
def _assemble_torsion(
2771
k_el: npt.NDArray[np.float64],
2872
f_el: npt.NDArray[np.float64],
@@ -55,7 +99,7 @@ def _assemble_torsion(
5599
return k_el, f_el, c_el
56100

57101

58-
@njit(cache=True, nogil=True) # type: ignore
102+
@conditional_decorator(njit, USE_NUMBA)
59103
def _shear_parameter(
60104
nx: float, ny: float, ixx: float, iyy: float, ixy: float
61105
) -> tuple[float, float, float, float, float, float]:
@@ -81,7 +125,7 @@ def _shear_parameter(
81125
return r, q, d1, d2, h1, h2
82126

83127

84-
@njit(cache=True, nogil=True) # type: ignore
128+
@conditional_decorator(njit, USE_NUMBA)
85129
def _assemble_shear_load(
86130
f_psi: npt.NDArray[np.float64],
87131
f_phi: npt.NDArray[np.float64],
@@ -126,7 +170,7 @@ def _assemble_shear_load(
126170
return f_psi, f_phi
127171

128172

129-
@njit(cache=True, nogil=True) # type: ignore
173+
@conditional_decorator(njit, USE_NUMBA)
130174
def _assemble_shear_coefficients(
131175
kappa_x: float,
132176
kappa_y: float,
@@ -648,7 +692,9 @@ def element_stress(
648692

649693
# extrapolate results to nodes, ignore numba warnings about performance
650694
with warnings.catch_warnings():
651-
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
695+
if USE_NUMBA:
696+
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
697+
652698
sig_zz_mxx = extrapolate_to_nodes(w=sig_zz_mxx_gp)
653699
sig_zz_myy = extrapolate_to_nodes(w=sig_zz_myy_gp)
654700
sig_zz_m11 = extrapolate_to_nodes(w=sig_zz_m11_gp)
@@ -951,7 +997,7 @@ def gauss_points(*, n: int) -> npt.NDArray[np.float64]:
951997

952998

953999
@lru_cache(maxsize=None)
954-
@njit(cache=True, nogil=True) # type: ignore
1000+
@conditional_decorator(njit, USE_NUMBA)
9551001
def __shape_function_cached(
9561002
coords: tuple[float, ...],
9571003
gauss_point: tuple[float, float, float],
@@ -1039,7 +1085,7 @@ def shape_function(
10391085

10401086

10411087
@lru_cache(maxsize=None)
1042-
@njit(cache=True, nogil=True) # type: ignore
1088+
@conditional_decorator(njit, USE_NUMBA)
10431089
def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64]:
10441090
"""The values of the ``Tri6`` shape function at a point ``p``.
10451091
@@ -1117,7 +1163,7 @@ def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64
11171163
)
11181164

11191165

1120-
@njit(cache=True, nogil=True) # type: ignore
1166+
@conditional_decorator(njit, USE_NUMBA)
11211167
def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11221168
"""Extrapolates results at six Gauss points to the six nodes of a ``Tri6`` element.
11231169
@@ -1130,7 +1176,7 @@ def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11301176
return h_inv @ w
11311177

11321178

1133-
@njit(cache=True, nogil=True) # type: ignore
1179+
@conditional_decorator(njit, USE_NUMBA)
11341180
def principal_coordinate(
11351181
phi: float,
11361182
x: float,
@@ -1153,7 +1199,7 @@ def principal_coordinate(
11531199
return x * cos_phi + y * sin_phi, y * cos_phi - x * sin_phi
11541200

11551201

1156-
@njit(cache=True, nogil=True) # type: ignore
1202+
@conditional_decorator(njit, USE_NUMBA)
11571203
def global_coordinate(
11581204
phi: float,
11591205
x11: float,
@@ -1176,7 +1222,7 @@ def global_coordinate(
11761222
return x11 * cos_phi - y22 * sin_phi, x11 * sin_phi + y22 * cos_phi
11771223

11781224

1179-
@njit(cache=True, nogil=True) # type: ignore
1225+
@conditional_decorator(njit, USE_NUMBA)
11801226
def point_above_line(
11811227
u: npt.NDArray[np.float64],
11821228
px: float,

0 commit comments

Comments
 (0)