# Source code for megadetector.detection.tf_detector

"""

tf_detector.py

Module containing the class TFDetector, for loading and running a TensorFlow detection model.

"""

#%% Imports and constants

import numpy as np

from megadetector.detection.run_detector import \
    CONF_DIGITS, COORD_DIGITS, FAILURE_INFER
from megadetector.utils.ct_utils import truncate_float

import tensorflow.compat.v1 as tf

# Import-time diagnostics: print the TensorFlow version string and the result of
# tf.test.is_gpu_available() as soon as this module is imported.
# NOTE(review): tf.test.is_gpu_available appears to be deprecated in TF 2.x in favor
# of tf.config.list_physical_devices('GPU') -- confirm before changing.
print('TensorFlow version:', tf.__version__)
print('Is GPU available? tf.test.is_gpu_available:', tf.test.is_gpu_available())


#%% Classes

class TFDetector:
    """
    A detector model loaded at the time of initialization. It is intended to be used with
    TensorFlow-based versions of MegaDetector (v2, v3, or v4).  If someone can find v1, I
    suppose you could use this class for v1 also.
    """

    #: TF versions of MD were trained with batch size of 1, and the resizing function is a
    #: part of the inference graph, so this is fixed.
    #:
    #: :meta private:
    BATCH_SIZE = 1

    def __init__(self, model_path, detector_options=None):
        """
        Loads a model from [model_path] and starts a tf.Session with this graph. Obtains
        input and output tensor handles.

        Args:
            model_path (str): path to .pb file
            detector_options (dict, optional): key-value pairs that control detector
                options; currently not used by TFDetector
        """

        detection_graph = TFDetector.__load_model(model_path)
        self.tf_session = tf.Session(graph=detection_graph)

        # Handles to the graph's input tensor and the three output tensors that are
        # fetched on every inference call
        self.image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        self.box_tensor = detection_graph.get_tensor_by_name('detection_boxes:0')
        self.score_tensor = detection_graph.get_tensor_by_name('detection_scores:0')
        self.class_tensor = detection_graph.get_tensor_by_name('detection_classes:0')


    @staticmethod
    def __round_and_make_float(d, precision=4):
        """
        Casts [d] to a Python float and passes it through truncate_float().

        Args:
            d: numeric value (e.g. a numpy scalar) to convert
            precision (int, optional): forwarded to truncate_float() as [precision]

        Returns:
            the value returned by truncate_float()
        """

        return truncate_float(float(d), precision=precision)


    @staticmethod
    def __convert_coords(tf_coords):
        """
        Converts coordinates from the model's output format [y1, x1, y2, x2] to the
        format used by our API and MegaDB: [x1, y1, width, height]. All coordinates
        (including model outputs) are normalized in the range [0, 1].

        Args:
            tf_coords: np.array of predicted bounding box coordinates from the TF detector,
                has format [y1, x1, y2, x2]

        Returns:
            list of Python float, predicted bounding box coordinates [x1, y1, width, height]
        """

        # change from [y1, x1, y2, x2] to [x1, y1, width, height]
        width = tf_coords[3] - tf_coords[1]
        height = tf_coords[2] - tf_coords[0]

        # The result must be a list (not an np.array) of Python floats (not numpy floats)
        return [TFDetector.__round_and_make_float(d, precision=COORD_DIGITS)
                for d in (tf_coords[1], tf_coords[0], width, height)]


    @staticmethod
    def __load_model(model_path):
        """
        Loads a detection model (i.e., create a graph) from a .pb file.

        Args:
            model_path: .pb file of the model.

        Returns:
            the loaded graph.
        """

        print('TFDetector: Loading graph...')
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(model_path, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                # name='' imports the graph's nodes without adding a name prefix
                tf.import_graph_def(od_graph_def, name='')
        print('TFDetector: Detection graph loaded.')

        return detection_graph


    def _generate_detections_one_image(self, image):
        """
        Runs the detector on a single image.

        Args:
            image: a numpy array (used as-is) or any object np.asarray() can convert
                to a uint8 array (e.g. a PIL Image)

        Returns:
            tuple: (boxes, scores, classes) exactly as returned by the session for the
            box, score, and class output tensors; each still carries the batch dimension
        """

        if isinstance(image, np.ndarray):
            np_im = image
        else:
            np_im = np.asarray(image, np.uint8)

        # The graph is fed a batch, so add a leading batch dimension of size 1
        im_w_batch_dim = np.expand_dims(np_im, axis=0)

        # need to change the above line to the following if supporting a batch size > 1
        # and resizing to the same size
        #
        # np_images = [np.asarray(image, np.uint8) for image in images]
        # images_stacked = np.stack(np_images, axis=0) if len(images) > 1 \
        #     else np.expand_dims(np_images[0], axis=0)

        # performs inference
        (box_tensor_out, score_tensor_out, class_tensor_out) = self.tf_session.run(
            [self.box_tensor, self.score_tensor, self.class_tensor],
            feed_dict={self.image_tensor: im_w_batch_dim})

        return box_tensor_out, score_tensor_out, class_tensor_out


    def generate_detections_one_image(self,
                                      image,
                                      image_id,
                                      detection_threshold,
                                      image_size=None,
                                      augment=False,
                                      verbose=False):
        """
        Runs the detector on an image.

        Args:
            image (Image): the PIL Image object (or numpy array) on which we should run the
                detector, with EXIF rotation already handled.
            image_id (str): a path to identify the image; will be in the "file" field of
                the output object
            detection_threshold (float): only detections above this threshold will be
                included in the return value; None is treated as 0
            image_size (tuple, optional): image size to use for inference; not supported
                by TFDetector, must be None
            augment (bool, optional): enable image augmentation.  Not currently supported,
                but included here for compatibility with PTDetector; must be False.
            verbose (bool, optional): enable additional debug output; currently unused
                by TFDetector

        Returns:
            dict: a dictionary with the following fields:
                - 'file' (filename, always present)
                - 'max_detection_conf' (removed from MegaDetector output files by default,
                  but generated here; absent if inference failed)
                - 'detections' (a list of detection objects containing keys 'category',
                  'conf', and 'bbox'; absent if inference failed)
                - 'failure' (a failure string; only present if inference failed)

        Raises:
            AssertionError: if [image_size] is not None or [augment] is True
        """

        assert image_size is None, 'Image sizing not supported for TF detectors'
        assert not augment, 'Image augmentation is not supported for TF detectors'

        if detection_threshold is None:
            detection_threshold = 0

        result = {
            'file': image_id
        }

        # Inference is best-effort: any exception is reported via the 'failure' field
        # rather than propagated to the caller
        try:

            b_box, b_score, b_class = self._generate_detections_one_image(image)

            # our batch size is 1; need to loop the batch dim if supporting batch size > 1
            boxes, scores, classes = b_box[0], b_score[0], b_class[0]

            detections_cur_image = []  # will be empty for an image with no confident detections
            max_detection_conf = 0.0

            for b, s, c in zip(boxes, scores, classes): #noqa
                if s > detection_threshold:
                    detection_entry = {
                        # use string type for the numerical class label, not int
                        'category': str(int(c)),
                        # cast to float for json serialization
                        'conf': truncate_float(float(s), precision=CONF_DIGITS),
                        'bbox': TFDetector.__convert_coords(b)
                    }
                    detections_cur_image.append(detection_entry)

                    # Only detections that passed the threshold contribute to the max
                    if s > max_detection_conf:
                        max_detection_conf = s

            result['max_detection_conf'] = truncate_float(float(max_detection_conf),
                                                          precision=CONF_DIGITS)
            result['detections'] = detections_cur_image

        except Exception as e:

            result['failure'] = FAILURE_INFER
            print('TFDetector: image {} failed during inference: {}'.format(image_id, str(e)))

        return result

    # ...def generate_detections_one_image(...)

# ...class TFDetector