TensorFlow 中的高效图像膨胀

2024-01-03

我正在寻找一种有效的实施方式形态学图像膨胀 https://en.wikipedia.org/wiki/Dilation_(morphology)在 TensorFlow 中使用方形内核。正如 OpenCV 所示,与实际效果相比,显而易见的方法似乎效率极低。查看粘贴在底部的运行源代码的结果 - 即使最快的方法也比 OpenCV 慢 30 倍左右。这些来自配备 M1 芯片组的 MacBook Air。

Dilation of 640x480 image with a 25x25 kernel took: 
  0.61ms using opencv
  545.40ms using tf.nn.max_pool2d
  228.66ms using tf.nn.dilation2d naively
  17.63ms using tf.nn.dilation2d with row-col

Question:有谁知道一种使用 TensorFlow 进行图像膨胀的方法,而且效率不是极低?

当前解决方案的源代码:

import numpy as np
import cv2
import tensorflow as tf
import time


def tf_dilate(heatmap, width: int, method: str = 'rowcol'):
    """ Dilate the heatmap with a square kernel """
    if method=='maxpool':
        return tf.nn.max_pool2d(heatmap[None, :, :, None], ksize=width, padding='SAME', strides=(1, 1))[0, :, :, 0]
    elif method == 'naive_dilate':
        return tf.nn.dilation2d(heatmap[None, :, :, None], filters=tf.zeros((width, width, 1), dtype=heatmap.dtype),
                                        strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))[0, :, :, 0]
    elif method == 'rowcol_dilate':

        row_dilation = tf.nn.dilation2d(heatmap[None, :, :, None], filters=tf.zeros((1, width, 1), dtype=heatmap.dtype),
                                        strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))
        full_dilation = tf.nn.dilation2d(row_dilation, filters=tf.zeros((width, 1, 1), dtype=heatmap.dtype),
                                         strides=(1, 1, 1, 1), padding="SAME", data_format="NHWC", dilations=(1, 1, 1, 1))
        return full_dilation[0, :, :, 0]
    else:
        raise NotImplementedError(f'No method {method}')


def test_dilation_options(img_shape=(480, 640), kernel_size=25):

    img = np.random.randn(*img_shape).astype(np.float32)**2

    def get_result_and_time(version: str):

        tf_image = tf.constant(img, dtype=tf.float32)
        t_start = time.time()
        if version=='opencv':
            result = cv2.dilate(img, kernel=np.ones((kernel_size, kernel_size), dtype=np.float32))
            return time.time()-t_start, result
        else:
            result = tf_dilate(tf_image, width=kernel_size, method=version)
            return time.time()-t_start, result.numpy()

    t_opencv, result_opencv = get_result_and_time('opencv')
    t_maxpool, result_maxpool = get_result_and_time('maxpool')
    t_naive_dilate, result_naive_dilate = get_result_and_time('naive_dilate')
    t_rowcol_dilate, result_rowcol_dilate = get_result_and_time('rowcol_dilate')
    assert np.array_equal(result_opencv, result_maxpool), "Maxpool result did not match opencv result"
    assert np.array_equal(result_opencv, result_naive_dilate), "Naive dilation result did not match opencv result"
    assert np.array_equal(result_opencv, result_rowcol_dilate), "Row-col dilation result did not match opencv result"
    print(f'Dilation of {img_shape[1]}x{img_shape[0]} image with a {kernel_size}x{kernel_size} kernel took: '
          f'\n  {t_opencv*1000:.2f}ms using opencv'
          f'\n  {t_maxpool*1000:.2f}ms using tf.nn.max_pool2d'
          f'\n  {t_naive_dilate*1000:.2f}ms using tf.nn.dilation2d naively'
          f'\n  {t_rowcol_dilate*1000:.2f}ms using tf.nn.dilation2d with row-col'
          )


if __name__ == '__main__':
    test_dilation_options()

好吧,如果你没问题的话近似解决方案中,总是存在“穷人的扩张”,它使用加权局部平均值(盒式滤波器)来近似扩张,其中通过对图像求幂来获取权重。它是O((H+K)*(W+K)) where W,H是图像的宽度、高度和K是内核大小。

它还具有以下优点:梯度不仅流过局部最大值,还流过竞争者直至抛出。

参见代码:

TensorImage = NewType('TensorImage', tf.Tensor)  # A (height, width, n_colors) uint8 image
TensorFloatImage = NewType('TensorFloatImage', tf.Tensor)
TensorHeatmap = NewType('TensorHeatmap', tf.Tensor)  # A (height, width) heatmap

def tf_box_filter(image: Union[TensorImage, TensorFloatImage, TensorHeatmap], width: int, normalize: bool = True, weights: Optional[TensorHeatmap] = None,
                  weight_eps: float = 1e-6, norm_weights: bool = True):
    image = tf.cast(image, tf.float32) if image.dtype != tf.float64 else image
    if weights is not None:
        if norm_weights:
            weights = weights/(width**2)
        if len(image.shape) == 3:
            weights = weights[:, :, None]  # Lets us broadcast weights against image

        image = image * weights

    lwidth = width // 2 + 1
    rwidth = width - lwidth

    integral_image_container = tf.pad(image,
                                      paddings=[(lwidth, rwidth), (lwidth, rwidth)] + [(0, 0)] * (len(image.shape) - 2))
    integral_image_container = tf.cumsum(tf.cumsum(integral_image_container, axis=0), axis=1)
    box_image = integral_image_container[width:, width:] \
                - integral_image_container[width:, :-width] \
                - integral_image_container[:-width, width:] \
                + integral_image_container[:-width, :-width]

    if not normalize:
        return box_image if (weights is None or not norm_weights) else box_image*(width**2)
    elif weights is None:
        return box_image / (width ** 2)
    else:
        box_weights = tf_box_filter(weights, width=width, normalize=False)
        return (box_image + weight_eps) / (box_weights + weight_eps)


def tf_poor_mans_dilate(heatmap: TensorHeatmap, width: int, power: int = 4, cast_to_64 = False) -> TensorHeatmap:
    """ A 'poor man's' version of dilation, whise runtime is O((image_height+kernel_width), (image_width+kernel_width))"""
    if cast_to_64:
        heatmap = tf.cast(heatmap, tf.float64)
    return tf_box_filter(heatmap, width, weights=heatmap**power, weight_eps=1e-9)


测试表明它比问题中的解决方案快大约 3 倍(当内核很大时速度更快)。


def test_poor_mans_dilate(show=False):
    """ Can be faster for large images and kernels

    Dilating image of shape (1280, 720) with kernel of shape 40x40
        Real Dilate: Elapsed time is 0.09009s
        Poor Man's Dilate: Elapsed time is 0.02953s

    Dilating image of shape (640, 480) with kernel of shape 40x40
        Real Dilate: Elapsed time is 0.03089s
        Poor Man's Dilate: Elapsed time is 0.008736s

    Dilating image of shape (640, 480) with kernel of shape 20x20
        Real Dilate: Elapsed time is 0.01475s
        Poor Man's Dilate: Elapsed time is 0.009809s
    """
    img = tf.random.Generator.from_seed(1234).normal((640, 480))**4
    width = 20
    print(f'Dilating image of shape {img.shape} with kernel of shape {width}x{width}')
    with profile_context('Real Dilate', print_result=True):
        dil_img = tf_dilate(img, width=width)
    with profile_context("Poor Man's Dilate", print_result=True):
        poor_dil_img = tf_poor_mans_dilate(img, width=width)

    assert np.allclose(dil_img.numpy().max(), poor_dil_img.numpy().max(), rtol=0.001)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow 中的高效图像膨胀 的相关文章

  • Python - StatsModels、OLS 置信区间

    在 Statsmodels 中 我可以使用以下方法拟合我的模型 import statsmodels api as sm X np array 22000 13400 47600 7400 12000 32000 28000 31000 6
  • 计数物体和更好的填充孔的方法

    我是 OpenCV 新手 正在尝试计算物体的数量在图像中 我在使用 MATLAB 图像处理工具箱之前已经完成了此操作 并在 OpenCV Android 中也采用了相同的方法 第一步是将图像转换为灰度 然后对其进行阈值计算 然后计算斑点的数
  • 如何使用 Ansible playbook 中的 service_facts 模块检查服务是否存在且未安装在服务器中?

    我用过service facts检查服务是否正在运行并启用 在某些服务器中 未安装特定的软件包 现在 我如何知道这个特定的软件包没有安装在该特定的服务器上service facts module 在 Ansible 剧本中 它显示以下错误
  • 测试 python Counter 是否包含在另一个 Counter 中

    如何测试是否是pythonCounter https docs python org 2 library collections html collections Counter is 包含在另一个中使用以下定义 柜台a包含在计数器中b当且
  • 使用 Tkinter 显示 numpy 数组中的图像

    我对 Python 缺乏经验 第一次使用 Tkinter 制作一个 UI 显示我的数字分类程序与 mnist 数据集的结果 当图像来自 numpy 数组而不是我的 PC 上的文件路径时 我有一个关于在 Tkinter 中显示图像的问题 我为
  • OpenCV 无法从 MacBook Pro iSight 捕获

    几天后 我无法再从 opencv 应用程序内部打开我的 iSight 相机 cap cv2 VideoCapture 0 返回 并且cap isOpened 回报true 然而 cap grab 刚刚返回false 有任何想法吗 示例代码
  • 从 Flask 访问 Heroku 变量

    我已经使用以下命令在 Heroku 配置中设置了数据库变量 heroku config add server xxx xxx xxx xxx heroku config add user userName heroku config add
  • BeautifulSoup 中的嵌套标签 - Python

    我在网站和 stackoverflow 上查看了许多示例 但找不到解决我的问题的通用解决方案 我正在处理一个非常混乱的网站 我想抓取一些数据 标记看起来像这样 table tbody tr tr tr td td td table tr t
  • Python 的“zip”内置函数的 Ruby 等价物是什么?

    Ruby 是否有与 Python 内置函数等效的东西zip功能 如果不是 做同样事情的简洁方法是什么 一些背景信息 当我试图找到一种干净的方法来进行涉及两个数组的检查时 出现了这个问题 如果我有zip 我可以写这样的东西 zip a b a
  • 如何使用Python创建历史时间线

    So I ve seen a few answers on here that helped a bit but my dataset is larger than the ones that have been answered prev
  • Pygame:有没有简单的方法可以找到按下的任何字母数字的字母/数字?

    我目前正在开发的游戏需要让人们以自己的名义在高分板上计时 我对如何处理按键有点熟悉 但我只处理过寻找特定的按键 有没有一种简单的方法可以按下任意键的字母 而不必执行以下操作 for event in pygame event get if
  • 无法在 Python 3 中导入 cProfile

    我试图将 cProfile 模块导入 Python 3 3 0 但出现以下错误 Traceback most recent call last File
  • 使用 \r 并打印一些文本后如何清除控制台中的一行?

    对于我当前的项目 有一些代码很慢并且我无法使其更快 为了获得一些关于已完成 必须完成多少的反馈 我创建了一个进度片段 您可以在下面看到 当你看到最后一行时 sys stdout write r100 80 n I use 80覆盖最终剩余的
  • 如何在Python中对类别进行加权随机抽样

    给定一个元组列表 其中每个元组都包含一个概率和一个项目 我想根据其概率对项目进行采样 例如 给出列表 3 a 4 b 3 c 我想在 40 的时间内对 b 进行采样 在 python 中执行此操作的规范方法是什么 我查看了 random 模
  • 每个 X 具有多个 Y 值的 Python 散点图

    我正在尝试使用 Python 创建一个散点图 其中包含两个 X 类别 cat1 cat2 每个类别都有多个 Y 值 如果每个 X 值的 Y 值的数量相同 我可以使用以下代码使其工作 import numpy as np import mat
  • 为字典中的一个键附加多个值[重复]

    这个问题在这里已经有答案了 我是 python 新手 我有每年的年份和值列表 我想要做的是检查字典中是否已存在该年份 如果存在 则将该值附加到特定键的值列表中 例如 我有一个年份列表 并且每年都有一个值 2010 2 2009 4 1989
  • 如何计算 pandas 数据帧上的连续有序值

    我试图从给定的数据帧中获取连续 0 值的最大计数 其中包含来自 pandas 数据帧的 id date value 列 如下所示 id date value 354 2019 03 01 0 354 2019 03 02 0 354 201
  • Scrapy:如何使用元在方法之间传递项目

    我是 scrapy 和 python 的新手 我试图将 parse quotes 中的项目 item author 传递给下一个解析方法 parse bio 我尝试了 request meta 和 response meta 方法 如 sc
  • Rocket UniData/UniVerse:ODBC 无法分配足够的内存

    每当我尝试使用pyodbc连接到 Rocket UniData UniVerse 数据时我不断遇到错误 pyodbc Error 00000 00000 Rocket U2 U2ODBC 0302810 Unable to allocate
  • 从列表指向字典变量

    假设你有一个清单 a 3 4 1 我想用这些信息来指向字典 b 3 4 1 现在 我需要的是一个常规 看到该值后 在 b 的位置内读写一个值 我不喜欢复制变量 我想直接改变变量b的内容 假设b是一个嵌套字典 你可以这样做 reduce di

随机推荐

  • 后台进程的 cy.exec 超时

    我正在尝试使用启动服务器cy exec并像这样后台处理 cy exec nohup python m my module arg 1 failOnNonZeroExit false then result gt if result code
  • 如何防止密码和其他敏感信息出现在 ASP.NET 转储中?

    如何防止在 IIS ASP NET 转储文件中向 ASP NET 网页提交和接收密码和其他敏感数据 重现步骤 使用 Visual Studio 2010 创建 ASP NET MVC 3 Intranet 应用程序 将其配置为使用 IIS
  • Spring嵌套事务

    在我的 Spring Boot 项目中 我实现了以下服务方法 Transactional public boolean validateBoard Board board boolean result false if inProgress
  • 更新更改 svn 时出错

    我安装了 PHPStorm 并使用 SVN 打开包含 PHP 项目的目录 在 更改 的 SVN 选项卡下 我遇到以下错误 Error updating changes svn E155021 The client is too old to
  • Spring JPA Repository - 在服务器重启时保留数据

    我目前正在尝试学习如何使用 Spring Boot 但遇到一个问题 我不确定如何解决 我已经按照使用 JPA 访问数据 http spring io guides gs accessing data jpa 指导 一切正常 但是 如果我重新
  • Pandas 和 Matplotlib - fill_ Between() 与 datetime64

    有一个 Pandas 数据框
  • ggplot 中的热图,每组不同的颜色

    我正在尝试在 ggplot 中生成热图 我希望每个组都有不同的颜色渐变 但不知道该怎么做 我当前的代码如下所示 dummy data data lt data frame group sample c Direct Patient Care
  • OL3:强制重绘图层

    我目前正在将 OpenLayers 客户端版本 2 13 1 升级为新版本的 OpenLayers OL3 我的设置包括作为 WMS 映射服务器的 Mapserver 和前面提到的 OpenLayers 客户端 在旧系统中 我支持用户交互
  • R 中百分比格式表

    我想获取一个百分比表 将值格式化为百分比并以良好的格式显示它们 如果重要的话 我正在使用 RStudio 并编织为 PDF 我看过其他关于此的帖子 但它们看起来都不干净 而且效果不佳 例如 下面的 apply 语句确实采用百分比格式 但是
  • 检索两个字符之间的子字符串

    我有这样的字符串 var str it itA itB et etA etB etC etD 如何检索 和 之间的元素 截至目前 我正在用新行分割文本 但无法解决这个问题 请帮我解决这个问题 请使用这个小提琴http jsfiddle ne
  • IronPython - JSON 选择

    在 IronPython 2 0 1 中处理 JSON 的最佳方法是什么 原生 Python 标准库 json 看起来尚未实现 如果我想使用 Newtonsoft Json NET 库 我该怎么做 我可以将程序集添加到 GAC 但我还有其他
  • 如何使用 php 渲染远程图像?

    这是一个 jpg https i stack imgur com PIFN0 jpg 假设我希望这个渲染自 img php file name PIFN0 jpg 以下是我尝试完成这项工作的方法 样本 php p Here s my ima
  • UICollectionView 启用取消选择单元格,同时禁用 allowedMultipleSelection

    When collectionView allowsMultipleSelection YES 我可以取消选择已选择的单元格 when collectionView allowsMultipleSelection NO 我无法取消选择已选择
  • Fortran 中不提升数组的标量参数

    为什么 Fortran 会将标量表达式提升为数组表达 但不作为过程的参数 特别是 为什么标准机构做出这样的设计决定 仅仅是因为含糊不清 程序就应该超载吗 在这种情况下 错误消息是否可以作为替代方法 例如 在下面的代码中 最后一条语句 x f
  • Jsoup,在执行表单POST之前获取值

    这是我用来提交表单的代码 Connection Response res Jsoup connect http example com data id myID data username myUsername data code MyAu
  • iPhone:cocos2d 中相机跟随玩家

    我正在用 cocos2d 制作 iPhone 游戏 我想知道如何使相机 视图遵循特定的精灵 我会使用 CCCamera 类吗 是的 CCCamera 可以工作 然而 它有一些缺点 使其不适合某些用途 相对于该精灵移动图层以及所有其他对象可能
  • 在 StructureMap 中注册一个默认实例

    我有一堂课 MyService 具有静态属性 MyService Context 代表当前上下文 特定于当前登录的用户 因此它会发生变化 我想要实现的目标 ObjectFactory Initialize x gt x For
  • 在 WPF 中,我们如何将 Duration 定义为资源?

    我在许多动画中使用了一个持续时间 0 0 0 5 并且我想仅在一个位置定义该数字 我可以将双精度定义为
  • 在 Win32 API 中绘制格式化文本的最快方法是什么?

    我正在使用普通 Win32 API 在 C 中实现一个文本编辑器 并且我正在尝试找到实现语法突出显示的最佳方法 我知道有像 scintilla 这样的现有控件 但我这样做是为了好玩 所以我想自己完成大部分工作 我还希望它又快又轻 从我到目前
  • TensorFlow 中的高效图像膨胀

    我正在寻找一种有效的实施方式形态学图像膨胀 https en wikipedia org wiki Dilation morphology 在 TensorFlow 中使用方形内核 正如 OpenCV 所示 与实际效果相比 显而易见的方法似