Restricted Boltzmann Machine features for digit classification

http://scikit-learn.org/stable/auto_examples/neural_networks/plot_rbm_logistic_classification.html#sphx-glr-auto-examples-neural-networks-plot-rbm-logistic-classification-py

此範例使用 BernoulliRBM 進行特徵萃取，以提升手寫數字辨識的準確率。伯努利限制玻爾茲曼機（`BernoulliRBM`）能對資料進行有效的非線性特徵萃取。
為了讓訓練出來的模型更為強健，將輸入影像分別向上、左、右、下平移一個像素，以產生更多訓練資料。
模型的超參數原本是以 grid search（網格搜尋）找出來的，但搜尋過程太耗費時間，因此不在這裡重現。
此範例將比較以下兩種方法的結果：
1. 直接使用原始像素值做邏輯回歸
2. 先以 BernoulliRBM 萃取特徵，再做邏輯回歸
結果顯示：使用 BernoulliRBM 萃取的特徵可以提升分類的準確度。

(一)引入函式庫與資料

  1. from __future__ import print_function
  2. print(__doc__)
  3. # Authors: Yann N. Dauphin, Vlad Niculae, Gabriel Synnaeve
  4. # License: BSD
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from scipy.ndimage import convolve
  8. from sklearn import linear_model, datasets, metrics
  9. from sklearn.model_selection import train_test_split
  10. from sklearn.neural_network import BernoulliRBM
  11. from sklearn.pipeline import Pipeline

(二)資料前處理、讀取資料、選取模型

  def nudge_dataset(X, Y):
      """Augment 8x8 digit images with copies shifted by one pixel.

      Each image in ``X`` is shifted one pixel up, left, right and down
      (pixels shifted in from outside the image are filled with 0), and the
      shifted copies are stacked below the originals.  The result is
      ``len(direction_vectors) + 1`` (= 5) times larger than the input, which
      gives the model more training data and makes it more robust to small
      translations.

      Parameters
      ----------
      X : ndarray, shape (n_samples, 64)
          Flattened 8x8 images, one per row.
      Y : ndarray, shape (n_samples,)
          Labels for the rows of ``X``.

      Returns
      -------
      X : ndarray, shape (5 * n_samples, 64)
          Originals followed by the up-, left-, right- and down-shifted copies.
      Y : ndarray, shape (5 * n_samples,)
          ``Y`` repeated once per block so labels stay aligned with ``X``.
      """
      # 3x3 kernels containing a single 1.  scipy.ndimage.convolve flips the
      # kernel, so these move the image up, left, right and down respectively.
      direction_vectors = [
          [[0, 1, 0],
           [0, 0, 0],
           [0, 0, 0]],
          [[0, 0, 0],
           [1, 0, 0],
           [0, 0, 0]],
          [[0, 0, 0],
           [0, 0, 1],
           [0, 0, 0]],
          [[0, 0, 0],
           [0, 0, 0],
           [0, 1, 0]]]

      def shift(x, w):
          # Un-flatten one row to 8x8, shift it (zero fill), flatten again.
          return convolve(x.reshape((8, 8)), mode='constant',
                          weights=w).ravel()

      X = np.concatenate([X] +
                         [np.apply_along_axis(shift, 1, X, vector)
                          for vector in direction_vectors])
      # One copy of the labels for the originals plus one per direction.
      # Derived from the kernel list (instead of a hard-coded 5) so the
      # labels can never fall out of step with the augmented images.
      n_copies = len(direction_vectors) + 1
      Y = np.concatenate([Y for _ in range(n_copies)], axis=0)
      return X, Y
  # Load Data
  digits = datasets.load_digits()
  X = np.asarray(digits.data, 'float32')
  X, Y = nudge_dataset(X, digits.target)
  # Min-max scale every pixel (column) of the grayscale images to [0, 1].
  # The denominator must be the column range (max - min), not just max; the
  # small epsilon avoids dividing by zero for pixels that never vary.
  X_min = np.min(X, 0)
  X_max = np.max(X, 0)
  X = (X - X_min) / (X_max - X_min + 0.0001)
  # Split the data into a training set (80%) and a test set (20%).
  # NOTE(review): nudge_dataset runs before the split, so shifted copies of a
  # test image can end up in the training set, which can make the reported
  # test scores optimistic.  Splitting first and nudging only the training
  # part avoids this.
  X_train, X_test, Y_train, Y_test = train_test_split(X, Y,
                                                      test_size=0.2,
                                                      random_state=0)
  # Models we will use.
  # max_iter is raised above scikit-learn's default of 100 so that the
  # (lbfgs) solver has room to converge with the weak penalty (large C)
  # set later.
  logistic = linear_model.LogisticRegression(max_iter=1000)
  rbm = BernoulliRBM(random_state=0, verbose=True)
  classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])

(三)設定模型參數與訓練模型

  # Hyper-parameters should be chosen by cross-validation.  These values were
  # found with GridSearchCV (fit every combination in a parameter grid and
  # keep the best-scoring one); the search is not repeated here to save time.
  rbm.learning_rate = 0.06
  rbm.n_iter = 20
  # n_components is the number of hidden units, i.e. the number of features
  # the RBM extracts.  More components tend to give better accuracy, but
  # fitting takes longer.
  rbm.n_components = 100
  logistic.C = 6000.0
  # Training RBM-Logistic Pipeline
  classifier.fit(X_train, Y_train)
  # Training Logistic regression directly on raw pixels (baseline).
  # max_iter is raised above the default of 100 so the solver can converge.
  logistic_classifier = linear_model.LogisticRegression(C=100.0, max_iter=1000)
  logistic_classifier.fit(X_train, Y_train)

(四)評估模型的分辨準確率

  # Score both models on the held-out test set and print a per-class
  # classification report for each.
  print()
  rbm_predictions = classifier.predict(X_test)
  rbm_report = metrics.classification_report(Y_test, rbm_predictions)
  print("Logistic regression using RBM features:\n%s\n" % (rbm_report,))
  raw_predictions = logistic_classifier.predict(X_test)
  raw_report = metrics.classification_report(Y_test, raw_predictions)
  print("Logistic regression using raw pixel features:\n%s\n" % (raw_report,))

Ex 2: Restricted Boltzmann Machine features for digit classification - 图1

圖1:使用RBM演算法後準確率為0.95

Ex 2: Restricted Boltzmann Machine features for digit classification - 图2

圖2：不使用任何特徵萃取方法、直接以原始像素做邏輯回歸，準確率為0.77

(五)畫出100個RBM萃取出的特徵

  # Draw each RBM component (one row of rbm.components_) as an 8x8 image.
  # The subplot grid is sized from the number of components instead of a
  # fixed 10x10, so raising rbm.n_components above 100 no longer makes
  # plt.subplot fail; with 100 components the figure is unchanged.
  n_components = len(rbm.components_)
  n_cols = int(np.ceil(np.sqrt(n_components)))
  n_rows = int(np.ceil(n_components / float(n_cols)))
  plt.figure(figsize=(4.2, 4))
  for i, comp in enumerate(rbm.components_):
      plt.subplot(n_rows, n_cols, i + 1)
      plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
                 interpolation='nearest')
      plt.xticks(())
      plt.yticks(())
  plt.suptitle('%d components extracted by RBM' % n_components, fontsize=16)
  plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
  plt.show()

Ex 2: Restricted Boltzmann Machine features for digit classification - 图3

圖3：使用RBM演算法萃取出來的特徵

(六)完整程式碼

  1. from __future__ import print_function
  2. print(__doc__)
  3. # Authors: Yann N. Dauphin, Vlad Niculae, Gabriel Synnaeve
  4. # License: BSD
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. from scipy.ndimage import convolve
  8. from sklearn import linear_model, datasets, metrics
  9. from sklearn.model_selection import train_test_split
  10. from sklearn.neural_network import BernoulliRBM
  11. from sklearn.pipeline import Pipeline
  ###############################################################################
  # Setting up
  def nudge_dataset(X, Y):
      """Augment 8x8 digit images with copies shifted by one pixel.

      This produces a dataset ``len(direction_vectors) + 1`` (= 5) times
      bigger than the original one by moving each 8x8 image in ``X`` one
      pixel up, left, right and down (pixels shifted in from outside the
      image are filled with 0).  The shifted copies are stacked below the
      originals.

      Parameters
      ----------
      X : ndarray, shape (n_samples, 64)
          Flattened 8x8 images, one per row.
      Y : ndarray, shape (n_samples,)
          Labels for the rows of ``X``.

      Returns
      -------
      X : ndarray, shape (5 * n_samples, 64)
          Originals followed by the up-, left-, right- and down-shifted copies.
      Y : ndarray, shape (5 * n_samples,)
          ``Y`` repeated once per block so labels stay aligned with ``X``.
      """
      # 3x3 kernels containing a single 1.  scipy.ndimage.convolve flips the
      # kernel, so these move the image up, left, right and down respectively.
      direction_vectors = [
          [[0, 1, 0],
           [0, 0, 0],
           [0, 0, 0]],
          [[0, 0, 0],
           [1, 0, 0],
           [0, 0, 0]],
          [[0, 0, 0],
           [0, 0, 1],
           [0, 0, 0]],
          [[0, 0, 0],
           [0, 0, 0],
           [0, 1, 0]]]

      def shift(x, w):
          # Un-flatten one row to 8x8, shift it (zero fill), flatten again.
          return convolve(x.reshape((8, 8)), mode='constant',
                          weights=w).ravel()

      X = np.concatenate([X] +
                         [np.apply_along_axis(shift, 1, X, vector)
                          for vector in direction_vectors])
      # One copy of the labels for the originals plus one per direction.
      # Derived from the kernel list (instead of a hard-coded 5) so the
      # labels can never fall out of step with the augmented images.
      n_copies = len(direction_vectors) + 1
      Y = np.concatenate([Y for _ in range(n_copies)], axis=0)
      return X, Y
  # Load Data
  digits = datasets.load_digits()
  X = np.asarray(digits.data, 'float32')
  X, Y = nudge_dataset(X, digits.target)
  # 0-1 scaling of every pixel (column).  The denominator must be the column
  # range (max - min), not just max; the small epsilon avoids dividing by
  # zero for pixels that never vary.
  X_min = np.min(X, 0)
  X_max = np.max(X, 0)
  X = (X - X_min) / (X_max - X_min + 0.0001)
  # NOTE(review): nudge_dataset runs before the split, so shifted copies of a
  # test image can end up in the training set, which can make the reported
  # test scores optimistic.  Splitting first and nudging only the training
  # part avoids this.
  X_train, X_test, Y_train, Y_test = train_test_split(X, Y,
                                                      test_size=0.2,
                                                      random_state=0)
  # Models we will use.
  # max_iter is raised above scikit-learn's default of 100 so that the
  # (lbfgs) solver has room to converge with the weak penalty (large C)
  # set later.
  logistic = linear_model.LogisticRegression(max_iter=1000)
  rbm = BernoulliRBM(random_state=0, verbose=True)
  classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)])
  ###############################################################################
  # Training
  # Hyper-parameters. These were set by cross-validation,
  # using a GridSearchCV. Here we are not performing cross-validation to
  # save time.
  rbm.learning_rate = 0.06
  rbm.n_iter = 20
  # n_components is the number of hidden units (extracted features).
  # More components tend to give better prediction performance, but larger
  # fitting time.
  rbm.n_components = 100
  logistic.C = 6000.0
  # Training RBM-Logistic Pipeline
  classifier.fit(X_train, Y_train)
  # Training Logistic regression directly on raw pixels (baseline).
  # max_iter is raised above the default of 100 so the solver can converge.
  logistic_classifier = linear_model.LogisticRegression(C=100.0, max_iter=1000)
  logistic_classifier.fit(X_train, Y_train)
  ###############################################################################
  # Evaluation
  # Score both models on the held-out test set and print a per-class
  # classification report for each.
  print()
  rbm_predictions = classifier.predict(X_test)
  rbm_report = metrics.classification_report(Y_test, rbm_predictions)
  print("Logistic regression using RBM features:\n%s\n" % (rbm_report,))
  raw_predictions = logistic_classifier.predict(X_test)
  raw_report = metrics.classification_report(Y_test, raw_predictions)
  print("Logistic regression using raw pixel features:\n%s\n" % (raw_report,))
  ###############################################################################
  # Plotting
  # Draw each RBM component (one row of rbm.components_) as an 8x8 image.
  # The subplot grid is sized from the number of components instead of a
  # fixed 10x10, so raising rbm.n_components above 100 no longer makes
  # plt.subplot fail; with 100 components the figure is unchanged.
  n_components = len(rbm.components_)
  n_cols = int(np.ceil(np.sqrt(n_components)))
  n_rows = int(np.ceil(n_components / float(n_cols)))
  plt.figure(figsize=(4.2, 4))
  for i, comp in enumerate(rbm.components_):
      plt.subplot(n_rows, n_cols, i + 1)
      plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
                 interpolation='nearest')
      plt.xticks(())
      plt.yticks(())
  plt.suptitle('%d components extracted by RBM' % n_components, fontsize=16)
  plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)
  plt.show()