TensorFlow 中的高效图像膨胀


我正在寻找一种在 TensorFlow 中使用方形内核高效实现形态学图像膨胀(https://en.wikipedia.org/wiki/Dilation_(morphology))的方法。与 OpenCV 所展示的可以达到的速度相比,那些显而易见的实现方法似乎效率极低。请看运行底部所贴源代码得到的结果——即使最快的方法也比 OpenCV 慢 30 倍左右。这些结果是在配备 M1 芯片的 MacBook Air 上测得的。

Dilation of 640x480 image with a 25x25 kernel took: 
  0.61ms using opencv
  545.40ms using tf.nn.max_pool2d
  228.66ms using tf.nn.dilation2d naively
  17.63ms using tf.nn.dilation2d with row-col

问题:有谁知道一种在 TensorFlow 中进行图像膨胀、且效率不至于极低的方法?


import numpy as np
import cv2
import tensorflow as tf
import time

def tf_dilate(heatmap, width: int, method: str = 'rowcol'):
    """Dilate a 2-D heatmap with a ``width`` x ``width`` square kernel.

    Args:
        heatmap: Rank-2 tensor. It is expanded to NHWC with
            ``heatmap[None, :, :, None]`` before the TensorFlow op and indexed
            back to 2-D with ``[0, :, :, 0]`` afterwards.
        width: Side length of the square structuring element.
        method: Which implementation to use:
            ``'maxpool'``       - one stride-1 ``tf.nn.max_pool2d``;
            ``'naive_dilate'``  - one ``tf.nn.dilation2d`` with a full
                                  ``width`` x ``width`` all-zero filter;
            ``'rowcol'`` / ``'rowcol_dilate'`` - two ``tf.nn.dilation2d``
                                  passes, first ``1 x width`` then
                                  ``width x 1``.

    Returns:
        The dilated heatmap as a rank-2 tensor.

    Raises:
        NotImplementedError: If ``method`` is not one of the names above.
    """
    # Validate up front so a bad name fails loudly instead of returning None.
    # 'rowcol' (the default) and 'rowcol_dilate' are accepted as synonyms.
    if method not in ('maxpool', 'naive_dilate', 'rowcol', 'rowcol_dilate'):
        raise NotImplementedError(f'No method {method}')

    batched = heatmap[None, :, :, None]

    if method == 'maxpool':
        return tf.nn.max_pool2d(batched, ksize=width, padding='SAME', strides=(1, 1))[0, :, :, 0]

    # Arguments shared by every dilation2d call below.
    dilation_kwargs = dict(strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))

    if method == 'naive_dilate':
        # An all-zero filter makes grayscale dilation a plain windowed maximum.
        return tf.nn.dilation2d(batched, filters=tf.zeros((width, width, 1), dtype=heatmap.dtype),
                                **dilation_kwargs)[0, :, :, 0]

    # Separable variant: a max over a square window equals a max along rows
    # followed by a max along columns.
    row_dilation = tf.nn.dilation2d(batched, filters=tf.zeros((1, width, 1), dtype=heatmap.dtype),
                                    **dilation_kwargs)
    full_dilation = tf.nn.dilation2d(row_dilation, filters=tf.zeros((width, 1, 1), dtype=heatmap.dtype),
                                     **dilation_kwargs)
    return full_dilation[0, :, :, 0]

def test_dilation_options(img_shape=(480, 640), kernel_size=25):
    """Time OpenCV against each ``tf_dilate`` method and check they agree.

    Args:
        img_shape: ``(height, width)`` of the random test image.
        kernel_size: Side length of the square dilation kernel.

    Raises:
        AssertionError: If any TensorFlow result differs from the OpenCV one.
    """
    # Squaring the normal samples makes every pixel non-negative.
    img = np.random.randn(*img_shape).astype(np.float32)**2

    def get_result_and_time(version: str):
        """Return ``(elapsed_seconds, result_as_numpy)`` for one backend."""
        # Build the tensor before starting the clock so conversion is not timed.
        tf_image = tf.constant(img, dtype=tf.float32)
        t_start = time.time()
        if version == 'opencv':
            result = cv2.dilate(img, kernel=np.ones((kernel_size, kernel_size), dtype=np.float32))
            return time.time()-t_start, result
        else:
            result = tf_dilate(tf_image, width=kernel_size, method=version)
            # The elapsed time is evaluated before ``.numpy()`` runs.
            return time.time()-t_start, result.numpy()

    t_opencv, result_opencv = get_result_and_time('opencv')
    t_maxpool, result_maxpool = get_result_and_time('maxpool')
    t_naive_dilate, result_naive_dilate = get_result_and_time('naive_dilate')
    t_rowcol_dilate, result_rowcol_dilate = get_result_and_time('rowcol_dilate')
    assert np.array_equal(result_opencv, result_maxpool), "Maxpool result did not match opencv result"
    assert np.array_equal(result_opencv, result_naive_dilate), "Naive dilation result did not match opencv result"
    assert np.array_equal(result_opencv, result_rowcol_dilate), "Row-col dilation result did not match opencv result"
    print(f'Dilation of {img_shape[1]}x{img_shape[0]} image with a {kernel_size}x{kernel_size} kernel took: '
          f'\n  {t_opencv*1000:.2f}ms using opencv'
          f'\n  {t_maxpool*1000:.2f}ms using tf.nn.max_pool2d'
          f'\n  {t_naive_dilate*1000:.2f}ms using tf.nn.dilation2d naively'
          f'\n  {t_rowcol_dilate*1000:.2f}ms using tf.nn.dilation2d with row-col'
          )

if __name__ == '__main__':
    # Script entry point: run the benchmark/consistency check defined above.
    test_dilation_options()

好吧,如果你可以接受近似解,那么总还有“穷人版膨胀”(poor man's dilation):它用加权局部平均(盒式滤波器)来近似膨胀,其中权重通过对图像求幂得到。其复杂度为 O((H+K)*(W+K)),其中 W、H 分别是图像的宽度和高度,K 是内核大小。



from typing import NewType, Optional, Union  # needed by the aliases below and by the signatures in this snippet

# Nominal aliases over tf.Tensor; they document intent only and add no runtime checks.
TensorImage = NewType('TensorImage', tf.Tensor)  # A (height, width, n_colors) uint8 image
TensorFloatImage = NewType('TensorFloatImage', tf.Tensor)  # presumably a float-valued image; shape/dtype not stated here - TODO confirm
TensorHeatmap = NewType('TensorHeatmap', tf.Tensor)  # A (height, width) heatmap

def tf_box_filter(image: Union[TensorImage, TensorFloatImage, TensorHeatmap], width: int, normalize: bool = True, weights: Optional[TensorHeatmap] = None,
                  weight_eps: float = 1e-6, norm_weights: bool = True):
    """Apply a ``width`` x ``width`` box filter using a 2-D cumulative sum.

    The image is padded by ``width`` in total along each of the first two
    axes, cumulatively summed along both, and each window sum is then read
    off as a four-corner difference. The output keeps the input's shape.

    Args:
        image: Rank-2 or rank-3 tensor; the first two axes are filtered.
            Cast to float32 unless it is already float64.
        width: Side length of the square window.
        normalize: If True return a (weighted) window mean, otherwise the
            (weighted) window sum.
        weights: Optional per-pixel weights, multiplied into ``image`` before
            summing. For a rank-3 image a trailing axis is added so the
            weights broadcast.
        weight_eps: Added to numerator and denominator of the weighted mean
            so windows whose weights sum to zero do not divide by zero.
        norm_weights: If True the weights are pre-divided by ``width**2``
            (presumably to keep the running sums small - TODO confirm); the
            scaling is undone again when ``normalize`` is False.

    Returns:
        ``normalize=False``: window sums of ``image`` (times ``weights`` if given).
        ``normalize=True`` without weights: window sums divided by ``width**2``.
        ``normalize=True`` with weights:
        ``(sum(image * weights) + weight_eps) / (sum(weights) + weight_eps)``
        per window.
    """
    image = tf.cast(image, tf.float32) if image.dtype != tf.float64 else image
    if weights is not None:
        if norm_weights:
            weights = weights/(width**2)
        if len(image.shape) == 3:
            weights = weights[:, :, None]  # Lets us broadcast weights against image

        image = image * weights

    # Split the total padding of ``width`` so the window sits around each pixel
    # once the four-corner difference below is taken.
    lwidth = width // 2 + 1
    rwidth = width - lwidth

    # Only the first two axes are padded; any further axes get (0, 0).
    integral_image_container = tf.pad(image,
                                      paddings=[(lwidth, rwidth), (lwidth, rwidth)] + [(0, 0)] * (len(image.shape) - 2))
    integral_image_container = tf.cumsum(tf.cumsum(integral_image_container, axis=0), axis=1)
    # Four-corner difference on the summed-area table gives each window's sum.
    box_image = integral_image_container[width:, width:] \
                - integral_image_container[width:, :-width] \
                - integral_image_container[:-width, width:] \
                + integral_image_container[:-width, :-width]

    if not normalize:
        return box_image if (weights is None or not norm_weights) else box_image*(width**2)
    elif weights is None:
        return box_image / (width ** 2)
    else:
        # Weighted mean: divide by the box-summed weights. ``weights`` here is
        # already scaled/expanded exactly as it was applied to ``image``.
        box_weights = tf_box_filter(weights, width=width, normalize=False)
        return (box_image + weight_eps) / (box_weights + weight_eps)

def tf_poor_mans_dilate(heatmap: TensorHeatmap, width: int, power: int = 4, cast_to_64 = False) -> TensorHeatmap:
    """Approximate dilation with a weighted box filter.

    Each pixel is weighted by ``heatmap ** power`` and the weighted local
    mean over a ``width`` x ``width`` window is returned via
    ``tf_box_filter``. Per the original note, the runtime scales with
    ``(image_height + kernel_width) * (image_width + kernel_width)``.

    Args:
        heatmap: The heatmap to dilate.
        width: Side length of the square window.
        power: Exponent used to build the weights.
        cast_to_64: If True, cast the heatmap to float64 before filtering.

    Returns:
        The approximately dilated heatmap.
    """
    source = tf.cast(heatmap, tf.float64) if cast_to_64 else heatmap
    emphasis = source**power
    return tf_box_filter(source, width, weights=emphasis, weight_eps=1e-9)

测试表明它比问题中的解决方案快大约 3 倍(当内核很大时速度更快)。

def test_poor_mans_dilate(show=False):
    """Compare ``tf_dilate`` with ``tf_poor_mans_dilate`` on a random image.

    Only the global maxima of the two results are compared (rtol=0.001);
    the approximation is not checked pixel by pixel.

    Args:
        show: Currently unused.

    The approximation can be faster for large images and kernels.
    Timings recorded by the original author:

    Dilating image of shape (1280, 720) with kernel of shape 40x40
        Real Dilate: Elapsed time is 0.09009s
        Poor Man's Dilate: Elapsed time is 0.02953s

    Dilating image of shape (640, 480) with kernel of shape 40x40
        Real Dilate: Elapsed time is 0.03089s
        Poor Man's Dilate: Elapsed time is 0.008736s

    Dilating image of shape (640, 480) with kernel of shape 20x20
        Real Dilate: Elapsed time is 0.01475s
        Poor Man's Dilate: Elapsed time is 0.009809s
    """
    # Fixed seed keeps the test image reproducible; **4 makes it non-negative.
    img = tf.random.Generator.from_seed(1234).normal((640, 480))**4
    width = 20
    print(f'Dilating image of shape {img.shape} with kernel of shape {width}x{width}')
    # ``profile_context`` is a timing helper defined outside this snippet.
    with profile_context('Real Dilate', print_result=True):
        # Name the method explicitly: 'rowcol_dilate' is the spelling that
        # tf_dilate's branches actually test for.
        dil_img = tf_dilate(img, width=width, method='rowcol_dilate')
    with profile_context("Poor Man's Dilate", print_result=True):
        poor_dil_img = tf_poor_mans_dilate(img, width=width)

    assert np.allclose(dil_img.numpy().max(), poor_dil_img.numpy().max(), rtol=0.001)

