# Source code for gammapy.detect.kernel

# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import OrderedDict
import logging
import numpy as np
import astropy.units as u
from astropy.convolution import Tophat2DKernel, CustomKernel
from ..image import SkyImage, SkyImageList
from .lima import compute_lima_image

log = logging.getLogger(__name__)

__all__ = [
    'KernelBackgroundEstimator',
]


class KernelBackgroundEstimator(object):
    """Estimate background and exclusion mask iteratively.

    Starting from an initial background estimate and exclusion mask
    (both provided, optionally) the algorithm works as follows:

    1. Compute significance image
    2. Create exclusion mask by thresholding significance image
    3. Compute improved background estimate based on new exclusion mask

    The steps are executed repeatedly until the exclusion mask does not
    change anymore.

    For flexibility the algorithm takes arbitrary source and background
    kernels.

    Parameters
    ----------
    kernel_src : `numpy.ndarray`
        Source kernel as a numpy array.
    kernel_bkg : `numpy.ndarray`
        Background convolution kernel as a numpy array.
    significance_threshold : float
        Significance threshold above which regions are excluded.
    mask_dilation_radius : `~astropy.units.Quantity`
        Radius by which mask is dilated with each iteration.
    delete_intermediate_results : bool
        Specify whether results of intermediate iterations should be deleted.
        (Otherwise, these are held in memory). Default False.
    save_intermediate_results : bool
        Specify whether to save intermediate results as FITS files to disk.
        Default False.
    base_dir : str (optional)
        Base of filenames if save_intermediate_results = True.
        Default 'temp'.

    See Also
    --------
    gammapy.background.RingBackgroundEstimator,
    gammapy.background.AdaptiveRingBackgroundEstimator
    """

    def __init__(self, kernel_src, kernel_bkg,
                 significance_threshold=5, mask_dilation_radius=0.02 * u.deg,
                 delete_intermediate_results=False,
                 save_intermediate_results=False, base_dir='temp'):
        self.parameters = OrderedDict(
            significance_threshold=significance_threshold,
            mask_dilation_radius=mask_dilation_radius,
            save_intermediate_results=save_intermediate_results,
            delete_intermediate_results=delete_intermediate_results,
            base_dir=base_dir,
        )

        self.kernel_src = kernel_src
        self.kernel_bkg = kernel_bkg
        # Results of the iterations performed by `run`; what is retained
        # depends on the `delete_intermediate_results` parameter.
        self.images_stack = []

    def run(self, images, niter_min=2, niter_max=10):
        """Run iterations until mask does not change (stopping condition).

        Parameters
        ----------
        images : `~gammapy.image.SkyImageList`
            Input sky images. Must contain 'counts'; 'exclusion' and
            'background' are created (and stored in ``images``) if missing,
            and 'significance' is always (re)computed and stored.
        niter_min : int
            Minimum number of iterations, to prevent early termination of the
            algorithm.
        niter_max : int
            Maximum number of iterations after which the algorithm is
            terminated, if the termination condition (no change of mask
            between iterations) is not already satisfied.

        Returns
        -------
        images : `~gammapy.image.SkyImageList`
            List of sky images containing 'background', 'exclusion' mask and
            'significance' images. If ``niter_max`` is not positive, the
            input images with the initial estimates are returned.
        """
        images.check_required(['counts'])
        p = self.parameters

        # initial mask, if not present
        if 'exclusion' not in images.names:
            images['exclusion'] = SkyImage.empty_like(images['counts'], fill=1)

        # initial background estimate, if not present
        if 'background' not in images.names:
            log.info('Estimating initial background.')
            images['background'] = self._estimate_background(
                images['counts'], images['exclusion'])

        images['significance'] = self._estimate_significance(
            images['counts'], images['background'])
        self.images_stack.append(images)

        # Fallback so that a non-positive `niter_max` returns the initial
        # estimates instead of failing on an unbound name.
        result = images

        for idx in range(niter_max):
            result_previous = self.images_stack.pop()
            result = self._run_iteration(result_previous)

            if p['delete_intermediate_results']:
                self.images_stack = [result]
            else:
                self.images_stack += [result_previous, result]

            if p['save_intermediate_results']:
                # One file per iteration: fill the placeholder with the
                # (zero-based) iteration index.
                filename = p['base_dir'] + 'ibe_iteration_{}.fits'.format(idx)
                result.write(filename)

            # NOTE(review): `idx` is zero-based, so this condition can first
            # be met after `niter_min + 1` iterations. Kept unchanged to
            # preserve existing results -- confirm the intended meaning of
            # `niter_min`.
            if self._is_converged(result, result_previous) and (idx >= niter_min):
                log.info('Exclusion mask succesfully converged,'
                         ' after {} iterations.'.format(idx + 1))
                break

        return result

    def _is_converged(self, result, result_previous):
        """Check convergence.

        Criterion: exclusion masks unchanged in subsequent iterations.

        Parameters
        ----------
        result, result_previous : `~gammapy.image.SkyImageList`
            Images of the current and the previous iteration; only the
            'exclusion' image of each is used.

        Returns
        -------
        converged : bool
            True if the masks agree everywhere after hole filling.
        """
        from scipy.ndimage import binary_fill_holes

        mask = result['exclusion'].data == result_previous['exclusion'].data

        # Because of pixel to pixel noise, the masks can still differ.
        # This is handled by removing structures of the scale of one pixel
        mask = binary_fill_holes(mask)
        return np.all(mask)

    # TODO: make more flexible, e.g. allow using adaptive ring etc.
    def _estimate_background(self, counts, exclusion):
        """
        Estimate background by convolving the excluded counts image with
        the background kernel and renormalizing the image.

        Parameters
        ----------
        counts : `~gammapy.image.SkyImage`
            Counts image.
        exclusion : `~gammapy.image.SkyImage`
            Exclusion mask; it is multiplied with the counts, so pixels with
            value 0 do not contribute.

        Returns
        -------
        background : `~gammapy.image.SkyImage`
            Image named 'background'.
        """
        wcs = counts.wcs.copy()

        # recompute background estimate
        counts_excluded = SkyImage(data=counts.data * exclusion.data, wcs=wcs)
        data = counts_excluded.convolve(self.kernel_bkg, mode='constant')
        # Convolving the mask itself gives the per-pixel normalisation.
        # NOTE(review): where this is zero the division below is not finite.
        norm = exclusion.convolve(self.kernel_bkg, mode='constant')
        return SkyImage(name='background', data=data.data / norm.data, wcs=wcs)

    # TODO: make more flexible, e.g. allow using TS images tec.
    def _estimate_significance(self, counts, background):
        """Compute the significance image using the source kernel.

        Wraps ``self.kernel_src`` in a `~astropy.convolution.CustomKernel`,
        passes it to ``compute_lima_image`` and returns the 'significance'
        entry of the result.
        """
        kernel = CustomKernel(self.kernel_src)
        images_lima = compute_lima_image(counts, background, kernel=kernel)
        return images_lima['significance']

    def _run_iteration(self, images):
        """Run one iteration.

        Parameters
        ----------
        images : `gammapy.image.SkyImageList`
            Input sky images; must contain 'counts', 'exclusion' and
            'background'.

        Returns
        -------
        images : `gammapy.image.SkyImageList`
            The input counts image together with the updated background,
            exclusion and significance images.
        """
        from scipy.ndimage import binary_erosion

        images.check_required(['counts', 'exclusion', 'background'])
        wcs = images['counts'].wcs.copy()
        p = self.parameters

        significance = self._estimate_significance(images['counts'],
                                                   images['background'])

        # update exclusion mask
        radius = p['mask_dilation_radius'].to('deg')
        # presumably the pixel scale is an angular Quantity per axis; only
        # the first entry is used -- verify against `wcs_pixel_scale`
        scale = images['counts'].wcs_pixel_scale()[0]
        structure = np.array(Tophat2DKernel((radius / scale).value))

        # True marks pixels that are kept (below threshold or undefined).
        mask = ((significance.data < p['significance_threshold'])
                | np.isnan(significance.data))
        # Eroding the kept region grows the excluded region.
        mask = binary_erosion(mask, structure, border_value=1)
        exclusion = SkyImage(name='exclusion', data=mask.astype('float'),
                             wcs=wcs)

        background = self._estimate_background(images['counts'], exclusion)
        return SkyImageList([images['counts'], background, exclusion,
                             significance])

    def images_stack_show(self, dpi=120):
        """Show image stack.

        Draws one row per entry of ``self.images_stack`` with the background
        image on the left and the significance image (overlaid with the
        exclusion mask contour) on the right.

        Parameters
        ----------
        dpi : int
            Dots per inch to scale the image.
        """
        import matplotlib.pyplot as plt

        niter_max = len(self.images_stack)
        wcs = self.images_stack[0]['background'].wcs

        height_pix, width_pix = self.images_stack[0]['background'].data.shape
        width = 2 * (width_pix / dpi + 1.)
        height = niter_max * (height_pix / dpi + .5)

        fig = plt.figure(figsize=(width, height))

        for idx, images in enumerate(self.images_stack):
            ax_bkg = fig.add_subplot(niter_max + 1, 2, 2 * idx + 1,
                                     projection=wcs)
            bkg = images['background']
            bkg.plot(ax=ax_bkg, vmin=0)
            ax_bkg.set_title('Background, N_iter = {}'.format(idx),
                             fontsize='small')

            ax_sig = fig.add_subplot(niter_max + 1, 2, 2 * idx + 2,
                                     projection=wcs)
            sig = images['significance']
            sig.plot(ax=ax_sig, vmin=0, vmax=20)
            ax_sig.set_title('Significance, N_Iter = {}'.format(idx),
                             fontsize='small')
            mask = images['exclusion'].data
            ax_sig.contour(mask, levels=[0], linewidths=2, colors='green')

            # Only the last row keeps its x axis labelling.
            if idx < (niter_max - 1):
                for ax in [ax_sig, ax_bkg]:
                    ax.set_xlabel('')
                    # NOTE(review): assumes the WCS has a 'glon' coordinate
                    ax.coords['glon'].ticklabels.set_visible(False)

            ax_bkg.set_ylabel('')
            ax_sig.set_ylabel('')

        plt.tight_layout(pad=1.08, h_pad=1.5, w_pad=0.2, rect=[0, 0, 1, 0.98])