Source code for megadetector.data_management.mewc_to_md

"""

mewc_to_md.py

Converts the output of the MEWC inference scripts to the MD output format.

"""

#%% Imports and constants

import os
import json
import pandas as pd
import sys
import argparse

from copy import deepcopy
from collections import defaultdict
from megadetector.utils.ct_utils import sort_list_of_dicts_by_key, invert_dictionary # noqa
from megadetector.utils.path_utils import recursive_file_list

from megadetector.postprocessing.validate_batch_results import \
    ValidateBatchResultsOptions, validate_batch_results

# Default value for the [mount_prefix] argument of mewc_to_md() and for the
# --mount_prefix command-line option: the leading string removed from filenames.
default_mewc_mount_prefix = '/images/'
# Default value for the [category_name_column] argument of mewc_to_md() and for the
# --category_name_column command-line option: the MEWC .csv column used for category names.
default_mewc_category_name_column = 'class_id'


#%% Functions

def mewc_to_md(mewc_input_folder,
               output_file=None,
               mount_prefix=default_mewc_mount_prefix,
               category_name_column=default_mewc_category_name_column,
               mewc_out_filename='mewc_out.csv',
               md_out_filename='md_out.json'):
    """
    Converts the output of the MEWC inference scripts to the MD output format.

    Every folder under [mewc_input_folder] that contains a MEWC results .csv file
    must also contain an MD results .json file; the per-folder results are merged
    into a single MD-formatted dict, with MEWC classifications attached to the
    corresponding detections.

    Args:
        mewc_input_folder (str): the folder we'll search for MEWC output files
        output_file (str, optional): .json file to write with class information; if
            None, nothing is written (and no validation is run)
        mount_prefix (str, optional): string to remove from all filenames in the MD
            .json file, typically the prefix used to mount the image folder.
        category_name_column (str, optional): column in the MEWC results .csv to use for
            category naming.
        mewc_out_filename (str, optional): MEWC-formatted .csv file that should be
            in [mewc_input_folder]
        md_out_filename (str, optional): MD-formatted .json file (without classification
            information) that should be in [mewc_input_folder]

    Returns:
        dict: an MD-formatted dict, the same as what's written to [output_file]
    """

    ##%% Read input files

    assert os.path.isdir(mewc_input_folder), \
        'Could not find folder {}'.format(mewc_input_folder)


    ##%% Find MEWC output files

    relative_path_to_mewc_info = {}

    print('Listing files in folder {}'.format(mewc_input_folder))
    all_files_relative = set(recursive_file_list(mewc_input_folder,return_relative_paths=True))

    # Iterate in sorted order (rather than set order) so that folder order, and
    # therefore the classification category IDs assigned below, are reproducible
    # from run to run.
    for fn_relative in sorted(all_files_relative):

        if not fn_relative.endswith(mewc_out_filename):
            continue

        folder_relative = '/'.join(fn_relative.split('/')[:-1])
        assert folder_relative not in relative_path_to_mewc_info
        md_output_file_relative = os.path.join(folder_relative,md_out_filename).replace('\\','/')
        assert md_output_file_relative in all_files_relative, \
            'Could not find MD output file {} to match to {}'.format(
                md_output_file_relative,fn_relative)
        relative_path_to_mewc_info[folder_relative] = \
            {'mewc_predict_file':fn_relative,'md_file':md_output_file_relative}

    print('Found {} MEWC results files'.format(len(relative_path_to_mewc_info)))


    ##%% Prepare to loop over results files

    md_results_all = {}
    md_results_all['images'] = []
    md_results_all['detection_categories'] = {}
    md_results_all['classification_categories'] = {}
    md_results_all['info'] = None

    # This is a *global* mapping of category names to IDs, across all MEWC .csv files
    classification_category_name_to_id = {}


    ##%% Loop over results files

    # relative_folder = next(iter(relative_path_to_mewc_info.keys()))
    for relative_folder in relative_path_to_mewc_info:

        ##%% Load MEWC and MD results for this folder

        mewc_info = relative_path_to_mewc_info[relative_folder]
        mewc_csv_fn_abs = os.path.join(mewc_input_folder,mewc_info['mewc_predict_file'])
        mewc_md_fn_abs = os.path.join(mewc_input_folder,mewc_info['md_file'])

        # One dict per row of the MEWC .csv file
        mewc_classification_info = pd.read_csv(mewc_csv_fn_abs).to_dict('records')

        assert os.path.isfile(mewc_md_fn_abs), \
            'Could not find file {}'.format(mewc_md_fn_abs)
        with open(mewc_md_fn_abs,'r') as f:
            md_results = json.load(f)


        ##%% Remove the mount prefix from MD files if necessary

        if mount_prefix is not None and len(mount_prefix) > 0:

            n_files_without_mount_prefix = 0

            # im = md_results['images'][0]
            for im in md_results['images']:
                if not im['file'].startswith(mount_prefix):
                    n_files_without_mount_prefix += 1
                else:
                    # Only strip the leading occurrence
                    im['file'] = im['file'].replace(mount_prefix,'',1)

            if n_files_without_mount_prefix > 0:
                print('Warning {} of {} files in the MD results did not include the mount prefix {}'.format(
                    n_files_without_mount_prefix,len(md_results['images']),mount_prefix))


        ##%% Convert MEWC snip IDs to image files

        # r = mewc_classification_info[0]
        for r in mewc_classification_info:

            # E.g. "IMG0-0.jpg"
            snip_file = r['filename']

            # E.g. "IMG0-0"
            snip_file_no_ext = os.path.splitext(snip_file)[0]

            # Everything after the last hyphen is the snip (detection) ID
            tokens = snip_file_no_ext.rsplit('-',1)

            if len(tokens) == 1:
                print('Warning: in folder {}, detection ID not found in snip filename {}, skipping'.format(
                    relative_folder,snip_file_no_ext))
                r['image_filename_without_extension'] = snip_file_no_ext
                r['snip_id'] = None
                continue

            r['image_filename_without_extension'] = tokens[0]
            r['snip_id'] = int(tokens[1])

        # ...for each MEWC result record


        ##%% Make sure MD results and MEWC results refer to the same files

        images_in_md_results_no_extension = \
            set(os.path.splitext(im['file'])[0] for im in md_results['images'])
        images_in_mewc_results_no_extension = \
            set(r['image_filename_without_extension'] for r in mewc_classification_info)

        # All files with classification results should also have detection results
        for fn in images_in_mewc_results_no_extension:
            assert fn in images_in_md_results_no_extension, \
                'Error: file {} is present in mewc-predict results, but not in MD results'.format(fn)

        # Note to self: the reverse is not checked, because no classification results
        # are present for empty images.


        ##%% Validate images

        for im in md_results['images']:
            fn_abs = os.path.join(mewc_input_folder,relative_folder,im['file'])
            if not os.path.isfile(fn_abs):
                print('Warning: image file {} does not exist'.format(fn_abs))


        ##%% Map filenames to MEWC results

        image_id_to_mewc_records = defaultdict(list)
        for r in mewc_classification_info:
            image_id_to_mewc_records[r['image_filename_without_extension']].append(r)


        ##%% Add classification info to MD results

        # im = md_results['images'][0]
        for im in md_results['images']:

            if ('detections' not in im) or (im['detections'] is None) or (len(im['detections']) == 0):
                continue

            detections = im['detections']

            # *Don't* sort by confidence, it looks like snip IDs use the original sort order
            # detections = sort_list_of_dicts_by_key(detections,'conf',reverse=True)

            # Start every detection with an empty classification list
            for det in detections:
                det['classifications'] = []

            image_id = os.path.splitext(im['file'])[0]
            mewc_records_this_image = image_id_to_mewc_records.get(image_id,[])

            # r = mewc_records_this_image[0]
            for r in mewc_records_this_image:

                # Records whose snip ID could not be parsed were already warned about
                if r['snip_id'] is None:
                    continue

                category_name = r[category_name_column]

                # Assign the next sequential ID the first time we see a category name
                if category_name not in classification_category_name_to_id:
                    category_id = str(len(classification_category_name_to_id))
                    classification_category_name_to_id[category_name] = category_id
                else:
                    category_id = classification_category_name_to_id[category_name]

                snip_id = r['snip_id']
                if snip_id >= len(detections):
                    print('Warning: image {} has a classified snip ID of {}, but only {} detections are present'.format(
                        image_id,snip_id,len(detections)))
                    continue

                # The snip ID is used as an index into this image's detection list
                det = detections[snip_id]
                det['classifications'].append([category_id,r['prob']])

            # ...for each classification in this image

        # ...for each image


        ##%% Map MD results to the global level

        # The 'info' field is taken from the first MD results file we process
        if md_results_all['info'] is None:
            md_results_all['info'] = md_results['info']

        for category_id in md_results['detection_categories']:
            if category_id not in md_results_all['detection_categories']:
                md_results_all['detection_categories'][category_id] = \
                    md_results['detection_categories'][category_id]
            else:
                assert md_results_all['detection_categories'][category_id] == \
                    md_results['detection_categories'][category_id], \
                    'MD results present with incompatible detection categories'

        # im = md_results['images'][0]
        for im in md_results['images']:
            im_copy = deepcopy(im)
            # Filenames in the merged results are relative to [mewc_input_folder]
            im_copy['file'] = os.path.join(relative_folder,im['file']).replace('\\','/')
            md_results_all['images'].append(im_copy)

    # ...for each folder that contains MEWC results


    ##%% Write output

    md_results_all['classification_categories'] = invert_dictionary(classification_category_name_to_id)

    if output_file is not None:

        # os.path.dirname() returns '' for a bare filename, and os.makedirs('')
        # raises, so only create the folder when there is one to create.
        output_dir = os.path.dirname(output_file)
        if len(output_dir) > 0:
            os.makedirs(output_dir,exist_ok=True)

        with open(output_file,'w') as f:
            json.dump(md_results_all,f,indent=1)

        validation_options = ValidateBatchResultsOptions()
        validation_options.check_image_existence = True
        validation_options.relative_path_base = mewc_input_folder
        validation_options.raise_errors = True
        _ = validate_batch_results(output_file,validation_options)

    return md_results_all
# ...def mewc_to_md(...)


#%% Interactive driver

if False:

    pass

    #%%

    mewc_input_folder = r'G:\temp\mewc-test'
    mount_prefix = '/images/'
    output_file = os.path.join(mewc_input_folder,'results_with_classes.json')

    _ = mewc_to_md(mewc_input_folder=mewc_input_folder,
                   output_file=output_file,
                   mount_prefix=mount_prefix,
                   category_name_column='class_id')


#%% Command-line driver

def main(): # noqa
    """
    Command-line entry point: parses arguments and runs mewc_to_md(). Prints help
    and exits if no arguments are supplied.
    """

    cli_parser = argparse.ArgumentParser()

    # Required positional arguments
    cli_parser.add_argument(
        'input_folder',
        type=str,
        help='Folder containing images and MEWC .json/.csv files')
    cli_parser.add_argument(
        'output_file',
        type=str,
        help='.json file where output will be written')

    # Optional arguments, defaulting to the module-level defaults
    cli_parser.add_argument(
        '--mount_prefix',
        type=str,
        default=default_mewc_mount_prefix,
        help='prefix to remove from each filename in MEWC results, typically the Docker mount point')
    cli_parser.add_argument(
        '--category_name_column',
        type=str,
        default=default_mewc_category_name_column,
        help='column in the MEWC .csv file to use for category names')

    # With no arguments at all, show usage rather than an argparse error
    if not sys.argv[1:]:
        cli_parser.print_help()
        cli_parser.exit()

    cli_args = cli_parser.parse_args()

    _ = mewc_to_md(mewc_input_folder=cli_args.input_folder,
                   output_file=cli_args.output_file,
                   mount_prefix=cli_args.mount_prefix,
                   category_name_column=cli_args.category_name_column)


if __name__ == '__main__':
    main()