In the era of hyper-sophisticated machine learning models like ChatGPT, it is surprising how effective the classic decision tree model remains, especially when used in conjunction with other techniques, such as bagging, boosting and random forests. In this blog post we demonstrate how to build an effective decision tree model, and train this model on some sample data.

Introduction

In earlier articles we demonstrated how simple distance metrics and a nearest-neighbour recommendation system could be implemented in ruby. We will stick with this machine learning (ML) theme in the present article and discuss the implementation of a simple decision tree algorithm, again in ruby.

We will implement a decision tree ML algorithm and we will then train our model with data for a classification task. Once trained, our decision tree model should be able to predict the classification for previously unseen data. We will use a simple dataset as we develop the algorithm, but we will also demonstrate how the algorithm can then be applied to a more challenging problem.

The algorithms implemented in this article are inspired by the python versions presented by Toby Segaran in Collective Intelligence (Amazon), which I highly recommend for anyone interested in learning the fundamentals of classic machine learning methods.

Decision Trees

The venerable decision tree has been at the basis of human classification systems from the beginning of time, and permeates many disciplines. Many of you will have come across these ideas in biology class as a mechanism for discriminating between different living organisms, for example:

[Figure: decision tree diagram showing boxes with decision nodes and branches for the true/false responses to each question, classifying a living organism into one of the five biological kingdoms (animal, plant, fungi, protista or monera), which form the terminating/leaf nodes.]
Diagram illustrating a simple decision tree to classify living organisms according to the five biological kingdoms. (Produced with reference to this article.)

This approach is characterised by a series of binary decision nodes, where each node represents a single question with two branches emanating from it: the true branch and the false branch. At the end of a series of decision nodes we should hopefully arrive at a leaf node, which defines the categorization that should apply.

Once we have a decision tree, such as the one depicted, using it to classify a new case is pretty intuitive. The question is, how do we generate this decision tree?

In machine learning we want the algorithm to build this tree for us, using existing data. Let's introduce a simple dataset which we can use as a reference as we tease out the algorithm. We will take the same dataset presented in [1]:

    
module SimpleData
  DATA = [
    ['slashdot', 'USA', 'yes', 18, 'None'],
    ['google', 'France', 'yes', 23, 'Premium'],
    ['digg', 'USA', 'yes', 24, 'Basic'],
    ['kiwitobes', 'France', 'yes', 23, 'Basic'],
    ['google', 'UK', 'no', 21, 'Premium'],
    ['(direct)', 'New Zealand', 'no', 12, 'None'],
    ['(direct)', 'UK', 'no', 21, 'Basic'],
    ['google', 'USA', 'no', 24, 'Premium'],
    ['slashdot', 'France', 'yes', 19, 'None'],
    ['digg', 'USA', 'no', 18, 'None'],
    ['google', 'UK', 'no', 18, 'None'],
    ['kiwitobes', 'UK', 'no', 19, 'None'],
    ['digg', 'New Zealand', 'yes', 12, 'Basic'],
    ['slashdot', 'UK', 'no', 21, 'None'],
    ['google', 'UK', 'no', 18, 'Basic'],
    ['kiwitobes', 'France', 'yes', 19, 'Basic']
  ]
  …
end
    
  

Each row in the dataset represents a visitor to some particular website, and the columns of each row contain the following data:

  1. The referring website, from which the visitor has arrived
  2. The geographic location of the visitor
  3. Whether the visitor has read the FAQ
  4. The number of pages viewed by the visitor
  5. The service to which the visitor has subscribed
The goal is to use the data we have to train a decision tree model, and then use the resulting model to predict if a visitor is likely to take out a Premium subscription on the site. In this regard, the final column of each row is the label which categorizes the row, and this is the label that we want our decision tree to predict. It can take one of three values: None, Premium or Basic.

As a brief aside, you can see that we have wrapped our DATA in a SimpleData module; the only other thing exposed by this module is a singleton method get_train_and_test_data:

    
  def self.get_train_and_test_data(split: 0.1)
    train_data = DATA.dup
    # Let's remove :split of data for testing
    full_size = train_data.size
    test_data = (0..(full_size*split)).to_a.map do |i|
      train_data.delete_at(rand(full_size-i))
    end

    [train_data, test_data]
  end
  

This method provides an interface to return our dataset as two separate arrays, a training set and a test set. This is a very important concept in building machine learning models, but I will not be able to do it justice in this post. In brief, we use our training set to train (or build) our decision tree model, and then we use our test set to evaluate the accuracy of the model we have produced. The method get_train_and_test_data will randomly pull rows out of the training data to populate the test set, in the ratio specified by split.
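For example, with the default split: 0.1 on our 16-row dataset, the range (0..1.6) enumerates the integers 0 and 1, so two randomly-chosen rows are moved into the test set:

train_rows, test_rows = SimpleData.get_train_and_test_data(split: 0.1)
train_rows.size # => 14
test_rows.size  # => 2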

We aim to find a series of conditions which, when applied to our data, will split it into pure sets containing only a single label. As an example, let's consider splitting our dataset by the values in the fourth column, i.e. the number of pages viewed. Suppose our first node splits the data according to whether the number of pages viewed is greater than 18; how would our data look if we applied this condition:

    
 SimpleData::DATA.partition{ |row| row[3] > 18}
=> 
[[["google", "France", "yes", 23, "Premium"],
  ["digg", "USA", "yes", 24, "Basic"],
  ["kiwitobes", "France", "yes", 23, "Basic"],
  ["google", "UK", "no", 21, "Premium"],
  ["(direct)", "UK", "no", 21, "Basic"],
  ["google", "USA", "no", 24, "Premium"],
  ["slashdot", "France", "yes", 19, "None"],
  ["kiwitobes", "UK", "no", 19, "None"],
  ["slashdot", "UK", "no", 21, "None"],
  ["kiwitobes", "France", "yes", 19, "Basic"]],

 [["slashdot", "USA", "yes", 18, "None"],
  ["(direct)", "New Zealand", "no", 12, "None"],
  ["digg", "USA", "no", 18, "None"],
  ["google", "UK", "no", 18, "None"],
  ["digg", "New Zealand", "yes", 12, "Basic"],
  ["google", "UK", "no", 18, "Basic"]]]
    
  

In this snippet we are using the very neat Enumerable#partition method. Take a look at the docs and you will see that this little beauty is tailor-made for our decision tree implementation; it allows us to split an array in two, based on whether the given block evaluates to true or false.

Using this criterion, i.e. num_pageviews > 18, we have gathered all of the Premium visitors into our first partition. This definitely feels like progress, like the two partitions are somehow purer. But is there any way that we can evaluate this perceived improvement? This is where we need to introduce the idea of entropy.
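First, though, as a quick sanity check we can confirm in irb that the false partition really does contain no Premium labels:

_, low_views = SimpleData::DATA.partition { |row| row[3] > 18 }
low_views.map(&:last).uniq # => ["None", "Basic"]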

Entropy is a quantity that has been defined because it turns out to be very useful in describing different physical systems. It is often described as a measure of disorder in a system, but in our case we will use it to tell us how mixed our arrays are. We want to apply criteria that will tend to reduce the average entropy in our partitions, i.e. will make them less mixed. The entropy in any given partition of our data can be defined to be: $$ \begin{aligned} U = - \sum_{i=1}^{N} p_{i} \log_{2} p_{i} \end{aligned} $$ where $p_{i}$ is the probability of obtaining label $i$ in the partition, and $N$ is 3, as we have 3 labels to sum over.

Things will be easier if we consider an example. In our original dataset (containing 16 rows) we had a mixture of labels: 6 visitors were Basic (b), 3 visitors were Premium (p) and 7 visitors were None (n). We can calculate the entropy of this set as follows:

$$ \begin{aligned} U_{\rm{orig}} &= - \sum_{i=\{\rm{b}, \rm{p}, \rm{n}\}} p_{i} \log_{2} p_{i} \\ &= - p_{\rm{b}} \log_{2} p_{\rm{b}} - p_{\rm{p}} \log_{2} p_{\rm{p}} - p_{\rm{n}} \log_{2} p_{\rm{n}} \\ &= - \frac{6}{16} \log_{2} \left( \frac{6}{16} \right) - \frac{3}{16} \log_{2} \left( \frac{3}{16} \right) - \frac{7}{16} \log_{2} \left( \frac{7}{16} \right) \\ &= 1.5052 \end{aligned} $$

We then use our condition, num_pageviews > 18, to split the data into two separate partitions of 10 and 6 rows. Each partition has its own entropy: $$ \begin{aligned} U_{1} &= - p_{\rm{b}} \log_{2} p_{\rm{b}} - p_{\rm{p}} \log_{2} p_{\rm{p}} - p_{\rm{n}} \log_{2} p_{\rm{n}} \\ &= - \frac{4}{10} \log_{2} \left( \frac{4}{10} \right) - \frac{3}{10} \log_{2} \left( \frac{3}{10} \right) - \frac{3}{10} \log_{2} \left( \frac{3}{10} \right) \\ &= 1.5709 \\ \\ U_{2} &= - p_{\rm{b}} \log_{2} p_{\rm{b}} - p_{\rm{p}} \log_{2} p_{\rm{p}} - p_{\rm{n}} \log_{2} p_{\rm{n}} \\ &= - \frac{2}{6} \log_{2} \left( \frac{2}{6} \right) - \frac{0}{6} \log_{2} \left( \frac{0}{6} \right) - \frac{4}{6} \log_{2} \left( \frac{4}{6} \right) \\ &= 0.9183 \end{aligned} $$ Note that the Premium term in $U_{2}$ vanishes, following the standard convention that $0 \log_{2} 0 = 0$.

To compare the entropy before and after the split we need to take the weighted average of the entropy for each partition, i.e. $\frac{10}{16} U_{1} + \frac{6}{16} U_{2} = 1.3262$. So our split criterion has succeeded in reducing the entropy by $1.5052 - 1.3262 = 0.1790$. This decrease in entropy can also be referred to as the information gain.

Now that we understand what entropy is, the implementation in ruby should be straightforward. Our entropy method takes a set of rows and groups them according to the label in the last column, getting a count for each label. The count for a given label, divided by the total number of rows, gives the probability for that label. We then combine these probabilities using the formula already presented:

    
  def entropy(rows)
    # Group the rows by their label (the last column), then accumulate
    # -p * log2(p) over each label's relative frequency
    rows.group_by(&:last).reduce(0) do |entropy, (_label, group)|
      prob = group.length.to_f / rows.size
      entropy - prob * Math.log(prob, 2)
    end
  end
    
  

We will see this entropy enter our algorithm later in the article.
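As a quick check, if we paste the entropy method and the SimpleData module into an irb session, we can reproduce the numbers we calculated by hand:

high, low = SimpleData::DATA.partition { |row| row[3] > 18 }

entropy(SimpleData::DATA) # => ~1.5052
entropy(high)             # => ~1.5710
entropy(low)              # => ~0.9183

# The weighted average entropy after the split
(high.size * entropy(high) + low.size * entropy(low)) / SimpleData::DATA.size.to_f
# => ~1.3262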

Before that, let's recall that our decision to split our data on this particular attribute and value was selected quasi-randomly. We could have chosen a different pageview limit in our criterion, or we could have split by a completely different attribute. Our decision tree algorithm must explore all the possible ways to split the data, and then choose the criterion which yields the greatest information gain. Let's look at how we can write this algorithm in ruby.

Representing a Decision Tree

The basis of our decision tree is the individual node. Each node is either a decision node or a leaf node. We could represent these as separate classes, but for our implementation we will handle both cases within a single class:

    
class TreeNode
  attr_reader :column_index, :value, :true_branch, :false_branch, :results

  def initialize(column_index: nil, value: nil, true_branch: nil, false_branch: nil, results: nil)
    @column_index = column_index
    @value = value
    @true_branch = true_branch
    @false_branch = false_branch
    @results = results
  end

   …
end
    
  

A decision node will be defined by a column_index, for the column upon which the node is partitioning the data, along with the value of that column which defines the threshold for the split. The decision node will also hold a reference to the true_branch and false_branch; these are references to the next TreeNode to be applied to the true and false partitions, respectively. For a decision node the results attribute will be nil.

By contrast, if we are dealing with a leaf node then results will be populated with the training instances (or rows) that were grouped into this leaf. As the leaf node is at the end of a tree branch, the other attributes of column_index, value, true_branch and false_branch will not be populated for such a node.
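To make this concrete, here is a hypothetical two-level tree built by hand (purely illustrative, not something our trainer has produced). The decision node asks whether the visitor viewed 21 or more pages:

premium_leaf = TreeNode.new(results: [['google', 'UK', 'no', 21, 'Premium']])
none_leaf    = TreeNode.new(results: [['slashdot', 'USA', 'yes', 18, 'None']])

root = TreeNode.new(column_index: 3, value: 21,
                    true_branch: premium_leaf, false_branch: none_leaf)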

With a decision tree composed in this manner, if we are given the root node we should be able to classify a previously unseen item. We do this by recursively navigating our tree nodes, applying the decision criteria at each node to decide which branch should be followed next, until we reach a leaf node. This classification process for a new item (or row) is captured in the TreeNode#classify method, which leans on the classify_prob method:

    
class TreeNode
  …
  def classify(row)
    classify_prob(row).max_by{ |_,v| v }[0]
  end

  def classify_prob(row)
    return summary_results if results
    if value.is_a?(Numeric)
      if row[column_index] >= value
        true_branch.classify_prob(row)
      else
        false_branch.classify_prob(row)
      end
    elsif row[column_index] == value
      true_branch.classify_prob(row)
    else
      false_branch.classify_prob(row)
    end
  end

  private

  def summary_results
    results.group_by(&:last).reduce({}) do |memo, (key, value)|
      (memo[key] = value.size) && memo
    end
  end

  …
end
    
  

If the current node is a decision node (i.e. where results are not present), then the classify_prob method will take the new item (or row) and will apply the decision encapsulated in this current node. We apply a slightly different threshold check depending upon whether the TreeNode#value is numeric or non-numeric. However, the recursive pattern is the same in both cases, so let's focus on the case where the current TreeNode#value is a number:

    
    if value.is_a?(Numeric)
      if row[column_index] >= value
        true_branch.classify_prob(row)
      else
        false_branch.classify_prob(row)
      end
    else …
    
  

This decision node tells us that we are interested in the attribute in column column_index, so we extract this element from row. We then compare this extracted element with the threshold value defined by our decision node. If our row element equals or exceeds the threshold value we pass the row to classify_prob on the true_branch, otherwise we pass the row to false_branch.classify_prob. Thus we are recursively applying the classify_prob method to the row, as we navigate from node to node. This recursion will exit once we reach a leaf node, which has the results populated.

If the current node is a leaf node (i.e. with results present) then the method will return a summary of the training rows which have been stored in that leaf. This summary will give a count of the different classes within the leaf node's results. For instance, if the leaf node has the following results stored:

    
  @results = [
    ["kiwitobes", "UK", "no", 19, "None"],
    ["slashdot", "UK", "no", 21, "None"],
    ["kiwitobes", "France", "yes", 19, "Basic"],
    ["slashdot", "USA", "yes", 18, "None"]
  ]
    
  

Then the summary_results will return a hash that looks like this: { "None" => 3, "Basic" => 1 }, reflecting that this leaf node has categorized 3 cases with a label of 'None' and one case of 'Basic'. If our row is categorized to this leaf node, we can conclude that our row should be classified as 'None', with a probability of 75%. The classify method is simply responsible for pulling out this label that has the maximum probability:

    
  def classify(row)
    classify_prob(row).max_by{ |_,v| v }[0]
  end
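
Tying this together, a leaf node holding the four rows above will report the same summary whichever row we pass to it, and classify simply picks the majority label:

leaf = TreeNode.new(results: [
  ['kiwitobes', 'UK', 'no', 19, 'None'],
  ['slashdot', 'UK', 'no', 21, 'None'],
  ['kiwitobes', 'France', 'yes', 19, 'Basic'],
  ['slashdot', 'USA', 'yes', 18, 'None']
])

leaf.classify_prob([]) # => {"None"=>3, "Basic"=>1}
leaf.classify([])      # => "None"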
    
  

Algorithm Implementation

We have seen how the TreeNode can be used to represent our decision tree, and how we can use an existing tree to classify new data. We now introduce our DecisionTreeTrainer, which will use the training data to build our tree of nodes:

    
class DecisionTreeTrainer
  def self.train(rows, max_depth: nil)
    self.new(rows, max_depth: max_depth).train
  end

  attr_reader :all_rows, :total_rows, :num_attributes, :root_node, :max_depth
  def initialize(rows, max_depth: 10)
    @all_rows = rows
    @num_attributes = all_rows.first.size - 1 # Never split on last attribute
    @total_rows = all_rows.size
    @root_node = nil
    @max_depth = max_depth || 10 # Honour the argument; the class method may pass nil
  end

  # Returns a TreeNode, which is the root of the tree we have trained
  def train
    return @root_node if @root_node
    puts "Starting training ..."
    start_time = Time.now
    @root_node = build_decision_node(self.all_rows)
    puts "Completed training in #{Time.now - start_time} seconds"
    @root_node
  end
  …
  

The DecisionTreeTrainer is initialized with our training data, which is stored in the all_rows attribute, and we can also specify a max_depth parameter. Let's summarize the instance variables on the DecisionTreeTrainer class:

  • all_rows: holds our training data. This is the data that will help us to build our decision tree model.
  • max_depth: prevents our recursive algorithm from building a tree that is too deep (defaults to a value of 10).
  • num_attributes: captures the number of attributes that make up each item in our dataset, i.e. the number of columns in each row, excluding the final label column.
  • total_rows: the total number of items in our training data.
  • root_node: the ultimate result of our training is the root node of the tree we have constructed. With this node we can traverse the tree to classify new items, as sketched below.
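As a quick sketch of the trainer's interface (assuming train_rows holds our training array):

trainer = DecisionTreeTrainer.new(train_rows, max_depth: 5)
root = trainer.train       # builds the tree and returns the root TreeNode
trainer.train.equal?(root) # => true, the trained tree is memoized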

The DecisionTreeTrainer#train method will return this root_node if it is already defined; otherwise it will calculate the root_node by calling build_decision_node, passing in the full dataset. This is where all the magic happens:

    
  …
  def build_decision_node(rows, depth = 0)
    max_info_gain = 0
    split_index = -1
    split_value = nil
    initial_entropy = entropy(rows)
    return TreeNode.new(results: rows) if (initial_entropy == 0 || depth >= max_depth)

    num_attributes.times do |i|
      rows.map{|r| r[i] }.uniq.each do |value|
        new_rows = divide_set(rows, i, value)
        new_entropy = new_rows.reduce(0) do |memo, branch|
          memo+=entropy(branch)*branch.size/rows.size
        end
        info_gain = (initial_entropy - new_entropy)
        if info_gain > max_info_gain
          max_info_gain = info_gain
          split_index = i
          split_value = value
        end
      end
    end

    # OK, if we have an info gain, let's split according to the best criterion found
    if max_info_gain <= 0
      TreeNode.new(results: rows)
    else
      true_rows, false_rows = divide_set(rows, split_index, split_value)

      TreeNode.new(
        column_index: split_index,
        value: split_value,
        true_branch: build_decision_node(true_rows, depth + 1),
        false_branch: build_decision_node(false_rows, depth + 1),
      )
    end
  end
    
  

The build_decision_node method, again, takes our training rows and an optional depth argument, which defaults to 0. We start by calculating the initial_entropy of our training set, using the entropy method defined earlier. If the initial_entropy evaluates to zero (meaning all the rows have the same label) then we just return a leaf node with the results set accordingly. We also terminate in the same way if the depth value reaches or exceeds the max_depth parameter which we previously defined.

Presuming that neither of these conditions has been met, we proceed to try and find a candidate column and value which we can use to split our training dataset. We start by looping over each column index, i, in turn. Then for each column we look at all the distinct values in that column, and try to split our training set on each value, using the divide_set method.

The splitting of the training dataset by a particular value in a particular column (column_index) can be achieved pretty neatly in ruby using the Enumerable#partition method as follows:

      
  def divide_set(rows, column_index, value)
    split_function = if value.is_a? Numeric
      lambda{ |row| row[column_index] >= value }
    else
      lambda{ |row| row[column_index] == value }
    end

    rows.partition(&split_function)
  end
      
    

If the value in the column is numeric then we check each item (or row) to see if the corresponding element is greater-than-or-equal to this threshold value. If it is, then the item/row will be returned within the true partition, otherwise the item/row will be returned in the false partition.

If the value is non-numeric then the true partition will include all rows where the matching element equals the value, otherwise the row will be returned in the false partition.
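As an illustration of divide_set in isolation (assuming the method is accessible, e.g. pasted into an irb session), splitting on the categorical read-the-FAQ column looks like this:

faq_yes, faq_no = divide_set(SimpleData::DATA, 2, 'yes')
faq_yes.size # => 7, the visitors who have read the FAQ
faq_no.size  # => 9, those who have not

Back in build_decision_node, this same call produces the new_rows pair whose entropy we score.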

The resulting new_rows will contain our split dataset, with some items in the first array (or partition), new_rows[0], and the remaining items in the second partition, new_rows[1]. We calculate the average entropy of these partitions, new_entropy, and compare this to the initial_entropy to determine the info_gain for this particular split.

By looping over each candidate value in each column, i, we eventually find the particular column and value which give us the greatest information gain. These are represented by split_index and split_value, respectively. Having defined our best split criteria, we are now in a position to build our TreeNode:

    
    if max_info_gain <= 0
      TreeNode.new(results: rows)
    else
      true_rows, false_rows = divide_set(rows, split_index, split_value)

      TreeNode.new(
        column_index: split_index,
        value: split_value,
        true_branch: build_decision_node(true_rows, depth + 1),
        false_branch: build_decision_node(false_rows, depth + 1),
      )
    end
    
  

One special case which we need to handle here is when our best split doesn't actually improve the information gain. In this case no split can improve the separation of our examples, so we just return a leaf node with the results populated with the full set of rows.

Presuming our best split does, indeed, provide a positive information gain, we split our dataset using these optimal criteria, split_index and split_value, to retrieve the two new partitions: true_rows and false_rows. We then build our TreeNode with these optimal criteria, but how do we know the next node in the tree for the true_branch and false_branch?

The answer is that we don't know these nodes yet; we need to calculate them. We do so by calling the build_decision_node method again, recursively: the true_branch is built from the true_rows partition, whilst the false_branch is built from the false_rows partition. In each case we increment the depth parameter by 1.

If you are anything like me, this recursive invocation will hurt your head! But understanding this part is key to understanding the whole algorithm, so take your time to let these few lines sink in. As we initialize our root TreeNode, it will partition the training data using the best criteria it can find and then it will use the partitioned data to build two new nodes for the true and false branches. As those nodes are constructed they will also partition the data they have received and build subsequent nodes, and so on. This will continue until we try to build a node for which one of the stopping criteria is triggered. These stopping criteria are:

  • initial_entropy for the rows is zero (i.e. they all have the same label),
  • we have reached our max_depth or
  • there is no split which will result in an information gain
If any of these stopping criteria are satisfied the recursion is terminated and we just build and return a leaf node, with the set of rows which we hold at that point.
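To see the max_depth criterion in action, we can train a deliberately shallow tree. With max_depth: 1 the recursion terminates after a single split, so both branches of the root should be leaf nodes (a sketch, relying on the max_depth fix noted in initialize above):

stump = DecisionTreeTrainer.train(train_rows, max_depth: 1)
stump.true_branch.results  # => an array of rows, i.e. a leaf node
stump.false_branch.results # => likewise a leaf node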

Training and running the model

And with that we have all the pieces we need to train our decision tree model. We can initialize our training data using the SimpleData module and pass the data into DecisionTreeTrainer.train to build our model. With the model in hand we can then test how accurate it is against our test data:

    
require './decision_tree_trainer'
require './simple_data'

train_rows, test_rows = SimpleData.get_train_and_test_data(split: 0.1)

tree = DecisionTreeTrainer.train(train_rows)

tree.print
correct_count = test_rows.reduce(0) do |count, row|
  count+=1 if tree.classify(row) == row[-1]
  count
end
puts "Accuracy: #{correct_count/test_rows.size.to_f}"
    
  

You can see that we also make use of a print method which we have defined on the TreeNode. This gives a basic visual representation of our tree hierarchy, but I will refer you to the source to see how this is implemented. Running this script we see the following output:

    
Starting training ...
Completed training in 0.000835017 seconds
(column: 0, value: google)
  true:
    (column: 3, value: 21)
      true:
        ["google", "France", "yes", 23, "Premium"]
        ["google", "UK", "no", 21, "Premium"]
      false:
        ["google", "UK", "no", 18, "None"]
        ["google", "UK", "no", 18, "Basic"]
  false:
    (column: 2, value: yes)
      true:
        (column: 0, value: slashdot)
          true:
            ["slashdot", "France", "yes", 19, "None"]
          false:
            ["digg", "USA", "yes", 24, "Basic"]
            ["kiwitobes", "France", "yes", 23, "Basic"]
            ["digg", "New Zealand", "yes", 12, "Basic"]
            ["kiwitobes", "France", "yes", 19, "Basic"]
      false:
        (column: 0, value: (direct))
          true:
            (column: 1, value: New Zealand)
              true:
                ["(direct)", "New Zealand", "no", 12, "None"]
              false:
                ["(direct)", "UK", "no", 21, "Basic"]
          false:
            ["digg", "USA", "no", 18, "None"]
            ["kiwitobes", "UK", "no", 19, "None"]
            ["slashdot", "UK", "no", 21, "None"]
Accuracy: 1.0
    
  

For the run displayed above the accuracy was 1.0, but other runs yield an accuracy of 0.5 or even 0.0. This is largely because our dummy dataset is very small. We can better evaluate our algorithm by applying it to a more meaningful dataset.

To this end, we will consider the mushroom dataset published by UC Irvine. This dataset contains thousands of samples of gilled mushrooms, with each sample described by a set of categorical features, along with a classification of whether the sample was poisonous (p) or edible (e). The first 10 rows from the dataset are shown here:

     
> head mushroom/agaricus-lepiota.data
p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g
e,b,s,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,n,m
p,x,y,w,t,p,f,c,n,n,e,e,s,s,w,w,p,w,o,p,k,s,u
e,x,s,g,f,n,f,w,b,k,t,e,s,s,w,w,p,w,o,e,n,a,g
e,x,y,y,t,a,f,c,b,n,e,c,s,s,w,w,p,w,o,p,k,n,g
e,b,s,w,t,a,f,c,b,g,e,c,s,s,w,w,p,w,o,p,k,n,m
e,b,y,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,s,m
p,x,y,w,t,p,f,c,n,p,e,e,s,s,w,w,p,w,o,p,k,v,g
e,b,s,y,t,a,f,c,b,g,e,c,s,s,w,w,p,w,o,p,k,s,m
    
  

The first column contains the classification, p or e. The subsequent columns represent different features of the sample. For example, column 2 refers to the cap-shape and can take one of the values of bell (b), conical (c), convex (x), flat (f), knobbed (k) or sunken (s). The next column is the cap-surface, which can be fibrous (f), grooves (g), scaly (y) or smooth (s). The actual details of these features are not important to our decision tree algorithm; it will find the best way to separate the samples based on their values.

We will introduce a MushroomData module to take care of loading the data from the file and putting it into the format expected by our algorithm, i.e. we need to place the label (p or e) at the end of each row rather than the start. In addition, we will not train our model on the full set of 22 features for each sample; instead we will just use the first 7 features to build our model:

     
require 'csv'

module MushroomData
  def self.get_train_and_test_data(max_rows: 10_000, split: 0.1)
    raw_data = CSV.read('mushroom/agaricus-lepiota.data')[0..max_rows]
    train_data = raw_data.map{ |row| row[1..7].append(row[0]) }

    # Let's remove :split of data for testing
    full_size = train_data.size
    test_data =(0..(full_size*split)).to_a.map do |i|
      train_data.delete_at(rand(full_size-i))
    end

    [train_data, test_data]
  end
end
    
  

With this we can adapt our existing script to load the mushroom data and train our decision tree classifier with the new dataset:

     
require './decision_tree_trainer'
require './simple_data'
require './mushroom_data'

train_rows, test_rows = if ARGV.first == "mushroom"
  MushroomData.get_train_and_test_data(max_rows: 8000, split: 0.1)
else
  SimpleData.get_train_and_test_data(split: 0.1)
end

tree = DecisionTreeTrainer.train(train_rows)
tree.print
correct_count = test_rows.reduce(0) do |count, row|
  count+=1 if tree.classify(row) == row[-1]
  count
end
puts "Accuracy: #{correct_count/test_rows.size.to_f}"
    
  

Running this script, we reuse our decision tree algorithm to train a model on this completely new and unrelated dataset. The representation of the tree produced will be output to the terminal, but we truncate it here for brevity:

     
> ruby runner.rb mushroom
Starting training ...
Completed training in 0.385516171 seconds
(column: 4, value: n)
  true:
    (column: 2, value: y)
      true:
        {"p"=>22}
      false:
        (column: 0, value: b)
          true:
            (column: 3, value: t)
        …
  false:
    (column: 3, value: t)
      true:
        (column: 4, value: f)
          true:
            {"p"=>256}
          false:
            (column: 4, value: p)
              true:
                {"p"=>231}
              false:
                {"e"=>721}
      false:
        {"p"=>2900}
Accuracy: 0.9975031210986267
    
  

An accuracy of 99.75%, not too bad for about 200 lines of ruby!

Conclusion

We have discussed the fundamental ideas behind the decision tree model, along with an algorithm for training a decision tree using labelled data. We showed how this algorithm could be implemented in ruby and tested our implementation on a realistic data set.

Whilst there are other machine learning techniques which receive a lot more attention, the decision tree remains a simple but powerful technique. It is simple to implement and intuitive to interpret. It can deal with both numeric and categorical data and requires very little data preparation.

Notwithstanding these benefits, it is not all roses for the decision tree model. Unchecked, the model has a tendency to overfit the data, i.e. generate very specific rules that only apply to the precise training data, but do not apply more generally. In addition the technique can be quite computationally expensive and does not scale easily to very large datasets.

To address some of these shortcomings, the accuracy and performance of decision trees can often be improved by employing extensions such as pruning, random forests, boosting and bagging. These techniques warrant their own discussion, so that will need to wait for a future post.

I hope you enjoyed this article and found it useful. If you would like to be kept up-to-date when we publish new articles please subscribe.

References

  1. Intro to the CART algorithm
  2. Collective Intelligence by Toby Segaran, available on Amazon
  3. Chapter 7: Modeling with Decision Trees, from Collective Intelligence by Toby Segaran
  4. Article on differentiation of the biological kingdoms
  5. Blog post on the virtues of the decision tree model
  6. Docs on Ruby's Enumerable#partition method
  7. Machine Learning Mastery blog post covering some background theory on decision trees
  8. GitHub repo with the code presented in this blog post
  9. Ruby docs for Array#partition method
  10. Wiki entry on training and test data for machine learning models
  11. UCI mushroom dataset
  12. Wikipedia entry for entropy
