6-5, Training a Model on a TPU

If you want to try training a model on a TPU in Google Colab, it is just as easy: you only need to add 6 lines of code.

In the Colab notebook, go to Edit -> Notebook settings -> Hardware accelerator and select TPU.

Note: the code below only runs correctly on Colab.

You can try it out in the following Colab notebook, "tf_TPU":

https://colab.research.google.com/drive/1XCIhATyE1R7lq6uwFlYlRsUr5d9_-r1s

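    # Colab-only magic that selects TensorFlow 2.x
    %tensorflow_version 2.x
    import tensorflow as tf
    print(tf.__version__)
    # Brings datasets, preprocessing, models, layers, optimizers, losses, metrics into scope
    from tensorflow.keras import *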

I. Preparing the data

    MAX_LEN = 300
    BATCH_SIZE = 32

    (x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
    x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
    x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)

    MAX_WORDS = x_train.max() + 1  # vocabulary size: indices run from 0 (padding) to the max index
    CAT_NUM = y_train.max() + 1    # number of topic classes

    # cache() comes before shuffle() so the cached data is reshuffled every epoch;
    # caching after shuffle() would replay the first epoch's order every time.
    ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
        .cache() \
        .shuffle(buffer_size=1000) \
        .batch(BATCH_SIZE) \
        .prefetch(tf.data.experimental.AUTOTUNE)

    # The evaluation set does not need shuffling.
    ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
        .cache() \
        .batch(BATCH_SIZE) \
        .prefetch(tf.data.experimental.AUTOTUNE)
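pad_sequences is what turns Reuters' variable-length lists of word indices into a fixed MAX_LEN matrix. Its defaults are padding="pre" and truncating="pre", so both padding and truncation happen at the front of each sequence. The pure-Python sketch below mimics those defaults on made-up toy data; the helper name pad_pre and the sample sequences are hypothetical, and this is not the Keras implementation.

```python
def pad_pre(sequence, maxlen, value=0):
    """Hypothetical helper mimicking pad_sequences(padding="pre", truncating="pre")."""
    seq = list(sequence)
    if len(seq) > maxlen:
        # truncating="pre": drop tokens from the front, keep the last maxlen
        seq = seq[len(seq) - maxlen:]
    # padding="pre": fill the left side with the padding value
    return [value] * (maxlen - len(seq)) + seq


# Toy word-index sequences, made up for illustration
sequences = [[1, 14, 7], [1, 3, 9, 27, 4, 8]]
padded = [pad_pre(s, maxlen=5) for s in sequences]
print(padded)  # [[0, 0, 1, 14, 7], [3, 9, 27, 4, 8]]

# Indices run from 0 (the padding value) up to the largest index, so an
# embedding table needs max_index + 1 rows, as MAX_WORDS = x_train.max() + 1 computes.
vocab_size = max(max(row) for row in padded) + 1
print(vocab_size)  # 28
```

Note that truncation keeps the end of a long sequence, so the leading start token (index 1 in this toy data) is the first thing dropped.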

II. Defining the model

    tf.keras.backend.clear_session()

    def create_model():
        model = models.Sequential()
        model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
        model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
        model.add(layers.MaxPool1D(2))
        model.add(layers.Flatten())
        # softmax output: the model returns class probabilities, not logits
        model.add(layers.Dense(CAT_NUM, activation="softmax"))
        return model

    def compile_model(model):
        model.compile(
            optimizer=optimizers.Nadam(),
            # from_logits=False because the model already ends in softmax;
            # from_logits=True would apply softmax a second time.
            loss=losses.SparseCategoricalCrossentropy(from_logits=False),
            metrics=[metrics.SparseCategoricalAccuracy(),
                     metrics.SparseTopKCategoricalAccuracy(5)])
        return model
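The final Dense layer already applies softmax, so its outputs are probabilities. If SparseCategoricalCrossentropy is built with from_logits=True, it runs softmax over those probabilities again, which pushes them toward a uniform distribution. The pure-Python check below (a hand-written softmax and cross-entropy, not the Keras code) takes one perfectly confident prediction over the 46 Reuters classes: read as probabilities the loss is 0, but after a second softmax it cannot fall below ln(1 + 45/e), about 2.87. That floor is a plausible reason the training losses in the log further down stay above 3.2 even while accuracy keeps rising.

```python
import math


def softmax(values):
    # Numerically stable softmax over a list of floats
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_cross_entropy(probs, target):
    # -log(p_target), written as log(1/p) so a perfect prediction gives +0.0
    return math.log(1.0 / probs[target])


NUM_CLASSES = 46  # CAT_NUM for Reuters
TARGET = 0

# A perfectly confident prediction, as a softmax output layer could produce it
probs = [0.0] * NUM_CLASSES
probs[TARGET] = 1.0

# from_logits=False: the loss reads the probabilities directly
loss_probs = sparse_cross_entropy(probs, TARGET)

# from_logits=True: softmax is applied again to values that are already probabilities
loss_double = sparse_cross_entropy(softmax(probs), TARGET)

print(f"from_logits=False: {loss_probs:.4f}")   # 0.0000
print(f"from_logits=True:  {loss_double:.4f}")  # 2.8653
print(f"floor ln(1 + 45/e): {math.log(1 + 45 / math.e):.4f}")  # 2.8653
```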

III. Training the model

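    # Add the following 6 lines
    import os
    # COLAB_TPU_ADDR (host:port of the TPU) is set by Colab's TPU runtime
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)  # tf.distribute.TPUStrategy in TF 2.3+
    with strategy.scope():
        # Build and compile inside the strategy scope so the model's variables
        # (including optimizer state) are created under the TPU strategy.
        model = create_model()
        model.summary()
        model = compile_model(model)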

Output:

    WARNING:tensorflow:TPU system 10.26.134.242:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
    INFO:tensorflow:Initializing the TPU system: 10.26.134.242:8470
    INFO:tensorflow:Clearing out eager caches
    INFO:tensorflow:Finished initializing TPU system.
    INFO:tensorflow:Found TPU system:
    INFO:tensorflow:*** Num TPU Cores: 8
    INFO:tensorflow:*** Num TPU Workers: 1
    INFO:tensorflow:*** Num TPU Cores Per Worker: 8
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
    INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)
    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    embedding (Embedding)        (None, 300, 7)            216874
    _________________________________________________________________
    conv1d (Conv1D)              (None, 296, 64)           2304
    _________________________________________________________________
    max_pooling1d (MaxPooling1D) (None, 148, 64)           0
    _________________________________________________________________
    conv1d_1 (Conv1D)            (None, 146, 32)           6176
    _________________________________________________________________
    max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
    _________________________________________________________________
    flatten (Flatten)            (None, 2336)              0
    _________________________________________________________________
    dense (Dense)                (None, 46)                107502
    =================================================================
    Total params: 332,856
    Trainable params: 332,856
    Non-trainable params: 0
    _________________________________________________________________
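The shapes and parameter counts in this summary follow from a few formulas: a Conv1D layer with its default "valid" padding and stride 1 shortens the sequence by kernel_size - 1 and has kernel_size x in_channels x filters weights plus one bias per filter; MaxPool1D(2) halves the length, rounding down; a Dense layer has in x out weights plus out biases. The plain-Python sketch below rechecks the table. MAX_WORDS = 30982 is an assumption inferred from the Embedding row (216874 / 7), since the original never prints it.

```python
# Plain-Python recomputation of the model.summary() table
MAX_WORDS = 216874 // 7  # inferred from the Embedding row: 30982
EMBED_DIM = 7
MAX_LEN = 300
CAT_NUM = 46


def conv1d(length, in_channels, filters, kernel_size):
    # padding="valid", strides=1 (Conv1D defaults)
    out_len = length - kernel_size + 1
    params = kernel_size * in_channels * filters + filters  # weights + biases
    return out_len, filters, params


def maxpool1d(length, pool_size):
    # padding="valid", strides default to pool_size
    return length // pool_size


embed_params = MAX_WORDS * EMBED_DIM
len1, ch1, conv1_params = conv1d(MAX_LEN, EMBED_DIM, 64, 5)
len2 = maxpool1d(len1, 2)
len3, ch3, conv2_params = conv1d(len2, ch1, 32, 3)
len4 = maxpool1d(len3, 2)
flat = len4 * ch3
dense_params = flat * CAT_NUM + CAT_NUM

total = embed_params + conv1_params + conv2_params + dense_params
print(len1, len2, len3, len4, flat)  # 296 148 146 73 2336
print(embed_params, conv1_params, conv2_params, dense_params)  # 216874 2304 6176 107502
print(total)  # 332856
```

The Embedding layer holds about two thirds of all parameters, which is typical when the vocabulary is large and the rest of the network is small.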

    history = model.fit(ds_train, validation_data=ds_test, epochs=10)

Output:

    Train for 281 steps, validate for 71 steps
    Epoch 1/10
    281/281 [==============================] - 12s 43ms/step - loss: 3.4466 - sparse_categorical_accuracy: 0.4332 - sparse_top_k_categorical_accuracy: 0.7180 - val_loss: 3.3179 - val_sparse_categorical_accuracy: 0.5352 - val_sparse_top_k_categorical_accuracy: 0.7195
    Epoch 2/10
    281/281 [==============================] - 6s 20ms/step - loss: 3.3251 - sparse_categorical_accuracy: 0.5405 - sparse_top_k_categorical_accuracy: 0.7302 - val_loss: 3.3082 - val_sparse_categorical_accuracy: 0.5463 - val_sparse_top_k_categorical_accuracy: 0.7235
    Epoch 3/10
    281/281 [==============================] - 6s 20ms/step - loss: 3.2961 - sparse_categorical_accuracy: 0.5729 - sparse_top_k_categorical_accuracy: 0.7280 - val_loss: 3.3026 - val_sparse_categorical_accuracy: 0.5499 - val_sparse_top_k_categorical_accuracy: 0.7217
    Epoch 4/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2751 - sparse_categorical_accuracy: 0.5924 - sparse_top_k_categorical_accuracy: 0.7276 - val_loss: 3.2957 - val_sparse_categorical_accuracy: 0.5543 - val_sparse_top_k_categorical_accuracy: 0.7217
    Epoch 5/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2655 - sparse_categorical_accuracy: 0.6008 - sparse_top_k_categorical_accuracy: 0.7290 - val_loss: 3.3022 - val_sparse_categorical_accuracy: 0.5490 - val_sparse_top_k_categorical_accuracy: 0.7231
    Epoch 6/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2616 - sparse_categorical_accuracy: 0.6041 - sparse_top_k_categorical_accuracy: 0.7295 - val_loss: 3.3015 - val_sparse_categorical_accuracy: 0.5503 - val_sparse_top_k_categorical_accuracy: 0.7235
    Epoch 7/10
    281/281 [==============================] - 6s 21ms/step - loss: 3.2595 - sparse_categorical_accuracy: 0.6059 - sparse_top_k_categorical_accuracy: 0.7322 - val_loss: 3.3064 - val_sparse_categorical_accuracy: 0.5454 - val_sparse_top_k_categorical_accuracy: 0.7266
    Epoch 8/10
    281/281 [==============================] - 6s 21ms/step - loss: 3.2591 - sparse_categorical_accuracy: 0.6063 - sparse_top_k_categorical_accuracy: 0.7327 - val_loss: 3.3025 - val_sparse_categorical_accuracy: 0.5481 - val_sparse_top_k_categorical_accuracy: 0.7231
    Epoch 9/10
    281/281 [==============================] - 5s 19ms/step - loss: 3.2588 - sparse_categorical_accuracy: 0.6062 - sparse_top_k_categorical_accuracy: 0.7332 - val_loss: 3.2992 - val_sparse_categorical_accuracy: 0.5521 - val_sparse_top_k_categorical_accuracy: 0.7257
    Epoch 10/10
    281/281 [==============================] - 5s 18ms/step - loss: 3.2577 - sparse_categorical_accuracy: 0.6073 - sparse_top_k_categorical_accuracy: 0.7363 - val_loss: 3.2981 - val_sparse_categorical_accuracy: 0.5516 - val_sparse_top_k_categorical_accuracy: 0.7306
    CPU times: user 18.9 s, sys: 3.86 s, total: 22.7 s
    Wall time: 1min 1s
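The "Train for 281 steps, validate for 71 steps" line follows from the dataset sizes: Reuters' default split from load_data() (test_split=0.2) has 8,982 training and 2,246 test newswires, and batch() without drop_remainder keeps the last partial batch. Under TPUStrategy the batch size the dataset produces is the global batch, which is divided among the 8 TPU cores reported in the initialization log, so with BATCH_SIZE = 32 each core sees only 4 examples per step. A common adjustment, offered here as a suggestion rather than something the original run tested, is to scale BATCH_SIZE with strategy.num_replicas_in_sync so each core gets more work per step. The sketch below only reproduces the arithmetic.

```python
import math

# Reuters split sizes with reuters.load_data()'s default test_split=0.2
TRAIN_SAMPLES = 8982
TEST_SAMPLES = 2246
BATCH_SIZE = 32   # global batch size used for ds_train / ds_test
NUM_REPLICAS = 8  # "Num TPU Cores: 8" in the initialization log


def steps_per_epoch(num_samples, batch_size):
    # batch() keeps the final partial batch (no drop_remainder), hence ceil
    return math.ceil(num_samples / batch_size)


train_steps = steps_per_epoch(TRAIN_SAMPLES, BATCH_SIZE)
test_steps = steps_per_epoch(TEST_SAMPLES, BATCH_SIZE)
last_batch = TRAIN_SAMPLES - (train_steps - 1) * BATCH_SIZE

# Under TPUStrategy each global batch is split across the replicas
per_replica = BATCH_SIZE // NUM_REPLICAS

print(train_steps, test_steps)  # 281 71
print(last_batch)               # 22
print(per_replica)              # 4
```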

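If anything in this book needs further discussion with the author, feel free to leave a message on the WeChat official account "Python与算法之美". The author's time and energy are limited, so replies will be given as circumstances permit.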

You can also reply with the keyword 加群 ("join group") in the official account to join the readers' discussion group.
