Description
🐛 Bug
torch.topk is slower than a simple divide-and-conquer implementation for k=100 and an input dimension of 2,000,000.
To Reproduce
Steps to reproduce the behavior:
    import torch

    def gpu_time(f, reps=100):
        # Time f on the GPU with CUDA events, averaged over reps calls.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for i in range(reps):
            y = f()
        end_event.record()
        torch.cuda.synchronize()  # Wait for the events to be recorded!
        elapsed_time_ms = start_event.elapsed_time(end_event)
        print("{} ms".format(elapsed_time_ms / reps))
        return y
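As a side note, gpu_time does not warm up or synchronize before the first record, so the measurement can include one-time initialization. A more conservative variant (a sketch added here for illustration, under the hypothetical name gpu_time_warm; it was not used for the numbers below) could look like:

    def gpu_time_warm(f, reps=100):
        # Same idea as gpu_time, but run once and drain pending work before timing
        # (my assumption about a more cautious methodology, not part of the report).
        f()                         # warm-up call, excluded from the measurement
        torch.cuda.synchronize()    # make sure nothing is still queued
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        for _ in range(reps):
            y = f()
        end_event.record()
        torch.cuda.synchronize()
        print("{} ms".format(start_event.elapsed_time(end_event) / reps))
        return y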
    def divide_and_conquer_topk(x, divide, k):
        # Split each row into `divide` chunks, take the top k of every chunk,
        # then take the top k of the divide * k surviving candidates.
        num_outputs = x.shape[1]
        batch_size = x.shape[0]
        assert num_outputs % divide == 0
        partial = torch.topk(x.view(-1, divide, num_outputs // divide), k=k, dim=-1)
        # Map per-chunk indices back to column indices of the original input.
        # The index tensor follows x.device so the same function also runs on the CPU below.
        indices_view = torch.arange(num_outputs, device=x.device).view(1, divide, num_outputs // divide)
        indices_view_2 = indices_view.expand((batch_size, divide, num_outputs // divide)).gather(index=partial.indices, dim=2).view(batch_size, -1)
        values_2d = partial.values.view(batch_size, -1)
        topk_2 = torch.topk(values_2d, k=k, dim=-1)
        return indices_view_2.gather(index=topk_2.indices, dim=-1), topk_2.values
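For reference, a small sanity check (added here for illustration; x_small and the _ref / _dc names are just placeholders) that the divide-and-conquer version returns the same top values as torch.topk:

    # Sanity check on a small CPU tensor using the helper defined above.
    x_small = torch.rand(4, 1000)
    values_ref, indices_ref = torch.topk(x_small, k=10, dim=-1)
    indices_dc, values_dc = divide_and_conquer_topk(x_small, divide=10, k=10)
    assert torch.equal(values_ref, values_dc)                      # same sorted top values
    assert torch.equal(x_small.gather(1, indices_dc), values_dc)   # indices point at those values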
    b = 512
    n = 2000000
    k = 100
    cuda = torch.device('cuda')
    x = torch.rand(b, n, device=cuda)  # assumed: x lives on the GPU, mirroring the CPU case below
    gpu_time(lambda: torch.topk(x, k, dim=-1));
This took 104.4182421875 ms per call, while

    gpu_time(lambda: divide_and_conquer_topk(x, 100, k));

took 86.761689453125 ms per call.
On the CPU, with

    x = torch.rand(b, n, device='cpu')

the output of

    %%time
    torch.topk(x, k, dim=-1)

is

    CPU times: user 10.3 s, sys: 11.9 ms, total: 10.3 s

while

    %%time
    divide_and_conquer_topk(x, 100, k)

outputs

    CPU times: user 7.19 s, sys: 88.8 ms, total: 7.28 s
Expected behavior
I did not expect such a simple implementation of topk to outperform the native PyTorch one.
I wonder whether there is something suboptimal in torch.topk. It might be specific to this case and therefore not very interesting, but it seemed surprising enough to report.
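If it helps triage, the gap could also be checked for a few values of the divide factor using the helpers above; this is only a sketch, I have not measured anything beyond the numbers reported above:

    # Sketch: these factors all divide n = 2,000,000 evenly and keep each chunk >= k.
    for divide in (10, 100, 1000):
        print("divide =", divide)
        gpu_time(lambda d=divide: divide_and_conquer_topk(x, d, k))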
Environment
Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1
OS: Amazon Linux AMI 2018.03
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-28)
CMake version: version 3.13.3
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 440.33.01
cuDNN version: Probably one of the following:
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7.6.5
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.7.6.5
Versions of relevant libraries:
[pip3] numpy==1.15.4
[pip3] numpydoc==0.8.0
[pip3] torch==1.4.0
[pip3] torchvision==0.5.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] mkl 2018.0.3 1
[conda] mkl-service 1.1.2 py36h17a0993_4
[conda] mkl_fft 1.0.6 py36h7dd41cf_0
[conda] mkl_random 1.0.1 py36h629b387_0
[conda] numpy 1.15.4 py36h1d66e8a_0
[conda] numpy-base 1.15.4 py36h81de0dd_0
[conda] numpydoc 0.8.0 py36_0
[conda] pytorch 1.4.0 py3.6_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] torchvision 0.5.0 py36_cu101 pytorch