决策树的训练和可视化

为了理解决策树,我们需要先构建一个决策树并亲身体验它到底如何进行预测。

接下来的代码就是在我们熟知的鸢尾花数据集上进行一个决策树分类器的训练。

  1. from sklearn.datasets import load_iris
  2. from sklearn.tree import DecisionTreeClassifier
  3. iris = load_iris()
  4. X = iris.data[:, 2:] # petal length and width y = iris.target
  5. tree_clf = DecisionTreeClassifier(max_depth=2)
  6. tree_clf.fit(X, y)

你可以通过使用export_graphviz()方法,通过生成一个叫做iris_tree.dot的图形定义文件将一个训练好的决策树模型可视化。

  1. from sklearn.tree import export_graphviz
  2. export_graphviz(
  3. tree_clf,
  4. out_file=image_path("iris_tree.dot"),
  5. feature_names=iris.feature_names[2:],
  6. class_names=iris.target_names,
  7. rounded=True,
  8. filled=True
  9. )

译者注:这段代码本人执行不成功,image_path未定义,换其他方法才画出图来。可能是版本原因?

然后,我们可以利用graphviz package [1] 中的dot命令行,将.dot文件转换成 PDF 或 PNG 等多种数据格式。例如,使用命令行将.dot文件转换成.png文件的命令如下:

[1] Graphviz是一款开源图形可视化软件包,http://www.graphviz.org/

  1. $ dot -Tpng iris_tree.dot -o iris_tree.png

我们的第一个决策树如图 6-1。

1528081141956