Ensemble of Narrow DNN Chains

Manuscript (Course Essay for Machine Learning at the University of Oxford), 2021

Our paper is available at “Ensemble of Narrow DNN Chains” (my Machine Learning course essay at Oxford).

Our code is publicly available at https://github.com/vtu81/ENDC.

We propose the Ensemble of Narrow DNN Chains (ENDC) framework:

  1. first, train narrow DNN chains that each perform well on a one-vs-all binary classification task;
  2. then, aggregate them by voting to make predictions for the multi-class classification task.
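The two steps can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the ENDC implementation. The helper names `one_vs_all_labels` and `vote` are hypothetical. "Voting" is modeled here as taking the class whose binary classifier is most confident (an argmax over per-class scores). The paper's exact voting rule may differ, for example counting hard positive votes with some tie-breaking.

```python
import numpy as np

def one_vs_all_labels(y, positive_class):
    """Binary target for one chain: 1 where the label equals positive_class, else 0."""
    return (np.asarray(y) == positive_class).astype(np.int64)

def vote(scores):
    """Aggregate binary-classifier outputs into a multi-class prediction.

    scores has shape (n_samples, n_classes); column k is the (assumed)
    confidence of the k-th one-vs-all classifier that a sample is class k.
    The predicted class is the one whose classifier is most confident.
    """
    return np.argmax(scores, axis=1)

# Toy usage: 3 classes, 4 samples, made-up classifier outputs.
y = np.array([0, 2, 1, 2])
print(one_vs_all_labels(y, 2))  # [0 1 0 1]

scores = np.array([
    [0.9, 0.2, 0.1],
    [0.3, 0.4, 0.8],
    [0.1, 0.7, 0.6],
    [0.2, 0.1, 0.3],
])
print(vote(scores))  # [0 2 1 2]
```

Each column of `scores` comes from a separate binary classifier. The aggregation step therefore needs no joint training: it only compares the classifiers' outputs.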

Our ensemble framework can:

  • utilize the abstract interpretability of DNNs,
  • significantly outperform traditional ML models on CIFAR-10,
  • while being 2-4 orders of magnitude smaller than normal DNNs and 6+ times smaller than traditional ML models,
  • and support full parallelism in both the training and deployment stages.
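The parallelism claim follows from the structure: each one-vs-all task depends only on the data and its own class index. The sketch below trains one scorer per class concurrently with `concurrent.futures`. It is hypothetical throughout. `train_binary` is a stand-in linear scorer (the difference of class means, thresholded at the midpoint), not a narrow DNN chain. Threads are used for portability, whereas real DNN training would more likely spread chains across processes or devices.

```python
from concurrent.futures import ThreadPoolExecutor
import numpy as np

def train_binary(args):
    """Hypothetical stand-in for training one chain on a one-vs-all task.

    Fits a linear scorer w.x + b from the positive/negative class means;
    a real ENDC chain would be a small DNN instead.
    """
    X, y, k = args
    pos_mean = X[y == k].mean(axis=0)
    neg_mean = X[y != k].mean(axis=0)
    w = pos_mean - neg_mean
    b = -w @ (pos_mean + neg_mean) / 2  # boundary halfway between the means
    return w, b

def train_all(X, y, n_classes, max_workers=None):
    # The tasks share no state, so they can all run at the same time.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(train_binary, [(X, y, k) for k in range(n_classes)]))

def predict(models, X):
    # Each scorer runs independently too; only the final argmax combines them.
    scores = np.stack([X @ w + b for w, b in models], axis=1)
    return scores.argmax(axis=1)

# Synthetic 3-class toy data (not one of the paper's datasets).
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 0.0], [0.0, 5.0]])
y = np.repeat(np.arange(3), 50)
X = centers[y] + rng.normal(scale=0.5, size=(150, 2))

models = train_all(X, y, n_classes=3)
acc = (predict(models, X) == y).mean()
print(f"training accuracy: {acc:.2f}")
```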

Our empirical study shows that a narrow DNN chain can learn binary classification tasks well. Moreover, our experiments on three datasets, MNIST, Fashion-MNIST, and CIFAR-10, confirm the potential of ENDC. Compared with traditional ML models, ENDC uses the fewest parameters while achieving similar accuracy on MNIST and Fashion-MNIST and significantly better accuracy on CIFAR-10.

Results

Overall Accuracy

| Dataset | Accuracy | Arch | # Params |
|---|---|---|---|
| MNIST | 93.40% | 1-channel | 1300 |
| Fashion-MNIST | 80.39% | 1-channel | 1300 |
| CIFAR-10 | 47.72% | 2-channel | 4930 |
  • Each binary classifier has fewer parameters than a single input image has entries (130 < 28×28 = 784 for MNIST and Fashion-MNIST; 493 < 3×32×32 = 3072 for CIFAR-10)!
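The parameter arithmetic above can be checked directly. This sketch takes the per-classifier counts (130 and 493) and input shapes from the text. It assumes one binary classifier per class (10 in total), which matches the # Params column. The helper names are illustrative only.

```python
import math

N_CLASSES = 10  # assumed: one binary classifier (narrow chain) per class

# (parameters per binary classifier, input shape as channels x height x width)
CONFIGS = {
    "MNIST": (130, (1, 28, 28)),
    "Fashion-MNIST": (130, (1, 28, 28)),
    "CIFAR-10": (493, (3, 32, 32)),
}

def input_entries(shape):
    """Number of scalar entries in one input image."""
    return math.prod(shape)

def ensemble_params(per_classifier, n_classes=N_CLASSES):
    """Total ensemble size: one chain per class."""
    return per_classifier * n_classes

for name, (per_clf, shape) in CONFIGS.items():
    n_in = input_entries(shape)
    print(f"{name}: {per_clf} params per chain < {n_in} input entries; "
          f"ensemble total = {ensemble_params(per_clf)}")
```

The output shows totals of 1300 for MNIST and Fashion-MNIST and 4930 for CIFAR-10, matching the table.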

Comparison

We compare ENDC with two traditional ML models:

  • Logistic Regression (LR)
  • Support Vector Classifier (SVC)

as well as with normal DNNs. The baseline results are collected from online sources; see our paper for the sources and details.

MNIST

| Method | Accuracy (%) | # Params |
|---|---|---|
| ENDC (ours) | 93.4 | 1.3K |
| LR | 91.7 | 7.7K+ |
| SVC | 97.8 | 7.7K+ |
| Normal DNN (LeNet) | 99.3 | 0.41M |

Fashion-MNIST

| Method | Accuracy (%) | # Params |
|---|---|---|
| ENDC (ours) | 80.4 | 1.3K |
| LR | 84.2 | 7.7K+ |
| SVC | 89.7 | 7.7K+ |
| Normal DNN (VGG-16) | 93.5 | 26M |

CIFAR-10

| Method | Accuracy (%) | # Params |
|---|---|---|
| ENDC (ours) | 47.7 | 4.9K |
| LR | 39.9 | 30.0K+ |
| SVC (PCA) | 40.2 | 0.44M+ |
| Normal DNN (VGG-16-BN) | 93.9 | 15M |
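The size claims ("6+ times smaller" than traditional ML, "2-4 orders of magnitude" smaller than normal DNNs) can be reproduced from these tables. This is a sketch under assumptions. ENDC totals are taken from the overall table (1300 and 4930). The LR count assumes a standard multinomial logistic regression on raw pixels (one weight per feature per class plus one bias per class), consistent with the "7.7K+" and "30.0K+" lower bounds. DNN sizes are the rounded values listed above.

```python
import math

def lr_params(n_features, n_classes=10):
    # Assumed multinomial LR: n_features weights per class plus one bias per class.
    return n_features * n_classes + n_classes

endc = {"MNIST": 1300, "Fashion-MNIST": 1300, "CIFAR-10": 4930}
lr = {"MNIST": lr_params(784), "Fashion-MNIST": lr_params(784), "CIFAR-10": lr_params(3072)}
dnn = {"MNIST": 0.41e6, "Fashion-MNIST": 26e6, "CIFAR-10": 15e6}  # LeNet, VGG-16, VGG-16-BN

for name in endc:
    print(f"{name}: LR/ENDC = {lr[name] / endc[name]:.2f}x, "
          f"DNN/ENDC = 10^{math.log10(dnn[name] / endc[name]):.1f}")
```

Under these assumptions the LR ratio is about 6.0x for MNIST and Fashion-MNIST and 6.2x for CIFAR-10. The DNN gap ranges from roughly 10^2.5 (LeNet) to 10^4.3 (VGG-16).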

Per-class Accuracy

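| Dataset | #0 (%) | #1 (%) | #2 (%) | #3 (%) | #4 (%) | #5 (%) | #6 (%) | #7 (%) | #8 (%) | #9 (%) |
|---|---|---|---|---|---|---|---|---|---|---|
| MNIST | 97.04 | 97.53 | 96.51 | 88.91 | 95.52 | 92.38 | 90.29 | 94.55 | 88.71 | 91.67 |
| Fashion-MNIST | 80.60 | 92.90 | 77.60 | 77.60 | 75.50 | 92.30 | 40.70 | 81.30 | 90.00 | 95.50 |
| CIFAR-10 | 48.90 | 55.70 | 43.50 | 31.80 | 41.00 | 45.40 | 61.90 | 42.00 | 49.90 | 57.10 |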
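The overall accuracy is the class-size-weighted mean of the per-class accuracies. The sketch below assumes evaluation on the standard test splits. The CIFAR-10 test set has 1000 images per class, so a plain mean suffices there. MNIST test classes are unequal in size, so their standard counts are used as weights. Because the per-class values are rounded, the reconstruction can only match to rounding.

```python
def weighted_accuracy(per_class_acc, class_counts):
    """Overall accuracy (%) from per-class accuracies (%) and test-set class sizes."""
    correct = sum(a * n for a, n in zip(per_class_acc, class_counts))
    return correct / sum(class_counts)

# Per-class accuracies (%) copied from the table above.
MNIST_ACC = [97.04, 97.53, 96.51, 88.91, 95.52, 92.38, 90.29, 94.55, 88.71, 91.67]
CIFAR10_ACC = [48.90, 55.70, 43.50, 31.80, 41.00, 45.40, 61.90, 42.00, 49.90, 57.10]

# Standard MNIST test-set size of each digit class (10,000 images in total).
MNIST_TEST_COUNTS = [980, 1135, 1032, 1010, 982, 892, 958, 1028, 974, 1009]

print(f"MNIST:    {weighted_accuracy(MNIST_ACC, MNIST_TEST_COUNTS):.2f}")  # 93.40
print(f"CIFAR-10: {weighted_accuracy(CIFAR10_ACC, [1000] * 10):.2f}")  # 47.72
```

Both values agree with the overall accuracy table (93.40% and 47.72%).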