66from scipy .sparse import issparse
77
88import numpy as np
9+ from lightning .impl .dataset_fast import get_dataset
10+
11+ from .kernels_fast import _fast_anova_kernel_batch , _fast_anova_grad
912
1013
1114def safe_power (X , degree = 2 ):
@@ -65,7 +68,7 @@ def homogeneous_kernel(X, P, degree=2):
6568 return polynomial_kernel (X , P , degree = degree , gamma = 1 , coef0 = 0 )
6669
6770
68- def anova_kernel (X , P , degree = 2 ):
71+ def anova_kernel (X , P , degree = 2 , method = 'auto' ):
6972 """ANOVA kernel between X and P::
7073
7174 K_A(x, p) = sum_i1>i2>...>id x_i1 p_i1 x_i2 p_i2 ... x_id p_id
@@ -81,11 +84,25 @@ def anova_kernel(X, P, degree=2):
8184
8285 degree : int, default 2
8386
87+ method : string, default: 'auto'
88+ - 'dp' : dynamic programming recursion
89+ - 'auto': vectorized formula for degree 2 or 3, revert to 'dp' for
90+ higher degrees.
91+
8492 Returns
8593 -------
8694 Gram matrix : array of shape (n_samples_1, n_samples_2)
8795 """
88- if degree == 2 :
96+
97+ if degree > 3 or method == 'dp' :
98+ n_samples = X .shape [0 ]
99+ n_components = P .shape [0 ]
100+ ds = get_dataset (X , 'c' )
101+
102+ K = np .empty ((n_samples , n_components ))
103+ _fast_anova_kernel_batch (ds , P , degree , K )
104+
105+ elif degree == 2 :
89106 K = homogeneous_kernel (X , P , degree = 2 )
90107 K -= _D (X , P , degree = 2 )
91108 K /= 2
@@ -95,11 +112,20 @@ def anova_kernel(X, P, degree=2):
95112 K += 2 * _D (X , P , degree = 3 )
96113 K /= 6
97114 else :
98- raise NotImplementedError ( "ANOVA kernel for degree >= 4 not yet "
99- "implemented efficiently." )
115+ raise ValueError ( "Unsupported parameters. Degree must be > 1." )
116+
100117 return K
101118
102119
120+ def anova_grad (X , i , P , degree = 2 ):
121+ """Computes the ANOVA gradient of the i-th row of X, wrt to P"""
122+
123+ ds = get_dataset (X , 'c' )
124+ grad = np .empty_like (P )
125+ _fast_anova_grad (ds , i , P , degree , grad )
126+ return grad
127+
128+
103129def _poly_predict (X , P , lams , kernel , degree = 2 ):
104130 if kernel == "anova" :
105131 K = anova_kernel (X , P , degree )
0 commit comments