Source code for megadetector.detection.run_md_and_speciesnet

"""

run_md_and_speciesnet.py

Script to run MegaDetector and SpeciesNet on a folder of images and/or videos.
Runs MD first, then runs SpeciesNet on every above-threshold crop.

"""

#%% Constants, imports, environment

import argparse
import json
import multiprocessing
import os
import sys
import time

from tqdm import tqdm
from multiprocessing import JoinableQueue, Process, Queue
from threading import Thread

import humanfriendly

from megadetector.detection import run_detector_batch
from megadetector.detection.video_utils import find_videos, run_callback_on_frames, is_video_file
from megadetector.detection.run_detector_batch import load_and_run_detector_batch
from megadetector.detection.run_detector import DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
from megadetector.detection.run_detector import CONF_DIGITS
from megadetector.detection.run_detector_batch import write_results_to_file
from megadetector.utils.ct_utils import round_float
from megadetector.utils.ct_utils import write_json
from megadetector.utils.ct_utils import make_temp_folder
from megadetector.utils.ct_utils import is_list_sorted
from megadetector.utils.ct_utils import is_sphinx_build
from megadetector.utils.ct_utils import args_to_object
from megadetector.utils.path_utils import find_images
from megadetector.utils.path_utils import test_file_write
from megadetector.utils.wi_taxonomy_utils import get_common_name_from_prediction_string
from megadetector.visualization import visualization_utils as vis_utils
from megadetector.postprocessing.validate_batch_results import \
    validate_batch_results, ValidateBatchResultsOptions
from megadetector.detection.process_video import \
    process_videos, ProcessVideoOptions
from megadetector.postprocessing.combine_batch_outputs import combine_batch_output_files

# We aren't taking an explicit dependency on the speciesnet package yet,
# so we wrap this in a try/except so sphinx can still document this module.
try:
    from speciesnet import SpeciesNetClassifier
    from speciesnet.utils import BBox
    from speciesnet.ensemble import SpeciesNetEnsemble
    from speciesnet.geofence_utils import roll_up_labels_to_first_matching_level
    from speciesnet.geofence_utils import geofence_animal_classification
except Exception:
    pass


#%% Constants

DEFAULT_DETECTOR_MODEL = 'MDV5A'
try:
    from speciesnet import DEFAULT_MODEL as DEFAULT_CLASSIFIER_MODEL
except Exception:
    DEFAULT_CLASSIFIER_MODEL = 'kaggle:google/speciesnet/pyTorch/v4.0.2a'
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION = 0.1
DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
DEFAULT_DETECTOR_BATCH_SIZE = 1
DEFAULT_CLASSIFIER_BATCH_SIZE = 8
DEFAULT_LOADER_WORKERS = 4
DEFAULT_WORKER_TYPE = 'thread'

# This determines the maximum number of image filenames that can be assigned to
# each of the producer workers before blocking.  The actual size of the queue
# will be MAX_IMAGE_QUEUE_SIZE_PER_WORKER * n_workers.  This is only used for
# the classification step.
MAX_IMAGE_QUEUE_SIZE_PER_WORKER = 30

# This determines the maximum number of crops that can accumulate in the queue
# used to communicate between the producers (which read and crop images) and the
# consumer (which runs the classifier).  This is only used for the classification step.
MAX_BATCH_QUEUE_SIZE = 300

# Default interval between frames we should process when processing video.
# This is only used for the detection step.
DEFAULT_SECONDS_PER_VIDEO_FRAME = 1.0

# Max number of classification scores to include per detection
DEFAULT_TOP_N_SCORES = 2

# Unless --norollup is specified, roll up taxonomic levels until the
# cumulative confidence is above this value.  Only relevant when
# geofencing is disabled, otherwise the default speciesnet library
# constants are used.
DEFAULT_ROLLUP_TARGET_CONFIDENCE = 0.65

# When the called supplies an existing MD results file, should we validate it before
# starting classification?
VALIDATE_DETECTION_FILE = False

verbose = False


#%% Main options class

[docs] class RunMDSpeciesNetOptions: """ Class controlling the behavior of run_md_and_speciesnet() """ def __init__(self): #: Folder containing images and/or videos to process self.source = None #: Output file for results (JSON format) self.output_file = None #: What to do if the output file exists ('overwrite', 'error', 'skip') self.overwrite_handling = 'overwrite' #: MegaDetector model identifier (MDv5a, MDv5b, MDv1000-redwood, etc.) self.detector_model = DEFAULT_DETECTOR_MODEL #: SpeciesNet classifier model identifier (e.g. kaggle:google/speciesnet/pyTorch/v4.0.2a) self.classification_model = DEFAULT_CLASSIFIER_MODEL #: Batch size for MegaDetector inference self.detector_batch_size = DEFAULT_DETECTOR_BATCH_SIZE #: Batch size for SpeciesNet classification self.classifier_batch_size = DEFAULT_CLASSIFIER_BATCH_SIZE #: Number of worker threads for preprocessing self.loader_workers = DEFAULT_LOADER_WORKERS #: Classify detections above this threshold self.detection_confidence_threshold_for_classification = \ DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION #: Include detections above this threshold in the output self.detection_confidence_threshold_for_output = \ DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT #: Folder for intermediate files (default: system temp) self.intermediate_file_folder = None #: Keep intermediate files (e.g. detection-only results file) self.keep_intermediate_files = False #: Disable taxonomic rollup self.norollup = False #: Target confidence threshold for taxonomic rollup self.rollup_target_confidence = DEFAULT_ROLLUP_TARGET_CONFIDENCE #: Country code (ISO 3166-1 alpha-3) for geofencing (default None, no geoferencing) self.country = None #: Admin1 region/state code for geofencing self.admin1_region = None #: Path to existing MegaDetector output file (skips detection step) self.detections_file = None #: Ignore videos, only process images self.skip_video = False #: Ignore images, only process videos self.skip_images = False #: Sample every Nth frame from videos #: #: Mutually exclusive with time_sample self.frame_sample = None #: Sample frames every N seconds from videos #: #: Mutually exclusive with frame_sample self.time_sample = DEFAULT_SECONDS_PER_VIDEO_FRAME #: Enable additional debug output self.verbose = False #: Worker type for parallelization; should be "thread" or "process" self.worker_type = DEFAULT_WORKER_TYPE #: Include raw (pre-rollup/geofence) classification scores in output self.include_raw_classifications = False if self.time_sample is None and self.frame_sample is None: self.time_sample = DEFAULT_SECONDS_PER_VIDEO_FRAME
# ...class RunMDSpeciesNetOptions #%% Support classes
[docs] class CropMetadata: """ Metadata for a crop extracted from an image detection. """ def __init__(self, image_file: str, detection_index: int, bbox: list[float], original_width: int, original_height: int): """ Args: image_file (str): path to the original image file detection_index (int): index of this detection in the image bbox (List[float]): normalized bounding box [x_min, y_min, width, height] original_width (int): width of the original image original_height (int): height of the original image """ self.image_file = image_file self.detection_index = detection_index self.bbox = bbox self.original_width = original_width self.original_height = original_height
[docs] class CropBatch: """ A batch of crops with their metadata for classification. """ def __init__(self): #: List of preprocessed images self.crops = [] #: List of CropMetadata objects self.metadata = []
[docs] def add_crop(self, crop_data, metadata): """ Args: crop_data (PreprocessedImage): preprocessed image data from SpeciesNetClassifier.preprocess() metadata (CropMetadata): metadata for this crop """ self.crops.append(crop_data) self.metadata.append(metadata)
def __len__(self): return len(self.crops)
#%% Support functions for classification def _process_image_detections(file_path: str, absolute_file_path: str, detection_results: dict, classifier: 'SpeciesNetClassifier', detection_confidence_threshold: float, batch_queue: Queue): """ Process detections from a single image. Args: file_path (str): relative path to the image file absolute_file_path (str): absolute path to the image file detection_results (dict): detection results for this image classifier (SpeciesNetClassifier): classifier instance for preprocessing detection_confidence_threshold (float): classify detections above this threshold batch_queue (Queue): queue to send crops to """ detections = detection_results['detections'] # Don't bother loading images that have no above-threshold detections detections_above_threshold = \ [d for d in detections if d['conf'] >= detection_confidence_threshold] if len(detections_above_threshold) == 0: return # Load the image try: image = vis_utils.load_image(absolute_file_path) original_width, original_height = image.size except Exception as e: print('Warning: failed to load image {}: {}'.format(file_path, str(e))) # Send failure information to consumer failure_metadata = CropMetadata( image_file=file_path, detection_index=-1, # -1 indicates whole-image failure bbox=[], original_width=0, original_height=0 ) batch_queue.put(('failure', 'Failed to load image: {}'.format(str(e)), failure_metadata)) return # Process each detection above threshold # # detection_index needs to index into the original list of detections # (this is how classification results will be associated with detections # later), so iterate over "detections" here, rather than # "detections_above_threshold". for detection_index, detection in enumerate(detections): conf = detection['conf'] if conf < detection_confidence_threshold: continue bbox = detection['bbox'] assert len(bbox) == 4 # Convert normalized bbox to BBox object for SpeciesNet speciesnet_bbox = BBox( xmin=bbox[0], ymin=bbox[1], width=bbox[2], height=bbox[3] ) # Preprocess the crop try: preprocessed_crop = classifier.preprocess( image, bboxes=[speciesnet_bbox], resize=True ) if preprocessed_crop is not None: metadata = CropMetadata( image_file=file_path, detection_index=detection_index, bbox=bbox, original_width=original_width, original_height=original_height ) # Send individual crop to the consumer batch_queue.put(('crop', preprocessed_crop, metadata)) except Exception as e: print('Warning: failed to preprocess crop from {}, detection {}: {}'.format( file_path, detection_index, str(e))) # Send failure information to consumer failure_metadata = CropMetadata( image_file=file_path, detection_index=detection_index, bbox=bbox, original_width=original_width, original_height=original_height ) batch_queue.put(('failure', 'Failed to preprocess crop: {}'.format(str(e)), failure_metadata)) # ...try/except # ...for each detection in this image # ...def _process_image_detections(...) def _process_video_detections(file_path: str, absolute_file_path: str, detection_results: dict, classifier: 'SpeciesNetClassifier', detection_confidence_threshold: float, batch_queue: Queue): """ Process detections from a single video. Args: file_path (str): relative path to the video file absolute_file_path (str): absolute path to the video file detection_results (dict): detection results for this video classifier (SpeciesNetClassifier): classifier instance for preprocessing detection_confidence_threshold (float): classify detections above this threshold batch_queue (Queue): queue to send crops to """ detections = detection_results['detections'] # Find frames with above-threshold detections frames_with_detections = set() frame_to_detections = {} for detection_index, detection in enumerate(detections): conf = detection['conf'] if conf < detection_confidence_threshold: continue frame_number = detection['frame_number'] frames_with_detections.add(frame_number) if frame_number not in frame_to_detections: frame_to_detections[frame_number] = [] frame_to_detections[frame_number].append((detection_index, detection)) # ...for each detection in this video if len(frames_with_detections) == 0: return frames_to_process = sorted(list(frames_with_detections)) # Define callback for processing each frame def frame_callback(frame_array, frame_id): """ Callback to process a single frame. Args: frame_array (numpy.ndarray): frame data in PIL format frame_id (str): frame identifier like "frame0006.jpg" """ # Extract frame number from frame_id (e.g., "frame0006.jpg" -> 6) import re match = re.match(r'frame(\d+)\.jpg', frame_id) if not match: print('Warning: could not parse frame number from {}'.format(frame_id)) return frame_number = int(match.group(1)) # Only process frames for which we have detection results if frame_number not in frame_to_detections: return # Convert numpy array to PIL Image from PIL import Image if frame_array.dtype != 'uint8': frame_array = (frame_array * 255).astype('uint8') frame_image = Image.fromarray(frame_array) original_width, original_height = frame_image.size # Process each detection in this frame for detection_index, detection in frame_to_detections[frame_number]: bbox = detection['bbox'] assert len(bbox) == 4 # Convert normalized bbox to BBox object for SpeciesNet speciesnet_bbox = BBox( xmin=bbox[0], ymin=bbox[1], width=bbox[2], height=bbox[3] ) # Preprocess the crop try: preprocessed_crop = classifier.preprocess( frame_image, bboxes=[speciesnet_bbox], resize=True ) if preprocessed_crop is not None: metadata = CropMetadata( image_file=file_path, detection_index=detection_index, bbox=bbox, original_width=original_width, original_height=original_height ) # Send individual crop immediately to consumer batch_queue.put(('crop', preprocessed_crop, metadata)) except Exception as e: print('Warning: failed to preprocess crop from {}, detection {}: {}'.format( file_path, detection_index, str(e))) # Send failure information to consumer failure_metadata = CropMetadata( image_file=file_path, detection_index=detection_index, bbox=bbox, original_width=original_width, original_height=original_height ) batch_queue.put(('failure', 'Failed to preprocess crop: {}'.format(str(e)), failure_metadata)) # ...try/except # ...for each detection # ...def frame_callback(...) # Process the video frames try: run_callback_on_frames( input_video_file=absolute_file_path, frame_callback=frame_callback, frames_to_process=frames_to_process, verbose=verbose ) except Exception as e: print('Warning: failed to process video {}: {}'.format(file_path, str(e))) # Send failure information to consumer for the whole video failure_metadata = CropMetadata( image_file=file_path, detection_index=-1, # -1 indicates whole-file failure bbox=[], original_width=0, original_height=0 ) batch_queue.put(('failure', 'Failed to process video: {}'.format(str(e)), failure_metadata)) # ...try/except # ...def _process_video_detections(...) def _crop_producer_func(image_queue: JoinableQueue, batch_queue: Queue, classifier_model: str, detection_confidence_threshold: float, source_folder: str, producer_id: int = -1, preloaded_classifier: 'SpeciesNetClassifier' = None): """ Producer function for classification workers. Reads images and videos from [image_queue], crops detections above a threshold, preprocesses them, and sends individual crops to [batch_queue]. See the documentation of _crop_consumer_func to for the format of the tuples placed on batch_queue. Args: image_queue (JoinableQueue): queue containing detection_results dicts (for both images and videos) batch_queue (Queue): queue to put individual crops into classifier_model (str): classifier model identifier to load in this process detection_confidence_threshold (float): classify detections above this threshold source_folder (str): source folder to resolve relative paths producer_id (int, optional): identifier for this producer worker preloaded_classifier (SpeciesNetClassifier, optional): pre-loaded classifier instance (for thread-based workers, to avoid loading models in threads) """ if verbose: print('Classification producer starting: ID {}'.format(producer_id)) # Load classifier; this is just being used as a preprocessor, so we force device=cpu. # # When using threads, we pre-load the classifier in the main thread to avoid PyTorch FX # issues with loading models in worker threads. if preloaded_classifier is not None: classifier = preloaded_classifier if verbose: print('Classification producer {}: using pre-loaded classifier'.format(producer_id)) else: # There are a number of reasons loading the model might fail; note to self: *don't* # catch Exceptions here. This should be a catastrophic failure that stops the whole # process. classifier = SpeciesNetClassifier(classifier_model, device='cpu') if verbose: print('Classification producer {}: loaded classifier'.format(producer_id)) while True: # Pull an image of detection results from the queue detection_results = image_queue.get() # Pulling None from the queue indicates that this producer is done if detection_results is None: image_queue.task_done() break file_path = detection_results['file'] # Skip files that failed at the detection stage if 'failure' in detection_results: image_queue.task_done() continue # Skip files with no detections detections = detection_results['detections'] if len(detections) == 0: image_queue.task_done() continue # Determine if this is an image or video absolute_file_path = os.path.join(source_folder, file_path) is_video = is_video_file(file_path) if is_video: # Process video _process_video_detections( file_path=file_path, absolute_file_path=absolute_file_path, detection_results=detection_results, classifier=classifier, detection_confidence_threshold=detection_confidence_threshold, batch_queue=batch_queue ) else: # Process image _process_image_detections( file_path=file_path, absolute_file_path=absolute_file_path, detection_results=detection_results, classifier=classifier, detection_confidence_threshold=detection_confidence_threshold, batch_queue=batch_queue ) image_queue.task_done() # ...while(we still have items to process) # Send sentinel to indicate this producer is done batch_queue.put(None) if verbose: print('Classification producer {} finished'.format(producer_id)) # ...def _crop_producer_func(...) def _crop_consumer_func(batch_queue: Queue, results_queue: Queue, classifier_model: str, batch_size: int, num_producers: int, enable_rollup: bool, country: str = None, admin1_region: str = None, preloaded_classifier: 'SpeciesNetClassifier' = None, rollup_target_confidence: float = DEFAULT_ROLLUP_TARGET_CONFIDENCE): """ Consumer function for classification inference. Pulls individual crops from batch_queue, assembles them into batches, runs inference, and puts results into results_queue. Args: batch_queue (Queue): queue containing individual crop tuples or failures. Items on this queue are either None (to indicate that a producer finished) or tuples formatted as (type,image,metadata). [type] is a string (either "crop" or "failure"), [image] is a PreprocessedImage, and [metadata] is a CropMetadata object. results_queue (Queue): queue to put classification results into classifier_model (str): classifier model identifier to load batch_size (int): batch size for inference num_producers (int): number of producer workers enable_rollup (bool): whether to apply taxonomic rollup country (str, optional): country code for geofencing admin1_region (str, optional): admin1 region for geofencing preloaded_classifier (SpeciesNetClassifier, optional): pre-loaded classifier instance (for thread-based workers, to avoid loading models in threads) rollup_target_confidence (float, optional): target confidence threshold for taxonomic rollup. Ignored if enable_rollup is False. """ if verbose: print('Classification consumer starting') # Load classifier # When using threads, we pre-load the classifier in the main thread to avoid PyTorch FX # issues with loading models in worker threads. if preloaded_classifier is not None: classifier = preloaded_classifier if verbose: print('Classification consumer: using pre-loaded classifier') else: try: classifier = SpeciesNetClassifier(classifier_model) if verbose: print('Classification consumer: loaded classifier') except Exception as e: print('Classification consumer: failed to load classifier: {}'.format(str(e))) results_queue.put({}) return all_results = {} # image_file -> {detection_index -> classification_result} current_batch = CropBatch() producers_finished = 0 # Load ensemble metadata if rollup/geofencing is enabled taxonomy_map = {} geofence_map = {} if (enable_rollup) or (country is not None): # Note to self: there are a number of reasons loading the ensemble # could fail here; don't catch this exception, this should be a # catatstrophic failure. ensemble = SpeciesNetEnsemble( classifier_model, geofence=(country is not None)) taxonomy_map = ensemble.taxonomy_map geofence_map = ensemble.geofence_map # ...if we need to load ensemble components while True: # Pull an item from the queue item = batch_queue.get() # This indicates that a producer worker finished if item is None: producers_finished += 1 if producers_finished == num_producers: # Process any remaining images if len(current_batch) > 0: _process_classification_batch( current_batch, classifier, all_results, enable_rollup, taxonomy_map, geofence_map, country, admin1_region, rollup_target_confidence ) break continue # ...if a producer finished # If we got here, we know we have a crop to process, or # a failure to ignore. assert isinstance(item, tuple) and len(item) == 3 item_type, data, metadata = item if metadata.image_file not in all_results: all_results[metadata.image_file] = {} # We should never be processing the same detection twice assert metadata.detection_index not in all_results[metadata.image_file] if item_type == 'failure': all_results[metadata.image_file][metadata.detection_index] = { 'failure': 'Failure classification: {}'.format(data) } else: assert item_type == 'crop' current_batch.add_crop(data, metadata) assert len(current_batch) <= batch_size # Process batch if necessary if len(current_batch) == batch_size: _process_classification_batch( current_batch, classifier, all_results, enable_rollup, taxonomy_map, geofence_map, country, admin1_region, rollup_target_confidence ) current_batch = CropBatch() # ...was this item a failure or a crop? # ...while (we have items to process) # Send all the results at once back to the main process results_queue.put(all_results) if verbose: print('Classification consumer finished') # ...def _crop_consumer_func(...) def _process_classification_batch(batch: CropBatch, classifier: 'SpeciesNetClassifier', all_results: dict, enable_rollup: bool, taxonomy_map: dict, geofence_map: dict, country: str = None, admin1_region: str = None, rollup_target_confidence: float = DEFAULT_ROLLUP_TARGET_CONFIDENCE): """ Run a batch of crops through the classifier. Args: batch (CropBatch): batch of crops to process classifier (SpeciesNetClassifier): classifier instance all_results (dict): dictionary to store results in, modified in-place with format: {image_file: {detection_index: {'predictions': [[class_name, score], ...]} or {image_file: {detection_index: {'failure': error_message}}}. enable_rollup (bool): whether to apply rollup taxonomy_map (dict): taxonomy mapping for rollup geofence_map (dict): geofence mapping country (str, optional): country code for geofencing admin1_region (str, optional): admin1 region for geofencing rollup_target_confidence (float, optional): target confidence threshold for taxonomic rollup, ignored if enable_rollup is False """ if len(batch) == 0: print('Warning: _process_classification_batch received empty batch') return # Prepare batch for inference filepaths = [f"{metadata.image_file}_{metadata.detection_index}" for metadata in batch.metadata] # Run batch inference try: batch_results = classifier.batch_predict(filepaths, batch.crops) except Exception as e: print('*** Batch classification failed: {} ***'.format(str(e))) # Mark all crops in this batch as failed for metadata in batch.metadata: if metadata.image_file not in all_results: all_results[metadata.image_file] = {} all_results[metadata.image_file][metadata.detection_index] = { 'failure': 'Failure classification: {}'.format(str(e)) } return # Process results assert len(batch_results) == len(batch.metadata) assert len(batch_results) == len(filepaths) for i_result in range(0, len(batch_results)): filepath = filepaths[i_result] result = batch_results[i_result] metadata = batch.metadata[i_result] assert metadata.image_file in all_results, \ 'File {} not in results dict'.format(metadata.image_file) detection_index = metadata.detection_index # Handle classification failure if 'failures' in result: print('*** Classification failure for image: {} ***'.format( filepath)) all_results[metadata.image_file][detection_index] = { 'failure': 'Failure classification: SpeciesNet classifier failed' } continue # Extract classification results; this is a dict with keys "classes" # and "scores", each of which points to a list. classifications = result['classifications'] classes = classifications['classes'] scores = classifications['scores'] classification_was_geofenced = False predicted_class = classes[0] predicted_score = scores[0] # Possibly apply geofencing if country: geofence_result = geofence_animal_classification( labels=classes, scores=scores, country=country, admin1_region=admin1_region, taxonomy_map=taxonomy_map, geofence_map=geofence_map, enable_geofence=True ) geofenced_class, geofenced_score, prediction_source = geofence_result if prediction_source != 'classifier': classification_was_geofenced = True predicted_class = geofenced_class predicted_score = geofenced_score # ...if we might need to apply geofencing # Possibly apply rollup; this was already done if geofencing was applied if enable_rollup and (not classification_was_geofenced): rollup_result = roll_up_labels_to_first_matching_level( labels=classes, scores=scores, country=country, admin1_region=admin1_region, target_taxonomy_levels=['species','genus','family','order','class','kingdom'], non_blank_threshold=rollup_target_confidence, taxonomy_map=taxonomy_map, geofence_map=geofence_map, enable_geofence=(country is not None) ) if rollup_result is not None: rolled_up_class, rolled_up_score, prediction_source = rollup_result predicted_class = rolled_up_class predicted_score = rolled_up_score # ...if we might need to apply taxonomic rollup # For now, we'll store category names as strings; these will be assigned to integer # IDs before writing results to file later. classification = [predicted_class,predicted_score] # Also report raw model classifications raw_classifications = [] for i_class in range(0,len(classes)): raw_classifications.append([classes[i_class],scores[i_class]]) all_results[metadata.image_file][detection_index] = { 'classifications': [classification], 'raw_classifications': raw_classifications } # ...for each result in this batch # ...def _process_classification_batch(...) #%% Inference functions def _run_detection_step(source_folder: str, detector_output_file: str, detector_model: str = DEFAULT_DETECTOR_MODEL, detector_batch_size: int = DEFAULT_DETECTOR_BATCH_SIZE, detection_confidence_threshold: float = DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD, detector_worker_threads: int = DEFAULT_LOADER_WORKERS, worker_type: str = DEFAULT_WORKER_TYPE, skip_images: bool = False, skip_video: bool = False, frame_sample: int = None, time_sample: float = None) -> str: """ Run MegaDetector on all images/videos in [source_folder]. Args: source_folder (str): folder containing images/videos detector_output_file (str): output .json file detector_model (str, optional): detector model identifier detector_batch_size (int, optional): batch size for detection detection_confidence_threshold (float, optional): confidence threshold for detections (to include in the output file) detector_worker_threads (int, optional): number of workers to use for preprocessing worker_type (str, optional): type of worker parallelization ("thread" or "process") skip_images (bool, optional): ignore images, only process videos skip_video (bool, optional): ignore videos, only process images frame_sample (int, optional): sample every Nth frame from videos time_sample (float, optional): sample frames every N seconds from videos """ print('Starting detection step...') # Validate arguments assert not (frame_sample is None and time_sample is None), \ 'Must specify either frame_sample or time_sample' # Find image and video files if not skip_images: image_files = find_images(source_folder, recursive=True, return_relative_paths=False) else: image_files = [] if not skip_video: video_files = find_videos(source_folder, recursive=True, return_relative_paths=False) else: video_files = [] if len(image_files) == 0 and len(video_files) == 0: raise ValueError( 'No images or videos found in {}'.format(source_folder)) print('Found {} images and {} videos'.format(len(image_files), len(video_files))) files_to_merge = [] # Process images if necessary if len(image_files) > 0: print('Running MegaDetector on {} images...'.format(len(image_files))) use_threads_for_queue = (worker_type == 'thread') image_results = load_and_run_detector_batch( model_file=detector_model, image_file_names=image_files, checkpoint_path=None, confidence_threshold=detection_confidence_threshold, checkpoint_frequency=-1, results=None, n_cores=0, use_image_queue=True, quiet=True, image_size=None, batch_size=detector_batch_size, include_image_size=False, include_image_timestamp=False, include_exif_tags=None, loader_workers=detector_worker_threads, preprocess_on_image_queue=True, use_threads_for_queue=use_threads_for_queue ) # Write image results to temporary file image_output_file = detector_output_file.replace('.json', '_images.json') write_results_to_file(image_results, image_output_file, relative_path_base=source_folder, detector_file=detector_model) print('Image detection results written to {}'.format(image_output_file)) files_to_merge.append(image_output_file) # ...if we had images to process # Process videos if necessary if len(video_files) > 0: print('Running MegaDetector on {} videos...'.format(len(video_files))) # Set up video processing options video_options = ProcessVideoOptions() video_options.model_file = detector_model video_options.input_video_file = source_folder video_options.output_json_file = detector_output_file.replace('.json', '_videos.json') video_options.json_confidence_threshold = detection_confidence_threshold video_options.frame_sample = frame_sample video_options.time_sample = time_sample video_options.recursive = True # Process videos process_videos(video_options) print('Video detection results written to {}'.format(video_options.output_json_file)) files_to_merge.append(video_options.output_json_file) # ...if we had videos to process # Merge results if we have both images and videos if len(files_to_merge) > 1: print('Merging image and video detection results...') combine_batch_output_files(files_to_merge, detector_output_file) print('Merged detection results written to {}'.format(detector_output_file)) elif len(files_to_merge) == 1: # Just rename the single file if files_to_merge[0] != detector_output_file: if os.path.isfile(detector_output_file): print('Detector file {} exists, over-writing'.format(detector_output_file)) os.remove(detector_output_file) os.rename(files_to_merge[0], detector_output_file) print('Detection results written to {}'.format(detector_output_file)) # ...def _run_detection_step(...) def _run_classification_step(detector_results_file: str, merged_results_file: str, source_folder: str, classifier_model: str = DEFAULT_CLASSIFIER_MODEL, classifier_batch_size: int = DEFAULT_CLASSIFIER_BATCH_SIZE, classifier_worker_threads: int = DEFAULT_LOADER_WORKERS, detection_confidence_threshold: float = \ DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION, enable_rollup: bool = True, country: str = None, admin1_region: str = None, top_n_scores: int = DEFAULT_TOP_N_SCORES, worker_type: str = DEFAULT_WORKER_TYPE, include_raw_classifications: bool = False, rollup_target_confidence: float = DEFAULT_ROLLUP_TARGET_CONFIDENCE): """ Run SpeciesNet classification on detections from MegaDetector results. Args: detector_results_file (str): path to MegaDetector output .json file merged_results_file (str): path to which we should write the merged results source_folder (str): source folder for resolving relative paths classifier_model (str, optional): classifier model identifier classifier_batch_size (int, optional): batch size for classification classifier_worker_threads (int, optional): number of worker threads detection_confidence_threshold (float, optional): classify detections above this threshold enable_rollup (bool, optional): whether to apply taxonomic rollup country (str, optional): country code for geofencing admin1_region (str, optional): admin1 region (typically a state code) for geofencing top_n_scores (int, optional): maximum number of scores to include for each detection worker_type (str, optional): type of worker parallelization ("thread" or "process") include_raw_classifications (bool, optional): include raw (pre-rollup/geofence) classification scores in output rollup_target_confidence (float, optional): target confidence threshold for taxonomic rollup. Ignored if enable_rollup is False. """ print('Starting classification step...') # Load MegaDetector results print('Reading detection results from {}'.format(detector_results_file)) with open(detector_results_file, 'r') as f: detector_results = json.load(f) print('Classification step loaded detection results for {} images'.format( len(detector_results['images']))) images = detector_results['images'] if len(images) == 0: raise ValueError('No images found in detector results') print('Using SpeciesNet classifier: {}'.format(classifier_model)) # Set multiprocessing start method to 'spawn' for CUDA compatibility if worker_type == 'process': original_start_method = multiprocessing.get_start_method() if original_start_method != 'spawn': multiprocessing.set_start_method('spawn', force=True) print('Set multiprocessing start method to spawn (was {})'.format( original_start_method)) ## Set up multiprocessing queues # This queue receives lists of image filenames (and associated detection results) # from the "main" thread (the one you're reading right now). Items are pulled off # of this queue by producer workers (on _crop_producer_func), where the corresponding # images are loaded from disk and preprocessed into crops. image_queue = JoinableQueue(maxsize= \ classifier_worker_threads * MAX_IMAGE_QUEUE_SIZE_PER_WORKER) # This queue receives cropped images from producers (on _crop_producer_func); those # crops are pulled off of this queue by the consumer (on _crop_consumer_func). batch_queue = Queue(maxsize=MAX_BATCH_QUEUE_SIZE) # This is not really used as a queue, rather it's just used to send all the results # at once from the consumer process to the main process (the one you're reading right # now). results_queue = Queue() WorkerClass = Thread if worker_type == 'thread' else Process # noqa # When using threads, pre-load classifiers in the main thread to avoid PyTorch FX issues # with loading models in worker threads. When using processes, pass None and let each # process load its own classifier. if worker_type == 'thread': # Producer classifier (CPU only, used for preprocessing) producer_classifier = SpeciesNetClassifier(classifier_model, device='cpu') # Consumer classifier (GPU if available, used for inference) consumer_classifier = SpeciesNetClassifier(classifier_model) else: producer_classifier = None consumer_classifier = None # Start producer workers producers = [] for i_worker in range(classifier_worker_threads): p = WorkerClass(target=_crop_producer_func, args=(image_queue, batch_queue, classifier_model, detection_confidence_threshold, source_folder, i_worker, producer_classifier)) p.start() producers.append(p) ## Start consumer worker consumer = WorkerClass(target=_crop_consumer_func, args=(batch_queue, results_queue, classifier_model, classifier_batch_size, classifier_worker_threads, enable_rollup, country, admin1_region, consumer_classifier, rollup_target_confidence)) consumer.start() # This will block every time the queue reaches its maximum depth, so for # very small jobs, this will not be a useful progress bar. with tqdm(total=len(images),desc='Classification') as pbar: for image_data in images: image_queue.put(image_data) pbar.update() # Send sentinel signals to producers for _ in range(classifier_worker_threads): image_queue.put(None) # Wait for all work to complete image_queue.join() print('Finished waiting for input queue') ## Wait for results classification_results = results_queue.get() ## Clean up processes for p in producers: p.join() consumer.join() print('Finished waiting for workers') ## Format results and write output class CategoryState: """ Helper class to manage classification category IDs. """ def __init__(self): self.next_category_id = 0 # Maps common name to string-int IDs self.common_name_to_id = {} # Maps string-ints to common names, as per format standard self.classification_categories = {} # Maps string-ints to latin taxonomy strings, as per format standard self.classification_category_descriptions = {} def _get_category_id(self, class_name): """ Get an integer-valued category ID for the 7-token string [class_name], creating a new one if necessary. """ # E.g.: # # "cb553c4e-42c9-4fe0-9bd0-da2d6ed5bfa1;mammalia;carnivora;canidae;urocyon;littoralis;island fox" tokens = class_name.split(';') assert len(tokens) == 7 common_name = get_common_name_from_prediction_string(class_name) if common_name not in self.common_name_to_id: self.common_name_to_id[common_name] = str(self.next_category_id) self.classification_categories[str(self.next_category_id)] = common_name self.classification_category_descriptions[str(self.next_category_id)] = class_name self.next_category_id += 1 category_id = self.common_name_to_id[common_name] return category_id # ...class CategoryState category_state = CategoryState() # Merge classification results back into detector results with proper category IDs for image_data in images: image_file = image_data['file'] if ('detections' not in image_data) or (image_data['detections'] is None): continue detections = image_data['detections'] if image_file not in classification_results: continue image_classifications = classification_results[image_file] for detection_index, detection in enumerate(detections): if detection_index in image_classifications: result = image_classifications[detection_index] if 'failure' in result: # Add failure to the image, not the detection if 'failure' not in image_data: image_data['failure'] = result['failure'] else: image_data['failure'] += ';' + result['failure'] else: # Convert class names to category IDs classification_pairs = [] raw_classification_pairs = [] scores = [x[1] for x in result['classifications']] assert is_list_sorted(scores, reverse=True) # Only report the requested number of scores per detection if len(result['classifications']) > top_n_scores: result['classifications'] = \ result['classifications'][0:top_n_scores] if len(result['raw_classifications']) > top_n_scores: result['raw_classifications'] = \ result['raw_classifications'][0:top_n_scores] for class_name, score in result['classifications']: category_id = category_state._get_category_id(class_name) score = round_float(score, precision=CONF_DIGITS) classification_pairs.append([category_id, score]) for class_name, score in result['raw_classifications']: category_id = category_state._get_category_id(class_name) score = round_float(score, precision=CONF_DIGITS) raw_classification_pairs.append([category_id, score]) # Add classifications to the detection detection['classifications'] = classification_pairs if include_raw_classifications: detection['raw_classifications'] = raw_classification_pairs # ...if this classification contains a failure # ...if this detection has classification information # ...for each detection # ...for each image # Update metadata in the results if 'info' not in detector_results: detector_results['info'] = {} detector_results['info']['classifier'] = classifier_model detector_results['info']['classification_completion_time'] = time.strftime( '%Y-%m-%d %H:%M:%S') # Add classification category mapping detector_results['classification_categories'] = \ category_state.classification_categories detector_results['classification_category_descriptions'] = \ category_state.classification_category_descriptions print('Writing output file') # Write results write_json(merged_results_file, detector_results) if verbose: print('Classification results written to {}'.format(merged_results_file)) # ...def _run_classification_step(...) #%% Main function
[docs] def run_md_and_speciesnet(options): """ Main entry point, runs MegaDetector and SpeciesNet on a folder. See RunMDSpeciesNetOptions for available arguments. Args: options (RunMDSpeciesNetOptions): options controlling MD and SN inference """ # Set global verbose flag global verbose verbose = options.verbose # Also set the run_detector_batch verbose flag run_detector_batch.verbose = verbose # Validate arguments if not os.path.isdir(options.source): raise ValueError( 'Source folder does not exist: {}'.format(options.source)) if (options.admin1_region is not None) and (options.country is None): raise ValueError('--admin1_region requires --country to be specified') if options.skip_images and options.skip_video: raise ValueError('Cannot skip both images and videos') if (options.frame_sample is not None) and (options.time_sample is not None): raise ValueError('--frame_sample and --time_sample are mutually exclusive') if (options.frame_sample is None) and (options.time_sample is None): options.time_sample = DEFAULT_SECONDS_PER_VIDEO_FRAME if options.worker_type not in ('thread','process'): raise ValueError('Unknown worker type {}'.format(options.worker_type)) # Set up intermediate file folder if options.intermediate_file_folder: temp_folder = options.intermediate_file_folder os.makedirs(temp_folder, exist_ok=True) else: temp_folder = make_temp_folder(subfolder='run_md_and_speciesnet') start_time = time.time() print('Processing folder: {}'.format(options.source)) print('Output file: {}'.format(options.output_file)) print('Intermediate file folder: {}'.format(temp_folder)) assert options.overwrite_handling in ('overwrite','error','skip'), \ 'Unknown overwrite_handling value {}'.format(options.overwrite_handling) if os.path.isdir(options.output_file): raise ValueError('Output file {} exists, but is a directory'.format( options.output_file)) if os.path.isfile(options.output_file): if options.overwrite_handling == 'overwrite': print('Over-writing existing output file {}'.format(options.output_file)) elif options.overwrite_handling == 'error': raise ValueError('Output file {} exists, and overwrite_handling is "error"'.format( options.output_file)) elif options.ovwrite_handling == 'skip': print('Bypassing proecssing: output file {} exists, and overwrite_handling is "skip"'.format( options.output_file)) return # Verify that we can create the output file test_file_write(options.output_file) # Determine detector output file path if options.detections_file is not None: detector_output_file = options.detections_file if VALIDATE_DETECTION_FILE: print('Using existing detections file: {}'.format(detector_output_file)) validation_options = ValidateBatchResultsOptions() validation_options.check_image_existence = True validation_options.relative_path_base = options.source validation_options.raise_errors = True validate_batch_results(detector_output_file,options=validation_options) print('Validated detections file') else: print('Bypassing validation of {}'.format(options.detections_file)) else: detector_output_file = os.path.join(temp_folder, 'detector_output.json') # Run MegaDetector _run_detection_step( source_folder=options.source, detector_output_file=detector_output_file, detector_model=options.detector_model, detector_batch_size=options.detector_batch_size, detection_confidence_threshold=options.detection_confidence_threshold_for_output, detector_worker_threads=options.loader_workers, skip_images=options.skip_images, skip_video=options.skip_video, frame_sample=options.frame_sample, time_sample=options.time_sample, worker_type=options.worker_type ) # Run SpeciesNet _run_classification_step( detector_results_file=detector_output_file, merged_results_file=options.output_file, source_folder=options.source, classifier_model=options.classification_model, classifier_batch_size=options.classifier_batch_size, classifier_worker_threads=options.loader_workers, detection_confidence_threshold=options.detection_confidence_threshold_for_classification, enable_rollup=(not options.norollup), country=options.country, admin1_region=options.admin1_region, worker_type=options.worker_type, include_raw_classifications=options.include_raw_classifications, rollup_target_confidence=options.rollup_target_confidence ) elapsed_time = time.time() - start_time print( 'Processing complete in {}'.format(humanfriendly.format_timespan(elapsed_time))) print('Results written to: {}'.format(options.output_file)) # Clean up intermediate files if requested if (not options.keep_intermediate_files) and \ (not options.intermediate_file_folder) and \ (not options.detections_file): try: os.remove(detector_output_file) except Exception as e: print('Warning: error removing temporary output file {}: {}'.format( detector_output_file, str(e)))
# ...def run_md_and_speciesnet(...) #%% Command-line driver def main(): """ Command-line driver for run_md_and_speciesnet.py """ if 'speciesnet' not in sys.modules: print('It looks like the speciesnet package is not available, try "pip install speciesnet"') if not is_sphinx_build(): sys.exit(-1) parser = argparse.ArgumentParser( description='Run MegaDetector and SpeciesNet on a folder of images/videos', formatter_class=argparse.ArgumentDefaultsHelpFormatter ) # Required arguments parser.add_argument('source', help='Folder containing images and/or videos to process') parser.add_argument('output_file', help='Output file for results (JSON format)') # Optional arguments parser.add_argument('--detector_model', default=DEFAULT_DETECTOR_MODEL, help='MegaDetector model identifier') parser.add_argument('--classification_model', default=DEFAULT_CLASSIFIER_MODEL, help='SpeciesNet classifier model identifier') parser.add_argument('--detector_batch_size', type=int, default=DEFAULT_DETECTOR_BATCH_SIZE, help='Batch size for MegaDetector inference') parser.add_argument('--classifier_batch_size', type=int, default=DEFAULT_CLASSIFIER_BATCH_SIZE, help='Batch size for SpeciesNet classification') parser.add_argument('--loader_workers', type=int, default=DEFAULT_LOADER_WORKERS, help='Number of worker threads for preprocessing') parser.add_argument('--detection_confidence_threshold_for_classification', type=float, default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_CLASSIFICATION, help='Classify detections above this threshold') parser.add_argument('--detection_confidence_threshold_for_output', type=float, default=DEFAULT_DETECTION_CONFIDENCE_THRESHOLD_FOR_OUTPUT, help='Include detections above this threshold in the output') parser.add_argument('--intermediate_file_folder', default=None, help='Folder for intermediate files (default: system temp)') parser.add_argument('--keep_intermediate_files', action='store_true', help='Keep intermediate files (e.g. detection-only results file)') parser.add_argument('--norollup', action='store_true', help='Disable taxonomic rollup') parser.add_argument('--rollup_target_confidence', type=float, default=DEFAULT_ROLLUP_TARGET_CONFIDENCE, help='Target confidence threshold for taxonomic rollup ' + \ f'(default {DEFAULT_ROLLUP_TARGET_CONFIDENCE}), only ' + \ 'used when geofencing is disabled') parser.add_argument('--country', default=None, help='Country code (ISO 3166-1 alpha-3) for geofencing') parser.add_argument('--admin1_region', '--state', default=None, help='Admin1 region/state code for geofencing') parser.add_argument('--detections_file', default=None, help='Path to existing MegaDetector output file (skips detection step)') parser.add_argument('--skip_video', action='store_true', help='Ignore videos, only process images') parser.add_argument('--skip_images', action='store_true', help='Ignore images, only process videos') parser.add_argument('--frame_sample', type=int, default=None, help='Sample every Nth frame from videos (mutually exclusive with --time_sample)') parser.add_argument('--time_sample', type=float, default=None, help='Sample frames every N seconds from videos (default {})'.\ format(DEFAULT_SECONDS_PER_VIDEO_FRAME) + \ ' (mutually exclusive with --frame_sample)') parser.add_argument('--verbose', action='store_true', help='Enable additional debug output') parser.add_argument('--include_raw_classifications', action='store_true', help='Include raw (pre-rollup/geofence) classification scores in output') if len(sys.argv[1:]) == 0: parser.print_help() parser.exit() args = parser.parse_args() options = RunMDSpeciesNetOptions() args_to_object(args,options) run_md_and_speciesnet(options) # ...def main(...) if __name__ == '__main__': main()