77
88from __future__ import annotations
99
10- import warnings
1110from dataclasses import dataclass , field
1211from functools import lru_cache
1312from typing import TYPE_CHECKING , Any , Callable
2322# numba is an optional dependency
2423try :
2524 from numba import njit
26- from numba .core .errors import NumbaPerformanceWarning
27-
28- USE_NUMBA = True
2925except ImportError :
3026
31- def njit () -> None :
32- """Assigns empty function to njit if numba isn't installed.
27+ def njit (** options : Any ) -> Callable [[Any ], Any ]:
28+ """Empty decorator if numba is not installed.
29+
30+ Args:
31+ options: Optional keyword arguments for numba that are discarded.
3332
3433 Returns:
35- None
34+ Empty njit decorator.
3635 """
37- return None
3836
39- USE_NUMBA = False
37+ def decorator (func : Callable [[Any ], Any ]) -> Callable [[Any ], Any ]:
38+ """Decorator.
4039
40+ Args:
41+ func: Function to decorate.
4142
42- def conditional_decorator (
43- dec : Callable [[Any ], Any ],
44- condition : bool ,
45- ) -> Callable [[Any ], Any ]:
46- """A decorator that applies a decorator only if a condition is True.
43+ Returns:
44+ Decorated function.
45+ """
4746
48- Args:
49- dec: Decorator to apply
50- condition: Apply decorator if this is true
47+ def wrapper (* args : Any , ** kwargs : Any ) -> Callable [[Any ], Any ]:
48+ """Wrapper.
5149
52- Returns :
53- Decorator wrapper
54- """
50+ Args :
51+ args: Arguments.
52+ kwargs: Keyword arguments.
5553
56- def decorator (func : Callable [[Any ], Any ]) -> Callable [[Any ], Any ]:
57- """Decorator wrapper.
58-
59- Args:
60- func: Function decorator operates on.
61-
62- Returns:
63- Original or decorated function.
64- """
65- if not condition :
66- return func
54+ Returns:
55+ Wrapped function.
56+ """
57+ return func (* args , ** kwargs ) # type: ignore
6758
68- return dec ( func ) # type: ignore
59+ return wrapper
6960
70- return decorator
61+ return decorator
7162
7263
73- @conditional_decorator ( njit , USE_NUMBA )
64+ @njit ( cache = True , nogil = True ) # type: ignore
7465def _assemble_torsion (
7566 k_el : npt .NDArray [np .float64 ],
7667 f_el : npt .NDArray [np .float64 ],
@@ -103,7 +94,7 @@ def _assemble_torsion(
10394 return k_el , f_el , c_el
10495
10596
106- @conditional_decorator ( njit , USE_NUMBA )
97+ @njit ( cache = True , nogil = True ) # type: ignore
10798def _shear_parameter (
10899 nx : float , ny : float , ixx : float , iyy : float , ixy : float
109100) -> tuple [float , float , float , float , float , float ]:
@@ -129,7 +120,7 @@ def _shear_parameter(
129120 return r , q , d1 , d2 , h1 , h2
130121
131122
132- @conditional_decorator ( njit , USE_NUMBA )
123+ @njit ( cache = True , nogil = True ) # type: ignore
133124def _assemble_shear_load (
134125 f_psi : npt .NDArray [np .float64 ],
135126 f_phi : npt .NDArray [np .float64 ],
@@ -174,7 +165,7 @@ def _assemble_shear_load(
174165 return f_psi , f_phi
175166
176167
177- @conditional_decorator ( njit , USE_NUMBA )
168+ @njit ( cache = True , nogil = True ) # type: ignore
178169def _assemble_shear_coefficients (
179170 kappa_x : float ,
180171 kappa_y : float ,
@@ -635,9 +626,9 @@ def element_stress(
635626 sig_zz_myy_gp = np .zeros (n_points )
636627 sig_zz_m11_gp = np .zeros (n_points )
637628 sig_zz_m22_gp = np .zeros (n_points )
638- sig_zxy_mzz_gp = np .zeros ((n_points , 2 ))
639- sig_zxy_vx_gp = np .zeros ((n_points , 2 ))
640- sig_zxy_vy_gp = np .zeros ((n_points , 2 ))
629+ sig_zxy_mzz_gp = np .zeros ((n_points , 2 ), order = "F" )
630+ sig_zxy_vx_gp = np .zeros ((n_points , 2 ), order = "F" )
631+ sig_zxy_vy_gp = np .zeros ((n_points , 2 ), order = "F" )
641632
642633 # Gauss points for 6 point Gaussian integration
643634 gps = gauss_points (n = n_points )
@@ -694,21 +685,17 @@ def element_stress(
694685 * (b .dot (phi_shear ) - nu / 2 * np .array ([h1 , h2 ]))
695686 )
696687
697- # extrapolate results to nodes, ignore numba warnings about performance
698- with warnings .catch_warnings ():
699- if USE_NUMBA :
700- warnings .simplefilter ("ignore" , category = NumbaPerformanceWarning )
701-
702- sig_zz_mxx = extrapolate_to_nodes (w = sig_zz_mxx_gp )
703- sig_zz_myy = extrapolate_to_nodes (w = sig_zz_myy_gp )
704- sig_zz_m11 = extrapolate_to_nodes (w = sig_zz_m11_gp )
705- sig_zz_m22 = extrapolate_to_nodes (w = sig_zz_m22_gp )
706- sig_zx_mzz = extrapolate_to_nodes (w = sig_zxy_mzz_gp [:, 0 ])
707- sig_zy_mzz = extrapolate_to_nodes (w = sig_zxy_mzz_gp [:, 1 ])
708- sig_zx_vx = extrapolate_to_nodes (w = sig_zxy_vx_gp [:, 0 ])
709- sig_zy_vx = extrapolate_to_nodes (w = sig_zxy_vx_gp [:, 1 ])
710- sig_zx_vy = extrapolate_to_nodes (w = sig_zxy_vy_gp [:, 0 ])
711- sig_zy_vy = extrapolate_to_nodes (w = sig_zxy_vy_gp [:, 1 ])
688+ # extrapolate results to nodes
689+ sig_zz_mxx = extrapolate_to_nodes (w = sig_zz_mxx_gp )
690+ sig_zz_myy = extrapolate_to_nodes (w = sig_zz_myy_gp )
691+ sig_zz_m11 = extrapolate_to_nodes (w = sig_zz_m11_gp )
692+ sig_zz_m22 = extrapolate_to_nodes (w = sig_zz_m22_gp )
693+ sig_zx_mzz = extrapolate_to_nodes (w = sig_zxy_mzz_gp [:, 0 ])
694+ sig_zy_mzz = extrapolate_to_nodes (w = sig_zxy_mzz_gp [:, 1 ])
695+ sig_zx_vx = extrapolate_to_nodes (w = sig_zxy_vx_gp [:, 0 ])
696+ sig_zy_vx = extrapolate_to_nodes (w = sig_zxy_vx_gp [:, 1 ])
697+ sig_zx_vy = extrapolate_to_nodes (w = sig_zxy_vy_gp [:, 0 ])
698+ sig_zy_vy = extrapolate_to_nodes (w = sig_zxy_vy_gp [:, 1 ])
712699
713700 return (
714701 sig_zz_n ,
@@ -1001,7 +988,7 @@ def gauss_points(*, n: int) -> npt.NDArray[np.float64]:
1001988
1002989
1003990@lru_cache (maxsize = None )
1004- @conditional_decorator ( njit , USE_NUMBA )
991+ @njit ( cache = True , nogil = True ) # type: ignore
1005992def __shape_function_cached (
1006993 coords : tuple [float , ...],
1007994 gauss_point : tuple [float , float , float ],
@@ -1089,7 +1076,7 @@ def shape_function(
10891076
10901077
10911078@lru_cache (maxsize = None )
1092- @conditional_decorator ( njit , USE_NUMBA )
1079+ @njit ( cache = True , nogil = True ) # type: ignore
10931080def shape_function_only (p : tuple [float , float , float ]) -> npt .NDArray [np .float64 ]:
10941081 """The values of the ``Tri6`` shape function at a point ``p``.
10951082
@@ -1167,7 +1154,7 @@ def shape_function_only(p: tuple[float, float, float]) -> npt.NDArray[np.float64
11671154)
11681155
11691156
1170- @conditional_decorator ( njit , USE_NUMBA )
1157+ @njit ( cache = True , nogil = True ) # type: ignore
11711158def extrapolate_to_nodes (w : npt .NDArray [np .float64 ]) -> npt .NDArray [np .float64 ]:
11721159 """Extrapolates results at six Gauss points to the six nodes of a ``Tri6`` element.
11731160
@@ -1180,7 +1167,7 @@ def extrapolate_to_nodes(w: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
11801167 return h_inv @ w
11811168
11821169
1183- @conditional_decorator ( njit , USE_NUMBA )
1170+ @njit ( cache = True , nogil = True ) # type: ignore
11841171def principal_coordinate (
11851172 phi : float ,
11861173 x : float ,
@@ -1203,7 +1190,7 @@ def principal_coordinate(
12031190 return x * cos_phi + y * sin_phi , y * cos_phi - x * sin_phi
12041191
12051192
1206- @conditional_decorator ( njit , USE_NUMBA )
1193+ @njit ( cache = True , nogil = True ) # type: ignore
12071194def global_coordinate (
12081195 phi : float ,
12091196 x11 : float ,
@@ -1226,7 +1213,7 @@ def global_coordinate(
12261213 return x11 * cos_phi - y22 * sin_phi , x11 * sin_phi + y22 * cos_phi
12271214
12281215
1229- @conditional_decorator ( njit , USE_NUMBA )
1216+ @njit ( cache = True , nogil = True ) # type: ignore
12301217def point_above_line (
12311218 u : npt .NDArray [np .float64 ],
12321219 px : float ,
0 commit comments