Thursday, October 24, 2019

Model Distillation

Model Distillation is the process of taking a big model, or an ensemble of models, and producing a smaller model that captures most of the performance of the original bigger model. It can also be described as a blind model-replication method.

The reasons for doing so are:
  1. improved run-time performance (fewer FLOPs)
  2. (possibly) better generalization because of the simpler model
  3. you don't have access to the original model's training
  4. you have access to a remotely deployed model and want to replicate it (this happens more often than you might imagine)
  5. the original model may be too complicated
  6. insights that may arise from the process itself

How it works

Assume an MNIST classifier \(F_{MNIST}\) composed of an ensemble of \(N\) convolutional deep neural networks that, for an input image \(x\), produces a logit \(z_i\) for each of the possible labels \(C_{0}\)-\(C_{9}\), which is then converted into a probability.

The distillation process will give us an \(F_{MNIST_{distilled}}\) composed of a single deep neural network that approximates the classification results of the larger ensemble of models.

In distillation, knowledge is transferred from the teacher model to the student by minimizing a loss function in which the target is the distribution of class probabilities predicted by the teacher model. That is, the output of a softmax function applied to the teacher model's logits.
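As a rough sketch of that idea (not code from the post), the soft-target loss can be written as the cross entropy between the teacher's softmax output and the student's predicted distribution. PyTorch, the function name, and the tensor shapes below are my own illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Illustrative sketch, assuming PyTorch: the student is trained to match
# the class distribution predicted by the teacher.
def soft_target_loss(student_logits, teacher_logits):
    teacher_probs = F.softmax(teacher_logits, dim=-1)          # teacher's class distribution
    student_log_probs = F.log_softmax(student_logits, dim=-1)  # student's log-probabilities
    # Cross entropy between the two distributions:
    # H(teacher, student) = -sum_i p_teacher(i) * log p_student(i)
    return -(teacher_probs * student_log_probs).sum(dim=-1).mean()
```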

Logits \(z_i\) are converted to probabilities \(p_i = P(C_i|x)\) using the softmax layer:

\(
p_i = \frac{\exp(z_i)}{\sum_{j} \exp(z_j)}
\)

However, in many cases, this probability distribution has the correct class at a very high probability, with all other class probabilities very close to 0. As such, it doesn't provide much information beyond the ground truth labels already provided in the dataset.
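To make that peakedness concrete, here is a small illustration with made-up logits for a single MNIST image (the numbers and the use of PyTorch are my own assumptions, chosen only to show the effect):

```python
import torch

# Hypothetical ensemble logits for one image whose true class is "7":
# one dominant logit, the rest much smaller.
logits = torch.tensor([1.0, 0.5, 0.2, 1.5, 0.3, 0.1, 0.4, 9.0, 1.2, 0.8])

probs = torch.softmax(logits, dim=-1)
print(probs.argmax().item())   # 7
print(probs[7].item())         # ~0.998 -- almost all the mass on one class
```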

To tackle this issue, Hinton et al., 2015 introduced the concept of "softmax temperature". The probability \(q_i\) is computed from the logit \(z_i\) using a scalar softmax temperature \(T\):

\(
q_i = \frac{\exp(z_i / T)}{\sum_{j} \exp(z_j / T)}
\)

where \(T\) is a temperature that is normally set to 1. Using a higher value for \(T\) produces a softer probability distribution over the classes. A softer distribution means the values are more diffuse: a probability of 0.999 may become 0.9, with the rest of the mass spread across the other classes.
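Reusing the same hypothetical logits from above makes the effect of \(T\) concrete (the temperature values are arbitrary illustrative choices):

```python
import torch

logits = torch.tensor([1.0, 0.5, 0.2, 1.5, 0.3, 0.1, 0.4, 9.0, 1.2, 0.8])

# Dividing the logits by a higher temperature before the softmax flattens
# the distribution, so the relative probabilities of the "wrong" classes
# become visible to the student.
for T in (1.0, 2.0, 5.0):
    q = torch.softmax(logits / T, dim=-1)
    print(T, q[7].item())   # the top-class probability shrinks as T grows
```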

In the simplest form of distillation, knowledge is transferred to the distilled model by training it on a transfer set, using a soft target distribution for each case in the transfer set produced by running the cumbersome (teacher) model with a high temperature in its softmax. The same high temperature is used when training the distilled model, but after it has been trained it uses a temperature of 1. When the correct labels are known for all or some of the transfer set, this method can be significantly improved by also training the distilled model to produce the correct labels. One way to do this is to use the correct labels to modify the soft targets, but Hinton et al. found that a better way is to simply use a weighted average of two different objective functions:
  1. The first objective function is the cross entropy with the soft targets. This cross entropy is computed using the same high temperature in the softmax of the distilled model as was used for generating the soft targets from the cumbersome model.
  2. The second objective function is the cross entropy with the correct labels. It is computed using exactly the same logits in the softmax of the distilled model, but at a temperature of 1.
This operation can have a multitude of knobs and parameters to adjust, but the core idea is simple and works quite well.
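Here is a minimal sketch of that weighted loss, assuming PyTorch and the common KL-divergence formulation of the soft-target term; the temperature \(T\), the weight alpha, and the function name are illustrative defaults rather than values from the post, and the \(T^2\) factor follows Hinton et al.'s suggestion for keeping the gradient magnitudes of the two terms comparable:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    # 1) Soft-target term: cross entropy (up to a constant, the KL divergence)
    #    between the teacher's and the student's temperature-T distributions.
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        soft_targets,
        reduction="batchmean",
    ) * (T * T)  # T^2 rescaling, per Hinton et al., 2015

    # 2) Hard-label term: ordinary cross entropy at a temperature of 1.
    hard_loss = F.cross_entropy(student_logits, labels)

    # Weighted average of the two objectives.
    return alpha * soft_loss + (1.0 - alpha) * hard_loss
```

During training, each transfer-set batch is passed through both the teacher and the student, their logits are fed to this loss, and only the student's parameters are updated.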