3.5.2 分类

3.5.2.1 KNN分类器

3.5.2 分类 - 图1

可能最简单的分类器是最接近的邻居: 给定一个观察,使用在N维空间中训练样例中最接近它标签,这里N是每个样例的特征数。

K个最临近的邻居分类器内部使用基于ball tree的算法,用来代表训练的样例。

KNN (K个最临近邻居) 分类的例子:

In [14]:

  1. # 创建并拟合一个最临近邻居分类器
  2. from sklearn import neighbors
  3. knn = neighbors.KNeighborsClassifier()
  4. knn.fit(iris.data, iris.target)

Out[14]:

  1. KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
  2. metric_params=None, n_neighbors=5, p=2, weights='uniform')

In [15]:

  1. knn.predict([[0.1, 0.2, 0.3, 0.4]])

Out[15]:

  1. array([0])

训练集和测试集

当用学习算法进行实验时,重要的一点是不要用拟合预测器的数据来测试预测器的预测力。实际上,我们通常会在测试集上得到准确的预测。

In [16]:

  1. perm = np.random.permutation(iris.target.size)
  2. iris.data = iris.data[perm]
  3. iris.target = iris.target[perm]
  4. knn.fit(iris.data[:100], iris.target[:100])

Out[16]:

  1. KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
  2. metric_params=None, n_neighbors=5, p=2, weights='uniform')

In [17]:

  1. knn.score(iris.data[100:], iris.target[100:])

Out[17]:

  1. 0.95999999999999996

额外的问题: 为什么我们使用随机排列?

3.5.2.2 分类的支持向量机 (SVMs))

3.5.2.2.1 线性支持向量机

3.5.2 分类 - 图2

SVMs试图构建一个最大化两个类的间距的超平面。它选取输入的一个子集,称为支持向量,这个子集中的观察距离分隔超平面最近。

In [18]:

  1. from sklearn import svm
  2. svc = svm.SVC(kernel='linear')
  3. svc.fit(iris.data, iris.target)

Out[18]:

  1. SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
  2. kernel='linear', max_iter=-1, probability=False, random_state=None,
  3. shrinking=True, tol=0.001, verbose=False)

在scikit-learn实现了几种支持向量机。最常用的是svm.SVCsvm.NuSVCsvm.LinearSVC; “SVC” 代表支持向量分类器 (也存在用于回归的SVMs, 在scikit-learn被称为“SVR”)。

练习

在digits数据集上训练svm.SVC。留下最后的10%,在这些观察上测试预测的效果。

3.5.2.2.2 使用核 (kernel))

类通常并不是都能用超平面分隔,因此,有一个不仅仅是线性也可能是多项式或者幂的决策函数是明智的 :

线性核 (kernel)3.5.2 分类 - 图3

In [19]:

  1. svc = svm.SVC(kernel='linear')

多项式核 (kernel)3.5.2 分类 - 图4

In [20]:

  1. svc = svm.SVC(kernel='poly', degree=3)
  2. # degree: 多项式的阶

RBF核 (kernel) (径向基核函数)3.5.2 分类 - 图5

In [21]:

  1. svc = svm.SVC(kernel='rbf')
  2. # gamma: 径向基核大小的倒数

练习 以上列出的核哪一个在digits数据集上有较好的预测表现?