The Wayback Machine - https://web.archive.org/web/20210915173119/https://github.com/scikit-learn/scikit-learn/pull/18370

[MRG] Nearest Centroid Divide by Zero Fix #18370

Merged
rth merged 12 commits into scikit-learn:master on Oct 2, 2020

Conversation

Contributor

@trewaite trewaite commented Sep 10, 2020

Reference Issues/PRs

#18324

What does this implement/fix? Explain your changes.

Checks whether all features of X have zero variance. If so, raises a ValueError instead of dividing by zero.

Any other comments?

I'm not sure whether this is the desired behavior, but it seems like the best way to handle this case. Let me know if the error message is specific enough.

Member

@cmarmo cmarmo commented Sep 21, 2020

Thanks @Trevor-Waite for your pull request. Do you mind adding a test checking that the error is thrown when necessary?
When you are ready for review, please change [WIP] to [MRG] in the title of your PR. Thanks!

@trewaite changed the title from "[WIP] Nearest Centroid Divide by Zero Fix" to "[MRG] Nearest Centroid Divide by Zero Fix" on Sep 24, 2020
Member

@thomasjpfan thomasjpfan left a comment

Thank you for the PR @Trevor-Waite !

if np.sum(variance) == 0:
    raise ValueError("All features have zero variance. "
                     "Division by zero.")
Member

@thomasjpfan thomasjpfan Sep 24, 2020

We need to be careful about using the variance directly here:

import numpy as np

X = np.empty((10, 2))
X[:, 0] = -0.13725701
X[:, 1] = -0.9853293

X_means = X.mean(axis=0)
var = (X - X_means)**2
var = var.sum(axis=0)

var
# array([7.70371978e-33, 1.23259516e-31])

We can use np.ptp instead to look for constant features:

np.ptp(X, axis=0)
# array([0., 0.])

Contributor Author

@trewaite trewaite Sep 24, 2020

No problem! Good catch, changed logic to np.ptp(X).sum() == 0 and changed test case to match your example.

@@ -154,16 +154,16 @@ def fit(self, X, y):
            self.centroids_[cur_class] = X[center_mask].mean(axis=0)

        if self.shrink_threshold:
            if np.ptp(X).sum() == 0:
Member

@thomasjpfan thomasjpfan Sep 24, 2020

I suspect we still need the axis=0 to do ptp for each feature:

Suggested change
- if np.ptp(X).sum() == 0:
+ if np.all(np.ptp(X, axis=0) == 0):

Also, dropping the sum avoids further numerical instability.
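To make the point about `axis=0` concrete, here is a small sketch that reuses the constant-feature array from the example earlier in this thread. Without `axis`, `np.ptp` operates on the flattened array, so two constant columns holding different values still produce a nonzero range and the check never fires. The variable names below are illustrative only and are not part of the PR.

```python
import numpy as np

# Two constant features with different values (same data as the earlier example).
X = np.empty((10, 2))
X[:, 0] = -0.13725701
X[:, 1] = -0.9853293

# Without axis: a single scalar, max minus min over ALL entries of X.
ptp_flat = np.ptp(X)

# With axis=0: one range per feature (column).
ptp_per_feature = np.ptp(X, axis=0)

# The check without axis misses the constant features ...
flat_check = ptp_flat.sum() == 0

# ... while the per-feature check detects them.
per_feature_check = np.all(ptp_per_feature == 0)

print(ptp_flat, ptp_per_feature, flat_check, per_feature_check)
```

Because each column's maximum and minimum are the same stored value, the per-feature range is exactly zero, with no rounding involved.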

Contributor Author

@trewaite trewaite Sep 24, 2020

Apologies, I pushed before I noticed your review. I will change it to the np.all approach.
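For readers following along, here is a minimal sketch of the `np.all` check together with the kind of test requested earlier in the thread. `check_not_all_constant` and `raises_value_error` are hypothetical stand-ins written for this illustration, not scikit-learn's code; the error message is the one proposed in this PR. In the scikit-learn test suite one would presumably fit the estimator inside `pytest.raises` instead.

```python
import numpy as np


def check_not_all_constant(X):
    """Hypothetical stand-in for the guard discussed in this PR."""
    X = np.asarray(X, dtype=float)
    # A feature is constant exactly when its max equals its min.
    if np.all(np.ptp(X, axis=0) == 0):
        raise ValueError("All features have zero variance. "
                         "Division by zero.")
    return X


def raises_value_error(X):
    """Return True if the guard rejects X, False otherwise."""
    try:
        check_not_all_constant(X)
    except ValueError:
        return True
    return False


# All features constant: the guard should raise.
X_constant = np.empty((10, 2))
X_constant[:, 0] = -0.13725701
X_constant[:, 1] = -0.9853293

# One feature varies: the guard should let the data through.
X_varied = X_constant.copy()
X_varied[0, 0] = 1.0

constant_rejected = raises_value_error(X_constant)
varied_rejected = raises_value_error(X_varied)
print(constant_rejected, varied_rejected)
```

Note that the guard only fires when every feature is constant; a single varying feature is enough for the data to pass.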

Member

@thomasjpfan thomasjpfan left a comment

Please add an entry to the change log at doc/whats_new/v0.24.rst with tag |Fix|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.

rth approved these changes Oct 2, 2020
Member

@rth rth left a comment

Thank you @trewaite !

@rth rth merged commit 2538489 into scikit-learn:master Oct 2, 2020
19 checks passed
amrcode added a commit to amrcode/scikit-learn that referenced this issue Oct 19, 2020
jayzed82 added a commit to jayzed82/scikit-learn that referenced this issue Oct 22, 2020
4 participants