Pruning Machine Learning Models in TensorFlow

Learn how to make your models smaller via pruning

In a previous article, we reviewed some of the pre-eminent literature on pruning neural networks. We learned that pruning is a model optimization technique that involves eliminating unnecessary values in the weight tensor. This results in smaller models with accuracy very close to the baseline model.

In this article, we’ll work through an example as we apply pruning and view the effect on the final model size and prediction errors.

Import the Usual Suspects

Our first step is to get a few imports out of the way:

  • os and zipfile will help us assess the size of the models.
  • tensorflow_model_optimization for model pruning.
  • load_model for loading a saved model.
  • And, of course, tensorflow and keras.

Finally, we initialize TensorBoard so that we’ll be able to visualize the models.

Dataset Generation

For this experiment, we’ll generate a regression dataset using scikit-learn and then split it into training and test sets.
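A minimal sketch of this step. The generator and sizes are assumptions: the weights file used later is named friedman_model_weights.h5, which suggests scikit-learn’s make_friedman1, and 10,000 samples with a 33% test split yield 3,300 test rows. The noise level and random_state are arbitrary choices.

```python
from sklearn.datasets import make_friedman1
from sklearn.model_selection import train_test_split

# Assumed sizes: 10,000 samples with 10 features. In Friedman #1 only the
# first five features influence the target; the other five are pure noise.
X, y = make_friedman1(n_samples=10000, n_features=10, noise=1.0, random_state=42)

# Hold out 33% of the rows for testing (3,300 of 10,000).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

print(X_train.shape, X_test.shape)
```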

Model Without Pruning

We’ll create a simple neural network to predict the target variable y and check its mean squared error. We’ll then compare this error with that of a version where the entire model is pruned, and then with a version where only the first Dense layer is pruned.

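from tensorflow import keras

def setup_model():
    # One hidden layer of 128 ReLU units and a single output unit.
    # X_train comes from the dataset-generation step and sets the input size.
    model = keras.Sequential([
        keras.layers.Dense(units=128, activation='relu', input_shape=(X_train.shape[1],)),
        keras.layers.Dense(units=1, activation='relu')])
    return model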

Next, we set up an early stopping callback that stops training once the monitored metric has not improved for 30 epochs.
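Keras provides this behavior through keras.callbacks.EarlyStopping. The monitored quantity isn’t shown here, so monitor='val_loss' in the comment below is an assumption. To make the patience rule concrete, here is a simplified pure-Python re-creation (early_stopping_epoch is a hypothetical helper, not Keras’s source) that reports the epoch at which training would stop:

```python
# In Keras the callback would be created roughly like this (monitor is assumed):
#   callbacks = [keras.callbacks.EarlyStopping(monitor='val_loss', patience=30)]


def early_stopping_epoch(val_losses, patience=30, min_delta=0.0):
    """Return the 0-based epoch at which patience-based early stopping fires,
    or None if training would run through every epoch in val_losses.

    Simplified re-creation of the rule, assuming lower values are better.
    """
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Loss improves for five epochs, then plateaus.
history = [5.0, 4.0, 3.0, 2.0, 1.0] + [1.0] * 40
print(early_stopping_epoch(history, patience=30))  # 34
```

With patience=30, the 30 flat epochs after the last improvement (epochs 5 through 34) exhaust the patience, so training stops at epoch 34.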

Let’s print a summary of the model so that we can compare it with the summary of the pruned models.

Let’s compile the model and train it.

Since it’s a regression problem, we’re monitoring the mean absolute error and the mean squared error.
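Both metrics are simple averages over the test set. A NumPy sketch of what 'mae' and 'mse' compute (mae and mse are hypothetical helper names used only for illustration):

```python
import numpy as np


def mae(y_true, y_pred):
    """Mean absolute error: average absolute difference."""
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    return float(np.mean(np.abs(y_true - y_pred)))


def mse(y_true, y_pred):
    """Mean squared error: average squared difference, so large errors dominate."""
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))


y_true = [3.0, 5.0, 2.0, 7.0]
y_pred = [2.5, 5.0, 4.0, 7.0]
print(mae(y_true, y_pred))  # 0.625
print(mse(y_true, y_pred))  # 1.0625
```

The single error of 2.0 contributes 4.0 of the 4.25 squared-error total, which is why MSE is the more sensitive of the two to occasional large misses.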

Here’s the model plotted to an image. The input dimension is 10 because the dataset we generated has 10 features.
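Since the input has 10 features, the parameter counts reported by model.summary() can be checked by hand: a Dense layer holds one weight per input-unit pair plus one bias per unit. A tiny helper (dense_params is a hypothetical name) makes the arithmetic explicit:

```python
def dense_params(n_inputs, units):
    """Parameters in a fully connected layer: one weight per input-unit pair plus one bias per unit."""
    return n_inputs * units + units


hidden = dense_params(10, 128)  # 10*128 + 128
output = dense_params(128, 1)   # 128*1 + 1
print(hidden, output, hidden + output)  # 1408 129 1537
```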

Let’s now check the mean squared error on the test set. In the next section, we’ll see how this error changes when we prune the entire model.

Pruning the Entire Model with a ConstantSparsity Pruning Schedule

Let’s compare the above MSE with the one obtained by pruning the entire model. The first step is to define the pruning parameters. The weight pruning is magnitude-based: during training, the weights with the smallest magnitudes are set to zero. The model becomes sparse, which makes it easier to compress. Sparse models can also speed up inference, provided the runtime is able to skip the zero-valued weights.
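The core of magnitude-based pruning fits in a few lines of NumPy: rank the weights by absolute value and zero the smallest fraction. This is a one-shot simplification (magnitude_prune is a hypothetical helper); tfmot instead keeps the mask in a non-trainable variable and updates it periodically during training.

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of entries with the smallest absolute values.

    Returns the pruned weights and a boolean mask (True = weight kept).
    """
    weights = np.asarray(weights, dtype=float)
    n_prune = int(round(sparsity * weights.size))
    # argsort over the flattened magnitudes; the first n_prune indices are the smallest.
    order = np.argsort(np.abs(weights), axis=None)
    mask = np.ones(weights.size, dtype=bool)
    mask[order[:n_prune]] = False
    mask = mask.reshape(weights.shape)
    return weights * mask, mask


w = np.array([[0.1, -0.9],
              [0.5, -0.2]])
pruned, mask = magnitude_prune(w, sparsity=0.5)
print(pruned)
```

At 50% sparsity the two smallest-magnitude weights (0.1 and -0.2) are zeroed, while -0.9 and 0.5 survive regardless of sign.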

The parameters expected are the pruning schedule, the block size, and the block pooling type.

  • pruning_schedule: ConstantSparsity(0.5, 0) sets a constant 50% sparsity starting at training step 0, meaning that 50% of the weights will be zeroed.
  • block_size: the dimensions (height, width) of the block sparsity pattern in matrix weight tensors.
  • block_pooling_type: the function used to pool the weights in each block. Must be AVG or MAX.
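from tensorflow_model_optimization.sparsity.keras import ConstantSparsity

pruning_params = {
    # Keep a constant 50% target sparsity, starting at training step 0.
    'pruning_schedule': ConstantSparsity(0.5, 0),
    # (1, 1) blocks: every weight is ranked and pruned individually.
    'block_size': (1, 1),
    'block_pooling_type': 'AVG'
}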
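To make block_size and block_pooling_type concrete, here is a NumPy sketch of the idea: pool |w| over non-overlapping blocks with a mean (AVG) or max (MAX), rank the blocks, and zero whole blocks at once. The helper names are hypothetical, and the sketch assumes the weight dimensions are divisible by the block size; tfmot handles uneven shapes itself.

```python
import numpy as np


def block_magnitudes(weights, block_size=(1, 1), pooling="AVG"):
    """Pool |weights| over non-overlapping blocks of shape block_size.

    Assumes both weight dimensions are divisible by the block dimensions.
    """
    weights = np.asarray(weights, dtype=float)
    rows, cols = weights.shape
    bh, bw = block_size
    # Reshape to (row_blocks, bh, col_blocks, bw) so each block can be reduced.
    blocks = np.abs(weights).reshape(rows // bh, bh, cols // bw, bw)
    if pooling == "AVG":
        return blocks.mean(axis=(1, 3))
    if pooling == "MAX":
        return blocks.max(axis=(1, 3))
    raise ValueError("pooling must be 'AVG' or 'MAX'")


def block_prune(weights, sparsity, block_size=(1, 1), pooling="AVG"):
    """Zero the lowest-scoring `sparsity` fraction of blocks."""
    weights = np.asarray(weights, dtype=float)
    scores = block_magnitudes(weights, block_size, pooling)
    n_prune = int(round(sparsity * scores.size))
    order = np.argsort(scores, axis=None)
    keep = np.ones(scores.size, dtype=bool)
    keep[order[:n_prune]] = False
    keep = keep.reshape(scores.shape)
    # Expand the per-block decision back to per-weight resolution.
    mask = np.repeat(np.repeat(keep, block_size[0], axis=0), block_size[1], axis=1)
    return weights * mask


w = np.array([[0.1, 0.2, 0.9, 0.1],
              [0.1, 0.1, 0.1, 0.1]])
print(block_magnitudes(w, (2, 2), "AVG"))  # mean |w| per 2x2 block
print(block_magnitudes(w, (2, 2), "MAX"))  # largest |w| per 2x2 block
print(block_prune(w, 0.5, (2, 2), "MAX"))  # left block zeroed as a unit
```

With block_size=(1, 1) each "block" is a single weight, so both pooling types reduce to plain per-weight magnitude pruning, which matches the (1, 1) setting above.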

We can now prune the entire model by applying our pruning parameters.

import tensorflow as tf
from tensorflow_model_optimization.sparsity.keras import prune_low_magnitude

# Wrapping the whole Sequential model applies pruning to every prunable layer.
model_to_prune = prune_low_magnitude(
    tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)),
        tf.keras.layers.Dense(1, activation='relu')
    ]), **pruning_params)

Let’s check the model summary and compare it with the summary of the unpruned model. From the summary we can see that the entire model has been wrapped for pruning; we’ll see the difference shortly when we look at the summary obtained after pruning only one dense layer.
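The extra parameters in the pruned summary can be estimated by hand, assuming tfmot’s prune_low_magnitude wrapper behaves as it does in the library’s Keras pruning guide: each pruned kernel gets a non-trainable mask of the same shape plus a scalar threshold, every wrapper keeps a scalar step counter, and biases are not pruned. The helper name is hypothetical; check the numbers against your own model.summary().

```python
def dense_layer_params(n_inputs, units, pruned=False):
    """Parameter count for a Dense layer, optionally wrapped for pruning.

    Assumed wrapper layout: kernel-shaped mask + scalar threshold + scalar step counter.
    """
    kernel, bias = n_inputs * units, units
    extra = kernel + 1 + 1 if pruned else 0
    return kernel + bias + extra


unpruned = dense_layer_params(10, 128) + dense_layer_params(128, 1)
whole_model = dense_layer_params(10, 128, pruned=True) + dense_layer_params(128, 1, pruned=True)
first_layer_only = dense_layer_params(10, 128, pruned=True) + dense_layer_params(128, 1)
print(unpruned, whole_model, first_layer_only)  # 1537 2949 2819
```

The added variables are non-trainable, so the trainable parameter count should stay the same as in the unpruned model.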

We have to compile the model before we can fit it to the training set.

Since we’re applying pruning, we have to define a couple of pruning callbacks in addition to the early stopping callback. We define the folder where the logs will be written, then create a list with the callbacks.

tfmot.sparsity.keras.UpdatePruningStep() updates pruning wrappers with the optimizer step. Failure to specify it will result in an error.

tfmot.sparsity.keras.PruningSummaries(log_dir=...) adds pruning summaries, such as the sparsity of the pruned weights, to TensorBoard.

With that out of the way, we can now fit the model to the training set.

Upon checking the mean squared error for this model, we notice that it’s slightly higher than the one for the unpruned model.

Pruning the Dense Layer Only with PolynomialDecay Pruning Schedule

Let’s now implement the same model, but this time we’ll prune only the first Dense layer. Notice that the pruning schedule is now PolynomialDecay.
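The two schedules differ in how the target sparsity evolves with the training step. The sketch below re-creates the schedule formulas as we understand them from tfmot’s schedule classes: ConstantSparsity holds one value, and PolynomialDecay interpolates from initial to final sparsity with a default power of 3. It also includes the rule that the mask is updated only every frequency steps (100 by default) inside [begin_step, end_step], which is why the wrappers need the training step that UpdatePruningStep supplies. The function names are hypothetical; treat this as an illustration, not tfmot’s code.

```python
def should_prune(step, begin_step, end_step=-1, frequency=100):
    """True on steps where the mask is updated: inside [begin_step, end_step]
    (end_step=-1 means no end) and on every `frequency`-th step from begin_step."""
    in_range = step >= begin_step and (end_step < 0 or step <= end_step)
    return in_range and (step - begin_step) % frequency == 0


def polynomial_decay(step, initial_sparsity, final_sparsity,
                     begin_step, end_step, power=3):
    """Ramp sparsity from initial to final between begin_step and end_step."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))  # clamp to [0, 1]
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power


# The settings used below: 20% -> 80% between steps 1000 and 2000.
for step in (1000, 1500, 2000):
    print(step, round(polynomial_decay(step, 0.2, 0.8, 1000, 2000), 4))
# 1000 0.2
# 1500 0.725
# 2000 0.8
```

Because of the cubic term, most of the sparsity increase happens early in the window: halfway through (step 1500) the target is already 0.725 of the final 0.8.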

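from tensorflow_model_optimization.sparsity.keras import PolynomialDecay

layer_pruning_params = {
    # Ramp sparsity from 20% to 80% between training steps 1000 and 2000.
    'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
        final_sparsity=0.8, begin_step=1000, end_step=2000),
    # Prune weights in 2x3 blocks, scoring each block by its largest magnitude.
    'block_size': (2, 3),
    'block_pooling_type': 'MAX'
}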

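model_layer_pruning = keras.Sequential([
    # Only the first Dense layer is wrapped, so only its weights are pruned.
    prune_low_magnitude(tf.keras.layers.Dense(128, activation='relu', input_shape=(X_train.shape[1],)),
                        **layer_pruning_params),
    tf.keras.layers.Dense(1, activation='relu')])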

From the summary, we can see that only the first dense layer will be pruned.

We then compile and fit the model.

Now, let’s check the mean squared error.

We can’t directly compare the MSE obtained here with the previous one, since we’ve used different pruning parameters and pruned different layers. If you’d like to compare them, make sure the pruning parameters are the same. In our tests, layer_pruning_params gave a lower error than pruning_params for this specific case. Comparing the MSE obtained with different pruning parameters is useful because it lets you settle on the parameters that hurt the model’s performance the least.

Comparing Model Sizes

Let’s now compare the sizes of the models with and without pruning. We start by training and saving the model weights for later use.

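def train_save_weights():
    model = setup_model()
    # Optimizer and loss are assumptions; the original compile call was truncated.
    model.compile(optimizer='adam', loss='mse', metrics=['mae', 'mse'])
    model.fit(X_train, y_train, epochs=300, validation_split=0.2, callbacks=callbacks, verbose=0)
    model.save_weights('.models/friedman_model_weights.h5')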

We’ll set up our base model and load the saved weights. We then prune the entire model, compile and fit it, and visualize the results on TensorBoard.

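base_model = setup_model()
base_model.load_weights('.models/friedman_model_weights.h5')  # optional but recommended for model accuracy
# With no extra arguments, prune_low_magnitude applies its default pruning schedule.
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Optimizer, loss, and epochs are assumptions: the original call was truncated.
# callbacks must include tfmot.sparsity.keras.UpdatePruningStep().
model_for_pruning.compile(optimizer='adam', loss='mse', metrics=['mae', 'mse'])
model_for_pruning.fit(X_train, y_train, epochs=300, validation_split=0.2,
                      callbacks=callbacks, verbose=0)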

%tensorboard --logdir={log_dir}

Here’s a single snapshot of the pruning summaries from TensorBoard.

The other pruning summaries can also be viewed on TensorBoard.

Let’s now define a function to compute the sizes of the models.

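import os
import zipfile

def get_gzipped_model_size(model, model_name, zip_dir):
    # Saves the model, zips it into zip_dir, and returns the archive size in bytes.
    model.save(model_name, include_optimizer=False)  # skip optimizer state
    zip_path = os.path.join(zip_dir, os.path.basename(model_name) + '.zip')
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as f:
        f.write(model_name)
    return os.path.getsize(zip_path)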

And now we define the model for export and then compute the sizes.

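For a pruned model, tfmot.sparsity.keras.strip_pruning() removes the pruning wrappers and returns the original model structure with the sparse weights. Notice the difference in size between the stripped and unstripped models: the unstripped model also stores the wrappers’ extra variables, such as the pruning masks.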

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
print("Size of gzipped pruned model without stripping: %d bytes" % (get_gzipped_model_size(model_for_pruning, '.models/model_for_pruning.h5', '.models/')))
print("Size of gzipped pruned model with stripping: %d bytes" % (get_gzipped_model_size(model_for_export, '.models/model_for_export.h5', '.models/')))
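Pruning doesn’t change tensor shapes, so the uncompressed weights take the same space before and after pruning; the saving shows up once the file is compressed, because zeros compress well. A standard-library sketch illustrates this under stated assumptions: random Gaussian values stand in for trained weights, the vector has the size of the 10×128 hidden-layer kernel, and zlib applies the same DEFLATE algorithm that zipfile.ZIP_DEFLATED uses.

```python
import random
import struct
import zlib

random.seed(0)
n = 10 * 128  # same number of weights as the hidden layer's kernel
dense = [random.gauss(0.0, 0.1) for _ in range(n)]

# Zero the half of the weights with the smallest magnitudes (50% sparsity).
order = sorted(range(n), key=lambda i: abs(dense[i]))
sparse = list(dense)
for i in order[: n // 2]:
    sparse[i] = 0.0


def deflated_size(values):
    """Bytes after packing as float32 and compressing with DEFLATE."""
    raw = struct.pack(f"{len(values)}f", *values)
    return len(zlib.compress(raw, 9))


print("raw float32 bytes:", 4 * n)
print("dense, compressed:", deflated_size(dense))
print("50% sparse, compressed:", deflated_size(sparse))
```

Random dense floats barely compress, while the sparse copy shrinks noticeably even though both occupy the same 5,120 raw bytes.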

Running predictions on both models, we see that they have the same mean squared error.

from sklearn.metrics import mean_squared_error

# ravel() flattens the (n, 1) predictions without hard-coding the test-set size.
model_for_pruning_predictions = model_for_pruning.predict(X_test)
print('Model for Pruning Error %.4f' % mean_squared_error(y_test, model_for_pruning_predictions.ravel()))
model_for_export_predictions = model_for_export.predict(X_test)
print('Model for Export Error  %.4f' % mean_squared_error(y_test, model_for_export_predictions.ravel()))

Final Thoughts

You can go ahead and test how different pruning schedules affect the size of the model. The observations made here are not universal: you’ll have to try different pruning parameters and see how they affect your model size, prediction error, and/or accuracy, depending on your problem.

To optimize the model even more, you could quantize it. If you’d like to explore that and more, check the repo and the resources below.

