
Commit 2d8649c

Saumitro Dasgupta committed

Rewrote the image loader+processor
The previous one was (insanely) sub-optimal

1 parent 13c7e9a   commit 2d8649c

4 files changed: +180 -71 lines changed


examples/imagenet/classify.py

Lines changed: 14 additions & 2 deletions
@@ -38,18 +38,30 @@ def classify(model_data_path, image_paths):
     # Construct the network
     net = models.GoogleNet({'data': input_node})
 
+    # Create an image producer (loads and processes images in parallel)
+    image_producer = dataset.ImageProducer(image_paths=image_paths, data_spec=spec)
+
     with tf.Session() as sesh:
+        # Start the image processing workers
+        coordinator = tf.train.Coordinator()
+        threads = image_producer.start(session=sesh, coordinator=coordinator)
+
         # Load the converted parameters
         print('Loading the model')
         net.load(model_data_path, sesh)
+
         # Load the input image
         print('Loading the images')
-        input_images = dataset.load_images(image_paths, spec).eval()
+        indices, input_images = image_producer.get(sesh)
+
         # Perform a forward pass through the network to get the class probabilities
         print('Classifying')
         probs = sesh.run(net.get_output(), feed_dict={input_node: input_images})
-        display_results(image_paths, probs)
+        display_results([image_paths[i] for i in indices], probs)
 
+        # Stop the worker threads
+        coordinator.request_stop()
+        coordinator.join(threads, stop_grace_period_secs=2)
 
 def main():
     # Parse arguments
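Note: the start/consume/stop sequence above follows TensorFlow's standard (pre-tf.data) Coordinator and QueueRunner lifecycle. A minimal, self-contained sketch of that pattern, not part of this commit (the toy queue and the constant it enqueues are illustrative only):

    import tensorflow as tf

    # A toy FIFO queue fed by two background threads, mirroring how
    # classify.py drives ImageProducer: start workers, consume, then stop.
    queue = tf.FIFOQueue(capacity=10, dtypes=[tf.int32])
    enqueue_op = queue.enqueue([tf.constant(1)])
    runner = tf.train.QueueRunner(queue, [enqueue_op] * 2)

    with tf.Session() as sesh:
        coordinator = tf.train.Coordinator()
        threads = runner.create_threads(sesh, coord=coordinator, start=True)
        print(sesh.run(queue.dequeue()))  # consume one element
        coordinator.request_stop()        # ask the worker threads to exit
        coordinator.join(threads, stop_grace_period_secs=2)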

examples/imagenet/dataset.py

Lines changed: 134 additions & 51 deletions
@@ -5,36 +5,14 @@
 import tensorflow as tf
 
 
-def read_image(path, to_bgr=True):
-    '''Returns the image at the given path as a tensor.'''
-    # Read the file
-    file_data = tf.read_file(path)
-    # Figure out the image format from the extension
-    ext = osp.splitext(path)[-1].lower()
-    if ext == '.png':
-        decoder = tf.image.decode_png
-    elif ext in ('.jpg', '.jpeg'):
-        decoder = tf.image.decode_jpeg
-    else:
-        raise ValueError('Unsupported image extension: {}'.format(ext))
-    img = decoder(file_data, channels=3)
-    if to_bgr:
-        # Convert from RGB channel ordering to BGR
-        # This matches, for instance, how OpenCV orders the channels.
-        img = tf.reverse(img, [False, False, True])
-    return img
-
-
-def _load_image(path, scale, isotropic, crop, mean):
-    '''Loads and pre-processes the image at the given path.
+def process_image(img, scale, isotropic, crop, mean):
+    '''Crops, scales, and normalizes the given image.
     scale : The image wil be first scaled to this size.
             If isotropic is true, the smaller side is rescaled to this,
             preserving the aspect ratio.
     crop : After scaling, a central crop of this size is taken.
     mean : Subtracted from the image
     '''
-    # Read in the image
-    img = read_image(path)
     # Rescale
     if isotropic:
         img_shape = tf.to_float(tf.shape(img)[:2])
@@ -52,22 +30,136 @@ def _load_image(path, scale, isotropic, crop, mean):
     return tf.to_float(img) - mean
 
 
-def load_image(path, spec):
-    '''Load a single image, processed based on the given spec.'''
-    return _load_image(path=path,
-                       scale=spec.scale_size,
-                       isotropic=spec.isotropic,
-                       crop=spec.crop_size,
-                       mean=spec.mean)
+class ImageProducer(object):
+    '''
+    Loads and processes batches of images in parallel.
+    '''
+
+    def __init__(self, image_paths, data_spec, num_concurrent=4, batch_size=None, labels=None):
+        # The data specifications describe how to process the image
+        self.data_spec = data_spec
+        # A list of full image paths
+        self.image_paths = image_paths
+        # An optional list of labels corresponding to each image path
+        self.labels = labels
+        # A boolean flag per image indicating whether its a JPEG or PNG
+        self.extension_mask = self.create_extension_mask(self.image_paths)
+        # Create the loading and processing operations
+        self.setup(batch_size=batch_size, num_concurrent=num_concurrent)
+
+    def setup(self, batch_size, num_concurrent):
+        # Validate the batch size
+        num_images = len(self.image_paths)
+        batch_size = min(num_images, batch_size or self.data_spec.batch_size)
+        if num_images % batch_size != 0:
+            raise ValueError(
+                'The total number of images ({}) must be divisible by the batch size ({}).'.format(
+                    num_images, batch_size))
+        self.num_batches = num_images / batch_size
+
+        # Create a queue that will contain image paths (and their indices and extension indicator)
+        self.path_queue = tf.FIFOQueue(capacity=num_images,
+                                       dtypes=[tf.int32, tf.bool, tf.string],
+                                       name='path_queue')
+
+        # Enqueue all image paths, along with their indices
+        indices = tf.range(num_images)
+        self.enqueue_paths_op = self.path_queue.enqueue_many([indices, self.extension_mask,
+                                                              self.image_paths])
+        # Close the path queue (no more additions)
+        self.close_path_queue_op = self.path_queue.close()
+
+        # Create an operation that dequeues a single path and returns a processed image
+        (idx, processed_image) = self.process()
+
+        # Create a queue that will contain the processed images (and their indices)
+        image_shape = (self.data_spec.crop_size, self.data_spec.crop_size, self.data_spec.channels)
+        processed_queue = tf.FIFOQueue(capacity=int(np.ceil(num_images / float(num_concurrent))),
+                                       dtypes=[tf.int32, tf.float32],
+                                       shapes=[(), image_shape],
+                                       name='processed_queue')
+
+        # Enqueue the processed image and path
+        enqueue_processed_op = processed_queue.enqueue([idx, processed_image])
+
+        # Create a dequeue op that fetches a batch of processed images off the queue
+        self.dequeue_op = processed_queue.dequeue_many(batch_size)
 
+        # Create a queue runner to perform the processing operations in parallel
+        num_concurrent = min(num_concurrent, num_images)
+        self.queue_runner = tf.train.QueueRunner(processed_queue,
+                                                 [enqueue_processed_op] * num_concurrent)
 
-def load_images(paths, spec):
-    '''Load multiple images, processed based on the given spec.'''
-    return tf.pack([load_image(path, spec) for path in paths])
+    def start(self, session, coordinator, num_concurrent=4):
+        '''Start the processing worker threads.'''
+        # Queue all paths
+        session.run(self.enqueue_paths_op)
+        # Close the path queue
+        session.run(self.close_path_queue_op)
+        # Start the queue runner and return the created threads
+        return self.queue_runner.create_threads(session, coord=coordinator, start=True)
 
+    def get(self, session):
+        '''
+        Get a single batch of images along with their indices. If a set of labels were provided,
+        the corresponding labels are returned instead of the indices.
+        '''
+        (indices, images) = session.run(self.dequeue_op)
+        if self.labels is not None:
+            labels = [self.labels[idx] for idx in indices]
+            return (labels, images)
+        return (indices, images)
 
-class ImageNet(object):
-    '''Iterates over the ImageNet validation set.'''
+    def batches(self, session):
+        '''Yield a batch until no more images are left.'''
+        for _ in xrange(self.num_batches):
+            yield self.get(session=session)
+
+    def load_image(self, image_path, is_jpeg):
+        # Read the file
+        file_data = tf.read_file(image_path)
+        # Decode the image data
+        img = tf.cond(is_jpeg,
+                      lambda: tf.image.decode_jpeg(file_data, channels=3),
+                      lambda: tf.image.decode_png(file_data, channels=3))
+        if self.data_spec.expects_bgr:
+            # Convert from RGB channel ordering to BGR
+            # This matches, for instance, how OpenCV orders the channels.
+            img = tf.reverse(img, [False, False, True])
+        return img
+
+    def process(self):
+        # Dequeue a single image path
+        idx, is_jpeg, image_path = self.path_queue.dequeue()
+        # Load the image
+        img = self.load_image(image_path, is_jpeg)
+        # Process the image
+        processed_img = process_image(img=img,
+                                      scale=self.data_spec.scale_size,
+                                      isotropic=self.data_spec.isotropic,
+                                      crop=self.data_spec.crop_size,
+                                      mean=self.data_spec.mean)
+        # Return the processed image, along with its index
+        return (idx, processed_img)
+
+    @staticmethod
+    def create_extension_mask(paths):
+
+        def is_jpeg(path):
+            extension = osp.splitext(path)[-1].lower()
+            if extension in ('.jpg', '.jpeg'):
+                return True
+            if not extension == '.png':
+                raise ValueError('Unsupported image format: {}'.format(extension))
+            return False
+
+        return [is_jpeg(p) for p in paths]
+
+    def __len__(self):
+        return len(self.image_paths)
+
+
+class ImageNetProducer(ImageProducer):
 
     def __init__(self, val_path, data_path, data_spec):
         # Read in the ground truth labels for the validation set
@@ -76,19 +168,10 @@ def __init__(self, val_path, data_path, data_spec):
         gt_pairs = [line.split() for line in gt_lines]
         # Get the full image paths
         # You will need a copy of the ImageNet validation set for this.
-        self.image_paths = [osp.join(data_path, p[0]) for p in gt_pairs]
+        image_paths = [osp.join(data_path, p[0]) for p in gt_pairs]
         # The corresponding ground truth labels
-        self.labels = np.array([int(p[1]) for p in gt_pairs])
-        # The data specifications for the model being validated (for preprocessing)
-        self.data_spec = data_spec
-
-    def batches(self, n):
-        '''Yields a batch of up to n preprocessed image tensors and their ground truth labels.'''
-        for i in xrange(0, len(self.image_paths), n):
-            images = load_images(self.image_paths[i:i + n], self.data_spec)
-            labels = self.labels[i:i + n]
-            yield (images, labels)
-
-    def __len__(self):
-        '''Returns the number of instances in the validation set.'''
-        return len(self.labels)
+        labels = np.array([int(p[1]) for p in gt_pairs])
+        # Initialize base
+        super(ImageNetProducer, self).__init__(image_paths=image_paths,
+                                               data_spec=data_spec,
+                                               labels=labels)
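For reference, the new ImageProducer is driven the same way in classify.py and validate.py. A hedged sketch of standalone usage, assuming `paths` is a list of JPEG/PNG file paths, `spec` is a DataSpec from models.get_data_spec, and the script runs from examples/imagenet so the import resolves:

    import tensorflow as tf
    from dataset import ImageProducer  # assumed import path

    producer = ImageProducer(image_paths=paths, data_spec=spec)

    with tf.Session() as sesh:
        coordinator = tf.train.Coordinator()
        threads = producer.start(session=sesh, coordinator=coordinator)
        # Each batch is (indices, images), or (labels, images) when labels were passed in
        for indices, images in producer.batches(sesh):
            print(len(indices), images.shape)
        coordinator.request_stop()
        coordinator.join(threads, stop_grace_period_secs=2)

Because batch_size defaults to min(len(paths), spec.batch_size) and must divide the total image count evenly, small path lists fall back to a single batch.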

examples/imagenet/models/helper.py

Lines changed: 13 additions & 9 deletions
@@ -12,10 +12,18 @@
 from nin import NiN
 from resnet import ResNet50, ResNet101, ResNet152
 
+
 class DataSpec(object):
     '''Input data specifications for an ImageNet model.'''
 
-    def __init__(self, batch_size, scale_size, crop_size, isotropic, channels=3, mean=None):
+    def __init__(self,
+                 batch_size,
+                 scale_size,
+                 crop_size,
+                 isotropic,
+                 channels=3,
+                 mean=None,
+                 bgr=True):
         # The recommended batch size for this model
         self.batch_size = batch_size
         # The image should be scaled to this size first during preprocessing
@@ -31,11 +39,15 @@ def __init__(self, batch_size, scale_size, crop_size, isotropic, channels=3, mea
         # Some of the earlier models (like AlexNet) used a spatial three-channeled mean.
         # However, using just the per-channel mean values instead doesn't affect things too much.
         self.mean = mean if mean is not None else np.array([104., 117., 124.])
+        # Whether this model expects images to be in BGR order
+        self.expects_bgr = True
+
 
 def alexnet_spec(batch_size=500):
     '''Parameters used by AlexNet and its variants.'''
     return DataSpec(batch_size=batch_size, scale_size=256, crop_size=227, isotropic=False)
 
+
 def std_spec(batch_size, isotropic=True):
     '''Parameters commonly used by "post-AlexNet" architectures.'''
     return DataSpec(batch_size=batch_size, scale_size=256, crop_size=224, isotropic=isotropic)
@@ -47,21 +59,13 @@ def std_spec(batch_size, isotropic=True):
 # These specifications are based on how the models were trained.
 # The recommended batch size is based on a Titan X (12GB).
 MODEL_DATA_SPECS = {
-
     AlexNet: alexnet_spec(),
-
     CaffeNet: alexnet_spec(),
-
     GoogleNet: std_spec(batch_size=200, isotropic=False),
-
     ResNet50: std_spec(batch_size=25),
-
     ResNet101: std_spec(batch_size=25),
-
     ResNet152: std_spec(batch_size=25),
-
     NiN: std_spec(batch_size=500),
-
     VGG16: std_spec(batch_size=224)
 }
 
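A small illustrative sketch of how these specs are consumed, assuming the names below are imported from examples/imagenet/models/helper.py (not part of this commit):

    # Build a spec directly with the new keyword-per-line signature,
    # or look one up from the table keyed by model class.
    custom_spec = DataSpec(batch_size=200,
                           scale_size=256,
                           crop_size=224,
                           isotropic=False,
                           bgr=True)
    googlenet_spec = MODEL_DATA_SPECS[GoogleNet]
    print(googlenet_spec.crop_size)   # 224
    print(googlenet_spec.mean)        # defaults to [104. 117. 124.]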

examples/imagenet/validate.py

Lines changed: 19 additions & 9 deletions
@@ -34,14 +34,14 @@ def load_model(name):
     return NetClass({'data': data_node})
 
 
-def validate(net, model_path, images, top_k=5):
+def validate(net, model_path, image_producer, top_k=5):
     '''Compute the top_k classification accuracy for the given network and images.'''
     # Get the data specifications for given network
     spec = models.get_data_spec(model_instance=net)
     # Get the input node for feeding in the images
     input_node = net.inputs['data']
     # Create a placeholder for the ground truth labels
-    label_node = tf.placeholder(tf.int32, shape=(spec.batch_size,))
+    label_node = tf.placeholder(tf.int32)
     # Get the output of the network (class probabilities)
     probs = net.get_output()
     # Create a top_k accuracy node
@@ -51,21 +51,29 @@ def validate(net, model_path, image_producer, top_k=5):
     # The number of correctly classified images
     correct = 0
     # The total number of images
-    total = len(images)
+    total = len(image_producer)
+
     with tf.Session() as sesh:
+        coordinator = tf.train.Coordinator()
         # Load the converted parameters
-        net.load(model_path, sesh)
+        net.load(data_path=model_path, session=sesh)
+        # Start the image processing workers
+        threads = image_producer.start(session=sesh, coordinator=coordinator)
         # Iterate over and classify mini-batches
-        for idx, (images, labels) in enumerate(images.batches(spec.batch_size)):
+        for (labels, images) in image_producer.batches(sesh):
             correct += np.sum(sesh.run(top_k_op,
-                                       feed_dict={input_node: images.eval(),
+                                       feed_dict={input_node: images,
                                                   label_node: labels}))
-            count += images.get_shape()[0].value
+            count += len(labels)
             cur_accuracy = float(correct) * 100 / count
             print('{:>6}/{:<6} {:>6.2f}%'.format(count, total, cur_accuracy))
+        # Stop the worker threads
+        coordinator.request_stop()
+        coordinator.join(threads, stop_grace_period_secs=2)
     print('Top {} Accuracy: {}'.format(top_k, float(correct) / total))
 
 
+
 def main():
     # Parse arguments
     parser = argparse.ArgumentParser()
@@ -82,10 +90,12 @@ def main():
 
     # Load the dataset
     data_spec = models.get_data_spec(model_instance=net)
-    images = dataset.ImageNet(args.val_gt, args.imagenet_data_dir, data_spec)
+    image_producer = dataset.ImageNetProducer(val_path=args.val_gt,
+                                              data_path=args.imagenet_data_dir,
+                                              data_spec=data_spec)
 
     # Evaluate its performance on the ILSVRC12 validation set
-    validate(net, args.model_path, images)
+    validate(net, args.model_path, image_producer)
 
 
 if __name__ == '__main__':
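The top_k_op used in the second hunk is created in unchanged context just after the "# Create a top_k accuracy node" comment and is not shown in this diff. A hedged sketch of how such a node is commonly built, using names defined earlier in validate():

    # Illustrative only: probs holds the network's class probabilities and
    # label_node is the (now shape-less) int32 placeholder defined above.
    top_k_op = tf.nn.in_top_k(probs, label_node, top_k)
    # Running it yields one boolean per image; np.sum then counts the correct ones.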
