1. 自定义模型
1.1 构建自定义模型的基本步骤
- 继承
keras.Model
类
- 在构造函数中创建层和变量
- 实现call方法来执行操作
- 实现get_config()
实现了get_config()就可以使用save()
方法保存模型并使用keras.models.load_model()
函数加载模型,使用save_weights()
和load_weights()
方法来保存和加载模型
其他功能和基本模型一样
1.2 基于模型内部的损失和指标
在call()方法计算损失,并使用add_loss()
方法将其添加到模型的损失函数中
2 使用自动微分计算梯度
- 首先定义两个变量W1和W2
- 创建一个tf.GradientTape上下文
- 要求该tape针对两个变量[W1,W2]计算z的梯度
w1, w2 = tf.Variable(5.), tf.Variable(3.)
with tf.GradientTape() as tape:
z = f(w1, w2)
gradients = tape.gradient(z, [w1, w2])
调用tape的gradient方法后,tape会立即被自动擦除。如果需要多次调用gradient(),必须使得tape具有持久性persistent=True
with tf.GradientTape(persistent=True) as tape:
z = f(w1, w2)
dz_dw1 = tape.gradient(z, w1)
dz_dw2 = tape.gradient(z, w2) # works now!
del tape
需要手动删除tape
2.1 跟踪对象
默认跟踪涉及变量的操作,但是可以强制观察你喜欢的任何张量
c1, c2 = tf.constant(5.), tf.constant(3.)
with tf.GradientTape() as tape:
tape.watch(c1)
tape.watch(c2)
z = f(c1, c2)
gradients = tape.gradient(z, [c1, c2])
print(gradients)
结果:
[<tf.Tensor: shape=(), dtype=float32, numpy=36.0>, <tf.Tensor: shape=(), dtype=float32, numpy=10.0>]
- 用处:实现正则化损失,从而在输入变化不大的时候惩罚那些变化很大的激活
2.2 梯度
一个梯度tape是用来计算单个值(通常是损失)相对于一组值(通常是模型参数)的梯度。一正一反获得所有梯度,可以调用jacobian()方法获取单独的梯度
with tf.GradientTape(persistent=True) as hessian_tape:
with tf.GradientTape() as jacobian_tape:
z = f(w1, w2)
jacobians = jacobian_tape.gradient(z, [w1, w2])
hessians = [hessian_tape.gradient(jacobian, [w1, w2])
for jacobian in jacobians]
del hessian_tape
print(jacobians)
print(hessians)
用来获得二阶偏导数
hessians = [hessian_tape.gradient(jacobian, [w1, w2])
for jacobian in jacobians]
2.3 阻止反向传播
使用tf.stop_gradient
def f(w1, w2):
return 3 * w1 ** 2 + tf.stop_gradient(2 * w1 * w2)
with tf.GradientTape() as tape:
z = f(w1, w2)
tape.gradient(z, [w1, w2])
2.4 返回值为nan
解决办法(1):重写函数,并使用 来修饰它并使它既返回其正常输出又返回计算导数的函数
@tf.custom_gradient
def my_better_softplus(z):
exp = tf.exp(z)
def my_softplus_gradients(grad):
return grad / (1 + 1 / exp)
return tf.math.log(exp + 1), my_softplus_gradients
解决办法(2):使用tf.where在较大输入时返回输入
def my_better_softplus(z):
return tf.where(z > 30., z, tf.math.log(tf.exp(z) + 1.))