
FIX use unique values of y_true and y_pred in plot_confusion_matrix instead of estimator.classes_ #18405

Merged
19 commits merged into scikit-learn:master on Oct 21, 2020

Conversation

@kyouma
Contributor

@kyouma kyouma commented Sep 15, 2020

Although the docstring and the API guide of sklearn.metrics.plot_confusion_matrix() say the following about the labels argument: "If 'None' is given, those that appear at least once in 'y_true' or 'y_pred' are used in sorted order", the estimator.classes_ attribute was used instead.

Reference Issues/PRs

What does this implement/fix? Explain your changes.

This change fixes errors that occur when y_true and y_pred do not contain some of the values from estimator.classes_.

Any other comments?

Although the docstring and the API guide say "If 'None' is given, those that appear at least once in `y_true` or `y_pred` are used in sorted order", the `estimator.classes_` attribute was used instead.
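
For context, here is a minimal sketch of the intended behaviour when labels=None, using the public helpers sklearn.utils.multiclass.unique_labels and sklearn.metrics.confusion_matrix. The helper function below is illustrative only and is not the exact code of this patch:

    from sklearn.metrics import confusion_matrix
    from sklearn.utils.multiclass import unique_labels

    def compute_matrix_and_display_labels(y_true, y_pred, labels=None):
        # Illustrative sketch, not the actual patch: with labels=None the
        # displayed labels should be the sorted union of the values that
        # actually appear in y_true and y_pred, not estimator.classes_.
        if labels is None:
            labels = unique_labels(y_true, y_pred)
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        return cm, labels

    # The classifier may know classes 0..5, but if only 0, 1 and 3 appear in
    # y_true and y_pred, the plot should show exactly those three labels.
    cm, shown = compute_matrix_and_display_labels([0, 1, 3, 3], [0, 1, 1, 3])
    print(shown)     # [0 1 3]
    print(cm.shape)  # (3, 3)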
Member

@thomasjpfan thomasjpfan left a comment

Thank you for the PR!

Please add a non-regression test that would fail at master but pass in this PR.

kyouma added 4 commits Sep 17, 2020
A test for plot_confusion_matrix() behaviour when 'labels=None' and the dataset with true labels contains labels previously unseen by the classifier (and therefore not present in its 'classes_' attribute). According to the function description, it must use the union of the predicted labels and the true labels.
An update to the 'test_error_on_a_dataset_with_unseen_labels()' function to fix 'E501 line too long' errors.
@kyouma
Contributor Author

@kyouma kyouma commented Sep 17, 2020

Thank you for the PR!

Please add a non-regression test that would fail at master but pass in this PR.

Thank you for the review, @thomasjpfan! This is my first pull request, so I will do my best to implement and prepare everything correctly.

I have added the test test_error_on_a_dataset_with_unseen_labels(), which checks the tick labels of the confusion matrix plot.
In the IPython console matplotlib does not throw exceptions on this test, so I had to add this check. In a Jupyter notebook, the very call to plot_confusion_matrix() would raise the exception ValueError("The number of FixedLocator locations (...), usually from a call to set_ticks, does not match the number of ticklabels (...).").
The updated plot_confusion_matrix() function is intended to pass this test.
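
For readers following along, a rough sketch of this kind of non-regression test is shown below. It assumes scikit-learn's pyplot test fixture and fits the classifier on fewer classes than appear in the evaluation data; the fixture usage and function name are illustrative and differ from the actual test in this PR:

    from sklearn.datasets import make_classification
    from sklearn.metrics import plot_confusion_matrix
    from sklearn.svm import SVC


    def test_plot_with_labels_unseen_by_the_classifier(pyplot):
        # Illustrative sketch, not the PR's exact test. Fit on classes
        # {0, 1, 2} only, then evaluate on data that also contains class 3,
        # which is therefore absent from clf.classes_.
        X, y = make_classification(n_samples=100, n_informative=5,
                                   n_classes=4, random_state=0)
        clf = SVC(random_state=0).fit(X[y < 3], y[y < 3])

        # Before this fix the call below could fail (e.g. with the
        # FixedLocator/ticklabels mismatch mentioned above), because the tick
        # labels came from clf.classes_ while the matrix itself was built from
        # the union of y and clf.predict(X).
        disp = plot_confusion_matrix(clf, X, y)

        tick_labels = [t.get_text() for t in disp.ax_.get_xticklabels()]
        assert tick_labels == [str(i) for i in range(4)]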

@kyouma kyouma requested a review from thomasjpfan Sep 17, 2020
raise TypeError(
    f"Labels in y_true and y_pred should be of the same type. "
    f"Got y_true={np.unique(y_true)} and "
    f"y_pred={np.unique(y_pred)}. Make sure that the "
    f"predictions provided by the classifier coincides with "
    f"the true labels."
Comment on lines 272 to 277


@thomasjpfan

thomasjpfan Sep 24, 2020
Member

Do we have a test to make sure this error is raised?


@kyouma

kyouma Sep 24, 2020
Author Contributor

I have removed the try/except wrapping because the function confusion_matrix(), which is used above to get the matrix itself, contains the same unique_labels() call that was wrapped by the try/except block, and unique_labels() raises an exception with a descriptive message when the true and predicted labels have different types. So if execution reaches this line, it will not cause any problems.
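
For reference, here is a small sketch of the behaviour being relied on: unique_labels(), and therefore confusion_matrix() which calls it, already fails with a descriptive message on mixed label types, so the wrapped re-raise added no information. The exact exception class and message depend on the scikit-learn version, hence the permissive pytest.raises below:

    import pytest
    from sklearn.metrics import confusion_matrix
    from sklearn.utils.multiclass import unique_labels

    y_true = ["ant", "bird", "cat"]  # string labels
    y_pred = [0, 1, 2]               # integer labels

    # Both calls reject the str/int mix up front, so an extra try/except
    # around a later unique_labels() call adds no new information.
    with pytest.raises((TypeError, ValueError)):
        unique_labels(y_true, y_pred)
    with pytest.raises((TypeError, ValueError)):
        confusion_matrix(y_true, y_pred)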

labels=None, display_labels=None)

disp_labels = set([tick.get_text() for tick in disp.ax_.get_xticklabels()])
expected_labels = unique_labels(y, fitted_clf.predict(X))


@thomasjpfan

thomasjpfan Sep 24, 2020
Member

In this case, we can list the labels:

    display_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
    expected_labels = [f'{i}' for i in range(6)]
    assert_array_equal(expected_labels, display_labels)


@kyouma

kyouma Sep 24, 2020
Author Contributor

Thank you, I have replaced these lines and the assertion check with your code.

kyouma added 3 commits Sep 24, 2020
This try/except is not necessary, as the same unique_labels() function is called inside confusion_matrix() above and raises an exception with a descriptive message if the types of the true and predicted labels differ.
Member

@thomasjpfan thomasjpfan left a comment

Please add an entry to the change log at doc/whats_new/v0.24.rst with tag |Fix|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.
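
For readers unfamiliar with the changelog format, an entry of this kind typically looks like the sketch below; the wording and the credited user are placeholders, not the text that was actually merged into doc/whats_new/v0.24.rst:

    - |Fix| :func:`metrics.plot_confusion_matrix` now uses the unique values of
      ``y_true`` and ``y_pred`` instead of ``estimator.classes_`` when
      ``labels=None``. :pr:`18405` by :user:`kyouma`.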

@@ -314,3 +315,16 @@ def test_default_labels(pyplot, display_labels, expected_labels):

    assert_array_equal(x_ticks, expected_labels)
    assert_array_equal(y_ticks, expected_labels)


def test_error_on_a_dataset_with_unseen_labels(pyplot, fitted_clf, data):


@thomasjpfan

thomasjpfan Sep 24, 2020
Member

We may need to wrap this to be <= 79:

Suggested change
def test_error_on_a_dataset_with_unseen_labels(pyplot, fitted_clf, data):
def test_error_on_a_dataset_with_unseen_labels(pyplot, fitted_clf, data, n_classes):


@kyouma

kyouma Sep 24, 2020
Author Contributor

Thank you.

kyouma and others added 4 commits Sep 24, 2020
…seen_labels()

- Replaced the assertion check
- Removed the unused import
The `labels` and `display_labels` parameters have been set to their default values.

Co-authored-by: Thomas J. Fan <[email protected]>
@kyouma
Contributor Author

@kyouma kyouma commented Sep 24, 2020

Thank you very much, I have implemented your suggestions and corrections. I have also added the |Fix| entry to doc/whats_new/v0.24.rst.

@kyouma kyouma requested a review from thomasjpfan Oct 1, 2020
Member

@thomasjpfan thomasjpfan left a comment

Minor comments, otherwise LGTM

doc/whats_new/v0.24.rst: two outdated review comments, resolved
@kyouma
Contributor Author

@kyouma kyouma commented Oct 1, 2020

Minor comments, otherwise LGTM

I have applied the suggested changes. Thank you for your guidance!

@glemaitre glemaitre self-requested a review Oct 21, 2020
Contributor

@glemaitre glemaitre left a comment

LGTM. I will merge when the CIs turn green.

@glemaitre glemaitre changed the title FIX for the behaviour of plot_confusion_matrix() with the argument 'labels' equal to 'None' FIX use unique values of y_true and y_pred in plot_confusion_matrix() instead of estimator.classes_ Oct 21, 2020
@glemaitre glemaitre changed the title FIX use unique values of y_true and y_pred in plot_confusion_matrix() instead of estimator.classes_ FIX use unique values of y_true and y_pred in plot_confusion_matrix instead of estimator.classes_ Oct 21, 2020
@glemaitre glemaitre merged commit 90b9b5d into scikit-learn:master Oct 21, 2020
18 of 20 checks passed

- triage (GitHub Actions)
- scikit-learn.scikit-learn Build #20201021.16 had test failures (Azure Pipelines)
- scikit-learn.scikit-learn (macOS pylatest_conda_mkl) failed (Azure Pipelines)
- LGTM analysis: C/C++, no code changes detected
- LGTM analysis: JavaScript, no code changes detected
- LGTM analysis: Python, no new or fixed alerts
- ci/circleci: deploy passed
- ci/circleci: doc passed
- ci/circleci: doc artifact, link to 0/doc/_changed.html
- ci/circleci: doc-min-dependencies passed
- ci/circleci: lint passed
- scikit-learn.scikit-learn (Linting) succeeded (Azure Pipelines)
- scikit-learn.scikit-learn (Linux py36_conda_openblas) succeeded (Azure Pipelines)
- scikit-learn.scikit-learn (Linux py36_ubuntu_atlas) succeeded (Azure Pipelines)
- scikit-learn.scikit-learn (Linux pylatest_pip_openblas_pandas) succeeded (Azure Pipelines)
- scikit-learn.scikit-learn (Linux32 py36_ubuntu_atlas_32bit) succeeded (Azure Pipelines)
- scikit-learn.scikit-learn (Linux_Runs pylatest_conda_mkl) succeeded (Azure Pipelines)
- scikit-learn.scikit-learn (Windows py36_pip_openblas_32bit) succeeded (Azure Pipelines)
- scikit-learn.scikit-learn (Windows py37_conda_mkl) succeeded (Azure Pipelines)
- scikit-learn.scikit-learn (macOS pylatest_conda_mkl_no_openmp) succeeded (Azure Pipelines)
thomasjpfan added a commit to thomasjpfan/scikit-learn that referenced this pull request Oct 28, 2020
…nstead of estimator.classes_ (scikit-learn#18405)

Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
Projects: none yet
Linked issues: none yet
4 participants