error in constructing multioutput model in keras(在KERAS中构建多输出模型时出错)
本文介绍了在KERAS中构建多输出模型时出错的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我正在尝试在KERAS中创建多输出模型。该模型从卷积开始,旨在将两个独立的致密层的结果叠加在一起。我为回归任务创建了一些随机数据,其中x1
是输入,df
是标签。df
包含三列。在定义了列车和测试拆分并形成模型后,我在拟合模型时收到错误。有人能帮我更正代码吗?
x1 = np.random.rand(500, 244, 244, 20)
df = pd.DataFrame(np.random.uniform(0,1,size=(500, 3)), columns=list('XYZ'))
x_train, x_test, y_train, y_test = train_test_split(x1,df ,test_size=0.2)
n1_y_train=y_train['X'].values
n1_y_test=y_test['X'].values
n2_y_train=y_train['Y'].values
n2_y_test=y_test['Y'].values
n3_y_train=y_train['Z'].values
n3_y_test=y_test['Z'].values
train_shape = x_train.shape
inputs = layers.Input(shape = train_shape[1:])
x = layers.Conv2D(16, (3,3), activation='relu', padding="same")(inputs)
x = layers.Flatten()(x)
# section1:
l1 = layers.Dense(16, activation='relu')(x)
l1 = layers.Dense(1)(l1)
# section2:
l2 = layers.Dense(32, activation='relu')(x)
l2 = layers.Dense(1)(l2)
output1 = tf.reduce_mean(tf.stack([l1, l2], axis=0), axis=0, name = "output1")
output2 = tf.reduce_mean(tf.stack([l1, l2], axis=0), axis=0, name = "output2")
output3 = tf.reduce_mean(tf.stack([l1, l2], axis=0), axis=0, name = "output3")
model = tf.keras.models.Model(inputs, [output1,output2,output3])
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss= tf.keras.losses.mse,
metrics=tf.keras.metrics.RootMeanSquaredError(name="rmse"))
history = model.fit(x_train,{"output1": n1_y_train, "output2": n2_y_train, "output3": n3_y_train},
validation_data = (x_test,{"output1": n1_y_test, "output2": n2_y_test, "output3": n3_y_test}),
verbose=2,
epochs=100,
batch_size=32)
错误:
ValueError: Found unexpected losses or metrics that do not correspond to any Model output: dict_keys(['output1', 'output2', 'output3']). Valid mode output names: ['tf.math.reduce_mean', 'tf.math.reduce_mean_1', 'tf.math.reduce_mean_2']. Received struct is: {'output1': <tf.Tensor 'IteratorGetNext:1' shape=(None,) dtype=float32>, 'output2': <tf.Tensor 'IteratorGetNext:2' shape=(None,) dtype=float32>, 'output3': <tf.Tensor 'IteratorGetNext:3' shape=(None,) dtype=float32>}.
推荐答案
您应该用Lambda
层包装输出层,并使用tf.concat
而不是tf.stack
。通过使用Lambda
层,您可以显式设置输出的名称,这些名称将由您的模型捕获。以下是一个工作示例:
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
x1 = np.random.rand(10, 244, 244, 20)
df = pd.DataFrame(np.random.uniform(0,1,size=(10, 3)), columns=list('XYZ'))
x_train, x_test, y_train, y_test = train_test_split(x1,df ,test_size=0.2)
n1_y_train=y_train['X'].values
n1_y_test=y_test['X'].values
n2_y_train=y_train['Y'].values
n2_y_test=y_test['Y'].values
n3_y_train=y_train['Z'].values
n3_y_test=y_test['Z'].values
train_shape = x_train.shape
inputs = tf.keras.layers.Input(shape = train_shape[1:])
x = tf.keras.layers.Conv2D(16, (3,3), activation='relu', padding="same")(inputs)
x = tf.keras.layers.Flatten()(x)
l1 = tf.keras.layers.Dense(16, activation='relu')(x)
l1 = tf.keras.layers.Dense(1)(l1)
l2 = tf.keras.layers.Dense(32, activation='relu')(x)
l2 = tf.keras.layers.Dense(1)(l2)
output1 = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1, keepdims=True), name='output1')(tf.concat([l1, l2], axis=1))
output2 = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1, keepdims=True), name='output2')(tf.concat([l1, l2], axis=1))
output3 = tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1, keepdims=True), name='output3')(tf.concat([l1, l2], axis=1))
model = tf.keras.Model(inputs, [output1,output2,output3])
model.compile(
optimizer=tf.keras.optimizers.Adam(),
loss= tf.keras.losses.mse,
metrics=tf.keras.metrics.RootMeanSquaredError(name="rmse"))
history = model.fit(x_train,{"output1": n1_y_train, "output2": n2_y_train, "output3": n3_y_train},
validation_data = (x_test,{"output1": n1_y_test, "output2": n2_y_test, "output3": n3_y_test}),
verbose=2,
epochs=100,
batch_size=2)
这篇关于在KERAS中构建多输出模型时出错的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
沃梦达教程
本文标题为:在KERAS中构建多输出模型时出错
基础教程推荐
猜你喜欢
- 哪些 Python 包提供独立的事件系统? 2022-01-01
- 使用 Google App Engine (Python) 将文件上传到 Google Cloud Storage 2022-01-01
- 将 YAML 文件转换为 python dict 2022-01-01
- Python 的 List 是如何实现的? 2022-01-01
- 合并具有多索引的两个数据帧 2022-01-01
- 如何在Python中绘制多元函数? 2022-01-01
- 使 Python 脚本在 Windows 上运行而不指定“.py";延期 2022-01-01
- 症状类型错误:无法确定关系的真值 2022-01-01
- 使用Python匹配Stata加权xtil命令的确定方法? 2022-01-01
- 如何在 Python 中检测文件是否为二进制(非文本)文 2022-01-01