{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"##### Copyright 2022 The TensorFlow Authors."
],
"metadata": {
"id": "kgMvP3SF-w_X"
}
},
{
"cell_type": "code",
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
],
"metadata": {
"cellView": "form",
"id": "yhrhZl5t-yUe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Visualizing TensorFlow Decision Forest Trees with dtreeviz\n",
"\n",
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/decision_forests/tutorials/dtreeviz_colab\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/decision-forests/blob/main/documentation/tutorials/dtreeviz_colab.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/decision-forests/blob/main/documentation/tutorials/dtreeviz_colab.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/decision-forests/documentation/tutorials/dtreeviz_colab.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
"</table>\n"
],
"metadata": {
"id": "ZuQzYr8B1J1K"
}
},
{
"cell_type": "markdown",
"source": [
"## Introduction\n",
"\n",
"The [beginner tutorial](https://www.tensorflow.org/decision_forests/tutorials/beginner_colab) demonstrates how to prepare data and how to train and evaluate Random Forest, Gradient Boosted Trees, and CART classifiers and regressors using TensorFlow Decision Forests (abbreviated *TF-DF* below). It also shows how to visualize trees using the built-in `plot_model_in_colab()` function and how to display feature importance measures.\n",
"\n",
"The goal of this tutorial is to dig deeper into the interpretation of classifier and regressor decision trees through visualization. We'll look at detailed tree structure illustrations and also depictions of how decision trees partition feature space to make decisions. Tree structure plots help us understand the behavior of our model, and feature space plots help us understand our data by surfacing the relationships between features and target variables.\n",
"\n",
"The visualization library we'll use is called [dtreeviz](https://github.com/parrt/dtreeviz), and, for consistency, we'll reuse the penguin and abalone data from the beginner tutorial. (To learn more about dtreeviz and the visualization of decision trees, see the [YouTube video](https://www.youtube.com/watch?v=4FC1D9SuDBc) or the article on the [design of dtreeviz](https://explained.ai/decision-tree-viz/index.html).)\n",
"\n",
"In this tutorial, you'll learn how to:\n",
"\n",
"* display the structure of decision trees from a TF-DF forest\n",
"* alter the size and style of dtreeviz tree structure plots\n",
"* plot leaf information, such as the number of instances per leaf, the distribution of target values in each leaf, and various statistics about leaves\n",
"* trace a tree's interpretation for a specific instance and show the path from the root to the leaf that makes the prediction\n",
"* print a plain-English explanation of the decisions the tree makes for an instance\n",
"* view one- and two-dimensional feature spaces to see how the model partitions them into regions of similar instances"
],
"metadata": {
"id": "WROGgnC0_X4n"
}
},
{
"cell_type": "markdown",
"source": [
"## Setup"
],
"metadata": {
"id": "gl0tV6RueIb3"
}
},
{
"cell_type": "markdown",
"source": [
"### Install TF-DF and dtreeviz"
],
"metadata": {
"id": "oC6TkF60jjRV"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rWVR82cI2XBD"
},
"outputs": [],
"source": [
"!pip install -q -U tensorflow_decision_forests==1.9.2 # Warning: dtreeviz is not compatible with TF-DF >= 1.10.0"
]
},
{
"cell_type": "code",
"source": [
"!pip install -q -U dtreeviz"
],
"metadata": {
"id": "cM2_M_KY2jjr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Import libraries"
],
"metadata": {
"id": "am88uNiGjpQN"
}
},
{
"cell_type": "code",
"source": [
"import tensorflow_decision_forests as tfdf\n",
"\n",
"import tensorflow as tf\n",
"\n",
"import os\n",
"import numpy as np\n",
"import pandas as pd\n",
"import math\n",
"\n",
"import dtreeviz\n",
"\n",
"from matplotlib import pyplot as plt\n",
"from IPython import display\n",
"\n",
"# avoid \"Arial font not found warnings\"\n",
"import logging\n",
"logging.getLogger('matplotlib.font_manager').setLevel(level=logging.CRITICAL)\n",
"\n",
"display.set_matplotlib_formats('retina') # generate hires plots\n",
"\n",
"np.random.seed(1234) # reproducible plots/data for explanatory reasons"
],
"metadata": {
"id": "UU8lDr622ZWi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Let's check the versions:\n",
"tfdf.__version__, dtreeviz.__version__ # want dtreeviz >= 2.2.0"
],
"metadata": {
"id": "tqu6EVKMbX3N"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"It'll be handy to have a function to split a data set into training and test sets, so let's define one:"
],
"metadata": {
"id": "VUXskgdymAPe"
}
},
{
"cell_type": "code",
"source": [
"def split_dataset(dataset, test_ratio=0.30, seed=1234):\n",
" \"\"\"\n",
" Splits a pandas DataFrame in two, usually for train/test sets.\n",
" Each row lands in the test set with probability test_ratio, so the\n",
" split sizes are approximate. Using the same random seed yields the same\n",
" split, so the descriptions in this tutorial line up with the generated images.\n",
" \"\"\"\n",
" np.random.seed(seed)\n",
" test_indices = np.random.rand(len(dataset)) < test_ratio\n",
" return dataset[~test_indices], dataset[test_indices]"
],
"metadata": {
"id": "Ia3YOR8mmCtI"
},
"execution_count": null,
"outputs": []
},
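{
"cell_type": "markdown",
"source": [
"Because `split_dataset()` draws one uniform random number per row and compares it to `test_ratio`, the split is only approximately 70/30, and re-running it with the same seed reproduces it exactly. Here is a quick, self-contained sanity check of that masking logic on a hypothetical 1,000-row frame (the helper `toy_test_mask()` is an illustrative copy of the mask computation, not part of TF-DF or dtreeviz):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"toy_df = pd.DataFrame({'value': range(1000)})  # hypothetical data\n",
"\n",
"def toy_test_mask(n, test_ratio=0.30, seed=1234):\n",
"    # Same masking logic as split_dataset(): each row independently\n",
"    # lands in the test set with probability test_ratio.\n",
"    np.random.seed(seed)\n",
"    return np.random.rand(n) < test_ratio\n",
"\n",
"mask_a = toy_test_mask(len(toy_df))\n",
"mask_b = toy_test_mask(len(toy_df))\n",
"\n",
"print(f'test fraction: {mask_a.mean():.3f}')  # close to, but usually not exactly, 0.30\n",
"print('same split on re-run:', bool((mask_a == mask_b).all()))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},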
{
"cell_type": "markdown",
"source": [
"## Visualizing Classifier Trees\n",
"\n",
"<img align=\"right\" src=\"https://allisonhorst.github.io/palmerpenguins/reference/figures/lter_penguins.png\" width=\"150\"/>Using the penguin data, let's build a classifier to predict the `species` (`Adelie`, `Gentoo`, or `Chinstrap`) from the other 7 columns. Then we can use dtreeviz to display the tree and interrogate the model, learning both how it makes decisions and what it reveals about our data."
],
"metadata": {
"id": "dfFjKqUeTpxe"
}
},
{
"cell_type": "markdown",
"source": [
"### Load, clean, and prep data\n",
"\n",
"As we did in the beginner tutorial, let's start by downloading the penguin data and loading it into a pandas DataFrame."
],
"metadata": {
"id": "7wDQTQ3mU50S"
}
},
{
"cell_type": "code",
"source": [
"# Download the Penguins dataset\n",
"!wget -q https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins.csv -O /tmp/penguins.csv\n",
"\n",
"# Load a dataset into a Pandas Dataframe.\n",
"df_penguins = pd.read_csv(\"/tmp/penguins.csv\")\n",
"df_penguins.head(3)"
],
"metadata": {
"id": "zxy8Z70z4gf-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"A quick check shows that there are missing values in the data set:"
],
"metadata": {
"id": "_f7uchI4m8o4"
}
},
{
"cell_type": "code",
"source": [
"df_penguins.columns[df_penguins.isna().any()].tolist()"
],
"metadata": {
"id": "9ezd79LBnM53"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Rather than impute missing values, let's just drop incomplete rows to focus on visualization for this tutorial:"
],
"metadata": {
"id": "XeR5el2An8bS"
}
},
{
"cell_type": "code",
"source": [
"df_penguins = df_penguins.dropna() # Drop incomplete rows (e.g., rows with a missing `sex` value)"
],
"metadata": {
"id": "jrNi7JspAre2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"TF-DF requires classification labels to be integers in [0, num_labels), so let's convert the label column `species` from strings to integers.\n",
"\n",
"**Note:** TF-DF supports categorical string input features. You don't need to encode any feature values."
],
"metadata": {
"id": "_1WLYlZloLcC"
}
},
{
"cell_type": "code",
"source": [
"penguin_label = \"species\" # Name of the classification target label\n",
"classes = list(df_penguins[penguin_label].unique())\n",
"df_penguins[penguin_label] = df_penguins[penguin_label].map(classes.index)\n",
"\n",
"print(f\"Target '{penguin_label}' classes: {classes}\")\n",
"df_penguins.head(3)"
],
"metadata": {
"id": "RBYJMGmn4lgu"
},
"execution_count": null,
"outputs": []
},
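{
"cell_type": "markdown",
"source": [
"Later, once a model makes predictions, the same `classes` list turns integer labels back into species names. A minimal sketch, assuming predictions arrive as one probability per class (an array shaped `(num_instances, num_classes)`, which is what TF-DF multi-class classifiers typically return from `predict()`); `example_classes` and `example_probs` are hypothetical stand-ins, so the real `classes` list is not overwritten:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical stand-in for the `classes` list built above.\n",
"example_classes = ['Adelie', 'Gentoo', 'Chinstrap']\n",
"\n",
"# Hypothetical per-class probabilities for two instances,\n",
"# shaped (num_instances, num_classes).\n",
"example_probs = np.array([[0.90, 0.05, 0.05],\n",
"                          [0.10, 0.20, 0.70]])\n",
"\n",
"example_ids = example_probs.argmax(axis=1)  # most likely integer label per row\n",
"example_names = [example_classes[i] for i in example_ids]\n",
"print(example_names)  # ['Adelie', 'Chinstrap']"
],
"metadata": {},
"execution_count": null,
"outputs": []
},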
{
"cell_type": "markdown",
"source": [
"Now, let's get a 70-30 split for training and testing using our convenience function defined above, and then convert those DataFrames into TensorFlow datasets."
],
"metadata": {
"id": "5d81SeuHq_2i"
}
},
{
"cell_type": "markdown",
"source": [
"### Split train/test set"
],
"metadata": {
"id": "4jljZTvCVCR9"
}
},
{
"cell_type": "code",
"source": [
"# Split into training and test sets\n",
"train_ds_pd, test_ds_pd = split_dataset(df_penguins)\n",
"print(f\"{len(train_ds_pd)} examples in training, {len(test_ds_pd)} examples for testing.\")\n",
"\n",
"# Convert to tensorflow data sets\n",
"train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_ds_pd, label=penguin_label)\n",
"test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(test_ds_pd, label=penguin_label)"
],
"metadata": {
"id": "fU49bP6C5dlJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Train a random forest classifier"
],
"metadata": {
"id": "D6Dfhu3mrutT"
}
},
{
"cell_type": "code",
"source": [
"# Train a Random Forest model\n",
"cmodel = tfdf.keras.RandomForestModel(verbose=0, random_seed=1234)\n",
"cmodel.fit(train_ds)"
],
"metadata": {
"id": "-Veh05HX5hU2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Just to verify that everything is working properly, let's check the accuracy of the model, which should be about 99%:"
],
"metadata": {
"id": "DcUeypEMsaFR"
}
},
{
"cell_type": "code",
"source": [
"cmodel.compile(metrics=[\"accuracy\"])\n",
"cmodel.evaluate(test_ds, return_dict=True, verbose=0)"
],
"metadata": {
"id": "7BNmaeLqsJOb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Yep, the model is accurate on the test set."
],
"metadata": {
"id": "rWMj6nSgBqGe"
}
},
{
"cell_type": "markdown",
"source": [
"### Display decision tree\n",
"\n",
"Now that we have a model, let's pick one of the trees in the random forest and take a look at its structure. The dtreeviz library asks us to bundle up the TF-DF model with the associated training data, which it can then use to repeatedly interrogate the model.\n",
"\n"
],
"metadata": {
"id": "mK7OsBZuVStM"
}
},
{
"cell_type": "code",
"source": [
"# Tell dtreeviz about training data and model\n",
"penguin_features = [f.name for f in cmodel.make_inspector().features()]\n",
"viz_cmodel = dtreeviz.model(cmodel,\n",
" tree_index=3,\n",
" X_train=train_ds_pd[penguin_features],\n",
" y_train=train_ds_pd[penguin_label],\n",
" feature_names=penguin_features,\n",
" target_name=penguin_label,\n",
" class_names=classes)"
],
"metadata": {
"id": "-p6DsbVZ5yFF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The most common dtreeviz API function is `view()`, which displays the structure of the tree as well as the feature distributions for the instances associated with each decision node."
],
"metadata": {
"id": "VOw7d4SawtIo"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.view(scale=1.2)"
],
"metadata": {
"id": "nYUeb1js5o8J"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The root of the decision tree indicates that classification begins by testing the `flipper_length_mm` feature with a split value of 206. If a test instance's `flipper_length_mm` feature value is less than 206, the decision tree descends to the left child. If it is greater than or equal to 206, classification proceeds by descending to the right child.\n",
"\n",
"To see why the model chose to split the training data at `flipper_length_mm`=206, let's zoom in on the root node:"
],
"metadata": {
"id": "x-MsxE4-jYLW"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.view(depth_range_to_display=[0,0], scale=1.5)"
],
"metadata": {
"id": "E_RCrfXtBMB1"
},
"execution_count": null,
"outputs": []
},
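{
"cell_type": "markdown",
"source": [
"One way to make purity concrete is Gini impurity, which is 0 for a group containing a single class and grows as classes mix. The sketch below routes hypothetical penguins with the root test (`flipper_length_mm < 206` goes left) and scores each side. The data are invented for illustration, and TF-DF's own split criterion may differ from Gini impurity:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"def gini(labels):\n",
"    # Gini impurity: 1 minus the sum of squared class proportions (0 = pure).\n",
"    p = pd.Series(labels).value_counts(normalize=True)\n",
"    return 1.0 - float((p ** 2).sum())\n",
"\n",
"# Hypothetical training instances (not the real penguin data).\n",
"toy_penguins = pd.DataFrame({\n",
"    'flipper_length_mm': [181, 190, 195, 198, 210, 215, 220, 230],\n",
"    'species': ['Adelie', 'Adelie', 'Chinstrap', 'Adelie',\n",
"                'Gentoo', 'Gentoo', 'Gentoo', 'Gentoo'],\n",
"})\n",
"\n",
"goes_left = toy_penguins['flipper_length_mm'] < 206  # root test: < 206 -> left child\n",
"left = toy_penguins[goes_left]\n",
"right = toy_penguins[~goes_left]\n",
"\n",
"print('before split:', round(gini(toy_penguins['species']), 3))  # 0.594\n",
"print('left child:  ', round(gini(left['species']), 3))  # 0.375\n",
"print('right child: ', round(gini(right['species']), 3))  # 0.0"
],
"metadata": {},
"execution_count": null,
"outputs": []
},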
{
"cell_type": "markdown",
"source": [
"It's clear to the human eye that almost all instances to the right of 206 are blue (`Gentoo` Penguins). So, with a single feature comparison, the model can split the training data into a fairly pure `Gentoo` group and a mixed group. (The model will further purify the subgroups with future splits below the root.)\n",
"\n",
"The decision tree also has a categorical decision node, which can test category subsets rather than simple numeric splits. For example, let's take a look at the second level of the tree:\n",
"\n",
"<!-- <img src=\"images/dtreeviz-class-catvars.png\" width=\"400\"> -->"
],
"metadata": {
"id": "MaLbmDVuBVaJ"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.view(depth_range_to_display=[1,1], scale=1.5)"
],
"metadata": {
"id": "FIwp-ooxCsMw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The node (on the left) tests feature `island` and, if a test instance has `island==Dream`, classification proceeds down its right child. For the other two categories, `Torgersen` and `Biscoe`, classification proceeds down its left child. (The `bill_length_mm` node on the right in this plot is not relevant to this discussion of categorical decision nodes.)\n",
"\n",
"This splitting behavior highlights that decision trees partition feature space into regions with the goal of increasing target value purity. We'll look at feature space in more detail below.\n",
"\n",
"Decision trees can get very large, and it's not always useful to plot them in their entirety. But we can look at simpler versions of the tree, portions of the tree, the number of training instances in the various leaves (where predictions are made), and so on. Here's an example where we turn off the fancy decision node distribution illustrations and scale the whole image down to 75%:"
],
"metadata": {
"id": "ctndFMe8CqqR"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.view(fancy=False, scale=.75)"
],
"metadata": {
"id": "Mj-LiHi4KP-h"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can also use a left-to-right orientation, which sometimes results in a smaller plot:"
],
"metadata": {
"id": "x9IYKl5VzDTD"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.view(orientation='LR', scale=.75)"
],
"metadata": {
"id": "OrARL5XmJrak"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"If you're not a big fan of the pie charts, you can also get bar charts."
],
"metadata": {
"id": "3ubba6J_bGCG"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.view(leaftype='barh', scale=.75)"
],
"metadata": {
"id": "V8qMliA-bN38"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Examining leaf stats\n",
"\n",
"Decision trees make decisions at the leaf nodes, so it is sometimes useful to zoom in on those, particularly if the entire graph is too large to see all at once. Here is how to examine the number of training data instances that are grouped into each leaf node:"
],
"metadata": {
"id": "RyyMHTolg9VN"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.leaf_sizes(figsize=(5,1.5))"
],
"metadata": {
"id": "mpMYp3vqOXpB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"A perhaps more interesting graph is one that shows the proportion of each kind of training instance in the various leaves. The goal of training is to have leaves with a single color, because such \"pure\" leaves can predict their class with high confidence."
],
"metadata": {
"id": "NhKTJzE9z-jD"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.ctree_leaf_distributions(figsize=(5,1.5))"
],
"metadata": {
"id": "rjCAJHubex8E"
},
"execution_count": null,
"outputs": []
},
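{
"cell_type": "markdown",
"source": [
"Those per-leaf proportions are easy to reproduce by hand once you know which leaf each training instance lands in. A sketch with hypothetical leaf assignments (this table is invented and is not produced by a dtreeviz call), where a leaf counts as pure when a single class makes up all of its instances:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical leaf assignments for a few training instances.\n",
"toy_leaves = pd.DataFrame({\n",
"    'leaf_id': [3, 3, 3, 5, 5, 5, 5, 8],\n",
"    'species': ['Adelie', 'Adelie', 'Chinstrap',\n",
"                'Gentoo', 'Gentoo', 'Gentoo', 'Gentoo', 'Adelie'],\n",
"})\n",
"\n",
"# Rows are leaves, columns are classes, values are the fraction of each leaf's instances.\n",
"proportions = pd.crosstab(toy_leaves['leaf_id'], toy_leaves['species'],\n",
"                          normalize='index')\n",
"print(proportions.round(2))\n",
"\n",
"# A leaf is pure when one class accounts for all of its instances.\n",
"is_pure = proportions.max(axis=1) == 1.0\n",
"print(is_pure)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},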
{
"cell_type": "markdown",
"source": [
"We can also zoom in on a specific leaf node to look at some stats of the various instance features. For example, leaf node 5 contains 31 instances, 24 of which have unique `bill_length_mm` values:"
],
"metadata": {
"id": "4S53gu2x0YWY"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.node_stats(node_id=5)"
],
"metadata": {
"id": "EnNMnQvyhLoR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### How decision trees classify an instance\n",
"\n",
"Now that we've looked at the structure and contents of a decision tree, let's figure out how the classifier makes a decision for a specific instance. By passing in an instance (a feature vector) as argument `x`, the `view()` function will highlight the path from the root to the leaf pursued by the classifier to make the prediction for that instance:"
],
"metadata": {
"id": "yyUmFRLIURN6"
}
},
{
"cell_type": "code",
"source": [
"x = train_ds_pd[penguin_features].iloc[20]\n",
"viz_cmodel.view(x=x, scale=.75)"
],
"metadata": {
"id": "N2wEfuctKZyY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The illustration highlights the tree path and the instance features that were tested (`island`, `bill_length_mm`, and `flipper_length_mm`).\n",
"\n",
"For a very large tree, you can also ask to see just the path through the tree, and not the entire tree, by using the `show_just_path` parameter:"
],
"metadata": {
"id": "e8RUIBz0385V"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.view(x=x, show_just_path=True, scale=.75)"
],
"metadata": {
"id": "LGtTtt5bXb5n"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"To obtain a plain-English explanation of how the tree classifies an instance, the most compact representation, use `explain_prediction_path()`:"
],
"metadata": {
"id": "Vs_qb1Y_4nBN"
}
},
{
"cell_type": "code",
"source": [
"print(viz_cmodel.explain_prediction_path(x=x))"
],
"metadata": {
"id": "t5D0VZr6fimw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The model tests `x`'s `bill_length_mm`, `flipper_length_mm`, and `island` features to reach the leaf, which, in this case, predicts `Adelie`."
],
"metadata": {
"id": "-LOuAvrJEvC-"
}
},
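{
"cell_type": "markdown",
"source": [
"Conceptually, an explanation like this comes from walking the tree from the root and recording each test the instance passes. Here is a minimal sketch on a hypothetical, hand-built tree stored as nested dicts; it is not dtreeviz's internal representation, and its thresholds and leaves are illustrative only:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical two-level tree: internal nodes test one numeric feature\n",
"# against a split value; leaves hold the predicted class.\n",
"toy_tree = {\n",
"    'feature': 'flipper_length_mm', 'split': 206.0,\n",
"    'left': {\n",
"        'feature': 'bill_length_mm', 'split': 43.0,\n",
"        'left': {'leaf': 'Adelie'},\n",
"        'right': {'leaf': 'Chinstrap'},\n",
"    },\n",
"    'right': {'leaf': 'Gentoo'},\n",
"}\n",
"\n",
"def explain_path(tree, instance):\n",
"    # Follow the instance from the root to a leaf, recording each test it passes.\n",
"    conditions = []\n",
"    node = tree\n",
"    while 'leaf' not in node:\n",
"        name, split = node['feature'], node['split']\n",
"        if instance[name] < split:\n",
"            conditions.append(f'{name} < {split}')\n",
"            node = node['left']\n",
"        else:\n",
"            conditions.append(f'{name} >= {split}')\n",
"            node = node['right']\n",
"    return conditions, node['leaf']\n",
"\n",
"conds, pred = explain_path(toy_tree, {'flipper_length_mm': 190, 'bill_length_mm': 39.1})\n",
"print(' and '.join(conds), '->', pred)\n",
"# flipper_length_mm < 206.0 and bill_length_mm < 43.0 -> Adelie"
],
"metadata": {},
"execution_count": null,
"outputs": []
},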
{
"cell_type": "markdown",
"source": [
"### Feature space partitioning\n",
"\n",
"So far we've looked at the structure of trees and how trees interpret instances to make decisions, but what exactly are the decision nodes doing? Decision trees partition feature space into groups of observations that share similar target values. Each leaf represents the partitioning resulting from the sequence of feature splitting performed from the root down to that leaf. For classification, the goal is to get partitions to share the same or mostly the same target class value.\n",
"\n",
"If we look back at the tree structure, we see that variable `flipper_length_mm` is tested by three nodes in the tree. The corresponding decision node split values are 189, 206, and 210.5, which means that the decision tree is splitting `flipper_length_mm` into four regions, which we can illustrate using `ctree_feature_space()`:\n",
"\n"
],
"metadata": {
"id": "cE1FtBAyzcOu"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.ctree_feature_space(features=['flipper_length_mm'], show={'splits','legend'}, figsize=(5,1.5))"
],
"metadata": {
"id": "9NExSEZw6Wp6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"(The vertical axis is not meaningful in this single-feature case. To increase visibility, that vertical axis just separates the dots representing different target classes into different elevations with some noise added.)\n",
"\n",
"The first split at 206 (tested at the root) separates the training data into an overlapping region of Adelie/Chinstrap penguins and a fairly pure region of Gentoo penguins. The subsequent split at 210.5 further isolates a region of pure Gentoo (above 210.5 flipper length). The decision tree also splits at 189, but the resulting regions are still impure. The tree relies on splitting by other variables to separate the \"confused\" clumps of `Adelie`/`Chinstrap` penguins. Because we have passed in a single feature name, no splits are shown for other features.\n",
"\n",
"Let's look at another feature that has more splits, `bill_length_mm`. Four nodes in the decision tree test that feature, so feature space is split into five regions. Notice how the model can split off a pure region of `Adelie` by testing for `bill_length_mm` less than 40:"
],
"metadata": {
"id": "Ij_ypRl27cti"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.ctree_feature_space(features=['bill_length_mm'], show={'splits','legend'},\n",
" figsize=(5,1.5))"
],
"metadata": {
"id": "Jl0qjk_UKuJ6"
},
"execution_count": null,
"outputs": []
},
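{
"cell_type": "markdown",
"source": [
"Along a single feature, assigning an instance to one of these regions is just a lookup against the sorted split values. A minimal sketch using the three `flipper_length_mm` split values quoted above (189, 206, and 210.5) and hypothetical flipper lengths; `np.digitize` returns the index of the region each value falls into, with values equal to a split going to the right, matching the tree's `<` test:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Split values for flipper_length_mm quoted in the text above.\n",
"flipper_splits = [189.0, 206.0, 210.5]\n",
"\n",
"# Hypothetical flipper lengths (mm).\n",
"lengths = np.array([180.0, 195.0, 206.0, 208.0, 215.0])\n",
"\n",
"# Region i holds values with splits[i-1] <= x < splits[i]; three splits give four regions.\n",
"regions = np.digitize(lengths, flipper_splits)\n",
"print(regions.tolist())  # [0, 1, 2, 2, 3]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},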
{
"cell_type": "markdown",
"source": [
"We can also examine how the tree partitions feature space for two features at once, such as `flipper_length_mm` and `bill_length_mm`:"
],
"metadata": {
"id": "Ylq5vqHa8lJI"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.ctree_feature_space(features=['flipper_length_mm','bill_length_mm'],\n",
" show={'splits','legend'}, figsize=(5,5))"
],
"metadata": {
"id": "5ggvD7wqM6Y8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The color of each region indicates the class the tree predicts for test instances whose features fall in that region.\n",
"\n",
"By considering two variables at once, the decision tree can create much purer (rectangular) regions, leading to more accurate predictions. For example, the upper left region encapsulates purely `Chinstrap` penguins.\n",
"\n",
"Depending on the variables we choose, the regions will be more or less pure. Here is another 2D feature space partition, for features `body_mass_g` and `bill_length_mm`, where shades indicate uncertainty."
],
"metadata": {
"id": "dVZzVtuwTYpF"
}
},
{
"cell_type": "code",
"source": [
"viz_cmodel.ctree_feature_space(features=['body_mass_g','bill_length_mm'],\n",
" show={'splits','legend'}, figsize=(5,5))"
],
"metadata": {
"id": "kbSAHR9q5pGP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Only the `Adelie` region is fairly pure. The tree relies on other variables to get a better partition, as we just saw with `flipper_length_mm` vs `bill_length_mm` space.\n",
"\n",
"The dtreeviz library cannot visualize more than two feature dimensions for classification at this time.\n",
"\n",
"At this point, you've got a good handle on how to visualize the structure of decision trees, how trees partition feature space, and how trees classify test instances. Let's turn now to regression and see how dtreeviz visualizes regression trees."
],
"metadata": {
"id": "dl50QFXBUhIF"
}
},
{
"cell_type": "markdown",
"source": [
"## Visualizing Regressor Trees\n",
"\n",
"<img align=\"right\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/87/Abalone_%28PSF%29.png\" width=\"120\">Let's use the [abalone dataset](https://storage.googleapis.com/download.tensorflow.org/data/abalone_raw.csv) used in the beginner tutorial to explore the structure of regression trees. As we did for classification above, we start by loading and preparing data for training. Given 8 variables, we'd like to predict the number of rings in an abalone's shell.\n"
],
"metadata": {
"id": "_lmlW71CzP9v"
}
},
{
"cell_type": "markdown",
"source": [
"### Load, clean, and prep data\n",
"\n",
"Using the following code snippet, we can see that the features are all numeric except for the `Type` (sex) variable."
],
"metadata": {
"id": "7qIzJwYsbpmo"
}
},
{
"cell_type": "code",
"source": [
"# Download the dataset.\n",
"!wget -q https://storage.googleapis.com/download.tensorflow.org/data/abalone_raw.csv -O /tmp/abalone.csv\n",
"\n",
"df_abalone = pd.read_csv(\"/tmp/abalone.csv\")\n",
"df_abalone.head(3)"
],
"metadata": {
"id": "CzFEh3sOzRbK"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Fortunately, there's no missing data to deal with:"
],
"metadata": {
"id": "L6RmA-2YXWYr"
}
},
{
"cell_type": "code",
"source": [
"df_abalone.isna().any()"
],
"metadata": {
"id": "rmvDHSKpX715"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Split train/test set"
],
"metadata": {
"id": "FvmbvEfCbubB"
}
},
{
"cell_type": "code",
"source": [
"abalone_label = \"Rings\" # Name of the regression target label\n",
"\n",
"# Split into training and test sets 70/30\n",
"df_train_abalone, df_test_abalone = split_dataset(df_abalone)\n",
"print(f\"{len(df_train_abalone)} examples in training, {len(df_test_abalone)} examples for testing.\")\n",
"\n",
"# Convert to tensorflow data sets\n",
"train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(df_train_abalone, label=abalone_label, task=tfdf.keras.Task.REGRESSION)\n",
"test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(df_test_abalone, label=abalone_label, task=tfdf.keras.Task.REGRESSION)"
],
"metadata": {
"id": "hiERDOZg1X_p"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Train a random forest regressor\n",
"\n",
"Now that we have training and test sets, let's train a random forest regressor. Left unrestricted, the trees grow too large to visualize, so we artificially restrict their depth. (Restricting the tree depth is also a form of regularization to prevent overfitting.) A max depth of 5 is deep enough to be fairly accurate but small enough to visualize."
],
"metadata": {
"id": "KhVF32ZorzcE"
}
},
{
"cell_type": "code",
"source": [
"rmodel = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION,\n",
" max_depth=5, # don't let the tree get too big\n",
" random_seed=1234, # create same tree every time\n",
" verbose=0)\n",
"rmodel.fit(x=train_ds)"
],
"metadata": {
"id": "7RZQNlzc1n3T"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's check the model's error using MAE, MSE, and RMSE. The range of `Rings` is 1-27, so an MAE of 1.66 on the test set is not great, but it's OK for our demonstration purposes."
],
"metadata": {
"id": "4moHNWt_ZM_W"
}
},
{
"cell_type": "code",
"source": [
"# Evaluate the model on the test dataset.\n",
"rmodel.compile(metrics=[\"mae\",\"mse\"])\n",
"evaluation = rmodel.evaluate(test_ds, return_dict=True, verbose=0)\n",
"\n",
"print(f\"MSE: {evaluation['mse']}\")\n",
"print(f\"MAE: {evaluation['mae']}\")\n",
"print(f\"RMSE: {math.sqrt(evaluation['mse'])}\")"
],
"metadata": {
"id": "3O3vZ4qx18FJ"
},
"execution_count": null,
"outputs": []
},
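{
"cell_type": "markdown",
"source": [
"For reference, all three numbers are simple averages over the test set. A sketch with hypothetical ring counts and predictions (not this model's output) showing how each metric is computed. RMSE is always at least as large as MAE, and a large gap between them signals a few big errors:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical ring counts and predictions for four test abalones.\n",
"y_true = np.array([9.0, 7.0, 15.0, 10.0])\n",
"y_pred = np.array([10.0, 7.5, 11.0, 9.5])\n",
"\n",
"errors = y_pred - y_true\n",
"mae = np.mean(np.abs(errors))  # mean absolute error, in rings\n",
"mse = np.mean(errors ** 2)  # mean squared error, in rings squared\n",
"rmse = np.sqrt(mse)  # back in rings, comparable to MAE\n",
"\n",
"print(f'MAE={mae:.3f}  MSE={mse:.3f}  RMSE={rmse:.3f}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},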
{
"cell_type": "markdown",
"source": [
"### Display decision tree\n",
"\n",
"To use dtreeviz, we need to bundle up the model and the training data. We also have to choose a particular tree from the random forest to display; let's choose tree 3 as we did for classification."
],
"metadata": {
"id": "bWZrhDfTfPQ4"
}
},
{
"cell_type": "code",
"source": [
"abalone_features = [f.name for f in rmodel.make_inspector().features()]\n",
"viz_rmodel = dtreeviz.model(rmodel, tree_index=3,\n",
" X_train=df_train_abalone[abalone_features],\n",
" y_train=df_train_abalone[abalone_label],\n",
" feature_names=abalone_features,\n",
" target_name='Rings')"
],
"metadata": {
"id": "mUqC6IGB2VWU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Function `view()` displays the structure of the tree, but now the decision nodes are scatterplots rather than stacked bar charts. Each decision node shows a marginal plot of the indicated variable versus the target (`Rings`):"
],
"metadata": {
"id": "MMOM77IXatzF"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.view(scale=1.2)"
],
"metadata": {
"id": "n0aIEGkm2qOd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"As with classification, regression proceeds from the root of the tree towards a specific leaf, which ultimately makes the prediction for a specific test instance. The nodes on the path to the leaf test numeric or categorical variables, directing the regressor into a specific region of feature space that (hopefully) has very similar target values.\n",
"\n",
"The leaves are strip plots that show the `Rings` values of all training instances in each leaf. The horizontal position is not meaningful; it is just a bit of noise that separates the dots so we can see where the density lies. Consider the lower left leaf with n=10, Rings=3.30: the average `Rings` value for the 10 instances in that leaf is 3.30, which is the decision tree's prediction for any test instance that reaches that leaf.\n",
"\n",
"Let's zoom in on the root of the tree to see how the regressor splits on variable `ShellWeight`:\n"
],
"metadata": {
"id": "OvTpoqiVcTES"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.view(depth_range_to_display=[0,0], scale=2)"
],
"metadata": {
"id": "AmG_4hkHRDqM"
},
"execution_count": null,
"outputs": []
},
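{
"cell_type": "markdown",
"source": [
"The dashed lines in a regression decision node are averages of the target on each side of the split. A sketch of that computation using the 0.164 `ShellWeight` split from the plot and a handful of hypothetical instances (not the real abalone data):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical training instances (not the real abalone data).\n",
"toy_abalone = pd.DataFrame({\n",
"    'ShellWeight': [0.05, 0.10, 0.15, 0.20, 0.30, 0.45],\n",
"    'Rings': [5, 6, 7, 9, 11, 13],\n",
"})\n",
"\n",
"goes_left = toy_abalone['ShellWeight'] < 0.164  # root test: < 0.164 -> left child\n",
"left_mean = toy_abalone.loc[goes_left, 'Rings'].mean()\n",
"right_mean = toy_abalone.loc[~goes_left, 'Rings'].mean()\n",
"print(left_mean, right_mean)  # 6.0 11.0"
],
"metadata": {},
"execution_count": null,
"outputs": []
},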
{
"cell_type": "markdown",
"source": [
"For a test instance with `ShellWeight<0.164`, the regressor proceeds down the left child of the root; otherwise it proceeds down the right child. The horizontal dashed lines indicate the average `Rings` value associated with instances whose `ShellWeight` is above or below 0.164.\n",
"\n",
"Decision nodes for categorical variables, on the other hand, test subsets of categories since categories are unordered. In the fourth level of the tree, there are two decision nodes that test categorical variable `Type`:\n",
"\n",
"<!-- <img src=\"images/dtreeviz-regr-catvars.png\" width=\"800\"> -->"
],
"metadata": {
"id": "m_WduAzdRlLK"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.view(depth_range_to_display=[3,3], scale=1.5)"
],
"metadata": {
"id": "U_X6pekdSGbT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Regressor nodes that test categoricals use color to indicate subsets. For example, the decision node on the left at the fourth level directs the regressor to descend to the left if the test instance has `Type=I` or `Type=F`; otherwise the regressor descends to the right. The yellow and blue colors indicate the two categorical value subsets associated with left and right branches. The horizontal dashed lines indicate the average `Rings` target value for instances with the associated categorical value(s).\n",
"\n",
"To display large trees, you can use the `orientation` parameter to get a left-to-right version of the tree, although it is fairly tall, so using `scale` to shrink it is a good idea. You can then use your machine's screen zoom feature to zoom in on areas of interest."
],
"metadata": {
"id": "dS28snShRRTX"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.view(orientation='LR', scale=.5)"
],
"metadata": {
"id": "oLFUeZbhcJvc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can save space with the non-fancy plot. It still shows the decision node split variables and split points; it's just not as pretty."
],
"metadata": {
"id": "sHkVF5fxbjUs"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.view(fancy=False, scale=.75)"
],
"metadata": {
"id": "5rpKz3C-hhTI"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Examining leaf stats\n",
"\n",
"When graphs get very large, it's sometimes better to focus on the leaves. Function `leaf_sizes()` indicates the number of instances found in each leaf:"
],
"metadata": {
"id": "QuA-7t5fgy1D"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.leaf_sizes(figsize=(5,1.5))"
],
"metadata": {
"id": "3GtflGu3gYeY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can also look at the distribution of instances in the leaves (`Rings` values). The vertical axis has a \"row\" for each leaf and the horizontal axis shows the distribution of `Rings` values for instances in each leaf. The column on the right shows the average target value for each leaf."
],
"metadata": {
"id": "fecr-p8jeHPj"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.rtree_leaf_distributions(figsize=(5,5))"
],
"metadata": {
"id": "RqoC8NDAdYnH"
},
"execution_count": null,
"outputs": []
},
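{
"cell_type": "markdown",
"source": [
"Both plots come down to grouping instances by the leaf they land in. The following minimal NumPy sketch shows that bookkeeping. It assumes you already have an array with one leaf id per instance, plus the matching `Rings` values. The `toy_*` arrays below are hypothetical, and this is not how dtreeviz computes its plots internally. `np.bincount` counts the instances in each leaf. A weighted `np.bincount` divided by those counts gives each leaf's average target value."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical leaf assignments (one leaf id per instance) and Rings values.\n",
"toy_leaf_ids = np.array([3, 3, 5, 7, 5, 3, 7, 7, 7])\n",
"toy_leaf_rings = np.array([6, 7, 9, 12, 10, 8, 11, 15, 14], dtype=float)\n",
"\n",
"# Number of instances per leaf id (ids that never occur get a count of 0).\n",
"counts = np.bincount(toy_leaf_ids)\n",
"# Sum of Rings per leaf id; dividing by the counts gives each leaf's average.\n",
"sums = np.bincount(toy_leaf_ids, weights=toy_leaf_rings)\n",
"populated = np.flatnonzero(counts)\n",
"leaf_means = sums[populated] / counts[populated]\n",
"\n",
"for leaf_id, size, leaf_mean in zip(populated, counts[populated], leaf_means):\n",
"    print(f'leaf {leaf_id}: {size} instances, mean Rings = {leaf_mean:.2f}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},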
{
"cell_type": "markdown",
"source": [
"Alternatively, we can get information on the features of the instances in a particular node. For example, here's how to get information on features in leaf id 29, the leaf with the most instances:"
],
"metadata": {
"id": "lQp_iRmkef04"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.node_stats(node_id=29)"
],
"metadata": {
"id": "nJN9Zol1ekh9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### How decision trees predict a value for an instance\n",
"\n",
"To make a prediction for a specific instance, the decision tree weaves its way from the root down to a specific leaf, according to the feature values in the test instance. The prediction of the individual tree is just the average of the `Rings` values from instances (from the training set) residing in that leaf. The dtreeviz library can illustrate this process if we provide a test instance via parameter `x`."
],
"metadata": {
"id": "UmVioRJsfXR4"
}
},
{
"cell_type": "code",
"source": [
"x = df_abalone[abalone_features].iloc[1234]\n",
"viz_rmodel.view(x=x, scale=.75)"
],
"metadata": {
"id": "B7S2xj-Z2wpf"
},
"execution_count": null,
"outputs": []
},
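{
"cell_type": "markdown",
"source": [
"The traversal itself is simple. The sketch below walks a hand-built, hypothetical tree stored as nested Python dictionaries. This is not how TF-DF or dtreeviz represent trees, and the thresholds and leaf values are made up. Each decision node sends the instance left or right. The value stored in the leaf the instance reaches is the prediction; that value is the average `Rings` of the training instances in that leaf. The function also records the tests it applies along the way. It deliberately avoids the name `x` so the cells below still use the abalone instance."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical two-level tree; thresholds and leaf values are made up.\n",
"# Internal nodes test a feature, and leaves store the mean target of the\n",
"# training instances that reached them.\n",
"toy_tree = {\n",
"    'feature': 'ShellWeight', 'threshold': 0.164,\n",
"    'left': {'leaf_value': 7.2},\n",
"    'right': {\n",
"        'feature': 'Diameter', 'threshold': 0.45,\n",
"        'left': {'leaf_value': 10.1},\n",
"        'right': {'leaf_value': 12.8},\n",
"    },\n",
"}\n",
"\n",
"def predict_one(node, instance):\n",
"    # Walk from the root to a leaf, recording each test that was applied.\n",
"    path = []\n",
"    while 'leaf_value' not in node:\n",
"        feature, threshold = node['feature'], node['threshold']\n",
"        if instance[feature] < threshold:\n",
"            path.append(f'{feature} < {threshold}')\n",
"            node = node['left']\n",
"        else:\n",
"            path.append(f'{feature} >= {threshold}')\n",
"            node = node['right']\n",
"    # The leaf's stored average is the tree's prediction.\n",
"    return node['leaf_value'], path\n",
"\n",
"toy_prediction, toy_path = predict_one(toy_tree, {'ShellWeight': 0.25, 'Diameter': 0.5})\n",
"print(toy_prediction, toy_path)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},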
{
"cell_type": "markdown",
"source": [
"If that visualization is too large, we can cut down the plot to just the path from the root to the leaf that is actually traversed:"
],
"metadata": {
"id": "ZaYbnEc9ganl"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.view(x=x, show_just_path=True, scale=1.0)"
],
"metadata": {
"id": "x7FWiRMcfBIk"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can make it even smaller using a horizontal orientation:"
],
"metadata": {
"id": "rs1sKf6mgnoO"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.view(x=x, show_just_path=True, scale=.75, orientation=\"LR\")"
],
"metadata": {
"id": "yNIIDqMA5V96"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Sometimes it's easier just to get an English description of how the model tested our feature values to make a decision:"
],
"metadata": {
"id": "8KoIVLEZgxQa"
}
},
{
"cell_type": "code",
"source": [
"print(viz_rmodel.explain_prediction_path(x=x))"
],
"metadata": {
"id": "qfldMHE6fptn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Feature space partitioning\n",
"\n",
"Using `rtree_feature_space()`, we can see how the decision tree partitions a feature space via a collection of splits. For example, here is how the decision tree partitions feature `ShellWeight`:"
],
"metadata": {
"id": "wnldbOEbfgXj"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.rtree_feature_space(features=['ShellWeight'],\n",
" show={'splits'})"
],
"metadata": {
"id": "RH7UhZnIBJUN"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"The horizontal orange bars indicate the average `Rings` value within each region. Here's another example using feature `Diameter` (with only one split in the tree):"
],
"metadata": {
"id": "f-XveSErcsl0"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.rtree_feature_space(features=['Diameter'], show={'splits'})"
],
"metadata": {
"id": "Nstj5J6uBVvj"
},
"execution_count": null,
"outputs": []
},
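{
"cell_type": "markdown",
"source": [
"Conceptually, the splits on a single feature cut that feature's axis into intervals. Each horizontal bar is the mean target value of the instances that fall in one interval. The following minimal NumPy sketch illustrates the idea with made-up split points and data rather than the splits learned by this model. `np.digitize` assigns each value to its interval."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical split points on one feature (from every node that tests it),\n",
"# plus made-up feature and target values.\n",
"toy_splits = np.array([0.164, 0.05, 0.3])\n",
"toy_feature = np.array([0.02, 0.04, 0.10, 0.15, 0.20, 0.25, 0.35, 0.50])\n",
"toy_target = np.array([3.0, 4.0, 7.0, 8.0, 10.0, 11.0, 14.0, 16.0])\n",
"\n",
"# Sorted splits cut the axis into len(edges) + 1 regions; np.digitize returns\n",
"# 0 below edges[0] and i for values in [edges[i-1], edges[i]).\n",
"edges = np.sort(toy_splits)\n",
"region = np.digitize(toy_feature, edges)\n",
"\n",
"# Mean target per non-empty region: the height of each horizontal bar.\n",
"region_means = {}\n",
"for r in range(len(edges) + 1):\n",
"    in_region = region == r\n",
"    if in_region.any():\n",
"        region_means[r] = float(toy_target[in_region].mean())\n",
"print(region_means)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},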
{
"cell_type": "markdown",
"source": [
"We can also look at two dimensional feature space, where the `Rings` values vary in color from green (low) to blue (high):"
],
"metadata": {
"id": "3caTuOEolAwy"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.rtree_feature_space(features=['ShellWeight','LongestShell'], show={'splits'})"
],
"metadata": {
"id": "-nEsJ-eXjLM2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"That heat map can be confusing because it's really a 2D projection of a 3D space: two features x target value. Instead, dtreeviz can show you this three-dimensional plot (from a variety of angles and elevations):"
],
"metadata": {
"id": "SPLzkfDDlqwr"
}
},
{
"cell_type": "code",
"source": [
"viz_rmodel.rtree_feature_space3D(features=['ShellWeight','LongestShell'],\n",
" show={'splits'}, elev=30, azim=140, dist=11, figsize=(9,8))"
],
"metadata": {
"id": "PtsRRc8gip99"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"If `ShellWeight` and `LongestShell` were the only features tested by the model, there would be no overlapping vertical \"plates\". Each 2D region of feature space would make a unique prediction. In this tree, there are other features that differentiate between ambiguous vertical prediction regions."
],
"metadata": {
"id": "3sUMxuFOnmui"
}
},
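{
"cell_type": "markdown",
"source": [
"Here is a toy illustration of the same reasoning with made-up data. The region labels and predictions below are hypothetical. Group the instances by the 2D region they fall in, then count the distinct leaf predictions in each region. A region with more than one distinct prediction can only arise from splits on some other feature."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from collections import defaultdict\n",
"\n",
"# Hypothetical data: the 2D (ShellWeight, LongestShell) region each instance\n",
"# falls in, and the prediction of the leaf it actually reaches.\n",
"toy_region_2d = ['A', 'A', 'A', 'B', 'B', 'C']\n",
"toy_leaf_prediction = [7.5, 7.5, 9.0, 11.0, 11.0, 13.2]\n",
"\n",
"predictions_per_region = defaultdict(set)\n",
"for region_id, prediction in zip(toy_region_2d, toy_leaf_prediction):\n",
"    predictions_per_region[region_id].add(prediction)\n",
"\n",
"# A region with more than one distinct prediction shows up as overlapping plates:\n",
"# some other feature must be separating the instances inside it.\n",
"overlapping = sorted(r for r, preds in predictions_per_region.items() if len(preds) > 1)\n",
"print(overlapping)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},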
{
"cell_type": "markdown",
"source": [
"At this point, you've learned how to use [dtreeviz](https://github.com/parrt/dtreeviz) to display the structure of decision trees, plot leaf information, trace how a model interprets a specific instance, and how a model partitions future space. You're ready to visualize and interpret trees using your own data sets!\n",
"\n",
"From here, you might also consider checking out these colabs: [Intermediate colab](https://www.tensorflow.org/decision_forests/tutorials/intermediate_colab) or [Making predictions](https://www.tensorflow.org/decision_forests/tutorials/predict_colab)."
],
"metadata": {
"id": "deDL-aQtv8sQ"
}
}
]
}