Skip to content

Commit 0c5bf53

Browse files
authored
Merge pull request #49 from Scyfer/added_dropout_switch
Added a switch variable which allows it to disable dropout
2 parents f77c420 + 998d853 commit 0c5bf53

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

kaffe/tensorflow/network.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def __init__(self, inputs, trainable=True):
4040
self.layers = dict(inputs)
4141
# If true, the resulting variables are set as trainable
4242
self.trainable = trainable
43+
# Switch variable for dropout
44+
self.use_dropout = tf.placeholder_with_default(tf.constant(1.0),
45+
shape=[],
46+
name='use_dropout')
4347
self.setup()
4448

4549
def setup(self):
@@ -236,4 +240,5 @@ def batch_normalization(self, input, name, scale_offset=True, relu=False):
236240

237241
@layer
238242
def dropout(self, input, keep_prob, name):
239-
return tf.nn.dropout(input, keep_prob, name=name)
243+
keep = 1 - self.use_dropout + (self.use_dropout * keep_prob)
244+
return tf.nn.dropout(input, keep, name=name)

0 commit comments

Comments
 (0)