我认为没有任何方法可以将“每次运行”参数传递给GridSearchCV
。也许最简单的方法是子类化KerasRegressor
做你想做的事。
class KerasRegressorTB(KerasRegressor):
def __init__(self, *args, **kwargs):
super(KerasRegressorTB, self).__init__(*args, **kwargs)
def fit(self, x, y, log_dir=None, **kwargs):
cbs = None
if log_dir is not None:
params = self.get_params()
conf = ",".join("{}={}".format(k, params[k])
for k in sorted(params))
conf_dir = os.path.join(log_dir, conf)
cbs = [TensorBoard(log_dir=conf_dir, histogram_freq=0,
write_graph=True, write_images=False)]
super(KerasRegressorTB, self).fit(x, y, callbacks=cbs, **kwargs)
你会像这样使用它:
# ...
estimator = KerasRegressorTB(build_fn=create_3_layers_model,
input_dim=input_dim, output_dim=output_dim)
#...
grid = GridSearchCV(estimator=estimator, param_grid=param_grid,
n_jobs=1, scoring=bug_fix_score,
cv=2, verbose=0, fit_params={'log_dir': './Graph'})
grid_result = grid.fit(x.as_matrix(), y.as_matrix())
Update:
Since GridSearchCV
由于交叉验证,多次运行相同的模型(即相同的参数配置),之前的代码最终将在每次运行中放置多个跟踪。看源码(here and here),似乎没有办法检索“当前 split id”。同时,您不应该只检查现有文件夹并根据需要添加子修复,因为作业是并行运行的(至少可能是这样,尽管我不确定 Keras/TF 是否属于这种情况)。你可以尝试这样的事情:
import itertools
import os
class KerasRegressorTB(KerasRegressor):
def __init__(self, *args, **kwargs):
super(KerasRegressorTB, self).__init__(*args, **kwargs)
def fit(self, x, y, log_dir=None, **kwargs):
cbs = None
if log_dir is not None:
# Make sure the base log directory exists
try:
os.makedirs(log_dir)
except OSError:
pass
params = self.get_params()
conf = ",".join("{}={}".format(k, params[k])
for k in sorted(params))
conf_dir_base = os.path.join(log_dir, conf)
# Find a new directory to place the logs
for i in itertools.count():
try:
conf_dir = "{}_split-{}".format(conf_dir_base, i)
os.makedirs(conf_dir)
break
except OSError:
pass
cbs = [TensorBoard(log_dir=conf_dir, histogram_freq=0,
write_graph=True, write_images=False)]
super(KerasRegressorTB, self).fit(x, y, callbacks=cbs, **kwargs)
我在用着os
需要 Python 2 兼容性,但如果您使用的是 Python 3,您可能会考虑更好的pathlib module用于路径和目录处理。
注意:我忘了之前提到过,但为了以防万一,请注意通过write_graph=True
将记录一个图表per run,根据您的模型,这可能意味着很多(相对而言)这个空间。这同样适用于write_images
,虽然我不知道该功能需要的空间。