"""
run_detector_batch.py
Module to run MegaDetector on lots of images, writing the results
to a file in the MegaDetector results format.
https://lila.science/megadetector-output-format
This enables the results to be used in our post-processing pipeline; see postprocess_batch_results.py.
This script can save results to checkpoints intermittently, in case disaster
strikes. To enable this, set --checkpoint_frequency to n > 0, and results
will be saved as a checkpoint every n images. Checkpoints will be written
to a file in the same directory as the output_file, and after all images
are processed and final results file written to output_file, the temporary
checkpoint file will be deleted. If you want to resume from a checkpoint, set
the checkpoint file's path using --resume_from_checkpoint.
Has multiprocessing support for CPUs only; if a GPU is available, it will
use the GPU instead of CPUs, and the --ncores option will be ignored. Checkpointing
is not supported when using a GPU.
The lack of GPU multiprocessing support might sound annoying, but in practice we
run a gazillion MegaDetector images on multiple GPUs using this script, we just only use
one GPU *per invocation of this script*. Dividing a list of images into one chunk
per GPU happens outside of this script.
Does not have a command-line option to bind the process to a particular GPU, but you can
prepend with "CUDA_VISIBLE_DEVICES=0 ", for example, to bind to GPU 0, e.g.:
CUDA_VISIBLE_DEVICES=0 python detection/run_detector_batch.py md_v4.1.0.pb ~/data ~/mdv4test.json
You can disable GPU processing entirely by setting CUDA_VISIBLE_DEVICES=''.
"""
#%% Constants, imports, environment
import argparse
import json
import os
import sys
import time
import copy
import shutil
import random
import warnings
import itertools
import humanfriendly
from datetime import datetime
from functools import partial
from copy import deepcopy
from tqdm import tqdm
import multiprocessing
from threading import Thread
from multiprocessing import Process, Manager
# This pool is used for multi-CPU parallelization, not for data loading workers
# from multiprocessing.pool import ThreadPool as workerpool
from multiprocessing.pool import Pool as workerpool
from megadetector.detection import run_detector
from megadetector.detection.run_detector import \
is_gpu_available,\
load_detector,\
try_download_known_detector,\
get_detector_version_from_filename,\
get_detector_metadata_from_version_string
from megadetector.utils import path_utils
from megadetector.utils import ct_utils
from megadetector.utils.ct_utils import parse_kvp_list
from megadetector.utils.ct_utils import split_list_into_n_chunks
from megadetector.utils.ct_utils import sort_list_of_dicts_by_key
from megadetector.visualization import visualization_utils as vis_utils
from megadetector.data_management import read_exif
from megadetector.data_management.yolo_output_to_md_output import read_classes_from_yolo_dataset_file
# Numpy FutureWarnings from tensorflow import
warnings.filterwarnings('ignore', category=FutureWarning)
# Default number of loaders to use when --image_queue is set
default_loaders = 4
# Should we do preprocessing on the image queue?
default_preprocess_on_image_queue = False
# Number of images to pre-fetch per worker
max_queue_size = 10
# How often should we print progress when using the image queue?
n_queue_print = 1000
# Only used if --include_exif_tags or --include_image_timestamp are supplied
exif_options_base = read_exif.ReadExifOptions()
exif_options_base.processing_library = 'pil'
exif_options_base.byte_handling = 'convert_to_string'
# Only relevant when we're running our test harness; because bugs in batch
# inference are dependent on batch grouping, we randomize batch grouping
# during testing to maximize the probability that latent bugs come up
# eventually.
randomize_batch_order_during_testing = True
# TODO: it's a little sloppy that the following are module-level globals, but in practice it
# doesn't really matter, so I'm not in a big rush to move these to options until I do
# a larger cleanup of all the long argument lists in this module.
# Should the consumer loop run on its own process, or here in the main process?
run_separate_consumer_process = False
# Enable additional debug output
verbose = False
# File format version
current_format_version = '1.6'
#%% Support functions for multiprocessing
def _producer_func(q,
image_files,
producer_id=-1,
preprocessor=None,
detector_options=None,
verbose=False,
image_size=None,
augment=None):
"""
Producer function; only used when using the (optional) image queue.
Reads images from disk and puts, optionally preprocesses them (depending on whether "preprocessor"
is None, then puts them on the blocking queue for processing. Each image is queued as a tuple of
[filename,Image]. Sends "None" to the queue when finished.
The "detector" argument is only used for preprocessing.
Args:
q (Queue): multiprocessing queue to put loaded/preprocessed images into
image_files (list): list of image file paths to process
producer_id (int, optional): identifier for this producer worker (for logging)
preprocessor (str, optional): model file path/identifier for preprocessing, or None to skip preprocessing
detector_options (dict, optional): key/value pairs that are interpreted differently
by different detectors
verbose (bool, optional): enable additional debug output
image_size (int, optional): image size to use for preprocessing
augment (bool, optional): enable image augmentation during preprocessing
"""
if verbose:
print('Producer starting: ID {}, preprocessor {}'.format(producer_id,preprocessor))
sys.stdout.flush()
if preprocessor is not None:
assert isinstance(preprocessor,str)
detector_options = deepcopy(detector_options)
# Tell the detector object it's being loaded as a preprocessor, so it
# shouldn't actually load model weights.
detector_options['preprocess_only'] = True
preprocessor = load_detector(preprocessor,
detector_options=detector_options,
verbose=verbose)
for im_file in image_files:
try:
image = vis_utils.load_image(im_file)
if preprocessor is not None:
image_info = preprocessor.preprocess_image(image,
image_id=im_file,
image_size=image_size,
verbose=verbose)
if 'failure' in image_info:
assert image_info['failure'] == run_detector.FAILURE_INFER
raise
image = image_info
except Exception as e:
print('Producer process: image {} cannot be loaded:\n{}'.format(im_file,str(e)))
image = run_detector.FAILURE_IMAGE_OPEN
q.put([im_file,image,producer_id])
# ...for each image
# This is a signal to the consumer function that a worker has finished
q.put(None)
if verbose:
print('Loader worker {} finished'.format(producer_id))
sys.stdout.flush()
# ...def _producer_func(...)
def _consumer_func(q,
return_queue,
model_file,
confidence_threshold,
loader_workers,
image_size=None,
include_image_size=False,
include_image_timestamp=False,
include_exif_tags=None,
augment=False,
detector_options=None,
preprocess_on_image_queue=default_preprocess_on_image_queue,
n_total_images=None,
batch_size=1,
checkpoint_path=None,
checkpoint_frequency=-1
):
"""
Consumer function; only used when using the (optional) image queue.
Pulls images from a blocking queue and processes them. Returns when "None" has
been read from each loader's queue.
Args:
q (Queue): multiprocessing queue to pull images from
return_queue (Queue): queue to put final results into
model_file (str or detector object): model file path/identifier or pre-loaded detector
confidence_threshold (float): only detections above this threshold are returned
loader_workers (int): number of producer workers (used to know when all are finished)
image_size (int, optional): image size to use for inference
include_image_size (bool, optional): include image dimensions in output
include_image_timestamp (bool, optional): include image timestamps in output
include_exif_tags (str, optional): comma-separated list of EXIF tags to include in output
augment (bool, optional): enable image augmentation
detector_options (dict, optional): key/value pairs that are interpreted differently
by different detectors
preprocess_on_image_queue (bool, optional): whether images are already preprocessed on
the queue
n_total_images (int, optional): total number of images expected (for progress bar)
batch_size (int, optional): batch size for GPU inference
checkpoint_path (str, optional): path to write checkpoint files, None disables
checkpointing
checkpoint_frequency (int, optional): write checkpoint every N images, -1 disables
checkpointing
"""
if verbose:
print('Consumer starting'); sys.stdout.flush()
start_time = time.time()
if isinstance(model_file,str):
detector = load_detector(model_file,
detector_options=detector_options,
verbose=verbose)
elapsed = time.time() - start_time
print('Loaded model (before queueing) in {}, printing updates every {} images'.format(
humanfriendly.format_timespan(elapsed),n_queue_print))
sys.stdout.flush()
else:
detector = model_file
print('Detector of type {} passed to consumer function'.format(type(detector)))
results = []
n_images_processed = 0
n_queues_finished = 0
last_checkpoint_count = 0
def _should_write_checkpoint():
"""
Check whether we should write a checkpoint. Returns True if we've crossed a
checkpoint boundary.
"""
if (checkpoint_frequency <= 0) or (checkpoint_path is None):
return False
# Calculate the checkpoint threshold we should have crossed
current_checkpoint_threshold = \
(n_images_processed // checkpoint_frequency) * checkpoint_frequency
last_checkpoint_threshold = \
(last_checkpoint_count // checkpoint_frequency) * checkpoint_frequency
# We should write a checkpoint if we've crossed into a new checkpoint interval
return (current_checkpoint_threshold > last_checkpoint_threshold)
pbar = None
if n_total_images is not None:
# TODO: in principle I should close this pbar
pbar = tqdm(total=n_total_images)
# Batch processing state
if batch_size > 1:
current_batch_items = []
while True:
r = q.get()
# Is this the last image in one of the producer queues?
if r is None:
n_queues_finished += 1
q.task_done()
if verbose:
print('Consumer thread: {} of {} queues finished'.format(
n_queues_finished,loader_workers))
# Was this the last worker to finish?
if n_queues_finished == loader_workers:
# Do we have any leftover images?
if (batch_size > 1) and (len(current_batch_items) > 0):
# We should never have more than one batch of work left to do, so this loop
# not strictly necessary; it's a bit of future-proofing.
leftover_batches = _group_into_batches(current_batch_items, batch_size)
if len(leftover_batches) > 1:
print('Warning: after all producer queues finished, '
'{} images were left for processing, which is more than'
'the batch size of {}'.format(len(current_batch_items),batch_size))
for leftover_batch in leftover_batches:
batch_results = _process_batch(leftover_batch,
detector,
confidence_threshold,
quiet=True,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment)
results.extend(batch_results)
if pbar is not None:
pbar.update(len(leftover_batch))
n_images_processed += len(leftover_batch)
# In theory we could write a checkpoint here, but because we're basically
# done at this point, there's not much upside to writing another checkpoint,
# so for simplicity, I'm skipping it.
# ...for each batch we have left to process
return_queue.put(results)
return
else:
continue
# ...if we pulled the sentinel signal (None) telling us that a worker finished
# At this point, we have a real image (i.e., not a sentinel indicating that a worker finished)
#
# "r" is always a tuple of (filename,image,producer_id)
#
# Image can be a PIL image (if the loader wasn't doing preprocessing) or a dict with
# a preprocessed image and associated metadata.
im_file = r[0]
image = r[1]
# Handle failed images immediately (don't batch them)
#
# Loader workers communicate failures by passing a string to
# the consumer, rather than an image.
if isinstance(image,str):
results.append({'file': im_file,
'failure': image})
n_images_processed += 1
if pbar is not None:
pbar.update(1)
# This is a catastrophic internal failure; preprocessing workers should
# be passing the consumer dicts that represent processed images
elif preprocess_on_image_queue and (not isinstance(image,dict)):
print('Expected a dict, received an image of type {}'.format(type(image)))
results.append({'file': im_file,
'failure': 'illegal image type'})
n_images_processed += 1
if pbar is not None:
pbar.update(1)
else:
# At this point, "image" is either an image (if the producer workers are only
# doing loading) or a dict (if the producer workers are doing preprocessing)
if batch_size > 1:
# Add to current batch
current_batch_items.append([im_file, image, r[2]])
# Process batch when full
if len(current_batch_items) >= batch_size:
batch_results = _process_batch(current_batch_items,
detector,
confidence_threshold,
quiet=True,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment)
results.extend(batch_results)
if pbar is not None:
pbar.update(len(current_batch_items))
n_images_processed += len(current_batch_items)
current_batch_items = []
else:
# Process single image
result = _process_image(im_file=im_file,
detector=detector,
confidence_threshold=confidence_threshold,
image=image,
quiet=True,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment)
results.append(result)
n_images_processed += 1
if pbar is not None:
pbar.update(1)
# ...if we are/aren't doing batch processing
# Write checkpoint if necessary
if _should_write_checkpoint():
print('Consumer: writing checkpoint after {} images'.format(
n_images_processed))
write_checkpoint(checkpoint_path, results)
last_checkpoint_count = n_images_processed
# ...whether we received a string (indicating failure) or an image from the loader worker
q.task_done()
# ...while True (consumer loop)
# ...def _consumer_func(...)
def _run_detector_with_image_queue(image_files,
model_file,
confidence_threshold,
quiet=False,
image_size=None,
include_image_size=False,
include_image_timestamp=False,
include_exif_tags=None,
augment=False,
detector_options=None,
loader_workers=default_loaders,
preprocess_on_image_queue=default_preprocess_on_image_queue,
batch_size=1,
checkpoint_path=None,
checkpoint_frequency=-1,
use_threads_for_queue=False):
"""
Driver function for the (optional) multiprocessing-based image queue. Spawns workers to read and
preprocess images, runs the consumer function in the calling process.
Args:
image_files (str): list of absolute paths to images
model_file (str): filename or model identifier (e.g. "MDV5A")
confidence_threshold (float): minimum confidence detection to include in
output
quiet (bool, optional): suppress per-image console printouts
image_size (int, optional): image size to use for inference, only mess with this
if (a) you're using a model other than MegaDetector or (b) you know what you're
doing
include_image_size (bool, optional): should we include image size in the output for each image?
include_image_timestamp (bool, optional): should we include image timestamps in the output for each image?
include_exif_tags (str, optional): comma-separated list of EXIF tags to include in output
augment (bool, optional): enable image augmentation
detector_options (dict, optional): key/value pairs that are interpreted differently
by different detectors
loader_workers (int, optional): number of loaders to use
preprocess_on_image_queue (bool, optional): if the image queue is enabled, should it handle
image loading and preprocessing (True), or just image loading (False)?
batch_size (int, optional): batch size for GPU processing
checkpoint_path (str, optional): path to write checkpoint files, None disables checkpointing
checkpoint_frequency (int, optional): write checkpoint every N images, -1 disables checkpointing
use_threads_for_queue (bool, optional): use threads (rather than processes) for the data
loading workers
Returns:
list: list of dicts in the format returned by process_image()
"""
# Validate inputs
assert isinstance(model_file,str)
if loader_workers <= 0:
loader_workers = 1
if detector_options is None:
detector_options = {}
q = multiprocessing.JoinableQueue(max_queue_size)
return_queue = multiprocessing.Queue(1)
producers = []
worker_string = 'thread' if use_threads_for_queue else 'process'
print('Starting a {} pool with {} workers'.format(worker_string,loader_workers))
preprocessor = None
if preprocess_on_image_queue:
print('Enabling image queue preprocessing')
preprocessor = model_file
n_total_images = len(image_files)
chunks = split_list_into_n_chunks(image_files, loader_workers, chunk_strategy='greedy')
for i_chunk,chunk in enumerate(chunks):
if use_threads_for_queue:
producer = Thread(target=_producer_func,args=(q,
chunk,
i_chunk,preprocessor,
detector_options,
verbose,
image_size,
augment))
else:
producer = Process(target=_producer_func,args=(q,
chunk,
i_chunk,
preprocessor,
detector_options,
verbose,
image_size,
augment))
producers.append(producer)
for producer in producers:
producer.daemon = False
producer.start()
if run_separate_consumer_process:
if use_threads_for_queue:
consumer = Thread(target=_consumer_func,args=(q,
return_queue,
model_file,
confidence_threshold,
loader_workers,
image_size,
include_image_size,
include_image_timestamp,
include_exif_tags,
augment,
detector_options,
preprocess_on_image_queue,
n_total_images,
batch_size,
checkpoint_path,
checkpoint_frequency))
else:
consumer = Process(target=_consumer_func,args=(q,
return_queue,
model_file,
confidence_threshold,
loader_workers,
image_size,
include_image_size,
include_image_timestamp,
include_exif_tags,
augment,
detector_options,
preprocess_on_image_queue,
n_total_images,
batch_size,
checkpoint_path,
checkpoint_frequency))
consumer.daemon = True
consumer.start()
else:
_consumer_func(q,
return_queue,
model_file,
confidence_threshold,
loader_workers,
image_size,
include_image_size,
include_image_timestamp,
include_exif_tags,
augment,
detector_options,
preprocess_on_image_queue,
n_total_images,
batch_size,
checkpoint_path,
checkpoint_frequency)
for i_producer,producer in enumerate(producers):
producer.join()
if verbose:
print('Producer {} finished'.format(i_producer))
if verbose:
print('All producers finished')
if run_separate_consumer_process:
consumer.join()
if verbose:
print('Consumer loop finished')
q.join()
if verbose:
print('Queue joined')
results = return_queue.get()
return results
# ...def _run_detector_with_image_queue(...)
#%% Other support functions
def _chunks_by_number_of_chunks(ls, n):
"""
Splits a list into n even chunks.
External callers should use ct_utils.split_list_into_n_chunks().
Args:
ls (list): list to break up into chunks
n (int): number of chunks
"""
for i in range(0, n):
yield ls[i::n]
#%% Batch processing helper functions
def _group_into_batches(items, batch_size):
"""
Group items into batches.
Args:
items (list): items to group into batches
batch_size (int): size of each batch
Returns:
list: list of batches, where each batch is a list of items
"""
if batch_size <= 0:
raise ValueError('Batch size must be positive')
batches = []
for i_item in range(0, len(items), batch_size):
batch = items[i_item:i_item + batch_size]
batches.append(batch)
return batches
def _process_batch(image_items_batch,
detector,
confidence_threshold,
quiet=False,
image_size=None,
include_image_size=False,
include_image_timestamp=False,
include_exif_tags=None,
augment=False):
"""
Process a batch of images using generate_detections_one_batch(). Does not necessarily return
results in the same order in which they were supplied; in particular, images that fail preprocessing
will be returned out of order.
Args:
image_items_batch (list): list of image file paths (strings) or list of tuples [filename, image, producer_id]
detector: loaded detector object
confidence_threshold (float): confidence threshold for detections
quiet (bool, optional): suppress per-image output
image_size (int, optional): image size override
include_image_size (bool, optional): include image dimensions in results
include_image_timestamp (bool, optional): include image timestamps in results
include_exif_tags (str, optional): comma-separated list of EXIF tags to include in output
augment (bool, optional): whether to use image augmentation
Returns:
list of dict: list of results for each image in the batch
"""
# This will be the set of items we send for inference; it may be
# smaller than the input list (image_items_batch) if some images
# fail to load. [valid_images] will be either a list of PIL Image
# objects or a list of dicts containing preprocessed images.
valid_images = []
valid_image_filenames = []
batch_results = []
for i_image, item in enumerate(image_items_batch):
# Handle both filename strings and tuples
if isinstance(item, str):
im_file = item
try:
image = vis_utils.load_image(im_file)
except Exception as e:
print('Image {} cannot be loaded: {}'.format(im_file,str(e)))
failed_result = {
'file': im_file,
'failure': run_detector.FAILURE_IMAGE_OPEN
}
batch_results.append(failed_result)
continue
else:
assert len(item) == 3
im_file, image, producer_id = item
valid_images.append(image)
valid_image_filenames.append(im_file)
# ...for each image in the batch
assert len(valid_images) == len(valid_image_filenames)
valid_batch_results = []
# Process the batch if we have any valid images
if len(valid_images) > 0:
try:
batch_detections = \
detector.generate_detections_one_batch(valid_images,
valid_image_filenames,
verbose=verbose)
assert len(batch_detections) == len(valid_images)
# Apply confidence threshold and add metadata
for i_valid_image,image_result in enumerate(batch_detections):
assert valid_image_filenames[i_valid_image] == image_result['file']
if 'failure' not in image_result:
# Apply confidence threshold
image_result['detections'] = \
[det for det in image_result['detections'] if det['conf'] >= confidence_threshold]
if include_image_size or include_image_timestamp or (include_exif_tags is not None):
image = valid_images[i_valid_image]
# If this was preprocessed by the producer thread, pull out the PIL version
if isinstance(image,dict):
image = image['img_original_pil']
if include_image_size:
image_result['width'] = image.width
image_result['height'] = image.height
if include_image_timestamp:
image_result['datetime'] = get_image_datetime(image)
if include_exif_tags is not None:
exif_options = copy.copy(exif_options_base)
exif_options.tags_to_include = include_exif_tags
image_result['exif_metadata'] = read_exif.read_pil_exif(
image,exif_options)
# ...if we need to store metadata
# ...if this image succeeded
# Failures here should be very rare; there's almost no reason an image would fail
# within a batch once it's been loaded
else:
print('Warning: within-batch processing failure for image {}'.format(
image_result['file']))
# Add to the list of results for the batch whether or not it succeeded
valid_batch_results.append(image_result)
# ...for each image in this batch
except Exception as e:
print('Batch processing failure for {} images: {}'.format(len(valid_images),str(e)))
# Throw out any successful results for this batch, this should almost never happen
valid_batch_results = []
for image_id in valid_image_filenames:
r = {'file':image_id,'failure': run_detector.FAILURE_INFER}
valid_batch_results.append(r)
# ...try/except
assert len(valid_batch_results) == len(valid_images)
# ...if we have valid images in this batch
batch_results.extend(valid_batch_results)
return batch_results
# ...def _process_batch(...)
#%% Image processing functions
def _process_images(im_files,
detector,
confidence_threshold,
use_image_queue=False,
quiet=False,
image_size=None,
checkpoint_queue=None,
include_image_size=False,
include_image_timestamp=False,
include_exif_tags=None,
augment=False,
detector_options=None,
loader_workers=default_loaders,
preprocess_on_image_queue=default_preprocess_on_image_queue,
use_threads_for_queue=False):
"""
Runs a detector (typically MegaDetector) over a list of image files, possibly using multiple
image loading workers, but not using multiple inference workers.
Args:
im_files (list): paths to image files
detector (str or detector object): loaded model or str; if this is a string, it can be a
path to a .pb/.pt model file or a known model identifier (e.g. "MDV5A")
confidence_threshold (float): only detections above this threshold are returned
use_image_queue (bool, optional): separate image loading onto a dedicated worker process
quiet (bool, optional): suppress per-image printouts
image_size (int, optional): image size to use for inference, only mess with this
if (a) you're using a model other than MegaDetector or (b) you know what you're
doing
checkpoint_queue (Queue, optional): internal parameter used to pass image queues around
include_image_size (bool, optional): should we include image size in the output for each image?
include_image_timestamp (bool, optional): should we include image timestamps in the output for each image?
include_exif_tags (str, optional): comma-separated list of EXIF tags to include in output
augment (bool, optional): enable image augmentation
detector_options (dict, optional): key/value pairs that are interpreted differently
by different detectors
loader_workers (int, optional): number of loaders to use (only relevant when using image queue)
preprocess_on_image_queue (bool, optional): if the image queue is enabled, should it handle
image loading and preprocessing (True), or just image loading (False)?
use_threads_for_queue (bool, optional): use threads (rather than processes) for the data
loading workers
Returns:
list: list of dicts, in which each dict represents detections on one image,
see the 'images' key in https://lila.science/megadetector-output-format
"""
if isinstance(detector, str):
start_time = time.time()
detector = load_detector(detector,
detector_options=detector_options,
verbose=verbose)
elapsed = time.time() - start_time
print('Loaded model (process_images) in {}'.format(humanfriendly.format_timespan(elapsed)))
if detector_options is None:
detector_options = {}
if use_image_queue:
results = _run_detector_with_image_queue(im_files,
detector,
confidence_threshold,
quiet=quiet,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment,
detector_options=detector_options,
loader_workers=loader_workers,
preprocess_on_image_queue=preprocess_on_image_queue,
use_threads_for_queue=use_threads_for_queue)
return results
else:
results = []
for im_file in im_files:
result = _process_image(im_file,
detector,
confidence_threshold,
quiet=quiet,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment)
if checkpoint_queue is not None:
checkpoint_queue.put(result)
results.append(result)
return results
# ...if we are/aren't using the image queue
# ...def _process_images(...)
def _process_image(im_file,
detector,
confidence_threshold,
image=None,
quiet=False,
image_size=None,
include_image_size=False,
include_image_timestamp=False,
include_exif_tags=False,
augment=False):
"""
Runs a detector (typically MegaDetector) on a single image file.
Args:
im_file (str): path to image file
detector (detector object): loaded model, this can no longer be a string by the time
you get this far down the pipeline
confidence_threshold (float): only detections above this threshold are returned
image (Image or dict, optional): previously-loaded image, if available, used when a worker
thread is handling image loading (and possibly preprocessing)
quiet (bool, optional): suppress per-image printouts
image_size (int, optional): image size to use for inference, only mess with this
if (a) you're using a model other than MegaDetector or (b) you know what you're
doing
include_image_size (bool, optional): should we include image size in the output for each image?
include_image_timestamp (bool, optional): should we include image timestamps in the output for each image?
include_exif_tags (str, optional): comma-separated list of EXIF tags to include in output
augment (bool, optional): enable image augmentation
Returns:
dict: dict representing detections on one image,
see the 'images' key in https://lila.science/megadetector-output-format
"""
if not quiet:
print('Processing image {}'.format(im_file))
if image is None:
try:
image = vis_utils.load_image(im_file)
except Exception as e:
if not quiet:
print('Image {} cannot be loaded. Exception: {}'.format(im_file, e))
result = {
'file': im_file,
'failure': run_detector.FAILURE_IMAGE_OPEN
}
return result
try:
result = detector.generate_detections_one_image(
image,
im_file,
detection_threshold=confidence_threshold,
image_size=image_size,
augment=augment,
verbose=verbose)
except Exception as e:
if not quiet:
print('Image {} cannot be processed. Exception: {}'.format(im_file, e))
result = {
'file': im_file,
'failure': run_detector.FAILURE_INFER
}
return result
# If this image has already been preprocessed
if isinstance(image,dict):
image = image['img_original_pil']
if include_image_size:
result['width'] = image.width
result['height'] = image.height
if include_image_timestamp:
result['datetime'] = get_image_datetime(image)
if include_exif_tags is not None:
exif_options = copy.copy(exif_options_base)
exif_options.tags_to_include = include_exif_tags
result['exif_metadata'] = read_exif.read_pil_exif(image,exif_options)
return result
# ...def _process_image(...)
def _load_custom_class_mapping(class_mapping_filename):
"""
Allows the use of non-MD models, disables the code that enforces MD-like class lists.
Args:
class_mapping_filename (str): .json file that maps int-strings to strings, or a YOLOv5
dataset.yaml file.
Returns:
dict: maps class IDs (int-strings) to class names
"""
if class_mapping_filename is None:
return
run_detector.USE_MODEL_NATIVE_CLASSES = True
if class_mapping_filename.endswith('.json'):
with open(class_mapping_filename,'r') as f:
class_mapping = json.load(f)
elif (class_mapping_filename.endswith('.yml') or class_mapping_filename.endswith('.yaml')):
class_mapping = read_classes_from_yolo_dataset_file(class_mapping_filename)
# convert from ints to int-strings
class_mapping = {str(k):v for k,v in class_mapping.items()}
else:
raise ValueError('Unrecognized class mapping file {}'.format(class_mapping_filename))
print('Loaded custom class mapping:')
print(class_mapping)
run_detector.DEFAULT_DETECTOR_LABEL_MAP = class_mapping
return class_mapping
#%% Main function
[docs]
def load_and_run_detector_batch(model_file,
image_file_names,
checkpoint_path=None,
confidence_threshold=run_detector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
checkpoint_frequency=-1,
results=None,
n_cores=1,
use_image_queue=False,
quiet=False,
image_size=None,
class_mapping_filename=None,
include_image_size=False,
include_image_timestamp=False,
include_exif_tags=None,
augment=False,
force_model_download=False,
detector_options=None,
loader_workers=default_loaders,
preprocess_on_image_queue=default_preprocess_on_image_queue,
batch_size=1,
verbose_output=False,
use_threads_for_queue=False):
"""
Load a model file and run it on a list of images.
Args:
model_file (str): path to model file, or supported model string (e.g. "MDV5A")
image_file_names (list or str): list of strings (image filenames), a single image filename,
a folder to recursively search for images in, or a .json or .txt file containing a list
of images.
checkpoint_path (str, optional): path to use for checkpoints (if None, checkpointing
is disabled)
confidence_threshold (float, optional): only detections above this threshold are returned
checkpoint_frequency (int, optional): int, write results to JSON checkpoint file every N
images, -1 disabled checkpointing
results (list, optional): list of dicts, existing results loaded from checkpoint; generally
not useful if you're using this function outside of the CLI
n_cores (int, optional): number of parallel worker to use, ignored if we're running on a GPU
use_image_queue (bool, optional): use a dedicated worker for image loading
quiet (bool, optional): disable per-image console output
image_size (int, optional): image size to use for inference, only mess with this
if (a) you're using a model other than MegaDetector or (b) you know what you're
doing
class_mapping_filename (str, optional): use a non-default class mapping supplied in a .json
file or YOLOv5 dataset.yaml file
include_image_size (bool, optional): should we include image size in the output for each image?
include_image_timestamp (bool, optional): should we include image timestamps in the output for each image?
include_exif_tags (str, optional): comma-separated list of EXIF tags to include in output
augment (bool, optional): enable image augmentation
force_model_download (bool, optional): force downloading the model file if
a named model (e.g. "MDV5A") is supplied, even if the local file already
exists
detector_options (dict, optional): key/value pairs that are interpreted differently
by different detectors. Can also be a list of k=v pairs, or a comma-delimited
string containing a list of k=v pairs.
loader_workers (int, optional): number of loaders to use, only relevant when use_image_queue is True
preprocess_on_image_queue (bool, optional): if the image queue is enabled, should it handle
image loading and preprocessing (True), or just image loading (False)?
batch_size (int, optional): batch size for GPU processing, automatically set to 1 for CPU processing
verbose_output (bool, optional): enable additional debug output
use_threads_for_queue (bool, optional): use threads (rather than processes) for the data
loading workers
Returns:
results: list of dicts; each dict represents detections on one image
"""
# Validate input arguments
if (n_cores is None) or (n_cores <= 0):
n_cores = 1
if detector_options is None:
detector_options = {}
elif isinstance(detector_options,list):
detector_options = parse_kvp_list(detector_options)
elif isinstance(detector_options,str):
detector_options = parse_kvp_list(detector_options.split(','))
assert isinstance(detector_options,dict)
if confidence_threshold is None:
confidence_threshold=run_detector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
# Disable checkpointing if checkpoint_path is None
if checkpoint_frequency is None or checkpoint_path is None:
checkpoint_frequency = -1
if class_mapping_filename is not None:
_load_custom_class_mapping(class_mapping_filename)
global verbose
if verbose_output:
print('Enabling verbose output')
verbose = True
# Handle the case where image_file_names is not yet actually a list
if isinstance(image_file_names,str):
# Find the images to score; images can be a directory, may need to recurse
if os.path.isdir(image_file_names):
image_dir = image_file_names
image_file_names = path_utils.find_images(image_dir, True)
print('{} image files found in folder {}'.format(len(image_file_names),image_dir))
# A single file, or a list of image paths
elif os.path.isfile(image_file_names):
list_file = image_file_names
if image_file_names.endswith('.json'):
with open(list_file,'r') as f:
image_file_names = json.load(f)
print('Loaded {} image filenames from .json list file {}'.format(
len(image_file_names),list_file))
elif image_file_names.endswith('.txt'):
with open(list_file,'r') as f:
image_file_names = f.readlines()
image_file_names = [s.strip() for s in image_file_names if len(s.strip()) > 0]
print('Loaded {} image filenames from .txt list file {}'.format(
len(image_file_names),list_file))
elif path_utils.is_image_file(image_file_names):
image_file_names = [image_file_names]
print('Processing image {}'.format(image_file_names[0]))
else:
raise ValueError(
'File {} supplied as [image_file_names] argument, but extension is neither .json nor .txt'\
.format(
list_file))
else:
raise ValueError(
'{} supplied as [image_file_names] argument, but it does not appear to be a file or folder'.format(
image_file_names))
if results is None:
results = []
already_processed = set([i['file'] for i in results])
model_file = try_download_known_detector(model_file,
force_download=force_model_download,
verbose=verbose)
gpu_available = is_gpu_available(model_file)
print('GPU available: {}'.format(gpu_available))
if (n_cores > 1) and gpu_available:
print('Warning: multiple cores requested, but a GPU is available; parallelization across ' + \
'GPUs is not currently supported, defaulting to one GPU')
n_cores = 1
if (n_cores > 1) and use_image_queue:
print('Warning: multiple cores requested, but the image queue is enabled; parallelization ' + \
'with the image queue is not currently supported, defaulting to one worker')
n_cores = 1
if use_image_queue:
assert n_cores <= 1
# Filter out already processed images
images_to_process = [im_file for im_file in image_file_names
if im_file not in already_processed]
if len(images_to_process) != len(image_file_names):
print('Bypassing {} images that have already been processed'.format(
len(image_file_names) - len(images_to_process)))
new_results = _run_detector_with_image_queue(images_to_process,
model_file,
confidence_threshold,
quiet,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment,
detector_options=detector_options,
loader_workers=loader_workers,
preprocess_on_image_queue=preprocess_on_image_queue,
batch_size=batch_size,
checkpoint_path=checkpoint_path,
checkpoint_frequency=checkpoint_frequency,
use_threads_for_queue=use_threads_for_queue)
# Merge new results with existing results from checkpoint
results.extend(new_results)
elif n_cores <= 1:
# Single-threaded processing, no image queue
# Load the detector
start_time = time.time()
detector = load_detector(model_file,
detector_options=detector_options,
verbose=verbose)
elapsed = time.time() - start_time
print('Loaded model in {}'.format(humanfriendly.format_timespan(elapsed)))
if (batch_size > 1) and (not gpu_available):
print('Batch size of {} requested, but no GPU is available, using batch size 1'.format(
batch_size))
batch_size = 1
# Filter out already processed images
images_to_process = [im_file for im_file in image_file_names
if im_file not in already_processed]
if len(images_to_process) != len(image_file_names):
print('Bypassing {} images that have already been processed'.format(
len(image_file_names) - len(images_to_process)))
image_count = 0
if (batch_size > 1):
# During testing, randomize the order of images_to_process to help detect
# non-deterministic batching issues
if randomize_batch_order_during_testing and ('PYTEST_CURRENT_TEST' in os.environ):
print('PyTest detected: randomizing batch order')
random.seed(int(time.time()))
debug_seed = random.randint(0, 2**31 - 1)
print('Debug seed: {}'.format(debug_seed))
random.seed(debug_seed)
random.shuffle(images_to_process)
# Use batch processing
image_batches = _group_into_batches(images_to_process, batch_size)
for batch in tqdm(image_batches):
batch_results = _process_batch(batch,
detector,
confidence_threshold,
quiet=quiet,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment)
results.extend(batch_results)
image_count += len(batch)
# Write a checkpoint if necessary
if (checkpoint_frequency != -1) and ((image_count % checkpoint_frequency) == 0):
print('Writing a new checkpoint after having processed {} images since '
'last restart'.format(image_count))
write_checkpoint(checkpoint_path, results)
else:
# Use non-batch processing
for im_file in tqdm(images_to_process):
image_count += 1
result = _process_image(im_file,
detector,
confidence_threshold,
quiet=quiet,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment)
results.append(result)
# Write a checkpoint if necessary
if (checkpoint_frequency != -1) and ((image_count % checkpoint_frequency) == 0):
print('Writing a new checkpoint after having processed {} images since '
'last restart'.format(image_count))
write_checkpoint(checkpoint_path, results)
# ...if the batch size is > 1
else:
# Multiprocessing is enabled at this point
# When using multiprocessing, tell the workers to load the model on each
# process, by passing the model_file string as the "model" argument to
# process_images.
detector = model_file
print('Creating worker pool with {} cores'.format(n_cores))
if len(already_processed) > 0:
n_images_all = len(image_file_names)
image_file_names = [fn for fn in image_file_names if fn not in already_processed]
print('Loaded {} of {} images from checkpoint'.format(
len(already_processed),n_images_all))
# Divide images into chunks; we'll send one chunk to each worker process
image_chunks = list(_chunks_by_number_of_chunks(image_file_names, n_cores))
pool = None
try:
pool = workerpool(n_cores)
if checkpoint_path is not None:
# Multiprocessing and checkpointing are both enabled at this point
checkpoint_queue = Manager().Queue()
# Pass the "results" array (which may already contain images loaded from an
# existing checkpoint) to the checkpoint queue handler function, which will
# append results to the list as they become available.
checkpoint_thread = Thread(target=_checkpoint_queue_handler,
args=(checkpoint_path, checkpoint_frequency,
checkpoint_queue, results), daemon=True)
checkpoint_thread.start()
pool.map(partial(_process_images,
detector=detector,
confidence_threshold=confidence_threshold,
use_image_queue=False,
quiet=quiet,
image_size=image_size,
checkpoint_queue=checkpoint_queue,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment,
detector_options=detector_options,
use_threads_for_queue=use_threads_for_queue),
image_chunks)
checkpoint_queue.put(None)
else:
# Multprocessing is enabled, but checkpointing is not
new_results = pool.map(partial(_process_images,
detector=detector,
confidence_threshold=confidence_threshold,
use_image_queue=False,
quiet=quiet,
checkpoint_queue=None,
image_size=image_size,
include_image_size=include_image_size,
include_image_timestamp=include_image_timestamp,
include_exif_tags=include_exif_tags,
augment=augment,
detector_options=detector_options,
use_threads_for_queue=use_threads_for_queue),
image_chunks)
new_results = list(itertools.chain.from_iterable(new_results))
# Append the results we just computed to "results", which is *usually* empty, but will
# be non-empty if we resumed from a checkpoint
results.extend(new_results)
# ...if checkpointing is/isn't enabled
finally:
if pool is not None:
pool.close()
pool.join()
print('Pool closed and joined for multi-core inference')
# ...if we're running (1) with image queue, (2) on one core, or (3) on multiple cores
# 'results' may have been modified in place, but we also return it for
# backwards-compatibility.
return results
# ...def load_and_run_detector_batch(...)
def _checkpoint_queue_handler(checkpoint_path, checkpoint_frequency, checkpoint_queue, results):
"""
Thread function to accumulate results and write checkpoints when checkpointing and
multiprocessing are both enabled.
"""
result_count = 0
while True:
result = checkpoint_queue.get()
if result is None:
break
result_count +=1
results.append(result)
if (checkpoint_frequency != -1) and (result_count % checkpoint_frequency == 0):
print('Writing a new checkpoint after having processed {} images since '
'last restart'.format(result_count))
write_checkpoint(checkpoint_path, results)
[docs]
def write_checkpoint(checkpoint_path, results):
"""
Writes the object in [results] to a json checkpoint file, as a dict with the
key "checkpoint". First backs up the checkpoint file if it exists, in case we
crash while writing the file.
Args:
checkpoint_path (str): the file to write the checkpoint to
results (object): the object we should write
"""
assert checkpoint_path is not None
# Back up any previous checkpoints, to protect against crashes while we're writing
# the checkpoint file.
checkpoint_tmp_path = None
if os.path.isfile(checkpoint_path):
checkpoint_tmp_path = checkpoint_path + '_tmp'
shutil.copyfile(checkpoint_path,checkpoint_tmp_path)
# Write the new checkpoint
ct_utils.write_json(checkpoint_path, {'checkpoint': results}, force_str=True)
# Remove the backup checkpoint if it exists
if checkpoint_tmp_path is not None:
try:
os.remove(checkpoint_tmp_path)
except Exception as e:
print('Warning: error removing backup checkpoint file {}:\n{}'.format(
checkpoint_tmp_path,str(e)))
[docs]
def load_checkpoint(checkpoint_path):
"""
Loads results from a checkpoint file. A checkpoint file is always a dict
with the key "checkpoint".
Args:
checkpoint_path (str): the .json file to load
Returns:
object: object retrieved from the checkpoint, typically a list of results
"""
print('Loading previous results from checkpoint file {}'.format(checkpoint_path))
with open(checkpoint_path, 'r') as f:
checkpoint_data = json.load(f)
if 'checkpoint' not in checkpoint_data:
raise ValueError('Checkpoint file {} is missing "checkpoint" field'.format(checkpoint_path))
results = checkpoint_data['checkpoint']
print('Restored {} entries from the checkpoint {}'.format(len(results),checkpoint_path))
return results
[docs]
def get_image_datetime(image):
"""
Reads EXIF datetime from a PIL Image object.
Args:
image (Image): the PIL Image object from which we should read datetime information
Returns:
str: the EXIF datetime from [image] (a PIL Image object), if available, as a string;
returns None if EXIF datetime is not available.
"""
exif_tags = read_exif.read_pil_exif(image,exif_options_base)
try:
datetime_str = exif_tags['DateTimeOriginal']
_ = time.strptime(datetime_str, '%Y:%m:%d %H:%M:%S')
return datetime_str
except Exception:
return None
[docs]
def write_results_to_file(results,
output_file,
relative_path_base=None,
detector_file=None,
info=None,
include_max_conf=False,
custom_metadata=None,
force_forward_slashes=True):
"""
Writes list of detection results to JSON output file. Format matches:
https://lila.science/megadetector-output-format
Args:
results (list): list of dict, each dict represents detections on one image
output_file (str): path to JSON output file, should end in '.json'
relative_path_base (str, optional): path to a directory as the base for relative paths, can
be None if the paths in [results] are absolute
detector_file (str, optional): filename of the detector used to generate these results, only
used to pull out a version number for the "info" field
info (dict, optional): dictionary to put in the results file instead of the default "info" field
include_max_conf (bool, optional): old files (version 1.2 and earlier) included a "max_conf" field
in each image; this was removed in version 1.3. Set this flag to force the inclusion
of this field.
custom_metadata (object, optional): additional data to include as info['custom_metadata']; typically
a dictionary, but no type/format checks are performed
force_forward_slashes (bool, optional): convert all slashes in filenames within [results] to
forward slashes
Returns:
dict: the MD-formatted dictionary that was written to [output_file]
"""
if relative_path_base is not None:
results_relative = []
for r in results:
r_relative = copy.copy(r)
r_relative['file'] = os.path.relpath(r_relative['file'], start=relative_path_base)
results_relative.append(r_relative)
results = results_relative
if force_forward_slashes:
results_converted = []
for r in results:
r_converted = copy.copy(r)
r_converted['file'] = r_converted['file'].replace('\\','/')
results_converted.append(r_converted)
results = results_converted
# The typical case: we need to build the 'info' struct
if info is None:
info = {
'detection_completion_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'format_version': current_format_version
}
if detector_file is not None:
detector_filename = os.path.basename(detector_file)
detector_version = get_detector_version_from_filename(detector_filename,verbose=True)
detector_metadata = get_detector_metadata_from_version_string(detector_version)
info['detector'] = detector_filename
info['detector_metadata'] = detector_metadata
else:
info['detector'] = 'unknown'
info['detector_metadata'] = get_detector_metadata_from_version_string('unknown')
# If the caller supplied the entire "info" struct
else:
if detector_file is not None:
print('Warning (write_results_to_file): info struct and detector file ' + \
'supplied, ignoring detector file')
if custom_metadata is not None:
info['custom_metadata'] = custom_metadata
# The 'max_detection_conf' field used to be included by default, and it caused all kinds
# of headaches, so it's no longer included unless the user explicitly requests it.
if not include_max_conf:
for im in results:
if 'max_detection_conf' in im:
del im['max_detection_conf']
# Sort results by filename; not required by the format, but convenient for consistency
results = sort_list_of_dicts_by_key(results,'file')
# Sort detections in descending order by confidence; not required by the format, but
# convenient for consistency
for im in results:
if ('detections' in im) and (im['detections'] is not None):
im['detections'] = sort_list_of_dicts_by_key(im['detections'], 'conf', reverse=True)
for im in results:
if 'failure' in im:
if 'detections' in im:
assert im['detections'] is None, 'Illegal failure/detection combination'
else:
im['detections'] = None
final_output = {
'images': results,
'detection_categories': run_detector.DEFAULT_DETECTOR_LABEL_MAP,
'info': info
}
# Create the folder where the output file belongs; this will fail if
# this is a relative path with no folder component
try:
os.makedirs(os.path.dirname(output_file),exist_ok=True)
except Exception:
pass
ct_utils.write_json(output_file, final_output, force_str=True)
print('Output file saved at {}'.format(output_file))
return final_output
# ...def write_results_to_file(...)
#%% Interactive driver
if False:
pass
#%%
model_file = 'MDV5A'
image_dir = r'g:\camera_traps\camera_trap_images'
output_file = r'g:\temp\md-test.json'
recursive = True
output_relative_filenames = True
include_max_conf = False
quiet = True
image_size = None
use_image_queue = False
confidence_threshold = 0.0001
checkpoint_frequency = 5
checkpoint_path = None
resume_from_checkpoint = 'auto'
allow_checkpoint_overwrite = False
ncores = 1
class_mapping_filename = None
include_image_size = True
include_image_timestamp = True
include_exif_tags = None
overwrite_handling = None
# Generate a command line
cmd = 'python run_detector_batch.py "{}" "{}" "{}"'.format(
model_file,image_dir,output_file)
if recursive:
cmd += ' --recursive'
if output_relative_filenames:
cmd += ' --output_relative_filenames'
if include_max_conf:
cmd += ' --include_max_conf'
if image_size is not None:
cmd += ' --image_size {}'.format(image_size)
if use_image_queue:
cmd += ' --use_image_queue'
if confidence_threshold is not None:
cmd += ' --threshold {}'.format(confidence_threshold)
if checkpoint_frequency is not None:
cmd += ' --checkpoint_frequency {}'.format(checkpoint_frequency)
if checkpoint_path is not None:
cmd += ' --checkpoint_path "{}"'.format(checkpoint_path)
if resume_from_checkpoint is not None:
cmd += ' --resume_from_checkpoint "{}"'.format(resume_from_checkpoint)
if allow_checkpoint_overwrite:
cmd += ' --allow_checkpoint_overwrite'
if ncores is not None:
cmd += ' --ncores {}'.format(ncores)
if class_mapping_filename is not None:
cmd += ' --class_mapping_filename "{}"'.format(class_mapping_filename)
if include_image_size:
cmd += ' --include_image_size'
if include_image_timestamp:
cmd += ' --include_image_timestamp'
if include_exif_tags is not None:
cmd += ' --include_exif_tags "{}"'.format(include_exif_tags)
if overwrite_handling is not None:
cmd += ' --overwrite_handling {}'.format(overwrite_handling)
print(cmd)
import clipboard; clipboard.copy(cmd)
#%% Run inference interactively
image_file_names = path_utils.find_images(image_dir, recursive=False)
results = None
start_time = time.time()
results = load_and_run_detector_batch(model_file=model_file,
image_file_names=image_file_names,
checkpoint_path=checkpoint_path,
confidence_threshold=confidence_threshold,
checkpoint_frequency=checkpoint_frequency,
results=results,
n_cores=ncores,
use_image_queue=use_image_queue,
quiet=quiet,
image_size=image_size)
elapsed = time.time() - start_time
print('Finished inference in {}'.format(humanfriendly.format_timespan(elapsed)))
#%% Command-line driver
def main(): # noqa
parser = argparse.ArgumentParser(
description='Module to run a TF/PT animal detection model on lots of images')
parser.add_argument(
'detector_file',
help='Path to detector model file (.pb or .pt). Can also be the strings "MDV4", ' + \
'"MDV5A", or "MDV5B" to request automatic download.')
parser.add_argument(
'image_file',
help=\
'Path to a single image file, a .json or .txt file containing a list of paths to images, or a directory')
parser.add_argument(
'output_file',
help='Path to output JSON results file, should end with a .json extension')
parser.add_argument(
'--recursive',
action='store_true',
help='Recurse into directories, only meaningful if image_file points to a directory')
parser.add_argument(
'--output_relative_filenames',
action='store_true',
help='Output relative file names, only meaningful if image_file points to a directory')
parser.add_argument(
'--include_max_conf',
action='store_true',
help='Include the "max_detection_conf" field in the output')
parser.add_argument(
'--verbose',
action='store_true',
help='Enable additional debug output')
parser.add_argument(
'--image_size',
type=int,
default=None,
help=('Force image resizing to a specific integer size on the long axis (not recommended to change this)'))
parser.add_argument(
'--augment',
action='store_true',
help='Enable image augmentation'
)
parser.add_argument(
'--use_image_queue',
action='store_true',
help='Pre-load images, may help keep your GPU busy; does not currently support ' + \
'checkpointing. Useful if you have a very fast GPU and a very slow disk.')
parser.add_argument(
'--preprocess_on_image_queue',
action='store_true',
help='Whether to do image resizing on the image queue (PyTorch detectors only)')
parser.add_argument(
'--use_threads_for_queue',
action='store_true',
help='Use threads (rather than processes) for the image queue; only relevant if --use_image_queue is set')
parser.add_argument(
'--threshold',
type=float,
default=run_detector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
help="Confidence threshold between 0 and 1.0, don't include boxes below this " + \
"confidence in the output file. Default is {}".format(
run_detector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD))
parser.add_argument(
'--checkpoint_frequency',
type=int,
default=-1,
help='Write results to a temporary file every N images; default is -1, which ' + \
'disables this feature')
parser.add_argument(
'--checkpoint_path',
type=str,
default=None,
help='File name to which checkpoints will be written if checkpoint_frequency is > 0, ' + \
'defaults to md_checkpoint_[date].json in the same folder as the output file')
parser.add_argument(
'--resume_from_checkpoint',
type=str,
default=None,
help='Path to a JSON checkpoint file to resume from, or "auto" to ' + \
'find the most recent checkpoint in the same folder as the output file. "auto" uses' + \
'checkpoint_path (rather than searching the output folder) if checkpoint_path is specified.')
parser.add_argument(
'--allow_checkpoint_overwrite',
action='store_true',
help='By default, this script will bail if the specified checkpoint file ' + \
'already exists; this option allows it to overwrite existing checkpoints')
parser.add_argument(
'--ncores',
type=int,
default=1,
help='Number of cores to use for inference; only applies to CPU-based inference (default 1)')
parser.add_argument(
'--loader_workers',
type=int,
default=default_loaders,
help='Number of image loader workers to use; only relevant when --use_image_queue ' + \
'is set (default {})'.format(default_loaders))
parser.add_argument(
'--class_mapping_filename',
type=str,
default=None,
help='Use a non-default class mapping, supplied in a .json file with a dictionary mapping' + \
'int-strings to strings. This will also disable the addition of "1" to all category ' + \
'IDs, so your class mapping should start at zero. Can also be a YOLOv5 dataset.yaml file.')
parser.add_argument(
'--include_image_size',
action='store_true',
help='Include image dimensions in output file'
)
parser.add_argument(
'--include_image_timestamp',
action='store_true',
help='Include image datetime (if available) in output file'
)
parser.add_argument(
'--include_exif_tags',
type=str,
default=None,
help='Command-separated list of EXIF tags to include in output, or "all" to include all tags'
)
parser.add_argument(
'--overwrite_handling',
type=str,
default='overwrite',
help='What should we do if the output file exists? overwrite/skip/error (default overwrite)'
)
parser.add_argument(
'--force_model_download',
action='store_true',
help=('If a named model (e.g. "MDV5A") is supplied, force a download of that model even if the ' +\
'local file already exists.'))
parser.add_argument(
'--previous_results_file',
type=str,
default=None,
help=('If supplied, this should point to a previous .json results file; any results in that ' +\
'file will be transferred to the output file without reprocessing those images. Useful ' +\
'for "updating" a set of results when you may have added new images to a folder you\'ve ' +\
'already processed. Only supported when using relative paths.'))
parser.add_argument(
'--detector_options',
nargs='*',
metavar='KEY=VALUE',
default='',
help='Detector-specific options, as a space-separated list of key-value pairs')
parser.add_argument(
'--batch_size',
type=int,
default=1,
help='Batch size for GPU inference (default 1). CPU inference will ignore this and use batch_size=1.')
# This argument is deprecated, we always use what was formerly "quiet mode"
parser.add_argument(
'--quiet',
action='store_true',
help=argparse.SUPPRESS)
# This argument is deprecated in favor use --include_exif_tags
parser.add_argument(
'--include_exif_data',
action='store_true',
help=argparse.SUPPRESS)
if len(sys.argv[1:]) == 0:
parser.print_help()
parser.exit()
args = parser.parse_args()
# Support the legacy --include_exif_data flag
if args.include_exif_data and (args.include_exif_tags is None):
args.include_exif_tags = 'all'
detector_options = parse_kvp_list(args.detector_options)
# If the specified detector file is really the name of a known model, find
# (and possibly download) that model
args.detector_file = try_download_known_detector(args.detector_file,
force_download=args.force_model_download,
verbose=verbose)
assert os.path.exists(args.detector_file), \
'detector file {} does not exist'.format(args.detector_file)
assert 0.0 <= args.threshold <= 1.0, 'Confidence threshold needs to be between 0 and 1'
assert args.output_file.endswith('.json'), 'output_file specified needs to end with .json'
if args.checkpoint_frequency != -1:
assert args.checkpoint_frequency > 0, 'Checkpoint_frequency needs to be > 0 or == -1'
if args.output_relative_filenames:
assert os.path.isdir(args.image_file), \
f'Could not find folder {args.image_file}, must supply a folder when ' + \
'--output_relative_filenames is set'
if args.previous_results_file is not None:
assert os.path.isdir(args.image_file) and args.output_relative_filenames, \
"Can only process previous results when using relative paths"
if os.path.exists(args.output_file):
if args.overwrite_handling == 'overwrite':
print('Warning: output file {} already exists and will be overwritten'.format(
args.output_file))
elif args.overwrite_handling == 'skip':
print('Output file {} exists, returning'.format(
args.output_file))
return
elif args.overwrite_handling == 'error':
raise Exception('Output file {} exists'.format(args.output_file))
else:
raise ValueError('Illegal overwrite handling string {}'.format(args.overwrite_handling))
output_dir = os.path.dirname(args.output_file)
if len(output_dir) > 0:
os.makedirs(output_dir,exist_ok=True)
assert not os.path.isdir(args.output_file), 'Specified output file is a directory'
if args.class_mapping_filename is not None:
_load_custom_class_mapping(args.class_mapping_filename)
# Load the checkpoint if available
#
# File paths in the checkpoint are always absolute paths; conversion to relative paths
# (if requested) happens at the time results are exported at the end of a job.
if args.resume_from_checkpoint is not None:
if args.resume_from_checkpoint == 'auto':
checkpoint_files = os.listdir(output_dir)
checkpoint_files = [fn for fn in checkpoint_files if \
(fn.startswith('md_checkpoint') and fn.endswith('.json'))]
if len(checkpoint_files) == 0:
raise ValueError('resume_from_checkpoint set to "auto", but no checkpoints found in {}'.format(
output_dir))
else:
if len(checkpoint_files) > 1:
print('Warning: found {} checkpoints in {}, using the latest'.format(
len(checkpoint_files),output_dir))
checkpoint_files = sorted(checkpoint_files)
checkpoint_file_relative = checkpoint_files[-1]
checkpoint_file = os.path.join(output_dir,checkpoint_file_relative)
else:
checkpoint_file = args.resume_from_checkpoint
results = load_checkpoint(checkpoint_file)
else:
results = []
# Find the images to process; images can be a directory, may need to recurse
if os.path.isdir(args.image_file):
image_file_names = path_utils.find_images(args.image_file, args.recursive)
if len(image_file_names) > 0:
print('{} image files found in the input directory'.format(len(image_file_names)))
else:
if (args.recursive):
print('No image files found in directory {}, exiting'.format(args.image_file))
else:
print('No image files found in directory {}, did you mean to specify '
'--recursive?'.format(
args.image_file))
return
# A json list of image paths
elif os.path.isfile(args.image_file) and args.image_file.endswith('.json'):
with open(args.image_file) as f:
image_file_names = json.load(f)
print('Loaded {} image filenames from .json list file {}'.format(
len(image_file_names),args.image_file))
# A text list of image paths
elif os.path.isfile(args.image_file) and args.image_file.endswith('.txt'):
with open(args.image_file) as f:
image_file_names = f.readlines()
image_file_names = [fn.strip() for fn in image_file_names if len(fn.strip()) > 0]
print('Loaded {} image filenames from .txt list file {}'.format(
len(image_file_names),args.image_file))
# A single image file
elif os.path.isfile(args.image_file) and path_utils.is_image_file(args.image_file):
image_file_names = [args.image_file]
print('Processing image {}'.format(args.image_file))
else:
raise ValueError('image_file specified is not a directory, a json list, or an image file, '
'(or does not have recognizable extensions).')
# At this point, regardless of how they were specified, [image_file_names] is a list of
# absolute image paths.
assert len(image_file_names) > 0, 'Specified image_file does not point to valid image files'
# Convert to forward slashes to facilitate comparison with previous results
image_file_names = [fn.replace('\\','/') for fn in image_file_names]
# We can head off many problems related to incorrect command line formulation if we confirm
# that one image exists before proceeding. The use of the first image for this test is
# arbitrary.
assert os.path.exists(image_file_names[0]), \
'The first image to be processed does not exist at {}'.format(image_file_names[0])
# Possibly load results from a previous pass
previous_results = None
if args.previous_results_file is not None:
assert os.path.isfile(args.previous_results_file), \
'Could not find previous results file {}'.format(args.previous_results_file)
with open(args.previous_results_file,'r') as f:
previous_results = json.load(f)
assert previous_results['detection_categories'] == run_detector.DEFAULT_DETECTOR_LABEL_MAP, \
"Can't merge previous results when those results use a different set of detection categories"
print('Loaded previous results for {} images from {}'.format(
len(previous_results['images']), args.previous_results_file))
# Convert previous result filenames to absolute paths if necessary
#
# We asserted above to make sure that we are using relative paths and processing a
# folder, but just to be super-clear...
assert os.path.isdir(args.image_file)
previous_image_files_set = set()
for im in previous_results['images']:
assert not os.path.isabs(im['file']), \
"When processing previous results, relative paths are required"
fn_abs = os.path.join(args.image_file,im['file']).replace('\\','/')
# Absolute paths are expected at the final output stage below
im['file'] = fn_abs
previous_image_files_set.add(fn_abs)
image_file_names_to_keep = []
for fn_abs in image_file_names:
if fn_abs not in previous_image_files_set:
image_file_names_to_keep.append(fn_abs)
print('Based on previous results file, processing {} of {} images'.format(
len(image_file_names_to_keep), len(image_file_names)))
image_file_names = image_file_names_to_keep
# ...if we're handling previous results
# Test that we can write to the output_file's dir if checkpointing requested
if args.checkpoint_frequency != -1:
if args.checkpoint_path is not None:
checkpoint_path = args.checkpoint_path
else:
checkpoint_path = os.path.join(output_dir,
'md_checkpoint_{}.json'.format(
datetime.now().strftime("%Y%m%d%H%M%S")))
# Don't overwrite existing checkpoint files, this is a sure-fire way to eventually
# erase someone's checkpoint.
if (checkpoint_path is not None) and (not args.allow_checkpoint_overwrite) \
and (args.resume_from_checkpoint is None):
assert not os.path.isfile(checkpoint_path), \
f'Checkpoint path {checkpoint_path} already exists, delete or move it before ' + \
're-using the same checkpoint path, or specify --allow_checkpoint_overwrite'
print('The checkpoint file will be written to {}'.format(checkpoint_path))
else:
if args.checkpoint_path is not None:
print('Warning: checkpointing disabled because checkpoint_frequency is -1, ' + \
'but a checkpoint path was specified')
checkpoint_path = None
start_time = time.time()
results = load_and_run_detector_batch(model_file=args.detector_file,
image_file_names=image_file_names,
checkpoint_path=checkpoint_path,
confidence_threshold=args.threshold,
checkpoint_frequency=args.checkpoint_frequency,
results=results,
n_cores=args.ncores,
use_image_queue=args.use_image_queue,
quiet=True,
image_size=args.image_size,
class_mapping_filename=args.class_mapping_filename,
include_image_size=args.include_image_size,
include_image_timestamp=args.include_image_timestamp,
include_exif_tags=args.include_exif_tags,
augment=args.augment,
# Don't download the model *again*
force_model_download=False,
detector_options=detector_options,
loader_workers=args.loader_workers,
preprocess_on_image_queue=args.preprocess_on_image_queue,
batch_size=args.batch_size,
verbose_output=args.verbose,
use_threads_for_queue=args.use_threads_for_queue)
elapsed = time.time() - start_time
images_per_second = len(results) / elapsed
print('Finished inference for {} images in {} ({:.2f} images per second)'.format(
len(results),humanfriendly.format_timespan(elapsed),images_per_second))
relative_path_base = None
# We asserted above to make sure that if output_relative_filenames is set,
# args.image_file is a folder, but we'll double-check for clarity.
if args.output_relative_filenames:
assert os.path.isdir(args.image_file)
relative_path_base = args.image_file
# Merge results from a previous file if necessary
if previous_results is not None:
previous_filenames_set = set([im['file'] for im in previous_results['images']])
new_filenames_set = set([im['file'] for im in results])
assert len(previous_filenames_set.intersection(new_filenames_set)) == 0, \
'Previous results handling error: redundant image filenames'
results.extend(previous_results['images'])
write_results_to_file(results,
args.output_file,
relative_path_base=relative_path_base,
detector_file=args.detector_file,
include_max_conf=args.include_max_conf)
if checkpoint_path and os.path.isfile(checkpoint_path):
os.remove(checkpoint_path)
print('Deleted checkpoint file {}'.format(checkpoint_path))
print('Done, thanks for MegaDetect\'ing!')
if __name__ == '__main__':
main()