Enhanced torch.chunk and torch.split #60531
Comments
I would like to work on this. Thanks! |
Would it be better to include a “mode” parameter instead of “drop_remainder” and “redistribute”? That might avoid confusion when both are True. If you really want to get fancy (although admittedly I don’t see much need for this) you could even have: drop_ends, drop_first, drop_last, spread_ends, spread_first, spread_last |
Agree that a |
@nmichlo, what will be the desired output for the following case?
torch.arange(12).split(5, dim=-1, redistribute=True)
# CURRENT:
>>> (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]), tensor([10, 11])) |
@gautamborad you have got me there, my proposal is flawed. I didn't properly consider all the cases for |
Thinking about this more, the behaviour might only make sense for chunk?
I have not looked at the source code, but it seems as though chunk internally wraps torch.split, which explains why it gives these odd results where the number of requested chunks is not respected:

for i in range(1, 13):
    chunks = torch.arange(i).chunk(5)  # same as: torch.arange(i).split(math.ceil(i/5))
    print(f'{i}: expected_chunks={min(i, 5)}, num_chunks={len(chunks)}, chunks={chunks}')
>>> 1: expected_chunks=1, num_chunks=1, chunks=(tensor([0]),)
>>> 2: expected_chunks=2, num_chunks=2, chunks=(tensor([0]), tensor([1]))
>>> 3: expected_chunks=3, num_chunks=3, chunks=(tensor([0]), tensor([1]), tensor([2]))
>>> 4: expected_chunks=4, num_chunks=4, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]))
>>> 5: expected_chunks=5, num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))
>>> 6: expected_chunks=5, num_chunks=3, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
>>> 7: expected_chunks=5, num_chunks=4, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6]))
>>> 8: expected_chunks=5, num_chunks=4, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]))
>>> 9: expected_chunks=5, num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8]))
>>> 10: expected_chunks=5, num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
>>> 11: expected_chunks=5, num_chunks=4, chunks=(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10]))
>>> 12: expected_chunks=5, num_chunks=4, chunks=(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))

EDIT: this is what I imagine the updated chunk function to look like, although not exactly tested for bugs.

import math
import numpy as np
import torch

def custom_chunk(tensor: torch.Tensor, chunks: int, dim=0, redistributed=False, drop_remainder=False):
    # remove the trailing items
    if drop_remainder:
        r = tensor.size(dim) % chunks
        if r != 0:
            tensor = torch.moveaxis(torch.moveaxis(tensor, dim, 0)[:-r], 0, dim)
    # split into chunks
    if redistributed:
        # new chunking logic
        chunk_sizes = (tensor.size(dim) // chunks) + (np.arange(chunks) < (tensor.size(dim) % chunks))
        return tensor.split(chunk_sizes.tolist(), dim=dim)
    else:
        # original chunking logic
        # - should be the same as tensor.chunk(chunks)
        split_size = math.ceil(tensor.size(dim) / chunks)
        return tensor.split(split_size, dim=dim)

Example cases below, for 5 chunks:

for i in range(4, 13, 2):
    chunk_orig_alt = custom_chunk(torch.arange(i), 5, redistributed=False, drop_remainder=False)  # same as torch.arange(i).chunk(5)
    chunk_orig_alt_d = custom_chunk(torch.arange(i), 5, redistributed=False, drop_remainder=True)
    chunk_redist = custom_chunk(torch.arange(i), 5, redistributed=True, drop_remainder=False)
    chunk_redist_d = custom_chunk(torch.arange(i), 5, redistributed=True, drop_remainder=True)
    print(f'{i}: r=F, d=F : num_chunks={len(chunk_orig_alt)}, chunks={chunk_orig_alt}')
    print(f'{i}: r=F, d=T : num_chunks={len(chunk_orig_alt_d)}, chunks={chunk_orig_alt_d}')
    print(f'{i}: r=T, d=F : num_chunks={len(chunk_redist)}, chunks={chunk_redist}')
    print(f'{i}: r=T, d=T : num_chunks={len(chunk_redist_d)}, chunks={chunk_redist_d}')
    print()

giving:

4: r=F, d=F : num_chunks=4, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]))
4: r=F, d=T : num_chunks=1, chunks=(tensor([], dtype=torch.int64),)
4: r=T, d=F : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([], dtype=torch.int64))
4: r=T, d=T : num_chunks=5, chunks=(tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64), tensor([], dtype=torch.int64))
6: r=F, d=F : num_chunks=3, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
6: r=F, d=T : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))
6: r=T, d=F : num_chunks=5, chunks=(tensor([0, 1]), tensor([2]), tensor([3]), tensor([4]), tensor([5]))
6: r=T, d=T : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))
8: r=F, d=F : num_chunks=4, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]))
8: r=F, d=T : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))
8: r=T, d=F : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6]), tensor([7]))
8: r=T, d=T : num_chunks=5, chunks=(tensor([0]), tensor([1]), tensor([2]), tensor([3]), tensor([4]))
10: r=F, d=F : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
10: r=F, d=T : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
10: r=T, d=F : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
10: r=T, d=T : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
12: r=F, d=F : num_chunks=4, chunks=(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))
12: r=F, d=T : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
12: r=T, d=F : num_chunks=5, chunks=(tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]), tensor([8, 9]), tensor([10, 11]))
12: r=T, d=T : num_chunks=5, chunks=(tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))

EDIT 2: Alternatively, modes could be used instead, although there is then no longer any handling of the case where drop_remainder and redistribute are combined:

def custom_chunk(tensor: torch.Tensor, chunks: int, dim=0, mode='split'):
    assert mode in {'error', 'drop_remainder', 'redistribute', 'split'}
    # check that it is perfectly divisible
    if mode == 'error':
        r = tensor.size(dim) % chunks
        if r != 0:
            raise ValueError('dimension must be divisible by the number of chunks')
    # remove the trailing items
    if mode == 'drop_remainder':
        r = tensor.size(dim) % chunks
        if r != 0:
            tensor = torch.moveaxis(torch.moveaxis(tensor, dim, 0)[:-r], 0, dim)
    # redistribute remainder
    elif mode == 'redistribute':
        # new chunking logic
        chunk_sizes = (tensor.size(dim) // chunks) + (np.arange(chunks) < (tensor.size(dim) % chunks))
        return tensor.split(chunk_sizes.tolist(), dim=dim)
    # fall back to old chunking logic
    split_size = math.ceil(tensor.size(dim) / chunks)
    return tensor.split(split_size, dim=dim) |
Add drop_remainder: bool = True support to torch.chunk and torch.split (like drop_last in torch.utils.data.DataLoader)
Add redistribute: bool = True support to torch.chunk and torch.split
Motivation
Often it is desirable to have evenly sized arrays, or approximately evenly sized arrays. A combination of drop_remainder and redistribute would enable new intuitive ways to use torch.split and torch.chunk.

Pitch
redistribute
Add the parameter redistribute to torch.chunk and torch.split which spreads the remainder evenly across the returned tensors. For example:
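A minimal sketch of the intended behaviour (redistribute is the proposed parameter, not an existing torch argument; the output shown is an assumption that follows the custom_chunk sketch in the comments above, spreading the remainder across the leading chunks):

x = torch.arange(12)
x.chunk(5)                     # current behaviour: only 4 chunks are returned
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([ 9, 10, 11]))
x.chunk(5, redistribute=True)  # proposed: always 5 chunks, with the remainder spread over the first chunks
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]), tensor([8, 9]), tensor([10, 11]))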
drop_remainder
Add the parameter drop_remainder to torch.chunk and torch.split which excludes the remainder elements from being returned, and so all the returned tensors will be the same size. For example:
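Again a minimal sketch of the intended behaviour (drop_remainder is the proposed parameter, not an existing torch argument; the output shown is an assumption):

x = torch.arange(12)
x.split(5)                       # current behaviour: the last tensor holds the remainder
>>> (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]), tensor([10, 11]))
x.split(5, drop_remainder=True)  # proposed: the remainder elements 10 and 11 are dropped
>>> (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))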
drop_remainder & redistribute interaction
drop_remainder should be applied before redistribute: if drop_remainder=True, then there is never a need to redistribute elements, so in this case redistribute can be ignored.
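A sketch of that interaction, using the same hypothetical parameters (once the remainder has been dropped the chunks are already even, so redistribute has nothing left to move):

x = torch.arange(12)
x.chunk(5, drop_remainder=True)                     # 10 and 11 are dropped first, then chunked evenly
>>> (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))
x.chunk(5, drop_remainder=True, redistribute=True)  # redistribute is effectively a no-op here
>>> (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8, 9]))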
Alternatives
New functions that provide similar behavior, or act over the outputs from chunk or split; however, this might be confusing in the first case or inefficient in the second case.
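As a rough sketch of the second alternative, a hypothetical helper (not part of torch) that acts over the outputs of chunk; it re-concatenates before re-splitting, which is why built-in support would likely be more efficient:

def redistribute_chunks(chunks, num_chunks, dim=0):
    # re-join the chunks (an extra copy) and re-split them into evenly spread sizes
    flat = torch.cat(chunks, dim=dim)
    n = flat.size(dim)
    sizes = [n // num_chunks + (1 if i < n % num_chunks else 0) for i in range(num_chunks)]
    return flat.split(sizes, dim=dim)

redistribute_chunks(torch.arange(12).chunk(5), 5)
>>> (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7]), tensor([8, 9]), tensor([10, 11]))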
Possible synonyms for drop_remainder include:
drop_last, matching torch.utils.data.DataLoader, but it may be argued that this does not convey the correct behaviour in all cases?

Possible synonyms for redistribute include:
redistribute_remainder
distribute_remainder
spread_remainder
distribute
spread
The _remainder suffix does not describe the movement of the remainder/last elements well, as internal elements are shifted too.

Additional context
N/A