Keras for Binary Classification
So I didn’t get around to seriously playing with Keras (a powerful library for building fully-differentiable machine learning models, aka neural networks) until now, besides running a few examples. And I have been a bit surprised by how tricky it actually was for me to get a simple task running, despite (or maybe because of) all the docs already available.
The thing is, many of the “basic examples” gloss over exactly what the inputs, and especially the outputs, look like, and that’s important. For me, the archetypal simplest machine learning problem is binary classification, but in Keras the canonical task is categorical classification. Only after fumbling around for a few hours did I realize this fundamental rift.
The examples (besides LSTM sequence classification) silently assume that you want to classify into categories (e.g. to predict words), not do a binary 1/0 classification. The consequence is that if you naively copy the example MLP before learning to think about it, your model will never learn anything and, to add insult to injury, will always show an accuracy of 1.0.
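Why exactly 1.0? Here is my best guess at the mechanism, sketched in NumPy under the assumption that categorical-mode accuracy compares the argmax of each label row with the argmax of each prediction row. With a single output column, that argmax is always index 0 on both sides, so every sample counts as correct even when every prediction is wrong. The helpers categorical_accuracy and binary_accuracy are hypothetical illustrations, not Keras functions.

```python
import numpy as np

def categorical_accuracy(y_true, y_pred):
    # Hypothetical helper: compare the index of the largest entry in each row.
    return float(np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)))

def binary_accuracy(y_true, y_pred, threshold=0.5):
    # Hypothetical helper: threshold the predicted probability, compare with the 0/1 label.
    return float(np.mean(y_true == (y_pred > threshold)))

# Four samples with one output column each, shape (4, 1), as with output_dim=1.
y_true = np.array([[1.0], [0.0], [1.0], [0.0]])
y_pred = np.array([[0.2], [0.9], [0.3], [0.6]])  # every prediction is on the wrong side

print(categorical_accuracy(y_true, y_pred))  # 1.0 (argmax of a single column is always 0)
print(binary_accuracy(y_true, y_pred))       # 0.0
```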
So, there are a few important things you need to do to perform binary classification:
- Pass output_dim=1 to your final Dense layer (this is the obvious one).
- Use sigmoid activation instead of softmax; obviously, softmax on a single output will always normalize whatever comes in to 1.0.
- Pass class_mode='binary' to model.compile() (this fixes the accuracy display, possibly more; you want to pass show_accuracy=True to model.fit()).
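To see why the sigmoid/softmax swap matters, here is a small NumPy sketch (not Keras code; softmax, sigmoid, the toy data and the training loop are my own illustrations). A softmax over one unit returns exactly 1.0 for any input, so the prediction carries no information. A sigmoid over the same unit yields a probability that plain gradient descent on binary cross-entropy can actually train.

```python
import numpy as np

def softmax(z):
    # Subtract the row max for numerical stability, then normalize each row to sum to 1.
    e = np.exp(z - np.max(z, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)

def sigmoid(z):
    # Squash each value independently into (0, 1).
    return 1.0 / (1.0 + np.exp(-z))

# Pre-activations of a final layer with a single unit, shape (3, 1).
z = np.array([[-3.0], [0.0], [5.0]])
print(softmax(z).ravel())  # [1. 1. 1.] (one unit: always normalized to 1.0)
print(sigmoid(z).ravel())  # distinct values in (0, 1), one per input

# Hypothetical toy binary problem: the label is 1 when x0 + x1 > 0.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(float).reshape(-1, 1)

# One sigmoid output unit trained with binary cross-entropy via plain gradient descent.
w = np.zeros((2, 1))
b = 0.0
lr = 1.0
for _ in range(500):
    p = sigmoid(X @ w + b)
    grad_z = (p - y) / len(X)  # gradient of mean binary cross-entropy w.r.t. the pre-activation
    w -= lr * (X.T @ grad_z)
    b -= lr * float(grad_z.sum())

# Binary accuracy: threshold the probability at 0.5 and compare with the labels.
accuracy = float(np.mean((sigmoid(X @ w + b) > 0.5) == (y == 1)))
print(accuracy)
```

With softmax in place of sigmoid, every prediction would stay at exactly 1.0 no matter what w and b become, which matches the "never learns anything" symptom.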
Other lessons learned:
- For some projects, my approach of first cobbling together an example from existing code and then thinking harder about it works great; for others, not so much…
- In IPython, do not forget to reinitialize model = Sequential() in some of your cells; a lot of confusion ensues otherwise.
- Keras is pretty awesome and powerful. Conceptually, I think I like NNBlocks’ usage philosophy more (regarding how you build the model), but sadly that library is still at a very early stage (I have created a bunch of GitHub issues).
(Edit: After a few hours, I toned down this post a bit. It wasn’t meant at all as an attack on Keras, though some might perceive it as such; it is just a word of caution to fellow Keras newbies. And it shouldn’t take much to improve the Keras docs.)