5
5
import tensorflow as tf
6
6
7
7
8
- def read_image (path , to_bgr = True ):
9
- '''Returns the image at the given path as a tensor.'''
10
- # Read the file
11
- file_data = tf .read_file (path )
12
- # Figure out the image format from the extension
13
- ext = osp .splitext (path )[- 1 ].lower ()
14
- if ext == '.png' :
15
- decoder = tf .image .decode_png
16
- elif ext in ('.jpg' , '.jpeg' ):
17
- decoder = tf .image .decode_jpeg
18
- else :
19
- raise ValueError ('Unsupported image extension: {}' .format (ext ))
20
- img = decoder (file_data , channels = 3 )
21
- if to_bgr :
22
- # Convert from RGB channel ordering to BGR
23
- # This matches, for instance, how OpenCV orders the channels.
24
- img = tf .reverse (img , [False , False , True ])
25
- return img
26
-
27
-
28
- def _load_image (path , scale , isotropic , crop , mean ):
29
- '''Loads and pre-processes the image at the given path.
8
+ def process_image (img , scale , isotropic , crop , mean ):
9
+ '''Crops, scales, and normalizes the given image.
30
10
scale : The image will first be scaled to this size.
31
11
If isotropic is true, the smaller side is rescaled to this,
32
12
preserving the aspect ratio.
33
13
crop : After scaling, a central crop of this size is taken.
34
14
mean : Subtracted from the image
35
15
'''
36
- # Read in the image
37
- img = read_image (path )
38
16
# Rescale
39
17
if isotropic :
40
18
img_shape = tf .to_float (tf .shape (img )[:2 ])
@@ -52,22 +30,136 @@ def _load_image(path, scale, isotropic, crop, mean):
52
30
return tf .to_float (img ) - mean
53
31
54
32
55
- def load_image (path , spec ):
56
- '''Load a single image, processed based on the given spec.'''
57
- return _load_image (path = path ,
58
- scale = spec .scale_size ,
59
- isotropic = spec .isotropic ,
60
- crop = spec .crop_size ,
61
- mean = spec .mean )
33
class ImageProducer(object):
    '''
    Loads and processes batches of images in parallel.

    Image paths are pushed through a path queue; each worker thread dequeues one
    path, decodes and pre-processes the image, and enqueues the result on a
    second queue from which whole batches are fetched.

    The data spec is read for: batch_size, scale_size, crop_size, isotropic,
    mean, channels, and expects_bgr.
    '''

    def __init__(self, image_paths, data_spec, num_concurrent=4, batch_size=None, labels=None):
        '''
        image_paths    : A list of full image paths (.jpg, .jpeg, or .png).
        data_spec      : Describes how each image is to be processed.
        num_concurrent : Number of parallel enqueue operations handed to the queue runner.
        batch_size     : Images per batch. Falls back to data_spec.batch_size when not given.
        labels         : Optional sequence of labels, indexable by image position.

        Raises ValueError for an unsupported file extension, an empty path list,
        or a batch size that does not evenly divide the number of images.
        '''
        # The data specifications describe how to process the image
        self.data_spec = data_spec
        # A list of full image paths
        self.image_paths = image_paths
        # An optional list of labels corresponding to each image path
        self.labels = labels
        # A boolean flag per image indicating whether it is a JPEG (True) or a PNG (False)
        self.extension_mask = self.create_extension_mask(self.image_paths)
        # Create the loading and processing operations
        self.setup(batch_size=batch_size, num_concurrent=num_concurrent)

    def _resolve_batching(self, batch_size):
        '''
        Returns the tuple (effective batch size, number of batches).

        The requested batch size (or data_spec.batch_size when the request is
        None or 0) is capped at the number of images.
        Raises ValueError if there are no images, or if the image count is not
        a multiple of the effective batch size.
        '''
        num_images = len(self.image_paths)
        if num_images == 0:
            # Without this guard the modulo below fails with a bare ZeroDivisionError.
            raise ValueError('At least one image path is required.')
        batch_size = min(num_images, batch_size or self.data_spec.batch_size)
        if num_images % batch_size != 0:
            raise ValueError(
                'The total number of images ({}) must be divisible by the batch size ({}).'.format(
                    num_images, batch_size))
        # Floor division keeps the batch count an integer on Python 2 and Python 3 alike.
        return batch_size, num_images // batch_size

    def setup(self, batch_size, num_concurrent):
        '''
        Builds the queues and the loading/processing operations.

        Sets num_batches, path_queue, enqueue_paths_op, close_path_queue_op,
        dequeue_op, and queue_runner on the instance.
        '''
        # Validate the batch size
        num_images = len(self.image_paths)
        batch_size, self.num_batches = self._resolve_batching(batch_size)

        # Create a queue that will contain image paths (and their indices and extension indicator)
        self.path_queue = tf.FIFOQueue(capacity=num_images,
                                       dtypes=[tf.int32, tf.bool, tf.string],
                                       name='path_queue')

        # Enqueue all image paths, along with their indices
        indices = tf.range(num_images)
        self.enqueue_paths_op = self.path_queue.enqueue_many([indices, self.extension_mask,
                                                              self.image_paths])
        # Close the path queue (no more additions)
        self.close_path_queue_op = self.path_queue.close()

        # Create an operation that dequeues a single path and returns a processed image
        (idx, processed_image) = self.process()

        # Create a queue that will contain the processed images (and their indices)
        image_shape = (self.data_spec.crop_size, self.data_spec.crop_size, self.data_spec.channels)
        processed_queue = tf.FIFOQueue(capacity=int(np.ceil(num_images / float(num_concurrent))),
                                       dtypes=[tf.int32, tf.float32],
                                       shapes=[(), image_shape],
                                       name='processed_queue')

        # Enqueue the processed image and its index
        enqueue_processed_op = processed_queue.enqueue([idx, processed_image])

        # Create a dequeue op that fetches a batch of processed images off the queue
        self.dequeue_op = processed_queue.dequeue_many(batch_size)

        # Create a queue runner to perform the processing operations in parallel.
        # There is no point in having more enqueue operations than images.
        num_concurrent = min(num_concurrent, num_images)
        self.queue_runner = tf.train.QueueRunner(processed_queue,
                                                 [enqueue_processed_op] * num_concurrent)

    def start(self, session, coordinator, num_concurrent=4):
        '''
        Start the processing worker threads and return them.

        num_concurrent is accepted for backward compatibility only and is not
        used here: the degree of parallelism is fixed when the queue runner is
        built in setup().
        '''
        # Queue all paths
        session.run(self.enqueue_paths_op)
        # Close the path queue
        session.run(self.close_path_queue_op)
        # Start the queue runner and return the created threads
        return self.queue_runner.create_threads(session, coord=coordinator, start=True)

    def get(self, session):
        '''
        Get a single batch of images along with their indices. If a set of labels was provided,
        the corresponding labels are returned instead of the indices.
        '''
        (indices, images) = session.run(self.dequeue_op)
        if self.labels is not None:
            labels = [self.labels[idx] for idx in indices]
            return (labels, images)
        return (indices, images)

    def batches(self, session):
        '''Yield a batch until no more images are left.'''
        # range (rather than xrange) keeps this usable on both Python 2 and Python 3.
        for _ in range(self.num_batches):
            yield self.get(session=session)

    def load_image(self, image_path, is_jpeg):
        '''
        Returns an operation that reads and decodes the image at image_path.

        is_jpeg selects the JPEG decoder when true and the PNG decoder otherwise.
        The image is decoded to 3 channels, and the channel order is reversed
        if the data spec expects BGR.
        '''
        # Read the file
        file_data = tf.read_file(image_path)
        # Decode the image data
        img = tf.cond(is_jpeg,
                      lambda: tf.image.decode_jpeg(file_data, channels=3),
                      lambda: tf.image.decode_png(file_data, channels=3))
        if self.data_spec.expects_bgr:
            # Convert from RGB channel ordering to BGR
            # This matches, for instance, how OpenCV orders the channels.
            img = tf.reverse(img, [False, False, True])
        return img

    def process(self):
        '''
        Returns (index, processed image) operations for a single image taken
        off the path queue.
        '''
        # Dequeue a single image path
        idx, is_jpeg, image_path = self.path_queue.dequeue()
        # Load the image
        img = self.load_image(image_path, is_jpeg)
        # Process the image
        processed_img = process_image(img=img,
                                      scale=self.data_spec.scale_size,
                                      isotropic=self.data_spec.isotropic,
                                      crop=self.data_spec.crop_size,
                                      mean=self.data_spec.mean)
        # Return the processed image, along with its index
        return (idx, processed_img)

    @staticmethod
    def create_extension_mask(paths):
        '''
        Returns a list with one boolean per path: True for JPEG, False for PNG.

        The check is based on the (case-insensitive) file extension.
        Raises ValueError for any other extension.
        '''

        def is_jpeg(path):
            extension = osp.splitext(path)[-1].lower()
            if extension in ('.jpg', '.jpeg'):
                return True
            if extension != '.png':
                raise ValueError('Unsupported image format: {}'.format(extension))
            return False

        return [is_jpeg(p) for p in paths]

    def __len__(self):
        '''Returns the total number of images.'''
        return len(self.image_paths)
160
+
161
+
162
+ class ImageNetProducer (ImageProducer ):
71
163
72
164
def __init__ (self , val_path , data_path , data_spec ):
73
165
# Read in the ground truth labels for the validation set
@@ -76,19 +168,10 @@ def __init__(self, val_path, data_path, data_spec):
76
168
gt_pairs = [line .split () for line in gt_lines ]
77
169
# Get the full image paths
78
170
# You will need a copy of the ImageNet validation set for this.
79
- self . image_paths = [osp .join (data_path , p [0 ]) for p in gt_pairs ]
171
+ image_paths = [osp .join (data_path , p [0 ]) for p in gt_pairs ]
80
172
# The corresponding ground truth labels
81
- self .labels = np .array ([int (p [1 ]) for p in gt_pairs ])
82
- # The data specifications for the model being validated (for preprocessing)
83
- self .data_spec = data_spec
84
-
85
- def batches (self , n ):
86
- '''Yields a batch of up to n preprocessed image tensors and their ground truth labels.'''
87
- for i in xrange (0 , len (self .image_paths ), n ):
88
- images = load_images (self .image_paths [i :i + n ], self .data_spec )
89
- labels = self .labels [i :i + n ]
90
- yield (images , labels )
91
-
92
- def __len__ (self ):
93
- '''Returns the number of instances in the validation set.'''
94
- return len (self .labels )
173
+ labels = np .array ([int (p [1 ]) for p in gt_pairs ])
174
+ # Initialize base
175
+ super (ImageNetProducer , self ).__init__ (image_paths = image_paths ,
176
+ data_spec = data_spec ,
177
+ labels = labels )
0 commit comments