1111from ..utils .extmath import fast_dot , svd_flip
1212
1313
14- def _weighted_average (x , y , x_weight , y_weight ):
15- num = x * x_weight + y * y_weight
16- denom = x_weight + y_weight
14+ def _mean_update (old_mean , new_mean , old_sample_count , new_sample_count ):
15+ """Minibatch mean update."""
16+ num = old_mean * old_sample_count + new_mean * new_sample_count
17+ denom = old_sample_count + new_sample_count
1718 return num / denom
1819
1920
21+ def _calc_sum_and_var (X ):
22+ """Calculate Youngs and Cramer components (T and S)."""
23+ stored_sum = np .sum (X , axis = 0 )
24+ unnormalized_variance = np .sum ((X - 1. / X .shape [0 ] *
25+ stored_sum ) ** 2 , axis = 0 )
26+ return stored_sum , unnormalized_variance
27+
28+
29+ def _variance_update (old_sum , new_sum , old_var , new_var , old_sample_count ,
30+ new_sample_count ):
31+ """Youngs and Cramer minibatch update."""
32+ batch_sum = old_sum + new_sum
33+ n = new_sample_count
34+ m = old_sample_count
35+ partial_var = float (m ) / (n * (m + n )) * (n / float (m ) * old_sum
36+ - new_sum ) ** 2
37+ batch_var = old_var + new_var + partial_var
38+ return batch_sum , batch_var
39+
40+
2041class IncrementalPCA (BaseEstimator , TransformerMixin ):
2142 """Incremental principal components analysis (IPCA).
2243
@@ -124,9 +145,6 @@ def fit(self, X, y=None):
124145 if hasattr (self , "components_" ):
125146 del self .components_
126147 del self .mean_
127- del self .explained_variance_
128- del self .explained_variance_ratio_
129- del self ._explained_variance_sum
130148 self .samples_seen_ = 0
131149 X = array2d (X )
132150 n_samples , n_features = X .shape
@@ -178,19 +196,19 @@ def partial_fit(self, X, y=None):
178196 U , S , V = linalg .svd (X , full_matrices = False )
179197 U , V = svd_flip (U , V , u_based_decision = False )
180198 components = V [:n_components ]
181-
182- explained_variance = (S ** 2 ) / n_samples
183- explained_variance_sum = np .sum (explained_variance )
184- explained_variance_ratio = (explained_variance /
185- explained_variance_sum )
199+ stored_sum , unnormalized_variance = _calc_sum_and_var (X )
200+ explained_variance = S ** 2 / n_samples
201+ variance_sum = np .sum (unnormalized_variance / n_samples )
202+ explained_variance_ratio = explained_variance / variance_sum
186203 else :
187204 old_components = self .components_
188205 old_mean = self .mean_
189206 old_sample_count = self .samples_seen_
190207 new_sample_count = n_samples
191208 new_mean = X .mean (axis = 0 )
192- mean = _weighted_average (old_mean , new_mean ,
193- old_sample_count , new_sample_count )
209+ mean = _mean_update (old_mean , new_mean , old_sample_count ,
210+ new_sample_count )
211+
194212 X -= new_mean
195213 append_vals = np .sqrt ((old_sample_count * new_sample_count ) /
196214 (old_sample_count + new_sample_count ))
@@ -204,22 +222,25 @@ def partial_fit(self, X, y=None):
204222 U , V = svd_flip (U , V , u_based_decision = False )
205223 components = V [:n_components ]
206224
207- explained_variance = (S ** 2 ) / (old_sample_count + new_sample_count )
208- old_sum = self ._explained_variance_sum
209- new_sum = np .sum (explained_variance )
210- var_diff = np .abs (new_sum - old_sum )
211- explained_variance_sum = old_sum + (float (new_sample_count ) /
212- old_sample_count ) * var_diff
213- explained_variance_ratio = (explained_variance /
214- explained_variance_sum )
225+ old_stored_sum = self ._stored_sum
226+ old_unnormalized_variance = self ._unnormalized_variance
227+ new_stored_sum , new_unnormalized_variance = _calc_sum_and_var (X )
228+ stored_sum , unnormalized_variance = _variance_update (
229+ old_stored_sum , new_stored_sum , old_unnormalized_variance ,
230+ new_unnormalized_variance , old_sample_count , new_sample_count )
231+ explained_variance = S ** 2 / (old_sample_count + new_sample_count )
232+ variance_sum = np .sum (unnormalized_variance / (old_sample_count +
233+ new_sample_count ))
234+ explained_variance_ratio = explained_variance / variance_sum
215235
216236 self .samples_seen_ += new_sample_count
217237 self .components_ = components [:n_components ]
218238 self .singular_vals_ = S [:n_components ]
239+ self .mean_ = mean
219240 self .explained_variance_ = explained_variance [:n_components ]
220- self ._explained_variance_sum = explained_variance_sum
221241 self .explained_variance_ratio_ = explained_variance_ratio [:n_components ]
222- self .mean_ = mean
242+ self ._stored_sum = stored_sum
243+ self ._unnormalized_variance = unnormalized_variance
223244 return self
224245
225246 def transform (self , X ):
0 commit comments