Skip to content

Commit d2720b7

Browse files
Merge pull request #349 from robbievanleeuwen/optional-numba
Make `numba` an optional dependency
2 parents 9030803 + 8cff33d commit d2720b7

File tree

6 files changed

+98
-21
lines changed

6 files changed

+98
-21
lines changed

.readthedocs.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ python:
2323
extra_requirements:
2424
- dxf
2525
- rhino
26+
- numba

docs/installation.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ package index:
2626
2727
pip install sectionproperties
2828
29+
Installing ``Numba``
30+
--------------------
31+
32+
``Numba`` translates a subset of Python and NumPy code into fast machine code, allowing
33+
algorithms to approach the speeds of C. The speed of several ``sectionproperties``
34+
analysis functions have been enhanced with `numba <https://github.com/numba/numba>`_.
35+
To take advantage of this increase in performance you can install ``numba`` alongside
36+
``sectionproperties`` with:
37+
38+
.. code-block:: shell
39+
40+
pip install sectionproperties[numba]
2941
3042
Installing ``PARDISO`` Solver
3143
-----------------------------

noxfile.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,13 @@ def docs_build(session: Session) -> None:
212212
args.insert(0, "--color")
213213

214214
session.run_always(
215-
"poetry", "install", "--only", "main", "--extras", "dxf rhino", external=True
215+
"poetry",
216+
"install",
217+
"--only",
218+
"main",
219+
"--extras",
220+
"dxf rhino numba",
221+
external=True,
216222
)
217223
session.install(
218224
"furo",
@@ -243,7 +249,13 @@ def docs(session: Session) -> None:
243249
"""
244250
args = session.posargs or ["--open-browser", "docs", "docs/_build"]
245251
session.run_always(
246-
"poetry", "install", "--only", "main", "--extras", "dxf rhino", external=True
252+
"poetry",
253+
"install",
254+
"--only",
255+
"main",
256+
"--extras",
257+
"dxf rhino numba",
258+
external=True,
247259
)
248260
session.install(
249261
"furo",

poetry.lock

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ Changelog = "https://github.com/robbievanleeuwen/section-properties/releases"
4848

4949
[tool.poetry.dependencies]
5050
python = ">=3.9.0,<3.12"
51-
numpy = "^1.25.2" # numba requires numpy <1.26
51+
numpy = "^1.25.2"
5252
scipy = "^1.11.3"
5353
matplotlib = "^3.8.0"
5454
shapely = "^2.0.1"
5555
triangle = "^20230923"
5656
rich = "^13.6.0"
5757
click = "^8.1.7"
5858
more-itertools = "^10.1.0"
59-
numba = "^0.58.0"
59+
numba = { version = "^0.58.0", optional = true }
6060
cad-to-shapely = { version = "^0.3.1", optional = true }
6161
rhino-shapley-interop = { version = "^0.0.4", optional = true }
6262
rhino3dm = { version = "==8.0.0b3", optional = true }
@@ -97,6 +97,7 @@ sphinxext-opengraph = "^0.8.2"
9797
[tool.poetry.extras]
9898
dxf = ["cad-to-shapely"]
9999
rhino = ["rhino-shapley-interop", "rhino3dm"]
100+
numba = ["numba"]
100101
pardiso = ["pypardiso"]
101102

102103
[tool.poetry.scripts]

src/sectionproperties/analysis/fea.py

Lines changed: 64 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,67 @@
1010
import warnings
1111
from dataclasses import dataclass, field
1212
from functools import lru_cache
13-
from typing import TYPE_CHECKING
13+
from typing import TYPE_CHECKING, Any, Callable
1414

1515
import numpy as np
1616
import numpy.typing as npt
17-
from numba import njit
18-
from numba.core.errors import NumbaPerformanceWarning
1917

2018

2119
if TYPE_CHECKING:
2220
from sectionproperties.pre.pre import Material
2321

2422

25-
@njit(cache=True, nogil=True) # type: ignore
23+
# numba is an optional dependency
24+
try:
25+
from numba import njit
26+
from numba.core.errors import NumbaPerformanceWarning
27+
28+
USE_NUMBA = True
29+
except ImportError:
30+
31+
def njit() -> None:
32+
"""Assigns empty function to njit if numba isn't installed.
33+
34+
Returns:
35+
None
36+
"""
37+
return None
38+
39+
USE_NUMBA = False
40+
41+
42+
def conditional_decorator(
43+
dec: Callable[[Any], Any],
44+
condition: bool,
45+
) -> Callable[[Any], Any]:
46+
"""A decorator that applies a decorator only if a condition is True.
47+
48+
Args:
49+
dec: Decorator to apply
50+
condition: Apply decorator if this is true
51+
52+
Returns:
53+
Decorator wrapper
54+
"""
55+
56+
def decorator(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
57+
"""Decorator wrapper.
58+
59+
Args:
60+
func: Function decorator operates on.
61+
62+
Returns:
63+
Original or decorated function.
64+
"""
65+
if not condition:
66+
return func
67+
68+
return dec(func) # type: ignore
69+
70+
return decorator
71+
72+
73+
@conditional_decorator(njit, USE_NUMBA)
2674
def _assemble_torsion(
2775
k_el: npt.NDArray[np.float64],
2876
f_el: npt.NDArray[np.float64],
@@ -55,7 +103,7 @@ def _assemble_torsion(
55103
return k_el, f_el, c_el
56104

57105

58-
@njit(cache=True, nogil=True) # type: ignore
106+
@conditional_decorator(njit, USE_NUMBA)
59107
def _shear_parameter(
60108
nx: float, ny: float, ixx: float, iyy: float, ixy: float
61109
) -> tuple[float, float, float, float, float, float]:
@@ -81,7 +129,7 @@ def _shear_parameter(
81129
return r, q, d1, d2, h1, h2
82130

83131

84-
@njit(cache=True, nogil=True) # type: ignore
132+
@conditional_decorator(njit, USE_NUMBA)
85133
def _assemble_shear_load(
86134
f_psi: npt.NDArray[np.float64],
87135
f_phi: npt.NDArray[np.float64],
@@ -126,7 +174,7 @@ def _assemble_shear_load(
126174
return f_psi, f_phi
127175

128176

129-
@njit(cache=True, nogil=True) # type: ignore
177+
@conditional_decorator(njit, USE_NUMBA)
130178
def _assemble_shear_coefficients(
131179
kappa_x: float,
132180
kappa_y: float,
@@ -648,7 +696,9 @@ def element_stress(
648696

649697
# extrapolate results to nodes, ignore numba warnings about performance
650698
with warnings.catch_warnings():
651-
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
699+
if USE_NUMBA:
700+
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
701+
652702
sig_zz_mxx = extrapolate_to_nodes(w=sig_zz_mxx_gp)
653703
sig_zz_myy = extrapolate_to_nodes(w=sig_zz_myy_gp)
654704
sig_zz_m11 = extrapolate_to_nodes(w=sig_zz_m11_gp)
@@ -951,7 +1001,7 @@ def gauss_points(*, n: int) -> npt.NDArray[np.float64]:
9511001

9521002

9531003
@lru_cache(maxsize=None)
954-
@njit(cache=True, nogil=True) # type: ignore
1004+
@conditional_decorator(njit, USE_NUMBA)
9551005
def __shape_function_cached(
9561006
coords: tuple[float, ...],
9571007
gauss_point: tuple[float, float, float],
@@ -1039,7 +1089,7 @@ def shape_function(
10391089

10401090

10411091
@lru_cache(maxsize=None)
1042-
@njit(cache=True, nogil=True) # type: ignore
1092+
@conditional_decorator(njit, USE_NUMBA)
10431093
def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64]:
10441094
"""The values of the ``Tri6`` shape function at a point ``p``.
10451095
@@ -1117,7 +1167,7 @@ def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64
11171167
)
11181168

11191169

1120-
@njit(cache=True, nogil=True) # type: ignore
1170+
@conditional_decorator(njit, USE_NUMBA)
11211171
def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11221172
"""Extrapolates results at six Gauss points to the six nodes of a ``Tri6`` element.
11231173
@@ -1130,7 +1180,7 @@ def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11301180
return h_inv @ w
11311181

11321182

1133-
@njit(cache=True, nogil=True) # type: ignore
1183+
@conditional_decorator(njit, USE_NUMBA)
11341184
def principal_coordinate(
11351185
phi: float,
11361186
x: float,
@@ -1153,7 +1203,7 @@ def principal_coordinate(
11531203
return x * cos_phi + y * sin_phi, y * cos_phi - x * sin_phi
11541204

11551205

1156-
@njit(cache=True, nogil=True) # type: ignore
1206+
@conditional_decorator(njit, USE_NUMBA)
11571207
def global_coordinate(
11581208
phi: float,
11591209
x11: float,
@@ -1176,7 +1226,7 @@ def global_coordinate(
11761226
return x11 * cos_phi - y22 * sin_phi, x11 * sin_phi + y22 * cos_phi
11771227

11781228

1179-
@njit(cache=True, nogil=True) # type: ignore
1229+
@conditional_decorator(njit, USE_NUMBA)
11801230
def point_above_line(
11811231
u: npt.NDArray[np.float64],
11821232
px: float,

0 commit comments

Comments
 (0)