Model class to store a PyTorch model.
%load_ext autoreload
%autoreload 2
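model = Sequential()
model.add(Dense(2, x.shape[1], activation='relu'))  # first hidden layer: 2 units; input size = number of features in x
model.add(Dense(2, activation='relu'))              # second hidden layer: 2 units
model.add(Dense(len(set(y))))                       # output layer: one unit per class
model.add(Activation('softmax'))                    # turn the outputs into class probabilities
model.compile(ce4softmax)                           # loss: cross-entropy for softmax outputs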
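The stack above is two 2-unit ReLU hidden layers, an output layer with one unit per class, and a softmax. This is a minimal NumPy sketch of what that stack computes. It assumes that `Dense` is a fully connected layer `x @ W + b` and that `ce4softmax` is the usual cross-entropy on softmax probabilities. Both are assumptions, because the library's internals are not shown here. The helper names (`forward`, `cross_entropy`), the sizes, and the random weights are hypothetical placeholders.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # Subtract the row max for numerical stability; each row then sums to 1.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def forward(x, params):
    # Hypothetical stand-in for Dense -> ReLU -> Dense -> ReLU -> Dense -> softmax.
    (w1, b1), (w2, b2), (w3, b3) = params
    h1 = relu(x @ w1 + b1)
    h2 = relu(h1 @ w2 + b2)
    return softmax(h2 @ w3 + b3)

def cross_entropy(probs, y):
    # Mean negative log-probability assigned to each sample's true class.
    return -np.mean(np.log(probs[np.arange(len(y)), y] + 1e-12))

rng = np.random.default_rng(0)
n_features, n_classes = 4, 3              # illustrative sizes only
x_demo = rng.normal(size=(5, n_features))
y_demo = np.array([0, 2, 1, 0, 2])
params = [
    (rng.normal(size=(n_features, 2)), np.zeros(2)),   # 2 hidden units
    (rng.normal(size=(2, 2)), np.zeros(2)),            # 2 hidden units
    (rng.normal(size=(2, n_classes)), np.zeros(n_classes)),  # one output per class
]
probs = forward(x_demo, params)
loss = cross_entropy(probs, y_demo)
```

Fusing the softmax with the cross-entropy is common because the combined gradient with respect to the logits simplifies to `probs - one_hot(y)`.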
bs = 50                                  # mini-batch size
model.lr_find(x, y, bs=bs)               # look for a suitable learning rate
model.fit(x, y, bs, epochs=3, lr=1e-1)   # train for 3 epochs with learning rate 0.1
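`lr_find` and `fit` are methods of the library in use, and their internals are not shown here. A common way a learning-rate finder works is the LR range test. You train on successive mini-batches of size `bs`, raise the learning rate geometrically after every step, and record the loss. The sketch below assumes `lr_find` behaves this way. To stay self-contained, it uses a hypothetical softmax-regression model in NumPy. `lr_range_test`, `loss_and_grads`, and their parameters are illustrative names, not the library's API.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def loss_and_grads(w, b, xb, yb):
    # Cross-entropy of a linear softmax classifier and its gradients.
    n = len(yb)
    p = softmax(xb @ w + b)
    loss = -np.mean(np.log(p[np.arange(n), yb] + 1e-12))
    d = p.copy()
    d[np.arange(n), yb] -= 1.0    # gradient w.r.t. logits: probs - one_hot(y)
    d /= n
    return loss, xb.T @ d, d.sum(axis=0)

def lr_range_test(x, y, n_classes, bs=50, lr_min=1e-5, lr_max=10.0, n_iter=50, seed=0):
    rng = np.random.default_rng(seed)
    w = np.zeros((x.shape[1], n_classes))
    b = np.zeros(n_classes)
    lrs = np.geomspace(lr_min, lr_max, n_iter)   # geometrically increasing rates
    losses = []
    for lr in lrs:
        idx = rng.choice(len(x), size=bs, replace=False)  # one random mini-batch
        loss, gw, gb = loss_and_grads(w, b, x[idx], y[idx])
        losses.append(loss)
        if not np.isfinite(loss):   # stop once training diverges
            break
        w -= lr * gw                # plain SGD step at the current rate
        b -= lr * gb
    return lrs[:len(losses)], np.array(losses)

rng = np.random.default_rng(1)
x_demo = rng.normal(size=(200, 4))
y_demo = (x_demo[:, 0] > 0).astype(int)   # toy two-class labels
lrs, losses = lr_range_test(x_demo, y_demo, n_classes=2, bs=50)
```

In practice you plot `losses` against `lrs` on a log axis. You then pick a rate somewhat below the point where the loss stops falling. A fixed value like the `lr=1e-1` passed to `fit` is the kind of choice such a plot informs.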
preds = model.predict(x[:2])   # class probabilities for the first two samples
preds
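This sketch assumes `predict` returns the softmax probabilities, with one row per sample and one column per class. Under that assumption, each sample's predicted class is the column with the largest probability. The probabilities below are made up for two samples and three classes.

```python
import numpy as np

# Hypothetical softmax output for two samples and three classes.
preds_demo = np.array([[0.7, 0.2, 0.1],
                       [0.1, 0.3, 0.6]])
labels = preds_demo.argmax(axis=-1)      # index of the most probable class per row
confidence = preds_demo.max(axis=-1)     # probability assigned to that class
```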
Because the last layer is a softmax, the probabilities for each sample sum to 1, as the check below shows.
preds.sum(axis=-1, keepdims=True)
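With `keepdims=True`, the summed axis is kept with length 1. The result therefore has shape `(n_samples, 1)` and still broadcasts against `preds`. Floating-point sums are rarely exactly 1, so a tolerance-based check is more robust than `==`. This sketch assumes `preds` is a NumPy array, and the values are hypothetical.

```python
import numpy as np

preds_demo = np.array([[0.1, 0.2, 0.7],
                       [0.25, 0.25, 0.5]])   # hypothetical probabilities
row_sums = preds_demo.sum(axis=-1, keepdims=True)   # shape (2, 1)
flat_sums = preds_demo.sum(axis=-1)                 # shape (2,) without keepdims
assert np.allclose(row_sums, 1.0)                   # tolerant equality check
normalized = preds_demo / row_sums                  # (2, 3) / (2, 1) broadcasts row-wise
```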