Skip to content

Commit ad0ccd4

Browse files
author
Saumitro Dasgupta
committed
Squeeze out singleton dimensions for softmax
1 parent e7295b8 commit ad0ccd4

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

kaffe/tensorflow/network.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,15 @@ def fc(self, input, num_out, name, relu=True):
193193

194194
@layer
195195
def softmax(self, input, name):
196+
input_shape = map(lambda v: v.value, input.get_shape())
197+
if len(input_shape)>2:
198+
# For certain models (like NiN), the singleton spatial dimensions
199+
# need to be explicitly squeezed, since they're not broadcast-able
200+
# in TensorFlow's NHWC ordering (unlike Caffe's NCHW).
201+
if input_shape[1]==1 and input_shape[2]==1:
202+
input = tf.squeeze(input, squeeze_dims=[1, 2])
203+
else:
204+
raise ValueError('Rank 2 tensor input expected for softmax!')
196205
return tf.nn.softmax(input, name)
197206

198207
@layer

0 commit comments

Comments
 (0)