# Source code for flame.engine

import os, logging, json, glob, gc
from time import time
from shutil import rmtree
from typing import Union, Any

import onnx
import onnxruntime as ort
from onnxruntime import InferenceSession
import numpy as np
from numpy.typing import NDArray
import mlflow
from mlflow import artifacts

from .image import FLAMEImage, is_FLAME_image
from .error import CAREInferenceError, FLAMEImageError
from .utils import min_max_norm, _float_or_float_array

# Overlap (in pixels) between neighbouring square patches, keyed by patch side
# length. Every tabulated overlap is one quarter of its patch size.
PATCH_OVERLAP_MAP = {
    patch_size: patch_size // 4
    for patch_size in (16, 32, 64, 128, 256, 512)
}

class CAREInferenceSession():
    """
    Class CAREInference Session

    Wraps an onnxruntime InferenceSession for a CARE denoising model, together with
    the model and dataset configuration used to normalize inputs, split them into
    model-sized square patches, and stitch / renormalize the model outputs.

    Attributes
    ----------
    logger : logging.Logger
        Logger for the inference engine
    execution_providers : list[str]
        List of available execution providers for onnxruntime session for the inference engine
    model_config : Dict
        Model configuration information loaded from provided path
    dataset_config : Dict
        Dataset configuration information loaded from provided path
    input_name : str
        Name of the input layer in the ONNX/TRT engine
    input_shape : list
        Shape of the input tensor to the ONNX/TRT engine, as reported by onnxruntime
    input_dtype : str
        onnxruntime type string of the input tensor (e.g. 'tensor(float)')
    inferenceSession : onnxruntime.InferenceSession
        onnxruntime InferenceSession object
    input_min, input_max, output_min, output_max :
        1st / 99th percentile pixel statistics read from the dataset config, used
        for input normalization and output renormalization.
    from_mlflow : bool
        Whether the CAREInferenceSession is loaded from an MLFlow run.
    mlflow_tracking_uri : str
        The MLFlow tracking URI mapping to the MLFlow tracking server used to initialize the CAREInferenceSession object.
    mlflow_run_id : str
        The MLFlow run ID corresponding to the MLFlow run used to initialize the CAREInferenceSession object.
    mlflow_run_name : str
        The name of the MLFlow run used to initialize the CAREInferenceSession object.
    """

    # onnxruntime input type strings -> numpy dtype the session expects. Patches
    # are cast to this dtype before inference because onnxruntime rejects
    # inputs whose dtype does not match the model exactly.
    _ORT_INPUT_DTYPES = {
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
        "tensor(float16)": np.float16,
    }

    def __init__(
        self,
        model_path: str,  # onnx only, for now
        model_config_path: str,
        dataset_config_path: str,
        cpu_ok: bool = False,
    ) -> None:
        """
        Args:
            - model_path: path to a ".onnx" model file.
            - model_config_path: path to the model configuration JSON
              (read later for ["Patch_Config"]["patch_size"]).
            - dataset_config_path: path to the dataset configuration JSON; must contain
              ["FLAME_Dataset"]["input"|"output"]["pixel_1pct"|"pixel_99pct"].
            - cpu_ok: when False, CUDAExecutionProvider must be available.

        Raises:
            CAREInferenceError: if execution providers, configs, the model, or the
            normalization statistics cannot be loaded.
        """
        self.logger = logging.getLogger("ENGINE")
        self.execution_providers = None
        self._check_execution_providers(cpu_ok)
        self.model_config = self._load_json(model_config_path)
        self.input_name, self.input_shape, self.input_dtype = None, None, None
        self.inferenceSession = self._load_model(model_path)
        self.dataset_config = self._load_json(dataset_config_path)
        self.from_mlflow, self.mlflow_tracking_uri = False, None
        self.mlflow_run_id, self.mlflow_run_name = None, None

        try:
            dataset_stats = self.dataset_config['FLAME_Dataset']
            self.input_min = _float_or_float_array(dataset_stats['input']['pixel_1pct'])
            self.input_max = _float_or_float_array(dataset_stats['input']['pixel_99pct'])
            self.logger.info(f"Found [{self.input_min}, {self.input_max}] for input normalization")
            self.output_min = _float_or_float_array(dataset_stats['output']['pixel_1pct'])
            self.output_max = _float_or_float_array(dataset_stats['output']['pixel_99pct'])
            self.logger.info(f"Found [{self.output_min}, {self.output_max}] for output normalization")
        except Exception as e:
            self.logger.error(f"Could not load normalization data from Dataset Config.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Could not load normalization data from Dataset Config.\n{e.__class__.__name__}: {e}") from e
        # NOTE: the CUDA requirement for cpu_ok=False is already enforced by
        # _check_execution_providers above.

    @classmethod
    def from_mlflow_uri(
        cls,
        tracking_uri: str,
        run_id: str,
        model_artifact_path: str = "model",
        config_artifact_path: str = "model_config",
        model_name: str = "model.onnx",
        json_name: str = "model_config.json",
        cpu_ok: bool = False,
    ):
        """
        Build a CAREInferenceSession from artifacts logged to an MLFlow run.

        Downloads `model_artifact_path` and `config_artifact_path` from run `run_id`
        into a private temporary directory. It then locates exactly one `model_name`
        file and one `json_name` file there, and uses that JSON as both the model
        and the dataset config. The temporary directory is removed afterwards,
        whether loading succeeds or fails.

        Args:
            - tracking_uri: MLFlow tracking server URI.
            - run_id: MLFlow run ID holding the artifacts.
            - model_artifact_path / config_artifact_path: artifact paths to download.
            - model_name / json_name: file names searched for (recursively) in the download.
            - cpu_ok: forwarded to the constructor.

        Returns:
            CAREInferenceSession with `from_mlflow` and the `mlflow_*` attributes populated.

        Raises:
            CAREInferenceError: if downloading artifacts or initializing the session fails.
        """
        import tempfile  # only needed for MLFlow loading

        logger = logging.getLogger("ENGINE")
        logger.info(f"Loading CAREInferenceSession from MLFlow tracking URI {tracking_uri} and run id {run_id}")
        mlflow.set_tracking_uri(tracking_uri)

        # A fresh private directory never picks up stale artifacts from an earlier
        # load (which would break the one-file check below). It also never deletes
        # a user-owned ./temp directory, and it is cleaned up even on failure.
        with tempfile.TemporaryDirectory() as temp_direc:
            for artifact_path in (model_artifact_path, config_artifact_path):
                try:
                    artifacts.download_artifacts(
                        tracking_uri=tracking_uri,
                        run_id=run_id,
                        artifact_path=artifact_path,
                        dst_path=temp_direc,
                    )
                except Exception as e:
                    logger.exception(f"Could not load '{artifact_path}' path from mlflow run of id {run_id}.\n{e.__class__.__name__}: {e}")
                    raise CAREInferenceError(f"Could not load '{artifact_path}' path from mlflow run of id {run_id}.\n{e.__class__.__name__}: {e}") from e

            try:
                model_paths = glob.glob(os.path.join(temp_direc, "**", model_name), recursive=True)
                json_paths = glob.glob(os.path.join(temp_direc, "**", json_name), recursive=True)
                assert len(model_paths) == 1 and len(json_paths) == 1, f"Expected to found 1 model and one json, not {len(model_paths)} and {len(json_paths)}, respectively."
                obj = cls(
                    model_path=model_paths[0],
                    model_config_path=json_paths[0],
                    dataset_config_path=json_paths[0],
                    cpu_ok=cpu_ok,
                )
                obj.from_mlflow = True
                obj.mlflow_tracking_uri = tracking_uri
                obj.mlflow_run_id = run_id
                obj.mlflow_run_name = mlflow.get_run(run_id).info.run_name
            except Exception as e:
                logger.exception(f"Could not initialize CAREInferenceSession object.\n{e.__class__.__name__}: {e}")
                raise CAREInferenceError(f"Could not initialize CAREInferenceSession object.\n{e.__class__.__name__}: {e}") from e
        return obj

    def _check_execution_providers(self, cpu_ok: bool) -> None:
        """
        Will check for available execution providers.
        Returns None, but sets self.execution_providers.

        Raises:
            CAREInferenceError: if `cpu_ok` is False and CUDAExecutionProvider is unavailable.
        """
        try:
            providers = ort.get_available_providers()
            if not cpu_ok:
                assert "CUDAExecutionProvider" in providers, f"CUDA not available, and 'cpu_ok' is False."
            self.execution_providers = providers
        except Exception as e:
            self.logger.error(f"Could not validate available execution providers with 'cpu_ok' set to {cpu_ok}")
            raise CAREInferenceError(f"Could not validate available execution providers with 'cpu_ok' set to {cpu_ok}") from e

    def _load_json(self, json_path: str) -> dict:
        """
        Read and parse a JSON file.

        Raises:
            CAREInferenceError: if the path is not a file or cannot be parsed.
        """
        try:
            assert os.path.isfile(json_path), f"Provided path {json_path} is not a file."
            # Context manager so the file handle is always closed.
            with open(json_path, 'r', encoding='utf-8') as json_file:
                json_dict = json.load(json_file)
        except Exception as e:
            self.logger.error(f"Could not load json from {json_path}.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Could not load json from {json_path}.\n{e.__class__.__name__}: {e}") from e
        return json_dict

    def _load_model(self, model_path: str) -> "InferenceSession":
        """
        Validate an ONNX model file and open an onnxruntime session for it.

        Also records the name, shape and type of the model's first input in
        self.input_name / self.input_shape / self.input_dtype.

        Args:
            - model_path: absolute path to a ".onnx" file

        Raises:
            CAREInferenceError: if the file is missing, invalid, or the session cannot be created.
        """
        try:
            assert os.path.isfile(model_path), f"Provided path {model_path} is not a file"
            onnx_model = onnx.load(model_path)
            onnx.checker.check_model(onnx_model)
            del onnx_model  # only needed for validation; free it before building the session
            ort_session = ort.InferenceSession(
                model_path,
                providers=self.execution_providers,
            )
            input_tensor = ort_session.get_inputs()[0]
            self.input_name, self.input_shape, self.input_dtype = input_tensor.name, input_tensor.shape, input_tensor.type
            self.logger.info(f"Model input: Name-{self.input_name} | Shape-{self.input_shape} | DType-{self.input_dtype}")
        except Exception as e:
            self.logger.error(f"Could not initialize Model Inference Session.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Could not initialize Model Inference Session.\n{e.__class__.__name__}: {e}") from e
        return ort_session

    def _validate_FLAME_images(self, inference_images: list) -> list:
        """Return only the entries of `inference_images` recognized by is_FLAME_image."""
        new_list = [image for image in inference_images if is_FLAME_image(image)]
        self.logger.info(f"Of {len(inference_images)} images provided, {len(new_list)} are valid for inference.")
        return new_list

    def predict(self, arr: NDArray[Any]) -> NDArray:
        """
        Assumes array input of shape NYXC. Will break Y and X dimension into patches
        necessary for inference by the ONNX model in this inference session.

        The input is clipped to [input_min, input_max] and min-max normalized with
        those bounds. The stitched output is passed through min_max_norm(..., 0, 1)
        and then rescaled to [output_min, output_max].

        Args:
            - arr: numpy ndarray of shape NYXC

        Returns:
            Denoised image of shape NYXC

        Raises:
            AssertionError: if `arr` is not 4-D, or its channel count does not match a
            multi-channel model.
            CAREInferenceError: if patch extraction, inference or stitching fails.
        """
        assert arr.ndim == 4, f"Input array must have 4 dimensions, not {arr.ndim} dimensions of shape {arr.shape}"

        # A model with a single input channel is run on each channel one-by-one.
        # Otherwise the model was trained for a specific number of channels and
        # sees all of them at once.
        single_channel_infer = self.input_shape[-1] == 1  # type: ignore
        if not single_channel_infer:
            assert arr.shape[-1] == self.input_shape[-1], f"Array channel dim {arr.shape[-1]} does not match ONNX input channel dim {self.input_shape[-1]} (assumption: NYXC)."  # type: ignore
            self.logger.info(f"Detected multiple channel inference. Inferring on all channels at the same time.")
        else:
            self.logger.info(f"Detected single channel inference. Running inference on each channel one-at-a-time.")

        input_dim = arr.shape
        input_dtype = arr.dtype
        # (Y, X, C) of each stitched model output for a single image.
        if single_channel_infer:
            initial_YXC = list(arr.shape[1:-1]) + [1]
        else:
            initial_YXC = list(arr.shape[1:])

        # NORMALIZATION OPERATIONS
        # TODO: Decide what to do if the dataset statistics do not match the number of channels in the input image.
        arr = np.clip(arr, np.array(self.input_min), np.array(self.input_max))
        arr = min_max_norm(arr, np.array(self.input_min), np.array(self.input_max))

        # INFERENCE
        # Use the model's real input name; 'patch' is only a fallback.
        input_name = self.input_name if self.input_name is not None else 'patch'
        model_dtype = self._ORT_INPUT_DTYPES.get(self.input_dtype)  # type: ignore[arg-type]

        def run_on_patches(patches):
            if model_dtype is not None and patches.dtype != model_dtype:
                patches = patches.astype(model_dtype)
            try:
                t1 = time()
                output = self.inferenceSession.run(None, {input_name: patches})[0]
                t2 = time()
                self.logger.info(f"Inference on patches of shape {patches.shape} and dtype {patches.dtype} took {t2 - t1:.2f}s")
                return output
            except Exception as e:
                self.logger.error(f"Inference session failed to predict on patches of dim {patches.shape} and dtype {patches.dtype}\n{e.__class__.__name__}: {e}")
                raise RuntimeError(f"Inference session failed to predict on patches of dim {patches.shape} and dtype {patches.dtype}\n{e.__class__.__name__}: {e}") from e

        try:
            # Collect per-image results and stack once (avoids quadratic concatenation).
            image_outputs = []
            for sample in arr:  # each image along N; sample is (Y, X, C)
                if single_channel_infer:
                    channel_outputs = []
                    for cdx in range(sample.shape[-1]):
                        patches = self._get_patches(sample[..., [cdx]])  # keep channel dimension while indexing
                        channel_outputs.append(self._stitch_patches(run_on_patches(patches), initial_YXC))
                    image_output = np.concatenate(channel_outputs, axis=-1)
                else:
                    patches = self._get_patches(sample)
                    image_output = self._stitch_patches(run_on_patches(patches), initial_YXC)
                image_outputs.append(image_output)
            full_output = np.stack(image_outputs, axis=0)
        except Exception as e:
            self.logger.error(f"Could not infer on array of shape {input_dim} and dtype {input_dtype}.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Could not infer on array of shape {input_dim} and dtype {input_dtype}.\n{e.__class__.__name__}: {e}") from e

        # RENORMALIZATION TO OUTPUT PIXEL DISTRIBUTION
        full_output = min_max_norm(full_output, mini=0, maxi=1)
        full_output = (full_output * (self.output_max - self.output_min)) + self.output_min
        return full_output

    def predict_FLAME(
        self,
        image: "FLAMEImage",
        input_frames: Union[int, None] = None,
    ) -> NDArray:
        """
        Takes FLAMEImage Object and infers on it using the ONNX engine.

        Will attempt to dynamically detect FLAMEImage dimensions (ZFCYX, CYX, etc...)
        and return corresponding denoised image. A frame ("F") and/or channel ("C")
        axis is added to the working array when missing and removed again from the
        output. `image.axes_shape` itself is never modified.

        Args:
            - image (FLAMEImage): The FLAMEImage object to be denoised
            - input_frames (int): The number of frames to input into the denoising model.
              If none are provided, then all available frames will be used.

        Returns:
            Numpy NDArray with denoised FLAMEImage data. Will match dimensions of input FLAMEImage.

        Raises:
            AssertionError: if more than one frame is requested for an image without a frame axis.
        """
        # Work on a local copy of the axes string so the image is untouched even if
        # inference raises part-way through.
        axes = image.axes_shape

        # ENSURING: Frame and Channel dims exist, getting indicated frames.
        if "F" in axes:
            frames = image.get_frames((0, input_frames) if input_frames is not None else None)
            frame_dim_created = False
        else:
            assert input_frames is None or input_frames == 1, f"User asked for {input_frames} frames for inference, but no frame dimension was found in {image}"
            frames = image.raw()[np.newaxis, ...]
            axes = "F" + axes
            frame_dim_created = True

        if "C" in axes:
            channel_idx = axes.index("C")
            channel_dim_created = False
        else:
            self.logger.info(f"Could not find channel dimension, so creating one...")
            frames = frames[..., np.newaxis]
            axes = axes + "C"
            channel_idx = len(axes) - 1
            channel_dim_created = True

        # RESHAPING OPERATIONS: move channel to the end (..., Y, X, C), then fold all
        # leading axes into one batch axis for predict (N, Y, X, C).
        frames = np.moveaxis(frames, channel_idx, -1)
        original_shape = frames.shape
        batch_size = int(np.prod(frames.shape[:-3]))
        frames = np.reshape(frames, (batch_size,) + tuple(frames.shape[-3:]))

        # INFERENCE
        output_image = self.predict(frames)

        # REVERSAL OF RESHAPING OPERATIONS
        output_image = np.reshape(output_image, original_shape)
        output_image = np.moveaxis(output_image, -1, channel_idx)

        # REMOVE ADDED DIMENSIONS (frame axis is first, channel axis is last when created)
        if frame_dim_created:
            output_image = output_image[0, ...]
        if channel_dim_created:
            output_image = output_image[..., 0]
        return output_image

    def inference_generator(self, inference_images: list[Union["FLAMEImage", NDArray]], FLAMEImage_input_frames: Union[int, None] = None):
        """
        Will yield inferred-upon images one-by-one, in the order of `inference_images`.
        Assumes 1-99 pcttile normalization.

        Entries recognized by is_FLAME_image go to predict_FLAME (with
        `FLAMEImage_input_frames`). Everything else is treated as an NYXC array
        and passed to predict.

        Yields:
            NDArray: the denoised result for each entry.

        Raises:
            CAREInferenceError: when inference on the current entry fails.
        """
        self.logger.info(f"Inference using 1-99 percentile normalization")
        length = len(inference_images)
        for idx, image in enumerate(inference_images):
            if is_FLAME_image(image=image):
                try:
                    self.logger.info(f"({idx}/{length}) - Inferring on FLAMEImage {image}...")
                    res = self.predict_FLAME(
                        image=image,  # type: ignore
                        input_frames=FLAMEImage_input_frames,
                    )
                except Exception as e:
                    self.logger.error(f"FLAMEImage detected, but inference failed.\n{e.__class__.__name__}: {e}")
                    raise CAREInferenceError(f"FLAMEImage detected, but inference failed.\n{e.__class__.__name__}: {e}") from e
            else:
                try:
                    self.logger.info(f"({idx}/{length}) - Inferring on array (shape: {image.shape} | dtype: {image.dtype})")  # type: ignore
                    res = self.predict(
                        arr=image,  # type: ignore
                    )
                except Exception as e:
                    self.logger.error(f"NDArray detected, but inference failed.\n{e.__class__.__name__}: {e}")
                    raise CAREInferenceError(f"NDArray detected, but inference failed.\n{e.__class__.__name__}: {e}") from e
            # Yield outside the try so errors raised by the consumer are not re-wrapped.
            yield res

    def _get_patch_overlap(self, patch_dim: int) -> int:
        """
        Return the patch overlap (pixels) to use for square patches of side `patch_dim`.

        Uses the PATCH_OVERLAP_MAP entry of the largest tabulated size <= patch_dim.
        Sizes below the smallest tabulated size use the smallest entry, and sizes
        above the largest use the largest entry.
        """
        sizes = sorted(PATCH_OVERLAP_MAP)
        eligible = [size for size in sizes if size <= patch_dim]
        chosen = eligible[-1] if eligible else sizes[0]
        return PATCH_OVERLAP_MAP[chosen]

    @staticmethod
    def _tile_starts(length: int, patch_dim: int, overlap: int) -> list[int]:
        """
        Start offsets of the patches that tile one image axis of size `length`.

        Consecutive patches are `patch_dim - overlap` apart, and the last patch is
        aligned flush with the end of the axis so the whole axis is covered. An axis
        no longer than `patch_dim` gets a single patch at 0.

        Raises:
            CAREInferenceError: if `overlap >= patch_dim` (the tiling would never advance).
        """
        stride = patch_dim - overlap
        if stride <= 0:
            raise CAREInferenceError(f"Patch overlap ({overlap}) must be smaller than the patch dimension ({patch_dim}).")
        if length <= patch_dim:
            return [0]
        starts = list(range(0, length - patch_dim, stride))
        starts.append(length - patch_dim)
        return starts

    @staticmethod
    def _tile_bounds(starts: list[int], patch_dim: int, length: int) -> list[tuple[int, int]]:
        """
        For each patch start, the [lo, hi) span of the output axis that patch owns.

        Neighbouring patches split their overlap at its midpoint. Each patch
        therefore discards roughly half the overlap on every interior side, which
        matches the po//2 crop for a regular stride. The first and last patches keep
        their full extent at the image border.
        """
        bounds = []
        last = len(starts) - 1
        for i, start in enumerate(starts):
            lo = 0 if i == 0 else (starts[i - 1] + patch_dim + start) // 2
            hi = length if i == last else (start + patch_dim + starts[i + 1]) // 2
            bounds.append((lo, hi))
        return bounds

    def _get_patches(self, arr: NDArray[Any]) -> NDArray:
        """
        Description:
            _get_patches will break down an input image (as an NDArray) into patches that can be inferred upon.
            Patch dimensions are read from self.model_config["Patch_Config"]["patch_size"].
            Patches are ordered row-major (all patches of the top row left to right, then the next row),
            matching the order _stitch_patches consumes them. Axes shorter than a patch are
            reflect-padded at their end so every patch has full size.

        Args:
            - arr: A numpy NDArray. Should have dimensions YXC

        Returns:
            - A numpy NDArray array of dimensions (N, patch_size, patch_size, C).

        Raises:
            CAREInferenceError: if the patch size, overlap or input dimensions are unusable.
        """
        try:
            patch_dim = self.model_config["Patch_Config"]['patch_size']
        except Exception as e:
            self.logger.error(f"Could not retrieve patch dimensions from self.model_config.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Could not retrieve patch dimensions from self.model_config.\n{e.__class__.__name__}: {e}") from e

        try:
            po = self._get_patch_overlap(patch_dim=patch_dim)
        except Exception as e:
            self.logger.error(f"Could not retrieve patch overlap.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Could not retrieve patch overlap.\n{e.__class__.__name__}: {e}") from e

        try:
            assert len(arr.shape) == 3, f"Input dimensions must be of size 3 (YXC), not {len(arr.shape)}."
            input_y, input_x, _ = arr.shape
        except Exception as e:
            self.logger.error(f"Cannot interpret input dimensions for patch extraction.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Cannot interpret input dimensions for patch extraction.\n{e.__class__.__name__}: {e}") from e

        try:
            y_starts = self._tile_starts(input_y, patch_dim, po)
            x_starts = self._tile_starts(input_x, patch_dim, po)
            pad_y = max(0, patch_dim - input_y)
            pad_x = max(0, patch_dim - input_x)
            if pad_y or pad_x:
                arr = np.pad(arr, ((0, pad_y), (0, pad_x), (0, 0)), mode="reflect")
            patches = [
                arr[y:y + patch_dim, x:x + patch_dim, :]
                for y in y_starts
                for x in x_starts
            ]
            return np.stack(patches, axis=0)
        except Exception as e:
            self.logger.error(f"Could not output extracted patches.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Could not output extracted patches.\n{e.__class__.__name__}: {e}") from e

    def _stitch_patches(self, patches: NDArray[Any], final_dim: tuple[int, int, int]) -> NDArray:
        """
        Description:
            _stitch_patches will take a patch array of shape (N, patch_y, patch_x, C)
            and stitch it back into a full-size image of shape 'final_dim'. Patches must
            be in the order produced by _get_patches for an image of that size.

        Args:
            - patches: numpy NDArray of shape (N, patch_y, patch_x, C)
            - final_dim: final dimensions of the image. Should match axes YXC.
              C in final_dim should match C in the dimension of 'patches'.

        Raises:
            AssertionError: on channel mismatch or non-square patches.
            CAREInferenceError: if final_dim cannot be unpacked, the overlap cannot be
            determined, or the patch count does not match the tiling of final_dim.
        """
        try:
            input_y, input_x, input_c = final_dim
        except Exception as e:
            self.logger.error(f"Could not inppack 'final_dim' of size {len(final_dim)} into Y,X,C.\n{e.__class__.__name__} {e}")
            raise CAREInferenceError(f"Could not inppack 'final_dim' of size {len(final_dim)} into Y,X,C.\n{e.__class__.__name__} {e}") from e

        # get patch dimension
        assert patches.shape[-1] == input_c, f"Channels in patch array and final_dim do not match ({patches.shape[-1]} vs. {input_c})"
        assert patches.shape[1] == patches.shape[2], f"Rectangular patch detected (assuming axes NYXC). Only square patches are supported"
        patch_dim = patches.shape[1]

        try:
            po = self._get_patch_overlap(patch_dim=patch_dim)
        except Exception as e:
            self.logger.error(f"Could not retrieve patch overlap.\n{e.__class__.__name__}: {e}")
            raise CAREInferenceError(f"Could not retrieve patch overlap.\n{e.__class__.__name__}: {e}") from e

        y_starts = self._tile_starts(input_y, patch_dim, po)
        x_starts = self._tile_starts(input_x, patch_dim, po)
        n_expected = len(y_starts) * len(x_starts)
        if patches.shape[0] != n_expected:
            self.logger.error(f"Expected {n_expected} patches to stitch an image of shape {tuple(final_dim)}, but received {patches.shape[0]}.")
            raise CAREInferenceError(f"Expected {n_expected} patches to stitch an image of shape {tuple(final_dim)}, but received {patches.shape[0]}.")

        y_spans = list(zip(y_starts, self._tile_bounds(y_starts, patch_dim, input_y)))
        x_spans = list(zip(x_starts, self._tile_bounds(x_starts, patch_dim, input_x)))

        output = np.zeros(shape=tuple(final_dim), dtype=patches.dtype)
        patch_index = 0
        for y0, (y_lo, y_hi) in y_spans:
            for x0, (x_lo, x_hi) in x_spans:
                # Copy only the span this patch owns, shifted into patch coordinates.
                output[y_lo:y_hi, x_lo:x_hi, :] = patches[patch_index, y_lo - y0:y_hi - y0, x_lo - x0:x_hi - x0, :]
                patch_index += 1
        return output

    def __repr__(self):
        parts = [
            f"Obj CAREInferenceSession @{hex(id(self))}:\n",
            f" - Input Dim: {self.input_shape}\n",
            f" - Input DType: {self.input_dtype}\n",
            f" - From MLFlow: {self.from_mlflow}\n",
        ]
        if self.from_mlflow:
            parts += [
                f" - MLFlow Tracking URI: {self.mlflow_tracking_uri}\n",
                f" - MLFlow Run ID: {self.mlflow_run_id}\n",
                f" - MLFlow Run Name: {self.mlflow_run_name}\n",
            ]
        return "".join(parts)

    def __str__(self):
        return repr(self)