# Source code for megadetector.utils.split_locations_into_train_val

"""

split_locations_into_train_val.py

Splits a list of location IDs into training and validation, targeting a specific
train/val split for each category, but allowing some categories to be tighter or looser
than others.  Does nothing particularly clever, just randomly splits locations into
train/val lots of times using the target val fraction, and picks the one that meets the
specified constraints and minimizes weighted error, where "error" is defined as the
sum of each class's absolute divergence from the target val fraction.

"""

#%% Imports/constants

import random
import numpy as np

from collections import defaultdict
from megadetector.utils.ct_utils import sort_dictionary_by_value
from tqdm import tqdm


#%% Main function

def split_locations_into_train_val(location_to_category_counts,
                                   n_random_seeds=10000,
                                   target_val_fraction=0.15,
                                   category_to_max_allowable_error=None,
                                   category_to_error_weight=None,
                                   default_max_allowable_error=0.1,
                                   require_complete_coverage=True):
    """
    Splits a list of location IDs into training and validation, targeting a specific
    train/val split for each category, but allowing some categories to be tighter or looser
    than others.  Does nothing particularly clever, just randomly splits locations into
    train/val lots of times using the target val fraction, and picks the one that meets the
    specified constraints and minimizes weighted error, where "error" is defined as the
    mean of each category's weighted absolute divergence from the target val fraction.

    Categories whose total count across all locations is zero have no defined val
    fraction; they are ignored and do not appear in the returned dict.

    Seeds the global [random] module generator for every candidate seed.

    Args:
        location_to_category_counts (dict): a dict mapping location IDs to dicts,
            with each dict mapping a category name to a count.  Any categories not present
            in a particular dict are assumed to have a count of zero for that location.

            For example:

            .. code-block:: none

                {'location-000': {'bear':4,'wolf':10},
                 'location-001': {'bear':12,'elk':20}}

        n_random_seeds (int, optional): number of random seeds to try, always starting from zero
        target_val_fraction (float, optional): fraction of images containing each species we'd
            like to put in the val split
        category_to_max_allowable_error (dict, optional): a dict mapping category names to
            maximum allowable errors.  These are hard constraints (i.e., we will error if we
            can't meet them).  Does not need to include all categories; categories not
            included will be assigned a maximum error according to
            [default_max_allowable_error].  If this is None, every category uses
            [default_max_allowable_error].
        category_to_error_weight (dict, optional): a dict mapping category names to error
            weights.  You can specify a subset of categories; categories not included here
            have a weight of 1.0.  If None, all categories have the same weight.
        default_max_allowable_error (float, optional): the maximum allowable error for
            categories not present in [category_to_max_allowable_error].  Set to None (or
            >= 1.0) to disable the maximum-error constraint for categories not present in
            [category_to_max_allowable_error].
        require_complete_coverage (bool, optional): require that every category appear in
            both train and val; this is checked for every category, regardless of the
            maximum-error settings

    Returns:
        tuple: A two-element tuple:

        - list of location IDs in the val split
        - a dict mapping category names to the fraction of images in the val split

    Raises:
        AssertionError: if no random seed satisfies all the hard constraints
    """

    location_ids = list(location_to_category_counts.keys())

    n_val_locations = int(target_val_fraction*len(location_ids))

    if category_to_max_allowable_error is None:
        category_to_max_allowable_error = {}

    if category_to_error_weight is None:
        category_to_error_weight = {}

    # Category ID to total count across all locations; used as the denominator of
    # each category's val fraction and for printouts
    category_id_to_count = {}
    for location_counts in location_to_category_counts.values():
        for category_id, count in location_counts.items():
            category_id_to_count[category_id] = \
                category_id_to_count.get(category_id, 0) + count

    # Categories with no images have no defined val fraction, so they can't take part
    # in the split evaluation.  This is a list in first-seen order (rather than a set)
    # so that iteration order doesn't depend on string hashing.
    category_ids = [category_id for category_id, count in category_id_to_count.items()
                    if count > 0]

    print('Splitting {} categories over {} locations'.format(
        len(category_ids),len(location_ids)))

    def compute_seed_errors(random_seed):
        """
        Computes the split and the per-category val fractions for a specific random seed.

        Returns a tuple of (val_locations, weighted_average_error, category_to_val_fraction).
        """

        # Randomly split into train/val
        random.seed(random_seed)
        val_locations = random.sample(location_ids,k=n_val_locations)

        # Count the images of each category that landed in val.  The totals are already
        # known, so only the val locations need to be visited.
        category_to_val_count = dict.fromkeys(category_ids, 0)
        for location_id in val_locations:
            for category_id, count in location_to_category_counts[location_id].items():
                if category_id in category_to_val_count:
                    category_to_val_count[category_id] += count

        # Val fraction and weighted absolute deviation from the target for each category
        category_to_val_fraction = {}
        weighted_category_errors = []

        for category_id in category_ids:
            val_fraction = \
                category_to_val_count[category_id] / category_id_to_count[category_id]
            category_to_val_fraction[category_id] = val_fraction
            category_weight = category_to_error_weight.get(category_id, 1.0)
            weighted_category_errors.append(
                abs(val_fraction - target_val_fraction) * category_weight)

        # np.mean() of an empty list is NaN (with a warning); no categories means no error
        if weighted_category_errors:
            weighted_average_error = np.mean(weighted_category_errors)
        else:
            weighted_average_error = 0.0

        return val_locations,weighted_average_error,category_to_val_fraction

    # ... def compute_seed_errors(...)

    def satisfies_hard_constraints(category_to_val_fraction):
        """
        Returns True if every category meets the coverage and maximum-error constraints.
        """

        for category, val_fraction in category_to_val_fraction.items():

            # If necessary, verify that this category doesn't *only* appear in train or
            # val.  This is independent of whether a maximum error applies.
            if require_complete_coverage and \
               ((val_fraction == 0.0) or (val_fraction == 1.0)):
                return False

            max_allowable_error = category_to_max_allowable_error.get(
                category, default_max_allowable_error)

            # None means "no maximum-error constraint for this category"
            if max_allowable_error is None:
                continue

            # Check whether this category exceeds the hard maximum deviation
            if abs(val_fraction - target_val_fraction) > max_allowable_error:
                return False

        # ...for each category

        return True

    # ... def satisfies_hard_constraints(...)

    # This will only include random seeds that satisfy the hard constraints
    random_seed_to_weighted_average_error = {}

    for random_seed in tqdm(range(0,n_random_seeds)):

        _,weighted_average_error,category_to_val_fraction = \
            compute_seed_errors(random_seed)

        if satisfies_hard_constraints(category_to_val_fraction):
            random_seed_to_weighted_average_error[random_seed] = weighted_average_error

    # ...for each random seed

    # Raised explicitly (rather than via an assert statement) so the check still runs
    # under "python -O"; the exception type is kept for existing callers.
    if not random_seed_to_weighted_average_error:
        raise AssertionError('No random seed met all the hard constraints')

    print('\n{} of {} random seeds satisfied hard constraints'.format(
        len(random_seed_to_weighted_average_error),n_random_seeds))

    # The first seed with the lowest weighted error wins
    min_error_seed = min(random_seed_to_weighted_average_error,
                         key=random_seed_to_weighted_average_error.get)

    # Re-run the winning seed to recover its split and per-category val fractions
    val_locations,_,category_to_val_fraction = compute_seed_errors(min_error_seed)

    print('\nVal locations:\n')
    for loc in val_locations:
        print('{}'.format(loc))
    print('')

    # Report categories from most to least common
    category_to_val_fraction = sort_dictionary_by_value(category_to_val_fraction,
                                                        sort_values=category_id_to_count,
                                                        reverse=True)

    print('Val fractions by category:\n')

    for category in category_to_val_fraction:
        print('{} ({}) {:.2f}'.format(
            category,category_id_to_count[category],
            category_to_val_fraction[category]))

    return val_locations,category_to_val_fraction
# ...def split_locations_into_train_val(...)