使用 Keras 预测进行 Python 多处理

2024-03-20

Context

Keras 模型 (链接在这里 https://drive.google.com/file/d/1f0WGCv11uObPziySE2wl6hXYKfyjqXBQ/view?usp=sharing,为了 MWE)需要并行预测大量测试数据。

我定义一个cube作为 3Dnumpy.ndarray of uint。它的每个垂直切片都是一个column,即npixels= 128 高度,nbins= 128 深度。

每个预测都会变换去噪列(相同大小)中的一列。

我提供了三种方法:单线程、多处理和pathos包多处理。这两种多线程方法都不起作用,我不明白原因。

Code

import keras
import numpy as np
import threading
import pathos.multiprocessing
import multiprocessing


def __res_sum_squares(y_true, y_pred):
    squared_diff = (y_true - y_pred) ** 2
    return keras.backend.sum(squared_diff)


__npixels, __nbins = 128, 128
__shape_col = (__npixels, __nbins)
__shape_nn = (1, __npixels, __nbins, 1)
__model = keras.models.load_model('./model.h5', compile=True, custom_objects={'res_sum_squares': __res_sum_squares})

__max_parallel_predictions = 4
__sema = threading.BoundedSemaphore(value=__max_parallel_predictions)


def __mt_pathos_manager(col_ratio):
    return __denoise(col_ratio[0], col_ratio[1])


def __denoise_frame_mt_pathos(frame_ratios):
    results = pathos.multiprocessing.ProcessingPool().map(__mt_pathos_manager, frame_ratios)
    return results


def __denoise_frame_mt_multiprocessing(frame_ratios):
    pool = multiprocessing.Pool()
    results = pool.map(__denoise, map(lambda col_ratio: col_ratio, frame_ratios))
    pool.close()
    pool.join()
    return results


def __denoise(col, ratio=None):
    """
        :param col: the source column
        :param ratio: logging purposes
        :return: the denoised column
    """
    really_predict = True
    if type(col) is tuple:
        col, ratio = col[0], col[1]
    col_denoise = np.reshape(col, __shape_nn)

    print("{} acquiring".format(ratio))
    __sema.acquire()
    print("{} acquired".format(ratio))
    #  ~    ~  ~  ~  ~  ~  ~  ~  ~  ~ CRITICAL SECTION START ~  ~  ~  ~  ~  ~  ~  ~  ~  ~
    col_denoise = __model.predict(col_denoise) if really_predict else col_denoise
    #  ~    ~  ~  ~  ~  ~  ~  ~  ~  ~ CRITICAL SECTION END   ~  ~  ~  ~  ~  ~  ~  ~  ~  ~
    print("{} releasing".format(ratio))
    __sema.release()
    print("{} released".format(ratio))

    return np.reshape(col_denoise, __shape_col)


def denoise_cube(cube, mp=False, mp_pathos=False):
    """
        :param cube: a numpy 3D array of ncols * npixels * nbins
        :param mp: use multiprocessing
        :param mp_pathos: use pathos multiprocessing
        :return: the denoised cube
    """
    ncol = cube.shape[0]
    ratios = [(ic * 100.0) / ncol for ic in range(0, ncol)]
    frame_ratios = zip(cube, ratios)

    if mp:
        if mp_pathos:
            l_cols_denoised = __denoise_frame_mt_pathos(frame_ratios)
        else:
            l_cols_denoised = __denoise_frame_mt_multiprocessing(frame_ratios)
    else:
        l_cols_denoised = [__denoise(col, ratio) for col, ratio in frame_ratios]
    return l_cols_denoised


if __name__ == "__main__":

    test_cube = np.random.rand(1000, __npixels, __nbins)

    # Single threaded impl: works fine
    denoise_cube(test_cube, mp=False)
    # Multiprocessing Pool: blocks at the eighth "acquired" print
    denoise_cube(test_cube, mp=True, mp_pathos=False)
    # Pathos multiprocessing Pool: blocks at the eighth "acquired" print
    denoise_cube(test_cube, mp=True, mp_pathos=True)

Analysis

我首先想到的是,不知何故,急于__model.predict()在 8 次调用后阻塞(= 测试机器上的 CPU 核心数)。 所以我放置了一个threading.BoundedSemaphore访问次数少于 8 次。什么都不起作用。

单线程按预期工作:

0.0 acquiring
0.0 acquired
0.0 releasing
0.0 released
< ............ >
99.9 acquiring
99.9 acquired
99.9 releasing
99.9 released

多重处理(两个版本)都没有。

0.0 acquiring
0.0 acquired
3.2 acquiring
3.2 acquired
6.4 acquiring
6.4 acquired
9.6 acquiring
9.6 acquired
12.8 acquiring
12.8 acquired
16.0 acquiring
16.0 acquired
19.2 acquiring
19.2 acquired
22.4 acquiring
22.4 acquired
< hangs >

等等,哪里有release印刷?似乎信号量没有被触及,或者每次调用都被复制,并且总是重新初始化。唔。

那么我们来寻找一下really_predict = True并交换其值:predict()以这种方式永远不会接通电话。

....这效果很好,太棒了!所以问题并不能完全解决multiprocessing,而不是之间的奇怪链接keras预测和multiprocessing汇集。有什么建议吗?


None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 Keras 预测进行 Python 多处理 的相关文章

随机推荐