是否可以使用 Keras 从自定义损失函数中调用/使用实例属性或全局变量?

2023-12-21

我想定义一个如下的损失函数:

def custom_loss_function(y_true, y_pred):
    calculate loss based on y_true, y_pred and self.list_of_values

其中变量 self.list_of_values 每次迭代都会在此函数之外进行修改,因此每次“调用”custom_loss_function 时都会有不同的值。我知道从这个帖子 https://stackoverflow.com/questions/55653015/how-do-i-call-a-global-variable-in-keras-custom-loss-function-to-change-the-ret损失函数仅被调用一次,然后“会话迭代地评估损失”。

我的疑问是是否可以使用损失函数中的全局/外部变量(具有动态值),然后像这样使用:

model.compile(loss=custom_loss_function, optimizer=Adam(lr=LEARNING_RATE), metrics=['accuracy'])

在此指定解决方案(答案部分),即使它存在于注释部分中,对于社区的利益.

变量,list_of_values可以被认为是一个Input Variable like

list_of_values = Input(shape=(1,), name='list_of_values')并定义Custom Loss function如下所示:

def sample_loss( y_true, y_pred, list_of_values ) :
    return list_of_values * categorical_crossentropy( y_true, y_pred ) 

还有,同样Global Variable可以作为输入传递给模型,例如:

model = Model( inputs=[x, y_true, list_of_values], outputs=y_pred, name='train_only' )

示例的完整代码如下所示:

from keras.layers import Input, Dense, Conv2D, MaxPool2D, Flatten
from keras.models import Model
from keras.losses import categorical_crossentropy

def sample_loss( y_true, y_pred, list_of_values ) :
    return list_of_values * categorical_crossentropy( y_true, y_pred ) 

x = Input(shape=(32,32,3), name='image_in')
y_true = Input( shape=(10,), name='y_true' )
list_of_values = Input(shape=(1,), name='list_of_values')
f = Conv2D(16,(3,3),padding='same')(x)
f = MaxPool2D((2,2),padding='same')(f)
f = Conv2D(32,(3,3),padding='same')(f)
f = MaxPool2D((2,2),padding='same')(f)
f = Conv2D(64,(3,3),padding='same')(f)
f = MaxPool2D((2,2),padding='same')(f)
f = Flatten()(f)
y_pred = Dense(10, activation='softmax', name='y_pred' )(f)
model = Model( inputs=[x, y_true, list_of_values], outputs=y_pred, name='train_only' )
model.add_loss( sample_loss( y_true, y_pred, list_of_values ) )
model.compile( loss=None, optimizer='sgd' )
print model.summary()

欲了解更多信息,请参阅此堆栈溢出答案 https://stackoverflow.com/a/50127646/13465258.

希望这可以帮助。快乐学习!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

是否可以使用 Keras 从自定义损失函数中调用/使用实例属性或全局变量? 的相关文章

随机推荐