Enhanced torch.chunk and torch.split #60531

nmichlo opened this issue Jun 23, 2021 · 6 comments

@nmichlo commented Jun 23, 2021

🚀 Feature

Add drop_remainder: bool = False support to torch.chunk and torch.split

  • similar in function to drop_last in torch.utils.data.DataLoader
  • If the length of the dimension is not perfectly divisible, drop the remaining elements from the returned tensors.

Add redistribute: bool = False support to torch.chunk and torch.split

  • Spread the elements evenly across the chunks/splits if the length of the dimension is not perfectly divisible by the number of chunks or split size.
  • The lengths of the returned tensors differ by at most one element.

Motivation

It is often desirable to have evenly sized tensors, or approximately evenly sized tensors.

A combination of drop_remainder and redistribute would enable new intuitive ways to use torch.split and torch.chunk.

Pitch

redistribute

Add the parameter redistribute to torch.chunk and torch.split, which spreads the remainder evenly across the returned tensors. For example:

# CURRENT:
torch.arange(7).chunk(3, dim=-1, redistribute=False)
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))

# NEW:
torch.arange(7).chunk(3, dim=-1, redistribute=True)
>>> (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))

# CURRENT:
torch.arange(7).split(3, dim=-1, redistribute=False)
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))

# NEW:
torch.arange(7).split(3, dim=-1, redistribute=True)
>>> (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))

drop_remainder

Add the parameter drop_remainder to torch.chunk and torch.split, which excludes the remainder elements from the output so that all returned tensors are the same size. For example:

# CURRENT:
torch.arange(7).chunk(3, dim=-1, drop_remainder=False)
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))

# NEW:
torch.arange(7).chunk(3, dim=-1, drop_remainder=True)
>>> (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))

# CURRENT:
torch.arange(7).split(3, dim=-1, drop_remainder=False)
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))

# NEW:
torch.arange(7).split(3, dim=-1, drop_remainder=True)
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]))

drop_remainder & redistribute interaction

drop_remainder should be applied before redistribute:

  • thus if drop_remainder=True, there is never a need to redistribute elements, so in this case redistribute can be ignored (see the example below).
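
For example, under the proposed semantics (hypothetical output, since neither parameter exists yet):

torch.arange(7).chunk(3, dim=-1, drop_remainder=True, redistribute=True)
>>> (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
# 7 % 3 == 1, so the trailing element 6 is dropped first; the remaining 6 elements
# already divide evenly into 3 chunks, so redistribute has no effect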

Alternatives

Alternative approaches would be new functions that provide similar behaviour, or helpers that act over the outputs of chunk or split; however, the former might be confusing and the latter inefficient (a rough sketch of the latter is shown below).
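
For instance, a post-hoc helper (hypothetical name redistribute_chunks, not an existing API) would have to re-concatenate and re-split the chunks, which is where the inefficiency comes from:

def redistribute_chunks(chunks, dim=0):
    # hypothetical helper: concatenate the chunks back together, then
    # re-split into len(chunks) nearly-even pieces (costs an extra copy)
    flat = torch.cat(chunks, dim=dim)
    n, k = flat.size(dim), len(chunks)
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    return flat.split(sizes, dim=dim)

redistribute_chunks(torch.arange(7).chunk(3))
>>> (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))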

Possible synonyms for drop_remainder include:

  • drop_last, matching torch.utils.data.DataLoader, but it may be argued that this does not convey the correct behaviour in all cases.

Possible synonyms for redistribute include:

  • redistribute_remainder
  • distribute_remainder
  • spread_remainder
  • distribute
  • spread

The _remainder suffix does not describe the behaviour well, as internal elements are shifted too, not only the remainder/last elements.

Additional context

N/A

@gautamborad commented Jun 24, 2021

I would like to work on this. Thanks!

@nmichlo commented Jun 24, 2021

Would it be better to include a "mode" parameter instead of the drop_remainder and redistribute parameters?

That might avoid confusion when both are True?
mode="default", mode="spread", mode="drop"

If you really want to get fancy (although admittedly I don't see much need for this), you could even have: drop_ends, drop_first, drop_last, spread_ends, spread_first, spread_last
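
A rough sketch of how the calls might look (mode names here are only suggestions, not an agreed API):

torch.arange(7).chunk(3, dim=-1, mode="default")  # current behaviour: sizes (3, 3, 1)
torch.arange(7).chunk(3, dim=-1, mode="spread")   # redistribute the remainder: sizes (3, 2, 2)
torch.arange(7).chunk(3, dim=-1, mode="drop")     # drop the remainder: sizes (2, 2, 2)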

@gautamborad commented Jun 28, 2021

Agree that a mode parameter would be less confusing compared to individual parameters like redistribute and drop_remainder.
We can start with a couple of implementations first and add more later. Let me know what you think.

@gautamborad commented Jun 29, 2021

@nmichlo, what will be the desired output for the following case?

torch.arange(12).split(5, dim=-1, redistribute=True)
# CURRENT:
>>> (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]), tensor([10, 11]))

@nmichlo commented Jun 29, 2021

@gautamborad you have got me there; my proposal is flawed. I didn't properly consider all the cases for split, as my first concern was chunk.

@nmichlo commented Jun 29, 2021

Thinking about this more, the behaviour might only make sense for chunk?

  • With redistribute=True, it should always return the requested number of chunks (unless the size of the dimension is less than the number of chunks), including for the values you just gave:

I have not looked at the source code, but it seems as though chunk internally wraps torch.split, which explains why it gives these odd results where the number of requested chunks is not respected:

for i in range(1, 13):
    chunks = torch.arange(i).chunk(5)  # same as: torch.arange(i).split(math.ceil(i/5))
    print(f'{i}: expected_chunks={min(i, 5)}, num_chunks={len(chunks)}, chunks={chunks}')

>>> 1: expected_chunks=1, num_chunks=1, chunks=(tensor([0]),)
>>> 2: expected_chunks=2, num_chunks=2, chunks=(tensor([0]), tensor([1]))
>>> 3: expected_chunks=3, num_chunks=3, chunks=(tensor([0]), tensor([1]), tensor([2]))
>>> 4: expected_chunks=4, num_chunks=4, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]))
>>> 5: expected_chunks=5, num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))
>>> 6: expected_chunks=5, num_chunks=3, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
>>> 7: expected_chunks=5, num_chunks=4, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6]))
>>> 8: expected_chunks=5, num_chunks=4, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]))
>>> 9: expected_chunks=5, num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8]))
>>> 10: expected_chunks=5, num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
>>> 11: expected_chunks=5, num_chunks=4, chunks=(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))
>>> 12: expected_chunks=5, num_chunks=4, chunks=(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

EDIT: this is what I imagine the updated chunk function would look like, although it has not been thoroughly tested for bugs.

import math
import numpy as np
import torch

def custom_chunk(tensor: torch.Tensor, chunks: int, dim=0, redistributed=False, drop_remainder=False):
    # remove the trailing remainder elements so that all chunks can be the same size
    if drop_remainder:
        r = tensor.size(dim) % chunks
        if r != 0:
            tensor = torch.moveaxis(torch.moveaxis(tensor, dim, 0)[:-r], 0, dim)
    # split into chunks
    if redistributed:
        # new chunking logic: the first (size % chunks) chunks each get one extra element
        chunk_sizes = (tensor.size(dim) // chunks) + (np.arange(chunks) < (tensor.size(dim) % chunks))
        return tensor.split(chunk_sizes.tolist(), dim=dim)
    else:
        # original chunking logic
        # - should be the same as tensor.chunk(chunks)
        split_size = math.ceil(tensor.size(dim) / chunks)
        return tensor.split(split_size, dim=dim)

with example cases below, for 5 chunks:

for i in range(4, 13, 2):
    chunk_orig_alt   = custom_chunk(torch.arange(i), 5, redistributed=False, drop_remainder=False)  # same as torch.arange(i).chunk(5)
    chunk_orig_alt_d = custom_chunk(torch.arange(i), 5, redistributed=False, drop_remainder=True)
    chunk_redist     = custom_chunk(torch.arange(i), 5, redistributed=True, drop_remainder=False)
    chunk_redist_d   = custom_chunk(torch.arange(i), 5, redistributed=True, drop_remainder=True)
    print(f'{i}: r=F, d=F : num_chunks={len(chunk_orig_alt)}, chunks={chunk_orig_alt}')
    print(f'{i}: r=F, d=T : num_chunks={len(chunk_orig_alt_d)}, chunks={chunk_orig_alt_d}')
    print(f'{i}: r=T, d=F : num_chunks={len(chunk_redist)}, chunks={chunk_redist}')
    print(f'{i}: r=T, d=T : num_chunks={len(chunk_redist_d)}, chunks={chunk_redist_d}')
    print()

giving:

4: r=F, d=F : num_chunks=4, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]))
4: r=F, d=T : num_chunks=1, chunks=(tensor([], dtype=torch.int64),)
4: r=T, d=F : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([], dtype=torch.int64))
4: r=T, d=T : num_chunks=5, chunks=(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))

6: r=F, d=F : num_chunks=3, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
6: r=F, d=T : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))
6: r=T, d=F : num_chunks=5, chunks=(tensor([0, 1]), tensor([2]), tensor([3]), tensor([4]), tensor([5]))
6: r=T, d=T : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))

8: r=F, d=F : num_chunks=4, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]))
8: r=F, d=T : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))
8: r=T, d=F : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6]), tensor([7]))
8: r=T, d=T : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))

10: r=F, d=F : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
10: r=F, d=T : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
10: r=T, d=F : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
10: r=T, d=T : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))

12: r=F, d=F : num_chunks=4, chunks=(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))
12: r=F, d=T : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
12: r=T, d=F : num_chunks=5, chunks=(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]), tensor([8, 9]), tensor([10, 11]))
12: r=T, d=T : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))

EDIT 2: alternatively, if modes are used instead, the sketch becomes the following. Note that there is then no longer any handling of the combined case redistribute=True and drop_remainder=True; the difference is only noticeable when tensor.size(dim) < chunks.

def custom_chunk(tensor: torch.Tensor, chunks: int, dim=0, mode='split'):
    assert mode in {'error', 'drop_remainder', 'redistribute', 'split'}
    # check that it is perfectly divisible
    if mode == 'error':
        r = tensor.size(dim) % chunks
        if r != 0:
            raise ValueError('dimension must be divisible by the number of chunks')
    # remove the trailing items
    if mode == 'drop_remainder':
        r = tensor.size(dim) % chunks
        if r != 0:
            tensor = torch.moveaxis(torch.moveaxis(tensor, dim, 0)[:-r], 0, dim)
    # redistribute remainder
    elif mode == 'redistribute':
        # new chunking logic
        chunk_sizes = (tensor.size(dim) // chunks) + (np.arange(chunks) < (tensor.size(dim) % chunks))
        return tensor.split(chunk_sizes.tolist(), dim=dim)
    # fall back to old chunking logic
    split_size = math.ceil(tensor.size(dim) / chunks)
    return tensor.split(split_size, dim=dim)
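
As a quick check of the mode-based sketch (assuming the imports from the first snippet; not exhaustively tested):

custom_chunk(torch.arange(7), 3, mode='split')
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))

custom_chunk(torch.arange(7), 3, mode='redistribute')
>>> (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))

custom_chunk(torch.arange(7), 3, mode='drop_remainder')
>>> (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))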