The Wayback Machine - https://web.archive.org/web/20230516140835/https://github.com/matplotlib/matplotlib/issues/25882

[Bug]: plt.hist takes significantly more time with torch and jax arrays #25882

Open
patel-zeel opened this issue May 13, 2023 · 7 comments · May be fixed by #25887

Comments


patel-zeel commented May 13, 2023

Bug summary

Hi,

The time taken by plt.hist when called directly on jax or torch arrays is significantly longer than the combined time of first converting them to numpy and then calling plt.hist. Shouldn't matplotlib internally convert them to numpy arrays before plotting?

To reproduce the bug, directly run the following snippet on Google Colab.

Code for reproduction

from time import time
import numpy as np

import torch

import jax
import jax.random as jr
import jax.numpy as jnp

import matplotlib.pyplot as plt

jax_array = jr.normal(jr.PRNGKey(0), (1000, 150))
torch_array = torch.randn(1000, 150)

def plot_hist(array):
    init = time()
    plt.figure()
    plt.hist(array)
    print(f"Time to plot: {time() - init:.2f} s")
    plt.show()
    
plot_hist(jax_array.ravel())
plot_hist(torch_array.ravel())
plot_hist(np.array(jax_array.ravel()))
plot_hist(np.array(torch_array.ravel()))

Actual outcome

Time to plot: 4.19 s (jax array)
Time to plot: 2.61 s (torch array)
Time to plot: 0.03 s (numpy array from jax)
Time to plot: 0.04 s (numpy array from torch)

[histogram screenshots omitted]

Expected outcome

Time to plot: 0.03 s

Time to plot: 0.04 s

Time to plot: 0.03 s

Time to plot: 0.04 s

Additional information

What are the conditions under which this bug happens? input parameters, edge cases, etc?

It happens with arrays of all shapes.

Has this worked in earlier versions?

Tested with default colab matplotlib version 3.7.1 and also with 3.6.3.

Do you know why this bug is happening?

Not exactly sure.

Do you maybe even know a fix?

Maybe convert any python object to a numpy array before plotting?

Operating system

Ubuntu 20.04.5 LTS

Matplotlib Version

3.7.1

Matplotlib Backend

module://matplotlib_inline.backend_inline

Python version

3.10.11

Jupyter version

6.4.8

Installation

None

@oscargus
Contributor

The unpacking happens here:

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    return x

The pytorch tensor does not support any of these conversion methods, so Matplotlib doesn't really know what to do with it. There is a discussion about this in #22645; if I remember correctly, we expect libraries to provide a to_numpy method (while still supporting the values attribute).

(I could not install jax, but I suppose something similar goes on there.)

@oscargus
Contributor

And when the conversion doesn't work, it ends up in this loop:

result = []
is_1d = True
for xi in X:
    # check if this is iterable, except for strings which we
    # treat as singletons.
    if not isinstance(xi, str):
        try:
            iter(xi)
        except TypeError:
            pass
        else:
            is_1d = False
    xi = np.asanyarray(xi)
    nd = np.ndim(xi)
    if nd > 1:
        raise ValueError(f'{name} must have 2 or fewer dimensions')
    result.append(xi.reshape(-1))

which is where most of the time is spent.
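A rough way to see why this path is slow (a sketch with plain Python scalars standing in for unconverted tensor elements; exact timings are machine-dependent): converting a container in one np.asanyarray call is far cheaper than running each element through np.asanyarray individually, which is effectively what happens when the input was never unpacked to a numpy array.

```python
from time import perf_counter
import numpy as np

# 150,000 scalars, mirroring the 1000 x 150 arrays in the reproduction above
data = list(np.random.default_rng(0).normal(size=150_000).astype(np.float32))

# One bulk conversion, as happens when the input is already a numpy array
t0 = perf_counter()
bulk = np.asanyarray(data)
t_bulk = perf_counter() - t0

# Per-element conversion, mimicking the loop handling each scalar separately
t0 = perf_counter()
per_elem = [np.asanyarray(v) for v in data]
t_elem = perf_counter() - t0

print(f"bulk: {t_bulk:.4f} s, per-element: {t_elem:.4f} s")
```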

@patel-zeel
Author

Thanks for the quick response, @oscargus! Given that both of these libraries support the .__array__() method for conversion to a numpy array, wouldn't it be easier to add one more if condition in _unpack_to_numpy to cover them?

type(jax_array.__array__()), type(torch_array.__array__())
# Output: (numpy.ndarray, numpy.ndarray)

@oscargus
Contributor

Yes, I noticed that too. It could make sense.

(I think the reason why we do this somewhat carefully is for unit information to not get lost.)

Would you be interested in submitting a patch? I think that if this goes last in the conversion chain, it shouldn't break too many things... (A problem here is that we do not yet test all the types that can possibly be used, so it is hard to know what "works". There's been a discussion about having a special test suite for that, but it has not been implemented yet.)

@patel-zeel
Author

patel-zeel commented May 13, 2023

Even tensorflow supports the __array__() method. I'd guess these three libraries account for almost 99% of the machine learning codebase available online :) It'd be great if this conversion goes in without breaking many things!

Sure, I'll submit a patch. I guess I only need to change _unpack_to_numpy to the following, right?

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    if hasattr(x, '__array__'):
        # Assume that any __array__() method returns a numpy array
        # (e.g. TensorFlow, JAX or PyTorch arrays)
        return x.__array__()
    return x

@timhoffm
Member

Yes, but please verify that __array__ actually returns a numpy array, like we do with values above.

@patel-zeel
Author

patel-zeel commented May 13, 2023

Thank you for the important suggestion, @timhoffm. The __array__ method check works in theory for the cases I imagined, but np.float32 objects get caught by that check. When the __array__ method is called on an np.float32 object, it gets converted to an ndarray, and eventually this leads to infinite recursion.
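To illustrate the problem (a small sketch of the preconditions, not of matplotlib's full call chain): a numpy scalar is not an ndarray, so it slips past the first check, yet it does expose __array__, which hands back a 0-d ndarray; and it is an np.generic, which is what the early-return fix below relies on.

```python
import numpy as np

x = np.float32(1.0)

print(isinstance(x, np.ndarray))  # False: not caught by the ndarray check
print(hasattr(x, '__array__'))    # True: would reach the new __array__ branch
print(type(x.__array__()))        # <class 'numpy.ndarray'> (a 0-d array)
print(isinstance(x, np.generic))  # True: basis for an early scalar return
```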

A temporary fix I could figure out is to add two more if conditions checking whether the object is of type np.floating (which includes all float types) or np.integer (which includes all integer types, including uint). I could also add a boolean check. Would that cover everything, or does this already look unpythonic?

More directions to solve this issue could be the following:

  1. Raise an issue to add to_numpy() methods in JAX and PyTorch repos.
  2. Raise an issue to have a universal numpy object checker type in NumPy library so that we can replace ndarray check with that. After this, any numpy object will be captured in the first check.
  3. Add hard-coded checks for JAX and PyTorch like the following:
if str(type(x)) == "<class 'torch.Tensor'>":
    return x.__array__()
if str(type(x)) == "<class 'jaxlib.xla_extension.ArrayImpl'>":
    return x.__array__()

I am open to your suggestions.

Edit1: np.generic works for most (all?) scalars, so we can add if isinstance(x, np.generic) as the second check, just after the ndarray check, like the following:

def _unpack_to_numpy(x):
    """Internal helper to extract data from e.g. pandas and xarray objects."""
    if isinstance(x, np.ndarray):
        # If numpy array, return directly
        return x
    if isinstance(x, np.generic):
        # If numpy scalar, return directly
        return x
    if hasattr(x, 'to_numpy'):
        # Assume that any to_numpy() method actually returns a numpy array
        return x.to_numpy()
    if hasattr(x, 'values'):
        xtmp = x.values
        # For example a dict has a 'values' attribute, but it is not a property
        # so in this case we do not want to return a function
        if isinstance(xtmp, np.ndarray):
            return xtmp
    if hasattr(x, '__array__'):
        # Assume that any __array__() method returns a numpy array
        # (e.g. TensorFlow, JAX or PyTorch arrays)
        x = x.__array__()
        # Anything that doesn't return an ndarray via __array__() is filtered
        # out by the following check
        if isinstance(x, np.ndarray):
            return x
    return x
