TensorFlow.js 模型训练 *

与 TensorFlow Serving 和 TensorFlow Lite 不同,TensorFlow.js 不仅支持模型的部署和推断,还支持直接在 TensorFlow.js 中进行模型训练、

在 TensorFlow 基础章节中,我们已经用 Python 实现过,针对某城市在2013年-2017年的房价的任务,通过对该数据进行线性回归,即使用线性模型 y = ax + b 来拟合上述数据,此处 ab 是待求的参数。

下面我们改用 TensorFlow.js 来实现一个 JavaScript 版本。

首先,我们定义数据,进行基本的归一化操作。

  1. import * as tf from '@tensorflow/tfjs'
  2.  
  3. const xsRaw = tf.tensor([2013, 2014, 2015, 2016, 2017])
  4. const ysRaw = tf.tensor([12000, 14000, 15000, 16500, 17500])
  5.  
  6. // 归一化
  7. const xs = xsRaw.sub(xsRaw.min())
  8. .div(xsRaw.max().sub(xsRaw.min()))
  9. const ys = ysRaw.sub(ysRaw.min())
  10. .div(ysRaw.max().sub(ysRaw.min()))

接下来,我们来求线性模型中两个参数 ab 的值。

使用 loss() 计算损失;使用 optimizer.minimize() 自动更新模型参数。

  1. const a = tf.scalar(Math.random()).variable()
  2. const b = tf.scalar(Math.random()).variable()
  3.  
  4. // y = a * x + b.
  5. const f = (x: tf.Tensor) => a.mul(x).add(b)
  6. const loss = (pred: tf.Tensor, label: tf.Tensor) => pred.sub(label).square().mean() as tf.Scalar
  7.  
  8. const learningRate = 1e-3
  9. const optimizer = tf.train.sgd(learningRate)
  10.  
  11. // 训练模型
  12. for (let i = 0; i < 10000; i++) {
  13. optimizer.minimize(() => loss(f(xs), ys))
  14. }
  15.  
  16. // 预测
  17. console.log(`a: ${a.dataSync()}, b: ${b.dataSync()}`)
  18. const preds = f(xs).dataSync() as Float32Array
  19. const trues = ys.arraySync() as number[]
  20. preds.forEach((pred, i) => {
  21. console.log(`x: ${i}, pred: ${pred.toFixed(2)}, true: ${trues[i].toFixed(2)}`)
  22. })

从下面的输出样例中我们可以看到,已经拟合的比较接近了。

  1. a: 0.9339302778244019, b: 0.08108722418546677
  2. x: 0, pred: 0.08, true: 0.00
  3. x: 1, pred: 0.31, true: 0.36
  4. x: 2, pred: 0.55, true: 0.55
  5. x: 3, pred: 0.78, true: 0.82
  6. x: 4, pred: 1.02, true: 1.00

可以直接在浏览器中运行,完整的 HTML 代码如下:

  1. <html>
  2. <head>
  3. <script src="http://unpkg.com/@tensorflow/tfjs/dist/tf.min.js"></script>
  4. <script>
  5. const xsRaw = tf.tensor([2013, 2014, 2015, 2016, 2017])
  6. const ysRaw = tf.tensor([12000, 14000, 15000, 16500, 17500])
  7.  
  8. // 归一化
  9. const xs = xsRaw.sub(xsRaw.min())
  10. .div(xsRaw.max().sub(xsRaw.min()))
  11. const ys = ysRaw.sub(ysRaw.min())
  12. .div(ysRaw.max().sub(ysRaw.min()))
  13. const a = tf.scalar(Math.random()).variable()
  14. const b = tf.scalar(Math.random()).variable()
  15.  
  16. // y = a * x + b.
  17. const f = (x) => a.mul(x).add(b)
  18. const loss = (pred, label) => pred.sub(label).square().mean()
  19.  
  20. const learningRate = 1e-3
  21. const optimizer = tf.train.sgd(learningRate)
  22.  
  23. // 训练模型
  24. for (let i = 0; i < 10000; i++) {
  25. optimizer.minimize(() => loss(f(xs), ys))
  26. }
  27.  
  28. // 预测
  29. console.log(`a: ${a.dataSync()}, b: ${b.dataSync()}`)
  30. const preds = f(xs).dataSync()
  31. const trues = ys.arraySync()
  32. preds.forEach((pred, i) => {
  33. console.log(`x: ${i}, pred: ${pred.toFixed(2)}, true: ${trues[i].toFixed(2)}`)
  34. })
  35. </script>
  36. </head>
  37. </html>