The Wayback Machine - https://web.archive.org/web/20220525121527/https://github.com/scikit-learn/scikit-learn/pull/22595

FIX Support extra class_weights in compute_class_weight #22595

Merged: 12 commits merged on Mar 7, 2022

thomasjpfan (Member) commented Feb 24, 2022

Reference Issues/PRs

Fixes #22413

What does this implement/fix? Explain your changes.

This PR:

  1. Changes the behavior of compute_class_weight to allow class_weight to contain classes that are not in classes.
  2. If none of the classes in class_weight appear in classes, then an error is raised.
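
The two rules above can be sketched in plain Python. This is only an illustration of the behavior described in this comment, not the code from the PR: `sketch_class_weights` is a hypothetical name, and the fallback weight of 1.0 for classes without an entry is an assumption.

```python
def sketch_class_weights(classes, class_weight):
    """Hypothetical helper illustrating the two rules; not scikit-learn's code."""
    # Rule 2: at least one key of class_weight must be a known class.
    if not any(c in class_weight for c in classes):
        raise ValueError("None of the keys in class_weight appear in classes")
    # Rule 1: extra keys in class_weight are ignored. Classes without an
    # entry fall back to an assumed default weight of 1.0.
    return [float(class_weight.get(c, 1.0)) for c in classes]


# "bird" is not one of the classes, so its weight is simply ignored.
weights = sketch_class_weights(["cat", "dog"], {"cat": 1, "dog": 2, "bird": 5})
```

Under this sketch a misspelled key such as "dogs" is accepted as long as at least one other key matches a class.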

Any other comments?

I think the intention of using class_weight as the source of truth is to prevent this type of error:

from sklearn.tree import DecisionTreeClassifier

# Misspelled "dog" in `class_weight`
tree = DecisionTreeClassifier(class_weight={"dogs": 2, "cat": 1})
tree.fit([[1, 2, 3], [1, 3, 2]], ["dog", "cat"])

With this PR, the misspelled key "dogs" would be accepted and "dog" would not be weighted.

glemaitre (Contributor) commented Mar 2, 2022

I find the behaviour a bit dangerous because you could potentially accept a key with a typo that will not be used.
I would expect at least a warning when a key in class_weight is not used, if you think that raising an error is too constraining.
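
A warning of the kind suggested here could look roughly like the sketch below. It is a hypothetical illustration only: `warn_unused_class_weight` and the message text are made up for this example and are not part of scikit-learn.

```python
import warnings


def warn_unused_class_weight(classes, class_weight):
    """Hypothetical helper: warn about class_weight keys that match no class."""
    unused = [key for key in class_weight if key not in classes]
    if unused:
        warnings.warn(
            f"class_weight keys {unused} do not match any class and are ignored",
            UserWarning,
        )
    return unused


# The misspelled key "dogs" matches no class, so a UserWarning is emitted.
unused_keys = warn_unused_class_weight(["dog", "cat"], {"dogs": 2, "cat": 1})
```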

thomasjpfan (Member, Author) commented Mar 3, 2022

I updated the PR to be more strict when it comes to unweighted classes:

  1. If all classes have a weight from class_weight, then there is no error.
  2. If there is a class in y that is not in class_weight, then an error is raised.

This means that an error is raised:

from sklearn.tree import DecisionTreeClassifier

tree = DecisionTreeClassifier(class_weight={"dogs": 2, "cat": 1})

# Raises error
tree.fit([[1, 2, 3], [1, 3, 2]], ["dog", "cat"])
# ValueError: The classes, ['dog'], are not in class_weight
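
The stricter rule stated in this comment can be sketched in plain Python as follows. This is a minimal illustration of the rule as worded here, not the merged implementation, whose exact conditions may differ. `strict_class_weights` is a hypothetical name.

```python
def strict_class_weights(classes, class_weight):
    """Hypothetical helper: every class must have an entry in class_weight."""
    missing = [c for c in classes if c not in class_weight]
    if missing:
        raise ValueError(f"The classes, {missing}, are not in class_weight")
    # Extra keys in class_weight (classes that never occur) are still allowed.
    return [class_weight[c] for c in classes]


# All classes are covered, and the extra "bird" key is ignored.
weights = strict_class_weights(["cat", "dog"], {"dog": 2, "cat": 1, "bird": 3})
```

With the misspelled dictionary `{"dogs": 2, "cat": 1}` this sketch raises a ValueError that names "dog" as the class without a weight.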

glemaitre (Contributor) left a comment

Otherwise, LGTM.

thomasjpfan added the Quick Review label on Mar 4, 2022
jjerphan merged commit 3605c14 into scikit-learn:main on Mar 7, 2022
27 checks passed
Diwakar-Gupta pushed a commit to Diwakar-Gupta/scikit-learn that referenced this issue Mar 7, 2022
thomasjpfan added a commit to thomasjpfan/scikit-learn that referenced this issue Mar 8, 2022
Labels
module:utils Quick Review