[MRG] Nearest Centroid Divide by Zero Fix #18370
Conversation
Thanks @Trevor-Waite for your pull request. Do you mind adding a test checking that the error is thrown when necessary?
Thank you for the PR @Trevor-Waite!
if np.sum(variance) == 0:
    raise ValueError("All features have zero variance. "
                     "Division by zero.")
We would need to be careful here with using the variance directly:
import numpy as np

# Both features are constant, yet the computed variance is tiny but nonzero.
X = np.empty((10, 2))
X[:, 0] = -0.13725701
X[:, 1] = -0.9853293
X_means = X.mean(axis=0)
var = (X - X_means)**2
var = var.sum(axis=0)
var
# array([7.70371978e-33, 1.23259516e-31])
We can use np.ptp instead to look for constant features:
np.ptp(X, axis=0)
# array([0., 0.])
No problem! Good catch, changed the logic to np.ptp(X).sum() == 0 and changed the test case to match your example.
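For reference, a minimal sketch of the kind of test discussed above, assuming fit raises the ValueError added in this PR when every feature is constant (the test name and the exact message matched are illustrative, not necessarily what was merged):

import numpy as np
import pytest
from sklearn.neighbors import NearestCentroid

def test_nearest_centroid_all_constant_features():
    # Every feature is constant, so shrinking would divide by zero.
    X = np.full((6, 2), -0.9853293)
    y = np.array([0, 0, 0, 1, 1, 1])
    clf = NearestCentroid(shrink_threshold=0.1)
    with pytest.raises(ValueError, match="zero variance"):
        clf.fit(X, y)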
@@ -154,16 +154,16 @@ def fit(self, X, y):
            self.centroids_[cur_class] = X[center_mask].mean(axis=0)

        if self.shrink_threshold:
            if np.ptp(X).sum() == 0:
I suspect we still need the axis=0 to do ptp for each feature:
-            if np.ptp(X).sum() == 0:
+            if np.all(np.ptp(X, axis=0) == 0):
Also, avoiding the sum avoids further numerical instability.
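As an illustration (a standalone sketch, not taken from the PR): without axis=0, ptp is computed over the flattened array, so a matrix whose features are each constant but at different values would not be flagged:

import numpy as np

# Each feature is constant, but at a different value.
X = np.array([[1.0, 2.0],
              [1.0, 2.0],
              [1.0, 2.0]])

np.ptp(X).sum()                 # 1.0  -> the summed check misses this case
np.all(np.ptp(X, axis=0) == 0)  # True -> the per-feature check catches it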
Apologies, I pushed before I noticed your review... will change to the np.all approach.
Please add an entry to the change log at doc/whats_new/v0.24.rst with tag |Fix|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.
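For illustration, a changelog entry along these lines (the wording, the module placement, and the contributor display name are assumptions, not the entry that was actually merged):

- |Fix| :class:`neighbors.NearestCentroid` with ``shrink_threshold`` now raises
  an informative ``ValueError`` when all features have zero variance, instead of
  dividing by zero. :pr:`18370` by :user:`Trevor Waite <Trevor-Waite>`.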
Co-authored-by: Thomas J. Fan <[email protected]>
Merged commit 2538489 into scikit-learn:master
Reference Issues/PRs
#18324
What does this implement/fix? Explain your changes.
Checks whether all features of X have zero variance. If so, raises a ValueError to avoid a divide by zero.
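A sketch of the scenario this guards against (the specific values are illustrative, and the final error message may differ slightly): fitting with shrink_threshold on data whose features all have zero variance previously hit a division by zero inside fit, and with this change it raises the ValueError instead:

import numpy as np
from sklearn.neighbors import NearestCentroid

X = np.zeros((6, 2))        # every feature has zero variance
y = [0, 0, 0, 1, 1, 1]
NearestCentroid(shrink_threshold=0.5).fit(X, y)
# ValueError: All features have zero variance. Division by zero.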
Any other comments?
Not sure if this is the desired behavior, but it seems like the best way to handle it. Let me know if the error message is specific enough.