Gaussian NB In Android, Not Python!

Diving deep into the code and mathematics

Mobile Machine Learning

Whether you are a beginner or a seasoned ML developer, you’ve probably heard about (and most likely implemented) the Naive Bayes classifier. These classifiers are especially useful when you have small datasets and limited computational power.

You might be aware of the world-famous scikit-learn implementation, sklearn.naive_bayes.GaussianNB, commonly used in Python.

We’ll get our hands dirty with Gaussian Naïve Bayes, its math, and the Android implementation. We’ll use our algorithm on the Iris dataset.


The Android app (written in Kotlin) can be found in the GitHub repo below:

Motivation

Neural Networks might be the kings, but kingdoms are not only made up of kings, right?

Although Neural Networks (i.e. Artificial Neural Networks) outperform other classifiers in a variety of problems, they have their limitations too.

  • Neural networks are computationally expensive. Training one requires far more computation than running inference, so training a NN on an edge device like an Android smartphone is usually impractical.
  • The size and inference speed of a NN might not be acceptable. A large NN like InceptionV3 has millions of parameters and could take a few seconds to make a single inference on an edge device, which would be too slow for use cases that need real-time classification.
  • NNs do not easily adapt to changing data. They are first trained on a dataset, then their variables are frozen and the model is deployed to edge devices, where it is usually not trained further. If the source of the data changes over time, we need to fine-tune or retrain the NN.

To overcome these problems, we might use non-NN classifiers like Decision Trees, Random Forests, or logistic regression (with a threshold). These can be implemented natively on the edge platform, such as Android or iOS, and can provide faster results without consuming much memory. We can even make custom changes to the algorithms, if the researcher in you comes alive!

In this article, we’ll implement a Gaussian version of the Naive Bayes classifier, which can deal with numerical features (as opposed to the traditional Naïve Bayes classifier, which requires categorical features) and predict the class of a given sample.

Overview

You may observe the functioning of the app in the GIF below. We’ll try to create such an application, but the internal classes could be used for any other problem as well.

Before starting with Gaussian Naïve Bayes, I encourage you to read the following excerpt from Wikipedia for context.

And the following video from the amazing YouTube channel StatQuest with Josh Starmer can also be helpful.

Loading the Data in Android

Our Android app expects data in CSV format. Observe that the CSV file is kept in the assets folder of the Android app.

To understand how we’ll load the CSV data into a nice, clean array, assume that we are given the Iris dataset.

A meaningful approach would be to read the data in columns, or what we call *feature columns. Each feature column holds the value of one particular feature for every sample in the CSV file. For instance, we’ll parse the CSV data into four feature columns [ sepal_length , sepal_width , petal_length , petal_width ] .

*feature columns: The term should not be confused with Feature Columns in TensorFlow. We use it loosely to mean a column that holds the values of a specific feature.
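To make the idea concrete, here is a minimal, standalone sketch of turning row-oriented samples into feature columns. The function name `toFeatureColumns` is hypothetical (it is not part of the app); it simply transposes a list of equally sized rows so that column j holds feature j of every sample.

```kotlin
// Hypothetical helper: transpose row-major samples into feature columns.
// rows[i][j] is feature j of sample i; element j of the result holds feature j of every sample.
fun toFeatureColumns(rows: List<FloatArray>): List<FloatArray> {
    if (rows.isEmpty()) return emptyList()
    val numFeatures = rows[0].size
    require(rows.all { it.size == numFeatures }) { "All rows must have the same number of features" }
    return List(numFeatures) { j -> FloatArray(rows.size) { i -> rows[i][j] } }
}
```

For the Iris rows in the order above, `toFeatureColumns(rows)[2]` would be the petal_length column.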

We implement a class, FeatureColumn, to hold the data for a particular feature.

import kotlin.math.pow

// Holds data for a particular feature.
class FeatureColumn( var name : String , var data : FloatArray ) {

    // Mean of given `data`
    var featureMean : Float

    // Variance of given `data`
    var featureVariance : Float

    // Standard deviation of given `data`
    var featureStdDev : Float

    init {

        featureMean = computeMean()
        featureVariance = computeVariance()

        // Compute the standard deviation of `data` = sqrt( variance )
        featureStdDev = featureVariance.pow( 0.5f )
    }

    override fun toString(): String {
        return "{ name = $name , mean = $featureMean , stddev = $featureStdDev }"
    }

    // Compute the mean of `data`.
    private fun computeMean() : Float {
        return data.average().toFloat()
    }

    // Compute the variance of `data`.
    private fun computeVariance() : Float {
        val mean = data.average().toFloat()
        val meanSquaredSum = data.map{ xi -> ( xi - mean ).pow( 2 ) }.sum()
        return meanSquaredSum / data.size
    }

}

Why is it a meaningful approach to read the CSV data column (feature column)-wise?

In the above snippet, observe the methods computeMean() and computeVariance(). To construct a Gaussian distribution for a feature, we need the mean and standard deviation (recall that the standard deviation is the square root of the variance) of that feature’s values, and storing the data column-wise gives us exactly those values in a single array. We’ll discuss more on this later.

With this, we conclude that the data given to the Gaussian Naive Bayes algorithm will be in the form Array<FeatureColumn>.
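Here is a minimal standalone sketch of the per-column statistics, useful for checking FeatureColumn’s numbers by hand. The names `ColumnStats` and `columnStats` are hypothetical, and, like FeatureColumn, the sketch assumes the population variance (dividing by n, not n - 1).

```kotlin
import kotlin.math.sqrt

// Hypothetical container for the per-feature statistics the classifier needs.
data class ColumnStats(val mean: Float, val variance: Float, val stdDev: Float)

// Population statistics (divide by n), matching FeatureColumn.computeVariance().
fun columnStats(data: FloatArray): ColumnStats {
    require(data.isNotEmpty()) { "Cannot compute statistics of an empty column" }
    val mean = data.average().toFloat()
    val variance = data.map { (it - mean) * (it - mean) }.sum() / data.size
    return ColumnStats(mean, variance, sqrt(variance))
}
```

For the column [2, 4, 4, 4, 5, 5, 7, 9], this gives a mean of 5, a variance of 4, and a standard deviation of 2.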

Wait a minute. How do we read a CSV file in Android? We don’t have pandas there to save our lives, so we turn to CSVReader, a Java class for parsing CSV files. We’ll also need a separate class, DataFrame, which takes the CSV file and returns an Array<FeatureColumn>.

// Helper class to load data from given CSV file and transform it into a Array<FeatureColumn>
class DataFrame( context: Context , assetsFileName : String ) {

    // HashMap which stores the CSV data in the form ( Column_Name , Float[] ). Where Float[] holds
    // the feature value for all samples.
    private var featureColumnData = HashMap<String,ArrayList<Float>>()

    // Variable to store the parsed CSV file from `CSVReader`.
    private var rawData : List<Array<String>>

    // Variable to store column names which are rawData[0].
    private var columnNames : Array<String>

    ...

    init {
        // Create a CSVReader object
        val csvReader = CSVReader( InputStreamReader( context.assets.open( assetsFileName ) ) )

        // Call csvReader.readAll() which outputs the contents of the CSV in the form List<String[]>.
        // Every String[] corresponds to a single row.
        rawData = csvReader.readAll()

        // Get the column names.
        columnNames = getColumnNames()

        // Initialize variables like `numFeatures` and `numSamples`.
        // Also initialize `featureColumnData` by setting the keys as column names of this Hashmap.
        initTable()
    }

The constructor of the DataFrame class has two arguments, context and assetsFileName. The assetsFileName variable is the name of the CSV file located in the assets folder.

  • The featureColumnData HashMap is important, as it stores the data in the form of columns. Its keys (of type String) are the column names, and its values (of type ArrayList<Float>) hold the values of a particular feature. This HashMap is later transformed into an Array<FeatureColumn>, which acts as the endpoint of the data processing part; the Gaussian NB algorithm reads its data from there.
  • In the init block, we create the csvReader object, reading the asset file named assetsFileName through an InputStreamReader.
  • We call csvReader.readAll(), which returns a List<Array<String>>. Each element of this list is one row of the CSV file, and its Array<String> holds that row’s value for every column.
  • We call getColumnNames() which returns the column names in the CSV file.
// Get column names from `rawData`.
private fun getColumnNames() : Array<String> {
    val columnNames = rawData[ 0 ]
    return columnNames
}
  • The initTable() method adds an empty ArrayList<Float> for each feature column to the featureColumnData HashMap, and also initializes the variables numFeatures and numSamples.
...

// Number of features in the dataset. This number equals ( num_cols - 1 ) where num_cols is the number of
// columns in the CSV file.
// ( Note: We assume that the file has the labels column as the last column ).
var numFeatures = 0

// Number of samples in the dataset. This number equals ( num_rows - 1 ) where num_rows is the number of
// rows in the CSV file.
// Note: We assume that the CSV file has its first row as the column names.
var numSamples = 0

...

// Initialize the `featureColumnData`
private fun initTable() {
    // `numFeatures` = num_cols - 1
    numFeatures = rawData[ 0 ].size - 1

    // `numSamples` = num_rows - 1
    numSamples = rawData.size - 1

    // For each entry in `featureColumnData` initialize an empty ArrayList with key=column_name.
    for ( i in 0 until numFeatures ) {
        featureColumnData[ columnNames[ i ] ] = ArrayList()
    }
}

Our final step is to transform rawData into an Array<FeatureColumn>, so we implement a new method, populateColumns().

// Set data from `rawData` to `featureColumnData`
private fun populateColumns() {
    // Create an empty ArrayList to store the labels
    val labels = ArrayList<String>()

    // Iterate through `rawData` starting from index=1 ( as index=0 refers to the column names )
    for ( strSample in rawData.subList( 1 , rawData.size ) ) {

        // Append the label which is the last element of `strSample`.
        labels.add( strSample[ numFeatures ] )

        // Convert rest of the elements of `strSample` to FloatArray.
        val floatSample = convertStringArrayToFloatArray( strSample.sliceArray( IntRange( 0 , numFeatures - 1 ) ) )

        // Append each of the elements in above FloatArray to `featureColumnData`.
        floatSample.forEachIndexed { index, fl ->
            featureColumnData[ columnNames[ index ] ]!!.add( fl )
        }
    }

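    // Transform `featureColumnData` HashMap into an Array<FeatureColumn>.
    // Iterate over `columnNames` rather than the HashMap itself: a HashMap does not
    // guarantee iteration order, and the features must stay in CSV column order so
    // that featureColumns[ j ] matches sample[ j ] at prediction time.
    val featureColumnsList = ArrayList<FeatureColumn>()
    for ( i in 0 until numFeatures ) {
        val name = columnNames[ i ]
        featureColumnsList.add(
                FeatureColumn( name , featureColumnData[ name ]!!.toFloatArray() )
        )
    }
    featureColumns = featureColumnsList.toTypedArray()

    // Store the label of every sample in `labels`.
    this.labels = labels.toTypedArray()
    // num_classes = number of distinct labels.
    numClasses = this.labels.distinct().size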

    // We're done, call onCSVProcessed with `featureColumns` and `labels`.
    readCSVCallback.onCSVProcessed( featureColumns, this.labels )
}
  • convertStringArrayToFloatArray is a utility method that converts an Array<String> to a FloatArray.
// Convert the given String[] to a Float[]
// Note: The `CSVReader` returns a row as String[] where each String is a number. We parse this String and convert it to
// a float.
fun convertStringArrayToFloatArray( strArray : Array<String> ) : FloatArray {
    val out = strArray.map { si -> si.toFloat() }.toFloatArray()
    return out
}
  • ReadCSVCallback notifies the caller once the CSV file has been processed, so the app can show the user a progress bar while processing is under way.
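convertStringArrayToFloatArray relies on String.toFloat(), which throws a NumberFormatException for non-numeric text. That is fine for a trusted CSV file, but the same helper is also used on text typed by the user. A minimal, hypothetical variant (the name `parseSampleOrNull` is not part of the app) that returns null instead of throwing, and also rejects a wrong number of values:

```kotlin
// Hypothetical variant of convertStringArrayToFloatArray for untrusted input.
// Returns null instead of throwing when a value is not a number
// or when the number of values differs from `expectedSize`.
fun parseSampleOrNull(text: String, expectedSize: Int): FloatArray? {
    val parts = text.split(",").map { it.trim() }
    if (parts.size != expectedSize) return null
    // `map` is inline, so `return null` exits parseSampleOrNull on the first bad value.
    val values = parts.map { it.toFloatOrNull() ?: return null }
    return values.toFloatArray()
}
```

A caller could then show an error message when the result is null, instead of letting the exception crash the app.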

So, we’ve finished the data processing part. We shall now move on to implementing the Gaussian Naïve Bayes algorithm, whose code you’ll find in the GaussianNB.kt file.

Code + Math Walkthrough

We’ll start by defining the classifier in a simple mathematical form:
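In standard notation, for K classes C₁, …, C_K and a sample x:

$$\hat{y} = \underset{k \in \{1, \dots, K\}}{\arg\max} \; p(C_k \mid \mathbf{x})$$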

Don’t worry if you haven’t dealt with argmax notation before. It simply means that y_hat is the class Cₖ whose index k maximizes the quantity p( Cₖ | x ). In the above expression:

  • x is the given sample, which can be interpreted as a vector holding N features.
  • p( Cₖ | x ) is the probability that x belongs to class Cₖ.
  • y_hat is the predicted class, i.e. the class Cₖ for which p( Cₖ | x ) is largest.

Let’s start with an eye-friendly expression for Bayes Theorem,
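Written for a class C_k and a sample x:

$$p(C_k \mid \mathbf{x}) = \frac{p(\mathbf{x} \mid C_k)\, p(C_k)}{p(\mathbf{x})}$$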

I would like to introduce you to some basic terminology here:
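The four terms of Bayes’ Theorem are conventionally called the posterior, the likelihood, the prior, and the evidence:

$$\underbrace{p(C_k \mid \mathbf{x})}_{\text{posterior}} = \frac{\overbrace{p(\mathbf{x} \mid C_k)}^{\text{likelihood}} \; \overbrace{p(C_k)}^{\text{prior}}}{\underbrace{p(\mathbf{x})}_{\text{evidence}}}$$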

In our case, we make the naive assumption that all features are conditionally independent given the class, which lets us use a simpler version of Bayes’ Theorem to determine the class of a given sample.
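With x = (x₁, …, x_N), the likelihood then factorizes over the features, and because the evidence p(x) is the same for every class, it can be dropped from the arg max:

$$p(C_k \mid \mathbf{x}) \propto p(C_k) \prod_{i=1}^{N} p(x_i \mid C_k), \qquad \hat{y} = \underset{k}{\arg\max} \; p(C_k) \prod_{i=1}^{N} p(x_i \mid C_k)$$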

To avoid *underflow when multiplying many small probabilities p( xᵢ | Cₖ ), we work with the logarithm of the product instead of the product itself.

*underflow: The term arithmetic underflow is a condition in a computer program where the result of a calculation is a number of smaller absolute value than the computer can actually represent in memory on its central processing unit (Wikipedia).
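To see why this matters, here is a small standalone sketch (both function names are hypothetical). Multiplying 100 densities of 1e-5 in Float arithmetic underflows to exactly 0, since the true value, 1e-500, is far below the smallest positive Float (about 1.4e-45), whereas summing their base-10 logarithms stays comfortably representable at about -500.

```kotlin
import kotlin.math.log10

// Multiplies `count` copies of `density` directly in Float arithmetic.
fun naiveProduct(density: Float, count: Int): Float {
    var product = 1f
    repeat(count) { product *= density }
    return product
}

// Sums log10(density) `count` times instead: the log10 of the same product.
fun logSum(density: Float, count: Int): Float {
    var sum = 0f
    repeat(count) { sum += log10(density) }
    return sum
}
```

Once the product has hit 0, every class would score 0 and the arg max would be meaningless; the log-sum keeps the classes distinguishable.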

Applying the basic rules of logarithms, the log of a product becomes a sum of logs:
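Because the logarithm is monotonically increasing, taking it does not change which class wins. The app’s code uses base-10 logs (log10); any base gives the same arg max:

$$\hat{y} = \underset{k}{\arg\max} \; \log\Big( p(C_k) \prod_{i=1}^{N} p(x_i \mid C_k) \Big) = \underset{k}{\arg\max} \; \Big[ \log p(C_k) + \sum_{i=1}^{N} \log p(x_i \mid C_k) \Big]$$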

Our first step will be to calculate the prior probabilities for each of the classes.

The prior probability of a particular class Cₖ (which, as in scikit-learn, could also be supplied by the user; we’ll omit that option for simplicity) is the fraction of training samples that belong to that class:
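With N_k the number of training samples labelled C_k and N the total number of training samples:

$$p(C_k) = \frac{N_k}{N}$$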

To do this, we implement the computePriorProbabilities() method in GaussianNB class.

// Class to implement Gaussian Naive Bayes
class GaussianNB( private var dataFrame : DataFrame ) {

    ...

    // Prior probabilities stored in a HashMap of form ( label , prior_prob )
    private var priorProbabilities : HashMap<String,Float>

    ...

    // Compute the prior probabilities.
    // These probabilities are p( class=some_class ) which are calculated as
    // p( class=apple ) = num_samples_label_as_apple / num_samples_in_ds
    private fun computePriorProbabilities( labels : Array<String> ) : HashMap<String,Float> {
        // Get the count ( freq ) of each unique class in labels.
        val labelCountMap = labels.groupingBy { it }.eachCount()
        // The prior probabilities are stored in a HashMap of form ( label , prob )
        val out = HashMap<String,Float>()
        for ( ( label , count ) in labelCountMap ) {
            // Store the prob with key=label
            out[ label ] = count.toFloat() / dataFrame.numSamples.toFloat()
        }
        return out
    }
    ...

}

Now comes the interesting part: how do we calculate p( xᵢ | Cₖ ) in the expression above? Well, as we are implementing Gaussian Naïve Bayes, there must be a Gaussian distribution involved. Remember that a Gaussian distribution is parameterized by μ and σ, its mean and standard deviation respectively.
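In its usual form, the class-conditional density of feature xᵢ is given below, where μ_ik and σ_ik are the mean and standard deviation of feature i over the training samples of class C_k. (The code in this article, as written, uses a single μ and σ per feature, computed over all samples.)

$$p(x_i \mid C_k) = \frac{1}{\sqrt{2\pi\sigma_{ik}^2}} \exp\!\left( -\frac{(x_i - \mu_{ik})^2}{2\sigma_{ik}^2} \right)$$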

As we need to calculate p( xᵢ | Cₖ ) for every feature xᵢ, we model each feature with a Gaussian distribution. To implement this in code, we create a class GaussianDistribution in GaussianNB.kt.

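import kotlin.math.PI
import kotlin.math.exp
import kotlin.math.ln
import kotlin.math.log10
import kotlin.math.pow
import kotlin.math.sqrt

// A Gaussian Distribution which will be constructed for every feature, given the mean and standard deviation.
class GaussianDistribution( var mean : Float , var stdDev : Float ) {

    // Normalizing constant 1 / sqrt( 2 * pi ).
    private val p = 1f / sqrt( 2.0 * PI ).toFloat()

    // Get the likelihood ( probability density ) for given x.
    fun getProb( x : Float ) : Float {
        val expTerm = exp( -0.5 * ( ( x - mean ) / stdDev ).pow( 2 ) ).toFloat()
        return ( 1 / stdDev ) * p * expTerm
    }

    // Get the log10 of the likelihood for given x.
    // Computed analytically instead of as log10( getProb( x ) ), so that a density too
    // small to represent as a Float does not become log10( 0 ) = -Infinity.
    fun getLogProb( x : Float ) : Float {
        val z = ( x - mean ) / stdDev
        return -log10( stdDev ) - 0.5f * log10( 2f * PI.toFloat() ) - 0.5f * z * z / ln( 10f )
    }

}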
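As a sanity check of the density and its base-10 log, here is a small standalone sketch with hypothetical function names. It works in Double for accuracy (the app uses Float) and writes log10 p(x) out analytically, so it stays finite even where the density itself underflows to 0.

```kotlin
import kotlin.math.PI
import kotlin.math.exp
import kotlin.math.ln
import kotlin.math.log10
import kotlin.math.sqrt

// Gaussian probability density N(x; mean, stdDev^2).
fun gaussianPdf(x: Double, mean: Double, stdDev: Double): Double {
    val z = (x - mean) / stdDev
    return exp(-0.5 * z * z) / (stdDev * sqrt(2.0 * PI))
}

// log10 of the same density:
// log10 p(x) = -log10(stdDev) - 0.5 * log10(2 * pi) - 0.5 * z^2 / ln(10)
fun gaussianLog10Pdf(x: Double, mean: Double, stdDev: Double): Double {
    val z = (x - mean) / stdDev
    return -log10(stdDev) - 0.5 * log10(2.0 * PI) - 0.5 * z * z / ln(10.0)
}
```

At the mean of a standard normal, gaussianPdf returns 1 / sqrt( 2π ) ≈ 0.3989. At x = 50, the density underflows to 0.0 while gaussianLog10Pdf returns roughly -543.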

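We’ll create an ArrayList<GaussianDistribution> to store one Gaussian distribution for each feature in featureColumns, which is an Array<FeatureColumn>. Here we reuse the mean and standard deviation computed in the FeatureColumn class. Note that these statistics are computed over all samples, so each feature gets a single distribution shared by every class. The textbook Gaussian Naïve Bayes (like scikit-learn’s GaussianNB) estimates μ and σ separately for each class, which is what allows the likelihood term to tell the classes apart.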

// Class to implement Gaussian Naive Bayes
class GaussianNB( private var dataFrame : DataFrame ) {

    // Array to store Gaussian Distributions for each feature.
    private var featureDistributions : Array<GaussianDistribution>

    // Prior probabilities stored in a HashMap of form ( label , prior_prob )
    private var priorProbabilities : HashMap<String,Float>

    private var resultCallback : ResultCallback? = null

    init {

        // Construct Gaussian Distributions for each feature and append them to `featureDistributions`.
        val GDs = ArrayList<GaussianDistribution>()
        for ( featureColumn in dataFrame.featureColumns ) {
            GDs.add( GaussianDistribution( featureColumn.featureMean , featureColumn.featureStdDev ))
        }
        featureDistributions = GDs.toTypedArray()

        // Compute prior probabilities and store them in `priorProbabilities`.
        priorProbabilities = computePriorProbabilities( dataFrame.labels )

    }

Note the uses of featureColumn.featureMean and featureColumn.featureStdDev in the above code snippet. Next, we create a public method predict() which, when given a FloatArray, computes the predicted label and returns it via a ResultCallback.

// Predict the label for the given sample and pass it to `resultCallback`.
fun predict( sample : FloatArray , resultCallback : ResultCallback ) {
    val labels = priorProbabilities.keys.toTypedArray()
    val probArray = FloatArray( labels.size )
    for ( ( i , label ) in labels.withIndex() ) {
        // We take the log probabilities so as to avoid underflow.
        var p = log10( priorProbabilities[ label ]!! )
        // While we take log, the product is transformed into a sum
        // log( a . b ) = log(a) + log(b)
        for ( j in 0 until dataFrame.numFeatures ) {
            // Use the j-th feature of the sample ( not the class index `i` ).
            p += featureDistributions[ j ].getLogProb( sample[ j ] )
        }
        probArray[ i ] = p
    }
    // Get the label with the highest ( log ) probability.
    val bestIndex = probArray.indices.maxByOrNull { probArray[ it ] } ?: return
    resultCallback.onPredictionResult( labels[ bestIndex ] , probArray )
}
  • We initialize probArray to store, for each class Cₖ, the score log p( Cₖ ) + Σᵢ log p( xᵢ | Cₖ ), which equals log p( Cₖ | x ) up to the constant log p( x ).
  • For each class, we add the log-likelihood of each feature by calling featureDistributions[ j ].getLogProb( sample[ j ] ).
  • Finally, we perform the arg max operation: we find the index of the class with the highest score and pass the corresponding label to the callback.
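Because the distributions in featureDistributions do not depend on the class, they cannot tell the classes apart; only the priors can. Here is a minimal, self-contained sketch of the textbook variant, which estimates a mean and variance for every (class, feature) pair. The class name `PerClassGaussianNB` and its `varSmoothing` constant are hypothetical, not part of the app. The sketch uses natural logs and Double precision, and adds a small constant to each variance to avoid dividing by zero.

```kotlin
import kotlin.math.PI
import kotlin.math.ln

// Hypothetical sketch, not the app's GaussianNB class: estimates a separate mean and
// variance for every (class, feature) pair, as textbook Gaussian Naive Bayes does.
class PerClassGaussianNB(private val varSmoothing: Double = 1e-9) {
    private val classes = mutableListOf<String>()
    private val logPriors = mutableListOf<Double>()
    private val means = mutableListOf<DoubleArray>()      // means[k][j]: mean of feature j in class k
    private val variances = mutableListOf<DoubleArray>()  // variances[k][j]: variance of feature j in class k

    // samples[i] holds the features of sample i and labels[i] is its class.
    fun fit(samples: List<DoubleArray>, labels: List<String>) {
        require(samples.isNotEmpty() && samples.size == labels.size) { "Need one label per sample" }
        classes.clear(); logPriors.clear(); means.clear(); variances.clear()
        val numFeatures = samples[0].size
        // Group sample indices by class label (groupBy keeps first-seen order).
        for ((label, members) in labels.indices.groupBy { labels[it] }) {
            val mean = DoubleArray(numFeatures) { j -> members.sumOf { samples[it][j] } / members.size }
            // Population variance per feature, plus a small constant so it is never zero.
            val variance = DoubleArray(numFeatures) { j ->
                members.sumOf { (samples[it][j] - mean[j]) * (samples[it][j] - mean[j]) } / members.size + varSmoothing
            }
            classes += label
            logPriors += ln(members.size.toDouble() / samples.size)
            means += mean
            variances += variance
        }
    }

    // Returns the class maximizing ln p(C_k) + sum_j ln p(x_j | C_k).
    fun predict(x: DoubleArray): String {
        check(classes.isNotEmpty()) { "Call fit() first" }
        val scores = classes.indices.map { k ->
            var score = logPriors[k]
            for (j in x.indices) {
                val d = x[j] - means[k][j]
                // Natural log of the Gaussian density N(x_j; mean, variance).
                score += -0.5 * ln(2.0 * PI * variances[k][j]) - d * d / (2.0 * variances[k][j])
            }
            score
        }
        return classes[scores.indices.maxByOrNull { scores[it] }!!]
    }
}
```

Usage follows a fit-then-predict pattern: call fit() once with the training rows and their labels, then predict() for each new sample. Adapting this to the app would mean building the per-class statistics from the DataFrame’s rows before they are split into feature columns.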

The Android App

We’ve implemented all the classes for our Gaussian NB algorithm. We’ll now add an EditText, two Buttons, and a TextView to take a sample as input and display the predicted label to the user.

class MainActivity : AppCompatActivity() {

    // View elements to take the inputs and display it to the user.
    private lateinit var sampleInputEditText : TextInputEditText
    private lateinit var outputTextView : TextView

    // DataFrame object which will hold the data.
    private lateinit var dataFrame: DataFrame

    // GaussianNB object to perform the calculations and return
    // the predictions.
    private lateinit var gaussianNB: GaussianNB


    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView( R.layout.activity_main )

        // Initialize the View elements
        sampleInputEditText = findViewById( R.id.sample_input_editText )
        outputTextView = findViewById( R.id.output_textView )

        // Initialize the DataFrame object.
        // The file iris_ds.csv is kept in the assets folder of the app.
        dataFrame = DataFrame( this , "iris_ds.csv" )

    }

Next, we add onClick methods for the two buttons and show the final output to the user via outputTextView:

...

// Called when predict_button is clicked. ( See activity_main.xml ).
fun onPredictButtonClick( view : View ) {
    // Split the String by ","
    var strX = sampleInputEditText.text.toString().split( "," ).toTypedArray()
    strX = strX.map{ xi -> xi.trim() }.toTypedArray()
    // Convert the String[] to float[]
    val x = dataFrame.convertStringArrayToFloatArray( strX )
    // Predict the class with GaussianNB.
    gaussianNB.predict( x , resultCallback )
}

// Called when load_data_button is clicked. ( See activity_main.xml ).
fun onLoadCSVButtonClick( view : View ) {
    // Load the CSV file.
    dataFrame.readCSV( readCSVCallback )
}

...

// Called when all calculations are done and the results could be processed
// further.
private val resultCallback = object : GaussianNB.ResultCallback {

    override fun onPredictionResult(label: String, probs: FloatArray) {
        // Display the output to the user via outputTextView.
        outputTextView.text = "Label : $label \n Prob : ${probs.contentToString()}"
    }

}

That’s all! We’ve just implemented the Gaussian Naïve Bayes algorithm in an Android app!

The End

I hope you liked the story. Feel free to express your thoughts in the comments below. Thank you!


Fritz

Our team has been at the forefront of Artificial Intelligence and Machine Learning research for more than 15 years and we're using our collective intelligence to help others learn, understand and grow using these new technologies in ethical and sustainable ways.
