"""
pytorch_detector.py
Module to run YOLO-based MegaDetector models.
"""
#%% Imports and constants
import os
import sys
import math
import zipfile
import tempfile
import shutil
import uuid
import json
import inspect
import cv2
import torch
import numpy as np
from megadetector.detection.run_detector import \
CONF_DIGITS, COORD_DIGITS, FAILURE_INFER, FAILURE_IMAGE_OPEN, \
get_detector_version_from_model_file, \
known_models
from megadetector.utils.ct_utils import parse_bool_string
from megadetector.utils.ct_utils import is_running_in_gha
from megadetector.utils import ct_utils
import torchvision
# We support a few ways of accessing the YOLOv5 dependencies:
#
# * The standard configuration as of 2023.09 expects that the YOLOv5 repo is checked
# out and on the PYTHONPATH (import utils)
#
# * Supported but non-default (used for PyPI packaging):
#
# pip install ultralytics-yolov5
#
# * Works, but not supported:
#
# pip install yolov5
#
# * Unfinished:
#
# pip install ultralytics
# The model type ('yolov5', 'yolov9', or 'ultralytics') whose support functions are
# currently published at module scope, or None if no YOLO library has been imported
# yet.  Written by _initialize_yolo_imports() after a successful import.
yolo_model_type_imported = None
def _get_model_type_for_model(model_file,
                              prefer_model_type_source='table',
                              default_model_type='yolov5',
                              verbose=False):
    """
    Work out which inference library (the "model type") a .pt file needs.

    Two sources are consulted: the metadata stored in the model file, and the global
    table of known model versions.  If only one source has an answer, that answer wins;
    if neither does, [default_model_type] is returned; if they disagree, a warning is
    printed and [prefer_model_type_source] breaks the tie.

    Args:
        model_file (str): the model file to read
        prefer_model_type_source (str, optional): which source to trust in the (very
            unlikely) case where the file metadata and the global model table disagree.
            "table" trusts the table; any other value trusts the file.
        default_model_type (str, optional): return value for the case where neither
            the file nor the global table specifies a model type
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the model type indicated for this model
    """

    # Source 1: metadata stored in the model file itself
    file_metadata = read_metadata_from_megadetector_model_file(model_file)
    type_from_file = None
    if (file_metadata is not None) and ('model_type' in file_metadata):
        type_from_file = file_metadata['model_type']
        if verbose:
            print('Parsed model type {} from model {}'.format(
                type_from_file,
                model_file))

    # Source 2: the global table of known model versions
    type_from_table = None
    detector_version = get_detector_version_from_model_file(model_file)
    if (detector_version is not None) and (detector_version in known_models):
        table_entry = known_models[detector_version]
        if 'model_type' in table_entry:
            type_from_table = table_entry['model_type']
            if verbose:
                print('Parsed model type {} from global metadata'.format(type_from_table))

    # Neither source has an opinion
    if (type_from_file is None) and (type_from_table is None):
        if verbose:
            print('Could not determine model type for {}, assuming {}'.format(
                model_file,default_model_type))
        return default_model_type

    # Exactly one source has an opinion
    if type_from_file is None:
        return type_from_table
    if type_from_table is None:
        return type_from_file

    # Both sources have an opinion, and they agree
    if type_from_table == type_from_file:
        return type_from_file

    # Both sources have an opinion, and they disagree
    print('Warning: model type from model version is {}, from file metadata is {}'.format(
        type_from_table,type_from_file))
    if prefer_model_type_source == 'table':
        return type_from_table
    return type_from_file
# ...def _get_model_type_for_model(...)
def _initialize_yolo_imports_for_model(model_file,
                                       prefer_model_type_source='table',
                                       default_model_type='yolov5',
                                       detector_options=None,
                                       verbose=False):
    """
    Initialize the appropriate YOLO imports for a model file.

    The model type is taken from detector_options['model_type'] when present;
    otherwise it is determined from the model file and the global model table.  If
    support for that model type has already been imported, the import step is skipped.

    Args:
        model_file (str): The model file for which we're loading support
        prefer_model_type_source (str, optional): how should we handle the (very unlikely)
            case where the metadata in the file indicates one model type, but the global model
            type table says something else.  Should be "table" (trust the table) or "file"
            (trust the file).
        default_model_type (str, optional): return value for the case where we can't find
            appropriate metadata in the file or in the global table.
        detector_options (dict, optional): dictionary of detector options that mean
            different things to different models; only the 'model_type' key is read here
        verbose (bool, optional): enable additional debug output

    Returns:
        str: the model type for which we initialized support (also returned when the
        imports were already in place and no work was required)
    """

    global yolo_model_type_imported

    if detector_options is not None and 'model_type' in detector_options:
        model_type = detector_options['model_type']
        print('Model type {} provided in detector options'.format(model_type))
    else:
        model_type = _get_model_type_for_model(model_file,
                                               prefer_model_type_source=prefer_model_type_source,
                                               default_model_type=default_model_type,
                                               verbose=verbose)

    if yolo_model_type_imported is not None:
        if model_type == yolo_model_type_imported:
            print('Bypassing imports for model type {}'.format(model_type))
            # Honor the documented contract: callers always get the model type back,
            # whether or not we had to import anything.
            return model_type
        else:
            print('Previously set up imports for model type {}, re-importing as {}'.format(
                yolo_model_type_imported,model_type))

    _initialize_yolo_imports(model_type,verbose=verbose)

    return model_type
def _clean_yolo_imports(verbose=False, aggressive_cleanup=False):
    """
    Remove all YOLO-related imports from sys.modules and sys.path, to allow a clean re-import
    of another YOLO library version.  The reason we jump through all these hoops, rather than
    just, e.g., handling different libraries in different modules, is that we need to make sure
    *pickle* sees the right version of modules during module loading, including modules we don't
    load directly (i.e., every module loaded within a YOLO library), and the only way I know to
    do that is to remove all the "wrong" versions from sys.modules and sys.path.

    A module is considered YOLO-related if one of the last four components of its file
    path contains "yolov5", "yolov9", or "ultralytics" (path components that directly
    follow an "envs" folder are ignored, since they are environment names).

    Args:
        verbose (bool, optional): enable additional debug output
        aggressive_cleanup (bool, optional): err on the side of removing modules,
            at least by ignoring whether they are/aren't in a site-packages folder.
            By default, only modules in a folder that includes "site-packages" will
            be considered for unloading.
    """

    yolo_name_fragments = ('yolov5', 'yolov9', 'ultralytics')

    modules_to_delete = []

    # Iterate over a snapshot: touching __file__ on a lazily-loaded module can trigger
    # imports, and mutating sys.modules while iterating it directly raises
    # "RuntimeError: dictionary changed size during iteration".
    for module_name, module in list(sys.modules.items()):

        try:

            module_file = getattr(module, '__file__', None)
            if module_file is None:
                continue

            module_file = module_file.replace('\\','/')
            if (not aggressive_cleanup) and ('site-packages' not in module_file):
                continue
            tokens = module_file.split('/')

            # For local path imports, a module filename that should be unloaded might
            # look like:
            #
            # c:/git/yolov9/models/common.py
            #
            # For pip imports, a module filename that should be unloaded might look like:
            #
            # c:/users/user/miniforge3/envs/megadetector/lib/site-packages/yolov9/utils/__init__.py
            first_token_to_check = len(tokens) - 4

            for i_token,token in enumerate(tokens):

                if i_token < first_token_to_check:
                    continue

                # Don't remove anything based on the environment name, which
                # always follows "envs" in the path
                if (i_token > 1) and (tokens[i_token-1] == 'envs'):
                    continue

                if any(fragment in token for fragment in yolo_name_fragments):
                    if verbose:
                        print('Module {} ({}) looks deletable'.format(module_name,module_file))
                    modules_to_delete.append(module_name)
                    break

        except Exception as e:
            # Module review is best-effort; a module we can't inspect is left alone
            if verbose:
                print('Exception during module review: {}'.format(str(e)))

    # ...for each module in the global namespace

    for module_name in modules_to_delete:

        if module_name not in sys.modules:
            continue

        if verbose:
            try:
                module = sys.modules[module_name]
                module_file = module.__file__.replace('\\','/')
                print('clean_yolo_imports: deleting module {}: {}'.format(module_name,module_file))
            except Exception:
                pass

        del sys.modules[module_name]

    # ...for each module we want to remove from the global namespace

    # sys.path may contain non-string entries (e.g. path-like objects); skip those
    paths_to_delete = [p for p in sys.path
                       if isinstance(p,str) and p.endswith(yolo_name_fragments)]

    for p in paths_to_delete:
        print('clean_yolo_imports: removing {} from path'.format(p))
        sys.path.remove(p)
# ...def _clean_yolo_imports(...)
def _initialize_yolo_imports(model_type='yolov5',
                             allow_fallback_import=True,
                             force_reimport=False,
                             verbose=False):
    """
    Imports required functions from one or more yolo libraries (yolov5, yolov9,
    ultralytics, targeting support for [model_type]), and makes them available at
    module scope as xyxy2xywh, letterbox, and scale_coords (plus non_max_suppression
    when the ultralytics package is used).

    If a different model type was imported previously, the previously-imported YOLO
    modules are removed from sys.modules/sys.path before importing.

    Args:
        model_type (str): The model type for which we're loading support; None is
            treated as 'yolov5'
        allow_fallback_import (bool, optional): If we can't import from the package for
            which we're trying to load support, fall back to "import utils".  This is
            typically used when the right support library is on the current PYTHONPATH.
        force_reimport (bool, optional): import the appropriate libraries even if the
            requested model type matches the current initialization state
        verbose (bool, optional): include additional debug output

    Returns:
        str: the model type for which we initialized support (also returned when the
        requested model type was already initialized and no import was needed)

    Raises:
        ModuleNotFoundError: if the fallback "import utils" path can't find the YOLOv5
            functions
        Exception: import errors from the yolov9 or ultralytics packages are re-raised
            after printing installation guidance
        AssertionError: if no import strategy succeeded
    """

    # When running in pytest, the megadetector 'utils' module is put in the global
    # namespace, which creates conflicts with yolov5; remove it from the global
    # namespace.
    if ('PYTEST_CURRENT_TEST' in os.environ):
        print('*** pytest detected ***')
        if ('utils' in sys.modules):
            utils_module = sys.modules['utils']
            if hasattr(utils_module, '__file__') and 'megadetector' in str(utils_module.__file__):
                print(f"Removing conflicting utils module: {utils_module.__file__}")
                sys.modules.pop('utils', None)
                # Also remove any submodules
                to_remove = [name for name in sys.modules if name.startswith('utils.')]
                for name in to_remove:
                    sys.modules.pop(name, None)

    global yolo_model_type_imported

    if model_type is None:
        model_type = 'yolov5'

    # The point of this function is to make the appropriate version
    # of the following functions available at module scope
    global non_max_suppression
    global xyxy2xywh
    global letterbox
    global scale_coords

    if yolo_model_type_imported is not None:
        if (yolo_model_type_imported == model_type) and (not force_reimport):
            print('Bypassing imports for YOLO model type {}'.format(model_type))
            # Honor the documented contract: always hand the model type back
            return model_type
        else:
            print('Starting yolo import cleanup')
            _clean_yolo_imports(verbose=verbose)

    try_yolov5_import = (model_type == 'yolov5')
    try_yolov9_import = (model_type == 'yolov9')
    try_ultralytics_import = (model_type == 'ultralytics')

    utils_imported = False

    # First try importing from the yolov5 package; this is how the pip
    # package finds YOLOv5 utilities.
    if try_yolov5_import and not utils_imported:

        try:
            # from yolov5.utils.general import non_max_suppression # type: ignore
            from yolov5.utils.general import xyxy2xywh # noqa
            from yolov5.utils.augmentations import letterbox # noqa
            try:
                from yolov5.utils.general import scale_boxes as scale_coords
            except Exception:
                from yolov5.utils.general import scale_coords
            utils_imported = True
            if verbose:
                print('Imported utils from YOLOv5 package')

        except Exception as e: # noqa
            # Deliberately silent: the "import utils" fallback below handles the
            # case where YOLOv5 is on the PYTHONPATH rather than pip-installed.
            #
            # print('yolov5 module import failed: {}'.format(e))
            # print(traceback.format_exc())
            pass

    # Next try importing from the yolov9 package
    if try_yolov9_import and not utils_imported:

        try:
            # from yolov9.utils.general import non_max_suppression # noqa
            from yolov9.utils.general import xyxy2xywh # noqa
            from yolov9.utils.augmentations import letterbox # noqa
            from yolov9.utils.general import scale_boxes as scale_coords # noqa
            utils_imported = True
            if verbose:
                print('Imported utils from YOLOv9 package')

        except Exception as e: # noqa
            # print('yolov9 module import failed: {}'.format(e))
            # print(traceback.format_exc())
            print('\n\n*****\nIt looks like you are trying to run a model that requires the yolov9pip package, '
                  'but the yolov9pip package is not installed. Nothing is wrong, this package is '
                  'just not installed by default with the MegaDetector Python package. Run '
                  '"pip install yolov9pip" to install it, and try again.\n*****\n\n')
            raise

    # If we haven't succeeded yet, import from the ultralytics package
    if try_ultralytics_import and not utils_imported:

        try:
            import ultralytics # type: ignore # noqa
        except Exception:
            print('\n\n*****\nIt looks like you are trying to run a model that requires the ultralytics package, '
                  'but the ultralytics package is not installed. Nothing is wrong, this package is '
                  'just not installed by default with the MegaDetector Python package. Run '
                  '"pip install ultralytics" to install it, and try again.\n*****\n\n')
            raise

        try:

            # The non_max_suppression() function moved from the ops module to the nms module
            # in mid-2025
            try:
                from ultralytics.utils.ops import non_max_suppression # type: ignore # noqa
            except Exception:
                from ultralytics.utils.nms import non_max_suppression # type: ignore # noqa
            from ultralytics.utils.ops import xyxy2xywh # type: ignore # noqa

            # In the ultralytics package, scale_boxes and scale_coords both exist;
            # we want scale_boxes.
            #
            # from ultralytics.utils.ops import scale_coords # noqa
            from ultralytics.utils.ops import scale_boxes as scale_coords # type: ignore # noqa
            from ultralytics.data.augment import LetterBox # type: ignore # noqa

            # Ultralytics changed the "scaleFill" parameter to "scale_fill", we want to support
            # both conventions.  The installed LetterBox class doesn't change between calls,
            # so inspect its signature once here rather than on every letterbox() call.
            letterbox_uses_old_scalefill_arg = False
            try:
                sig = inspect.signature(LetterBox.__init__)
                if 'scaleFill' in sig.parameters:
                    letterbox_uses_old_scalefill_arg = True
            except Exception:
                pass

            # letterbox() became a LetterBox class in the ultralytics package.  Create a
            # backwards-compatible letterbox function wrapper that wraps the class up.
            def letterbox(img,new_shape,auto=False,scaleFill=False, #noqa
                          scaleup=True,center=True,stride=32):

                if letterbox_uses_old_scalefill_arg:
                    if verbose:
                        print('Using old scaleFill calling convention')
                    letterbox_transformer = LetterBox(new_shape,auto=auto,scaleFill=scaleFill,
                                                      scaleup=scaleup,center=center,stride=stride)
                else:
                    letterbox_transformer = LetterBox(new_shape,auto=auto,scale_fill=scaleFill,
                                                      scaleup=scaleup,center=center,stride=stride)

                letterbox_result = letterbox_transformer(image=img)

                if isinstance(new_shape,int):
                    new_shape = [new_shape,new_shape]

                # The letterboxing is done, we just need to reverse-engineer what it did
                shape = img.shape[:2]

                r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
                if not scaleup:
                    r = min(r, 1.0)
                ratio = r, r

                new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
                dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
                if auto:
                    dw, dh = np.mod(dw, stride), np.mod(dh, stride)
                elif scaleFill:
                    dw, dh = 0.0, 0.0
                    new_unpad = (new_shape[1], new_shape[0])
                    ratio = (new_shape[1] / shape[1], new_shape[0] / shape[0])

                # Padding is split evenly between the two sides
                dw /= 2
                dh /= 2
                pad = (dw,dh)

                return [letterbox_result,ratio,pad]

            utils_imported = True
            if verbose:
                print('Imported utils from ultralytics package')

        except Exception as e:
            # Not fatal on its own: the fallback import below may still succeed
            print('Ultralytics module import failed: {}'.format(str(e)))

    # If we haven't succeeded yet, assume the YOLOv5 repo is on our PYTHONPATH.
    if (not utils_imported) and allow_fallback_import:

        try:
            # import pre- and post-processing functions from the YOLOv5 repo
            # from utils.general import non_max_suppression # type: ignore
            from utils.general import xyxy2xywh # type: ignore
            from utils.augmentations import letterbox # type: ignore

            # scale_coords() is scale_boxes() in some YOLOv5 versions
            try:
                from utils.general import scale_coords # type: ignore
            except ImportError:
                from utils.general import scale_boxes as scale_coords # type: ignore
            utils_imported = True
            imported_file = sys.modules[scale_coords.__module__].__file__
            if verbose:
                print('Imported utils from {}'.format(imported_file))

        except ModuleNotFoundError as e:
            # Chain the original exception so the failing module name stays visible
            raise ModuleNotFoundError('Could not import YOLOv5 functions:\n{}'.format(str(e))) from e

    assert utils_imported, 'YOLO utils import error'

    yolo_model_type_imported = model_type
    if verbose:
        print('Prepared YOLO imports for model type {}'.format(model_type))

    return model_type
# ...def _initialize_yolo_imports(...)
#%% NMS
[docs]
def nms(prediction, conf_thres=0.25, iou_thres=0.45, max_det=300):
    """
    Per-class non-maximum suppression, built on torchvision.ops.nms().

    For every image, candidates are filtered by objectness, each remaining candidate is
    assigned its single best class (score = objectness * class confidence), candidates
    are filtered again by that score, NMS is run separately within each class, and the
    survivors are merged, sorted by descending score, and truncated to [max_det].

    Args:
        prediction (torch.Tensor): model predictions, indexed as
            [image, candidate, value], where the values for each candidate are
            [x_center, y_center, width, height, objectness, class1_conf, class2_conf, ...]
        conf_thres (float): threshold applied first to objectness, then to the
            combined (objectness * class confidence) score
        iou_thres (float): IoU threshold passed to torchvision.ops.nms()
        max_det (int): maximum number of detections to keep per image

    Returns:
        list: one tensor per image in the batch.  Each tensor has shape [N, 6], where N
        is the number of detections kept for that image, and the columns are
        [x1, y1, x2, y2, confidence, class_id].  Box corners are in the same units as
        the input boxes; class_id is the 0-based class index stored as a float.
    """

    device = prediction.device
    n_images = prediction.shape[0]

    # Reading the third dimension fails fast on input that isn't 3-dimensional
    num_classes = prediction.shape[2] - 5 # noqa

    def _no_detections():
        return torch.zeros((0, 6), device=device)

    results = []

    for i_image in range(n_images):

        candidates = prediction[i_image]

        # First pass: objectness threshold
        candidates = candidates[candidates[:, 4] > conf_thres]
        if candidates.shape[0] == 0:
            results.append(_no_detections())
            continue

        # [x_center, y_center, w, h] -> [x1, y1, x2, y2]
        center_x = candidates[:, 0]
        center_y = candidates[:, 1]
        half_width = candidates[:, 2] / 2.0
        half_height = candidates[:, 3] / 2.0
        corners = torch.stack([center_x - half_width,
                               center_y - half_height,
                               center_x + half_width,
                               center_y + half_height], dim=1)

        # Per-class score is objectness times class probability; keep only the
        # best class for each candidate (single-label)
        scores_by_class = candidates[:, 5:] * candidates[:, 4:5]
        top_score, top_class = scores_by_class.max(1, keepdim=True)

        # Second pass: threshold on the combined score
        confident = top_score.view(-1) > conf_thres
        if not confident.any():
            results.append(_no_detections())
            continue

        corners = corners[confident]
        top_score = top_score[confident]
        top_class = top_class[confident]

        # Run NMS independently within each class
        per_class_detections = []

        for class_id in top_class.unique():

            in_class = (top_class == class_id).view(-1)
            boxes_this_class = corners[in_class]
            scores_this_class = top_score[in_class].view(-1)

            if boxes_this_class.shape[0] == 0:
                continue

            survivors = torchvision.ops.nms(boxes_this_class, scores_this_class, iou_thres)
            if len(survivors) == 0:
                continue

            class_column = torch.full((len(survivors), 1), class_id.item(),
                                      device=device, dtype=torch.float)

            # Rows are [x1, y1, x2, y2, conf, class]
            per_class_detections.append(
                torch.cat([boxes_this_class[survivors],
                           scores_this_class[survivors].unsqueeze(1),
                           class_column], 1))

        # ...for each category

        if not per_class_detections:
            results.append(_no_detections())
            continue

        # Merge classes, highest confidence first, then cap the count
        merged = torch.cat(per_class_detections, 0)
        merged = merged[merged[:, 4].argsort(descending=True)]
        if merged.shape[0] > max_det:
            merged = merged[:max_det]
        results.append(merged)

    # ...for each image in the batch

    return results
# ...def nms(...)
#%% Model metadata functions
# ...def add_metadata_to_megadetector_model_file(...)
# ...with zipfile.Zipfile(...)
# ...def read_metadata_from_megadetector_model_file(...)
#%% Inference classes
# Compatibility mode used by PTDetector when the caller's detector options don't
# specify one (or specify None).
default_compatibility_mode = 'classic'
# This is a useful hack when I want to verify that my test driver (md_tests.py) is
# correctly forcing a specific compatibility mode (I use "classic-test" in that case).
# When True, the PTDetector constructor asserts that the compatibility mode is neither
# 'classic' nor 'default'.
require_non_default_compatibility_mode = False
[docs]
class PTDetector:
"""
Class that runs a PyTorch-based MegaDetector model. Also used as a preprocessor
for images that will later be run through an instance of PTDetector.
"""
def __init__(self, model_path, detector_options=None, verbose=False):
    """
    PTDetector constructor.  If detector_options['preprocess_only'] exists and is
    True, this instance is being used as a preprocessor, so we don't load model weights.

    Args:
        model_path (str): path to the model file
        detector_options (dict, optional): detector options.  Keys read here are
            'force_cpu', 'use_model_native_classes', 'compatibility_mode',
            'preprocess_only', and 'device'; the dict is also forwarded to the YOLO
            import setup, which reads 'model_type'.
        verbose (bool, optional): enable additional debug output
    """

    if verbose:
        print('Initializing PTDetector (verbose)')

    # Set up the import environment for this model, unloading previous
    # YOLO library versions if necessary.
    _initialize_yolo_imports_for_model(model_path,
                                       detector_options=detector_options,
                                       verbose=verbose)

    # Parse options specific to this detector family
    force_cpu = False
    use_model_native_classes = False
    compatibility_mode = default_compatibility_mode

    if detector_options is not None:

        if 'force_cpu' in detector_options:
            force_cpu = parse_bool_string(detector_options['force_cpu'])
        if 'use_model_native_classes' in detector_options:
            use_model_native_classes = parse_bool_string(detector_options['use_model_native_classes'])
        if 'compatibility_mode' in detector_options:
            if detector_options['compatibility_mode'] is None:
                compatibility_mode = default_compatibility_mode
            else:
                compatibility_mode = detector_options['compatibility_mode']

    # This is a global option used only during testing, to make sure I'm hitting
    # the cases where we are not using "classic" preprocessing.
    if require_non_default_compatibility_mode:

        print('### DEBUG: requiring non-default compatibility mode ###')
        assert compatibility_mode != 'classic'
        assert compatibility_mode != 'default'

    preprocess_only = False
    if (detector_options is not None) and \
       ('preprocess_only' in detector_options) and \
       (detector_options['preprocess_only']):
        preprocess_only = True

    if verbose or (not preprocess_only):
        print('Loading PT detector with compatibility mode {}'.format(compatibility_mode))

    self.model_metadata = read_metadata_from_megadetector_model_file(model_path)

    #: Image size passed to the letterbox() function; 1280 means "1280 on the long side,
    #: preserving aspect ratio".
    if self.model_metadata is not None and 'image_size' in self.model_metadata:
        self.default_image_size = self.model_metadata['image_size']
        print('Loaded image size {} from model metadata'.format(self.default_image_size))
    else:
        # This is not the default for most YOLO models, but most of the time, if someone
        # is loading a model here that does not have metadata, it's MDv5[ab].0.0
        print('No image size available in model metadata, defaulting to 1280')
        self.default_image_size = 1280

    #: Either a string ('cpu','cuda:0') or a torch.device()
    self.device = 'cpu'

    #: Have we already printed a warning about using a non-standard image size?
    #:
    #: :meta private:
    self.printed_image_size_warning = False

    #: If this is False, we assume the underlying model is producing class indices in the
    #: set (0,1,2) (and we assert() on this), and we add 1 to get to the backwards-compatible
    #: MD classes (1,2,3) before generating output.  If this is True, we use whatever
    #: indices the model provides
    self.use_model_native_classes = use_model_native_classes

    #: This allows us to maintain backwards compatibility across a set of changes to the
    #: way this class does inference.  Currently should start with either "default" or
    #: "classic".
    self.compatibility_mode = compatibility_mode

    #: Stride size passed to the YOLO letterbox() function
    self.letterbox_stride = 32

    # This is a convenient heuristic to determine the stride size without actually loading
    # the model: the only models in the YOLO family with a stride size of 64 are the
    # YOLOv5*6 and YOLOv5*6u models, which are 1280px models.
    #
    # See:
    #
    # github.com/ultralytics/ultralytics/issues/21544
    #
    # Note to self, though, if I decide later to require loading the model on preprocessing
    # workers so I can more reliably choose a stride, this is the right way to determine the
    # stride:
    #
    # self.letterbox_stride = int(self.model.stride.max())
    if self.default_image_size == 1280:
        self.letterbox_stride = 64

    print('Using model stride: {}'.format(self.letterbox_stride))

    #: Use half-precision inference... fixed by the model, generally don't mess with this
    self.half_precision = False

    # Preprocessing-only instances never touch the model weights or a compute device
    if preprocess_only:
        return

    if (detector_options is not None) and \
       ('device' in detector_options) and \
       detector_options['device'] is not None:

        print('Using caller-specified device: {}'.format(detector_options['device']))
        self.device = detector_options['device']

    else:

        if not force_cpu:

            if torch.cuda.is_available():
                self.device = torch.device('cuda:0')

            try:
                # is_built() is a function; referencing it without calling it is
                # always truthy, which silently disabled this half of the check.
                if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                    # MPS inference fails on GitHub runners as of 2025.08.  This is
                    # independent of model size.  So, we disable MPS when running in GHA.
                    if is_running_in_gha():
                        print('GitHub actions detected, bypassing MPS backend')
                    else:
                        print('Using MPS device')
                        self.device = 'mps'
            except AttributeError:
                # torch builds without the MPS backend don't expose torch.backends.mps
                pass

            # If no CUDA or MPS device was selected, try DirectML if it's installed
            if self.device == 'cpu':
                try:
                    import torch_directml
                    self.device = torch_directml.device()
                    print('Using DirectML device')
                except ImportError:
                    pass

    # AddaxAI depends on this printout, don't remove it
    print('PTDetector using device {}'.format(str(self.device).lower()))

    try:

        self.model = PTDetector._load_model(model_path,
                                            device=self.device,
                                            compatibility_mode=self.compatibility_mode,
                                            verbose=verbose)

    except Exception as e:

        # In a very esoteric scenario where an old version of YOLOv5 is used to run
        # newer models, we run into an issue because the "Model" class became
        # "DetectionModel".  New YOLOv5 code handles this case by just setting them
        # to be the same, so doing that externally doesn't seem *that* rude.
        if "Can't get attribute 'DetectionModel'" in str(e):
            print('Forward-compatibility issue detected, patching')
            from models import yolo # type: ignore
            yolo.DetectionModel = yolo.Model
            self.model = PTDetector._load_model(model_path,
                                                device=self.device,
                                                compatibility_mode=self.compatibility_mode,
                                                verbose=verbose)
        else:
            raise

    if (self.device != 'cpu'):
        if verbose:
            print('Sending model to GPU')
        self.model.to(self.device)
@staticmethod
def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
    """
    Load a YOLO checkpoint from disk and return the fused, float32 model in eval mode.

    Args:
        model_pt_path (str): path to the .pt checkpoint; the loaded object must contain
            a 'model' entry
        device (str or torch.device): device the model will run on; used to choose the
            map_location passed to torch.load()
        compatibility_mode (str, optional): currently unused by this function
        verbose (bool, optional): enable additional debug output

    Returns:
        The model stored in the checkpoint, after .float().fuse().eval()
    """

    if verbose:
        print(f'Using PyTorch version {torch.__version__}')

    # Compare on the string form so that 'mps', 'mps:0', and torch.device('mps') are
    # all recognized; a torch.device never compares equal to a plain string.
    device_string = str(device)

    load_kwargs = {}

    # I get quirky errors when loading YOLOv5 models on MPS hardware using
    # map_location, but this is the recommended method, so I'm using it everywhere
    # other than MPS devices.
    if not device_string.startswith('mps'):
        # DirectML (privateuseone) doesn't support map_location; load to CPU, .to(device) handles the rest
        load_kwargs['map_location'] = 'cpu' if 'privateuseone' in device_string else device

    # NOTE: weights_only=False unpickles arbitrary Python objects from the checkpoint,
    # which can execute code; only load model files from sources you trust.
    try:
        checkpoint = torch.load(model_pt_path, weights_only=False, **load_kwargs)
    except Exception as e:
        # For a transitional period, we want to support torch 1.1x, where the weights_only
        # parameter doesn't exist
        if "'weights_only' is an invalid keyword" in str(e):
            checkpoint = torch.load(model_pt_path, **load_kwargs)
        else:
            raise

    # Compatibility fix that allows us to load older YOLOv5 models with
    # newer versions of YOLOv5/PT
    for m in checkpoint['model'].modules():
        if type(m) is torch.nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
            m.recompute_scale_factor = None

    # Calling .to(device) should no longer be necessary now that we're using map_location=device
    # model = checkpoint['model'].float().fuse().eval().to(device)
    model = checkpoint['model'].float().fuse().eval()

    return model
# ...def _load_model(...)
[docs]
def preprocess_image(self,
                     img_original,
                     image_id='unknown',
                     image_size=None,
                     verbose=False):
    """
    Prepare an image for detection, including scaling and letterboxing.

    Args:
        img_original (Image or np.array): the image on which we should run the detector, with
            EXIF rotation already handled
        image_id (str, optional): a path to identify the image; will be in the "file" field
            of the output object
        image_size (int, optional): image size (long side) to use for inference, or None to
            use the default size specified at the time the model was loaded
        verbose (bool, optional): enable additional debug output (currently unused here)

    Returns:
        dict: dict with fields:
            - file (the value of [image_id])
            - img_processed (the letterboxed image)
            - img_original (the np.array that was passed to letterbox(); in non-classic
              compatibility modes this has already been resized to [image_size])
            - img_original_pil (the input image if it was not an np.array, otherwise None)
            - target_shape (the shape requested from letterbox(): the integer image size
              in classic mode, otherwise a two-element integer array)
            - scaling_shape (the shape of the input array before any resizing, for
              normalizing coordinates later)
            - letterbox_ratio (letterbox parameter used for normalizing coordinates later)
            - letterbox_pad (letterbox parameter used for normalizing coordinates later)
    """

    # Prepare return dict
    result = {'file': image_id }

    # Store the PIL version of the original image, the caller may want to use
    # it for metadata extraction later.
    img_original_pil = None

    # If we were given a PIL image, rather than a numpy array
    if not isinstance(img_original,np.ndarray):
        img_original_pil = img_original
        img_original = np.asarray(img_original)

        # PIL images are RGB already
        # img_original = img_original[:, :, ::-1]

    # Save the original shape for scaling boxes later
    scaling_shape = img_original.shape

    # If the caller is requesting a specific target size...
    if image_size is not None:

        assert isinstance(image_size,int)

        # Only announce the override once per run of user-sized images
        if not self.printed_image_size_warning:
            print('Using user-supplied image size {}'.format(image_size))
            self.printed_image_size_warning = True

    # Otherwise resize to self.default_image_size
    else:

        image_size = self.default_image_size
        self.printed_image_size_warning = False

    # ...if the caller has specified an image size

    is_classic_mode = ('classic' in self.compatibility_mode)

    # In "classic mode", we only do the letterboxing resize, we don't do an
    # additional initial resizing operation.  Otherwise, resize the image so the
    # long side matches the target image size.  This is not letterboxing (i.e.,
    # padding) yet, just resizing.
    if not is_classic_mode:

        use_ceil_for_resize = ('use_ceil_for_resize' in self.compatibility_mode)

        h,w = img_original.shape[:2]
        resize_ratio = image_size / max(h,w)

        # Only resize if we have to
        if resize_ratio != 1:

            # Match what yolov5 does: use linear interpolation for upsizing;
            # area interpolation for downsizing
            if resize_ratio > 1:
                interpolation_method = cv2.INTER_LINEAR
            else:
                interpolation_method = cv2.INTER_AREA

            if use_ceil_for_resize:
                target_w = math.ceil(w * resize_ratio)
                target_h = math.ceil(h * resize_ratio)
            else:
                target_w = int(w * resize_ratio)
                target_h = int(h * resize_ratio)

            img_original = cv2.resize(
                img_original, (target_w, target_h),
                interpolation=interpolation_method)

    if is_classic_mode:

        letterbox_auto = True
        letterbox_scaleup = True
        target_shape = image_size

    else:

        letterbox_auto = False
        letterbox_scaleup = False

        # The padding to apply as a fraction of the stride size
        pad = 0.5

        # Resize to a multiple of the model stride
        #
        # This is how we would determine the stride if we knew the model had been loaded:
        #
        # model_stride = int(self.model.stride.max())
        #
        # ...but because we do this on preprocessing workers now, we try to avoid loading the model
        # just for preprocessing, and we assume the stride was determined at the time the PTDetector
        # object was created.  If a model does happen to be loaded, we only use it to warn
        # about a mismatch.
        try:
            model_stride = int(self.model.stride.max())
            if model_stride != self.letterbox_stride:
                print('*** Warning: model stride is {}, stride at construction time was {} ***'.format(
                    model_stride,self.letterbox_stride
                ))
        except Exception:
            # No model loaded (preprocess-only instance), or the model has no stride
            pass

        model_stride = self.letterbox_stride

        # Only the spatial dimensions matter here; the channel count must not take
        # part in the max()
        max_dimension = max(img_original.shape[:2])
        normalized_shape = [img_original.shape[0] / max_dimension,
                            img_original.shape[1] / max_dimension]
        target_shape = np.ceil(((np.array(normalized_shape) * image_size) / model_stride) + \
                               pad).astype(int) * model_stride

    # Now we letterbox, which is just padding, since we've already resized
    img,letterbox_ratio,letterbox_pad = letterbox(img_original,
                                                  new_shape=target_shape,
                                                  stride=self.letterbox_stride,
                                                  auto=letterbox_auto,
                                                  scaleFill=False,
                                                  scaleup=letterbox_scaleup)

    result['img_processed'] = img
    result['img_original'] = img_original
    result['img_original_pil'] = img_original_pil
    result['target_shape'] = target_shape
    result['scaling_shape'] = scaling_shape
    result['letterbox_ratio'] = letterbox_ratio
    result['letterbox_pad'] = letterbox_pad

    return result
# ...def preprocess_image(...)
[docs]
def generate_detections_one_batch(self,
                                  img_original,
                                  image_id=None,
                                  detection_threshold=0.00001,
                                  image_size=None,
                                  augment=False,
                                  verbose=False):
    """
    Run a detector on a batch of images.

    Images are preprocessed individually (unless they arrive as preprocessed dicts),
    grouped by the shape of their processed pixel array, and each group is run through
    the model as one batch.  Failures are reported per image rather than raised.

    Args:
        img_original (list): list of images (Image, np.array, or dict) on which we should run
            the detector, with EXIF rotation already handled, or dicts representing preprocessed
            images with associated letterbox parameters
        image_id (list or None): list of paths to identify the images; will be in the "file" field
            of the output objects.  Will be ignored when img_original contains preprocessed dicts,
            except as a fallback name for a dict that has no 'file' field.
        detection_threshold (float, optional): only detections above this confidence threshold
            will be included in the return value; None is treated as 0.0
        image_size (int, optional): image size (long side) to use for inference, or None to
            use the default size specified at the time the model was loaded
        augment (bool, optional): enable (implementation-specific) image augmentation
        verbose (bool, optional): enable additional debug output

    Returns:
        list: a list of dictionaries, one per input image and in input order, each with the
        following fields:
            - 'file' (filename, always present)
            - 'max_detection_conf' (removed from MegaDetector output files by default, but
              generated here; absent for images that failed)
            - 'detections' (a list of detection objects containing keys 'category', 'conf',
              and 'bbox', or None for images that failed)
            - 'failure' (a failure string; only present for images that failed)

    Raises:
        ValueError: if img_original is not a list, if it mixes dicts with non-dicts, or if
            image_id is missing or mismatched for non-dict inputs
    """

    # Validate inputs
    if not isinstance(img_original, list):
        raise ValueError('img_original must be a list for batch processing')

    if len(img_original) == 0:
        return []

    # Check input consistency
    if isinstance(img_original[0], dict):

        # All items in img_original should be preprocessed dicts
        for i, img in enumerate(img_original):
            if not isinstance(img, dict):
                raise ValueError(f'Mixed input types in batch: item {i} is not a dict, but item 0 is a dict')

    else:

        # All items in img_original should be PIL/numpy images, and image_id should be a list of strings
        if image_id is None:
            raise ValueError('image_id must be a list when img_original contains PIL/numpy images')
        if not isinstance(image_id, list):
            raise ValueError('image_id must be a list for batch processing')
        if len(image_id) != len(img_original):
            raise ValueError(
                'Length mismatch: img_original has {} items, image_id has {} items'.format(
                    len(img_original),len(image_id)))
        for i_img, img in enumerate(img_original):
            if isinstance(img, dict):
                raise ValueError(
                    'Mixed input types in batch: item {} is a dict, but item 0 is not a dict'.format(
                        i_img))

    if detection_threshold is None:
        detection_threshold = 0.0

    batch_size = len(img_original)
    results = [None] * batch_size

    def _failure_id(i_img, img):
        """
        Pick the name under which a failed image is reported.  This must never raise,
        because it runs inside an exception handler: image_id is not validated when the
        inputs are preprocessed dicts, so it may be None, too short, or not a list.
        """
        if isinstance(img, dict) and 'file' in img:
            return img['file']
        if isinstance(image_id, list) and i_img < len(image_id):
            return image_id[i_img]
        return f'index_{i_img}'

    # Preprocess all images, handling failures, and group the successfully preprocessed
    # images by actual processed image shape (not target_shape) for batching.
    #
    # Maps shape tuple --> list of (original_idx, image_info, current_image_id)
    shape_groups = {}

    for i_img, img in enumerate(img_original):

        try:

            if isinstance(img, dict):
                # Already preprocessed
                image_info = img
                current_image_id = image_info['file']
            else:
                # Need to preprocess
                current_image_id = image_id[i_img]
                image_info = self.preprocess_image(
                    img_original=img,
                    image_id=current_image_id,
                    image_size=image_size,
                    verbose=verbose)

            # Reading the shape inside the try block means a malformed preprocessed
            # dict fails just this image, rather than the whole batch.
            actual_shape = tuple(image_info['img_processed'].shape)

        except Exception as e:

            failed_image_id = _failure_id(i_img, img)
            print('Warning: preprocessing failed for image {}: {}'.format(
                failed_image_id, str(e)))
            results[i_img] = {
                'file': failed_image_id,
                'detections': None,
                'failure': FAILURE_IMAGE_OPEN
            }
            continue

        shape_groups.setdefault(actual_shape, []).append(
            (i_img, image_info, current_image_id))

    # ...for each image in this batch

    # Process each shape group as a batch
    for group_shape, group_items in shape_groups.items():

        try:
            self._process_batch_group(group_items, results, detection_threshold, augment, verbose)
        except Exception as e:
            # If inference fails for the entire batch, mark all images in this batch as failed
            print('Warning: batch inference failed for shape {}: {}'.format(group_shape, str(e)))
            for original_idx, _, current_image_id in group_items:
                results[original_idx] = {
                    'file': current_image_id,
                    'detections': None,
                    'failure': FAILURE_INFER
                }

    # ...for each shape group

    return results

# ...def generate_detections_one_batch(...)
def _process_batch_group(self, group_items, results, detection_threshold, augment, verbose):
"""
Process a group of images with the same target shape as a single batch.
Args:
group_items (list): List of (original_idx, image_info, current_image_id) tuples
results (list): Results list to populate (modified in place)
detection_threshold (float): Detection confidence threshold
augment (bool): Enable augmentation
verbose (bool): Enable verbose output
Returns:
list of dict: list of dictionaries the same length as group_items, with fields 'file',
'detections', 'max_detection_conf'.
"""
if len(group_items) == 0:
return
# Extract batch data
batch_images = []
batch_metadata = []
# For each image in this batch...
for original_idx, image_info, current_image_id in group_items:
img = image_info['img_processed']
# Convert HWC to CHW and prepare tensor
img_tensor = img.transpose((2, 0, 1))
img_tensor = np.ascontiguousarray(img_tensor)
img_tensor = torch.from_numpy(img_tensor)
batch_images.append(img_tensor)
metadata = {
'original_idx': original_idx,
'current_image_id': current_image_id,
'scaling_shape': image_info['scaling_shape'],
'letterbox_pad': image_info['letterbox_pad'],
'img_original': image_info['img_original']
}
batch_metadata.append(metadata)
# ...for each image in this batch
# Stack images into a batch tensor
batch_tensor = torch.stack(batch_images)
batch_tensor = batch_tensor.float()
batch_tensor /= 255.0
batch_tensor = batch_tensor.to(self.device)
if self.half_precision:
batch_tensor = batch_tensor.half()
# Run the model on the batch
pred = self.model(batch_tensor, augment=augment)[0]
if 'privateuseone' in str(self.device):
pred = pred.cpu()
# Configure NMS parameters
if 'classic' in self.compatibility_mode:
nms_iou_thres = 0.45
else:
nms_iou_thres = 0.6
use_library_nms = False
# Model output format changed in recent ultralytics packages, and the nms implementation
# in this module hasn't been updated to handle that format yet.
if (yolo_model_type_imported is not None) and (yolo_model_type_imported == 'ultralytics'):
use_library_nms = True
if use_library_nms:
pred = non_max_suppression(prediction=pred,
conf_thres=detection_threshold,
iou_thres=nms_iou_thres,
agnostic=False,
multi_label=False)
else:
pred = nms(prediction=pred,
conf_thres=detection_threshold,
iou_thres=nms_iou_thres)
assert isinstance(pred, list)
assert len(pred) == len(batch_metadata), \
'Mismatch between prediction length {} and batch size {}'.format(
len(pred),len(batch_metadata))
# Process each image's detections
for i_image, det in enumerate(pred):
metadata = batch_metadata[i_image]
original_idx = metadata['original_idx']
current_image_id = metadata['current_image_id']
scaling_shape = metadata['scaling_shape']
letterbox_pad = metadata['letterbox_pad']
img_original = metadata['img_original']
detections = []
max_conf = 0.0
if len(det) > 0:
# Prepare scaling parameters
gn = torch.tensor(scaling_shape)[[1, 0, 1, 0]]
if 'classic' in self.compatibility_mode:
ratio = None
ratio_pad = None
else:
ratio = (img_original.shape[0]/scaling_shape[0],
img_original.shape[1]/scaling_shape[1])
ratio_pad = (ratio, letterbox_pad)
# Rescale boxes
if 'classic' in self.compatibility_mode:
det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], img_original.shape).round()
else:
det[:, :4] = scale_coords(batch_tensor.shape[2:], det[:, :4], scaling_shape, ratio_pad).round()
# Process each detection
for *xyxy, conf, cls in reversed(det):
if conf < detection_threshold:
continue
# Convert to YOLO format then to MD format
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()
api_box = ct_utils.convert_yolo_to_xywh(xywh)
if 'classic' in self.compatibility_mode:
api_box = ct_utils.truncate_float_array(api_box, precision=COORD_DIGITS)
conf = ct_utils.truncate_float(conf.tolist(), precision=CONF_DIGITS)
else:
api_box = ct_utils.round_float_array(api_box, precision=COORD_DIGITS)
conf = ct_utils.round_float(conf.tolist(), precision=CONF_DIGITS)
if not self.use_model_native_classes:
cls = int(cls.tolist()) + 1
if cls not in (1, 2, 3):
raise KeyError(f'{cls} is not a valid class.')
else:
cls = int(cls.tolist())
detections.append({
'category': str(cls),
'conf': conf,
'bbox': api_box
})
max_conf = max(max_conf, conf)
# ...for each detection
# ...if there are > 0 detections
# Store result for this image
results[original_idx] = {
'file': current_image_id,
'detections': detections,
'max_detection_conf': max_conf
}
# ...for each image
# ...def _process_batch_group(...)
[docs]
def generate_detections_one_image(self,
                                  img_original,
                                  image_id='unknown',
                                  detection_threshold=0.00001,
                                  image_size=None,
                                  augment=False,
                                  verbose=False):
    """
    Run a detector on a single image by wrapping it in a one-element batch and
    delegating to generate_detections_one_batch().

    Args:
        img_original (Image, np.array, or dict): the image on which we should run the detector,
            with EXIF rotation already handled, or a dict representing a preprocessed image with
            associated letterbox parameters
        image_id (str, optional): a path to identify the image; will be in the "file" field
            of the output object.  Not forwarded when img_original is a dict.
        detection_threshold (float, optional): only detections above this confidence threshold
            will be included in the return value
        image_size (int, optional): image size (long side) to use for inference, or None to
            use the default size specified at the time the model was loaded
        augment (bool, optional): enable (implementation-specific) image augmentation
        verbose (bool, optional): enable additional debug output

    Returns:
        dict: the single element returned by generate_detections_one_batch() for this image,
        a dictionary with the following fields:
            - 'file' (filename, always present)
            - 'max_detection_conf' (removed from MegaDetector output files by default, but
              generated here)
            - 'detections' (a list of detection objects containing keys 'category', 'conf',
              and 'bbox')
            - 'failure' (a failure string, when the image could not be processed)
    """

    # The batch function reads the identifier of a preprocessed dict from the dict itself,
    # so an ID list is only supplied for raw images
    batch_image_ids = None if isinstance(img_original, dict) else [image_id]

    batch_results = self.generate_detections_one_batch(
        img_original=[img_original],
        image_id=batch_image_ids,
        detection_threshold=detection_threshold,
        image_size=image_size,
        augment=augment,
        verbose=verbose)

    # Return the single result
    return batch_results[0]

# ...def generate_detections_one_image(...)
# ...class PTDetector