Source code for megadetector.data_management.threshold_coco_dataset

"""

threshold_coco_dataset.py

Given a COCO-formatted dataset that stores confidence in the semi-standard "score"
field, remove annotations below a threshold.

"""

#%% Imports and constants

import json
import sys
import argparse

from tqdm import tqdm

from megadetector.utils.ct_utils import write_json


#%% Functions

def threshold_coco_dataset(input_filename,
                           confidence_threshold=0.0,
                           output_filename=None,
                           confidence_field='score',
                           missing_confidence_handling='error'):
    """
    Given a COCO-formatted dataset that stores confidence in the semi-standard "score"
    field, remove annotations below a threshold.

    Annotations whose confidence is greater than or equal to [confidence_threshold] are
    kept; all others are removed. When [missing_confidence_handling] is 'warning',
    annotations that lack the confidence field are reported and removed as well.

    Args:
        input_filename (str): the (input) COCO-formatted .json file containing annotations
        confidence_threshold (float, optional): discard annotations below this confidence
            value
        output_filename (str, optional): write the thresholded output here; if None,
            nothing is written
        confidence_field (str, optional): the field within annotations that represents
            confidence values
        missing_confidence_handling (str, optional): what to do if a confidence value is
            missing (should be 'error' or 'warning')

    Returns:
        dict: the thresholded COCO database

    Raises:
        ValueError: if [missing_confidence_handling] is not 'error' or 'warning', or if
            an annotation is missing [confidence_field] and [missing_confidence_handling]
            is 'error'
    """

    # Validate arguments
    #
    # Raise explicitly rather than assert, so the check is not stripped when Python
    # runs with -O.
    if missing_confidence_handling not in ('error','warning'):
        raise ValueError(
            f'Illegal missing confidence handling {missing_confidence_handling}')

    # Read input data
    print('Reading file {}'.format(input_filename))

    # JSON is UTF-8 by specification, so don't rely on the platform's default encoding
    with open(input_filename,'r',encoding='utf-8') as f:
        d = json.load(f)

    annotations_to_keep = []
    annotations = d['annotations']

    for ann in tqdm(annotations):

        if confidence_field not in ann:

            ann_id = ann.get('id','unknown')
            s = 'annotation {} is missing field {}'.format(ann_id,confidence_field)

            if missing_confidence_handling == 'error':
                raise ValueError(s)

            # The only other legal value is 'warning': report this annotation and drop it
            print('Warning: {}'.format(s))
            continue

        confidence = ann[confidence_field]
        if confidence >= confidence_threshold:
            annotations_to_keep.append(ann)

    # ...for each annotation

    print('Keeping {} of {} annotations'.format(
        len(annotations_to_keep),len(annotations)))

    d['annotations'] = annotations_to_keep

    if output_filename is not None:
        write_json(output_filename,d)

    return d
# ...def threshold_coco_dataset(...)


#%% Command-line driver

def main():
    """
    Command-line entry point for threshold_coco_dataset.

    Parses the input file, output file, and confidence threshold (plus optional
    settings) from the command line and forwards them to threshold_coco_dataset.
    Prints usage and exits if no arguments are supplied.
    """

    arg_parser = argparse.ArgumentParser(
        description='Threshold a COCO dataset, write the results to a new file')

    # Required positional arguments, in the order they appear on the command line
    arg_parser.add_argument('input_filename',
                            type=str,
                            help='Path to the input COCO .json file')
    arg_parser.add_argument('output_filename',
                            type=str,
                            help='Path to the .json file where thresholded data will be saved')
    arg_parser.add_argument('confidence_threshold',
                            type=float,
                            help='Confidence threshold')

    # Optional settings
    arg_parser.add_argument('--confidence_field',
                            type=str,
                            default='score',
                            help='Field to use for confidence values, default "score"')
    arg_parser.add_argument('--missing_confidence_handling',
                            type=str,
                            default='error',
                            choices=['error', 'warning'],
                            help='Whether to error on annotations that are missing the confidence field')

    # With no arguments at all, show the help text instead of an argparse error
    if not sys.argv[1:]:
        arg_parser.print_help()
        arg_parser.exit()

    options = arg_parser.parse_args()

    threshold_coco_dataset(
        input_filename=options.input_filename,
        confidence_threshold=options.confidence_threshold,
        output_filename=options.output_filename,
        confidence_field=options.confidence_field,
        missing_confidence_handling=options.missing_confidence_handling)


if __name__ == '__main__':
    main()