我正在构建一个简单的 CNN 模型用于多类分类。训练和测试数据位于data_path
根据所需的类子目录flow_from_directory
的函数ImageDataGenerator
.
这是我根据数据构建和训练模型的代码:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dropout, Flatten, Dense, Conv2D, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Build Model
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(40, 24, 1)))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(12, activation='softmax'))
model.compile('binary_crossentropy', 'SGD', ['accuracy'])
# Init Generators
generator = ImageDataGenerator(rescale=1./255,
horizontal_flip=True,
fill_mode='nearest',
validation_split=0.2)
def get_train_images():
train_images = generator.flow_from_directory(os.path.join(data_path, 'train'),
target_size=(40, 24, 1),
batch_size=32,
color_mode='grayscale',
class_mode='categorical',
subset='training',
shuffle=True)
def get_validation_images():
validation_images = generator.flow_from_directory(os.path.join(data_path, 'train'),
target_size=(40, 24, 1),
batch_size=32,
color_mode='grayscale',
class_mode='categorical',
subset='validation',
shuffle=True)
# Train Model
model.fit(get_train_images, validation_data=get_validation_images, epochs=20)
拟合函数给出以下错误:
File "C:\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "C:\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1049, in fit
data_handler = data_adapter.DataHandler(
File "C:\Python38\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 1104, in __init__
adapter_cls = select_data_adapter(x, y)
File "C:\Python38\lib\site-packages\tensorflow\python\keras\engine\data_adapter.py", line 968, in select_data_adapter
raise ValueError(
ValueError: Failed to find data adapter that can handle input: <class 'method'>, <class 'NoneType'>
看起来是某种兼容性问题。我正在使用张量流版本 2.3.1。有人可以指出我做错了什么并帮助我解决这个问题吗?
Thanks!