使用 Keras 实现一个多输入多输出的网络

结构图

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import keras
from keras.layers import Input, Dense
from keras.models import Model

# Build a two-input / two-output functional model.
#
# Inputs:
#   input1 -- vectors of length 784
#   input2 -- vectors of length 10
input1 = Input(name="input1", shape=(784,))
input2 = Input(name="input2", shape=(10,))

# Shared hidden representation computed from input1 only.
# NOTE(review): a single-unit hidden layer is a very narrow bottleneck --
# confirm this width is intended.
hidden = Dense(units=1, activation="relu")(input1)

# First head: depends only on the shared hidden layer.
output1 = Dense(units=10, activation="relu", name="output1")(hidden)

# Second head: the hidden activations are joined with input2 along the
# last axis (the layer's default) before the final dense layer.
hidden_input2 = keras.layers.Concatenate()([hidden, input2])
output2 = Dense(units=10, activation="relu", name="output2")(hidden_input2)

# The order of the lists here fixes the order in which arrays must be
# supplied to fit/evaluate/predict when they are passed positionally.
model = Model(
    inputs=[input1, input2],
    outputs=[output1, output2],
)

# Configure training for the two-output model.
#
# Both `loss` and `loss_weights` are keyed by output-layer name so the
# weighting cannot silently drift if the order of the model's outputs
# changes: output1 is weighted 1 and output2 is weighted 0.4.
#
# NOTE: the `...` values are placeholders carried over from the snippet;
# replace them with real loss functions and an optimizer before running.
model.compile(
    loss={'output1': ..., 'output2': ...},
    optimizer=...,
    loss_weights={'output1': 1, 'output2': 0.4},
    metrics=['accuracy'],
)

# Train for one epoch, passing one array per model input and one target
# array per model output, in the same order the model declares them.
# train_* / test_* are not defined in this snippet and must be provided by
# the surrounding code. The test arrays are used as validation data here.
model.fit(
    x=[train_X1, train_X2],
    y=[train_y1, train_y2],
    batch_size=None,
    epochs=1,
    validation_data=(
        [test_X1, test_X2],
        [test_y1, test_y2],
    ),
)