6-3 Training a Model with a Single GPU

The training process in deep learning is often extremely time-consuming: training a model for a few hours is routine, training for several days is common, and occasionally a model trains for dozens of days.

The time cost of training comes mainly from two parts: data preparation and parameter iteration.

When data preparation is the main bottleneck of training time, we can speed it up by preparing data with more parallel workers.
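One common way to do this in TensorFlow is tf.data's parallel map (thread-based rather than process-based, but it targets the same bottleneck). Below is a minimal sketch with dummy data; `preprocess` is a hypothetical stand-in for whatever per-sample transformation is slow:

```python
import tensorflow as tf

# Dummy data standing in for a real dataset.
images = tf.random.uniform((1000, 28, 28), maxval=255)
labels = tf.random.uniform((1000,), maxval=10, dtype=tf.int32)

def preprocess(x, y):
    # Hypothetical stand-in for an expensive per-sample transformation.
    return tf.cast(x, tf.float32) / 255.0, y

ds = tf.data.Dataset.from_tensor_slices((images, labels)) \
    .map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
    .batch(32) \
    .prefetch(tf.data.experimental.AUTOTUNE)  # overlap data prep with training
```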

When parameter iteration becomes the main bottleneck, the usual remedy is to accelerate it with a GPU or with Google's TPU.

For details, see 《用GPU加速Keras模型——Colab免费GPU使用攻略》 (Accelerating Keras models with a GPU: a guide to Colab's free GPU):

https://zhuanlan.zhihu.com/p/68509398

Whether you use the built-in fit method or a custom training loop, switching from CPU to single-GPU training is very convenient and requires no code changes at all. When a GPU is available and no device is specified explicitly, TensorFlow automatically prefers the GPU for creating tensors and executing tensor computations.
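You can confirm this automatic placement by inspecting the `.device` attribute of a freshly created tensor; a quick check, assuming at least one GPU is visible:

```python
import tensorflow as tf

a = tf.constant([1.0, 2.0, 3.0])
print(a.device)  # e.g. '/job:localhost/replica:0/task:0/device:GPU:0' when a GPU is available
```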

On a shared server at a company or university lab, however, with multiple GPUs and multiple users, TensorFlow by default claims all the memory of every GPU even though it actually computes on only part of one. To keep a single user's job from locking everyone else out, we usually add a few lines at the top of the script to pin the job to a specific GPU and limit its memory usage, so that other users can train models at the same time.
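An alternative to the in-code configuration shown in section 1 below is to hide GPUs from TensorFlow before it initializes, via the standard CUDA environment variable; a minimal sketch (the variable must be set before TensorFlow is imported):

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process

import tensorflow as tf  # must come after the environment variable is set
```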

In a Colab notebook, select GPU under Edit -> Notebook settings -> Hardware accelerator.

Note: the following code executes correctly only on Colab (the %tensorflow_version magic is specific to Colab notebooks).

You can test it via the following Colab notebook, 《tf_单GPU》:

https://colab.research.google.com/drive/1r5dLoeJq5z01sU72BX2M5UiNSkuxsEFe

```python
%tensorflow_version 2.x
import tensorflow as tf
print(tf.__version__)
```

```python
from tensorflow.keras import *

# Print a separator line stamped with the current time
@tf.function
def printbar():
    ts = tf.timestamp()
    today_ts = ts % (24 * 60 * 60)

    hour = tf.cast(today_ts // 3600 + 8, tf.int32) % tf.constant(24)  # UTC+8 (Beijing time)
    minute = tf.cast((today_ts % 3600) // 60, tf.int32)
    second = tf.cast(tf.floor(today_ts % 60), tf.int32)

    def timeformat(m):
        # Zero-pad single-digit values so the timestamp is always HH:MM:SS
        if tf.strings.length(tf.strings.format("{}", m)) == 1:
            return tf.strings.format("0{}", m)
        else:
            return tf.strings.format("{}", m)

    timestring = tf.strings.join([timeformat(hour), timeformat(minute),
                                  timeformat(second)], separator=":")
    tf.print("==========" * 8, end="")
    tf.print(timestring)
```

1. GPU Configuration

```python
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0]  # If there are multiple GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # Allocate GPU memory on demand
    # Alternatively, cap the GPU memory at a fixed amount (e.g. 4 GB):
    # tf.config.experimental.set_virtual_device_configuration(gpu0,
    #     [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
    tf.config.set_visible_devices([gpu0], "GPU")
```
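After this configuration, TensorFlow should report exactly one visible logical GPU regardless of how many physical GPUs the machine has; a quick check:

```python
logical_gpus = tf.config.list_logical_devices("GPU")
print(len(gpus), "physical GPU(s),", len(logical_gpus), "visible logical GPU(s)")
```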

Compare the computation speed of the GPU and the CPU:

```python
printbar()
with tf.device("/gpu:0"):
    tf.random.set_seed(0)
    a = tf.random.uniform((10000, 100), minval=0, maxval=3.0)
    b = tf.random.uniform((100, 100000), minval=0, maxval=3.0)
    c = a @ b
    tf.print(tf.reduce_sum(tf.reduce_sum(c, axis=0), axis=0))
printbar()
```
```
================================================================================17:37:01
2.24953778e+11
================================================================================17:37:01
```
```python
printbar()
with tf.device("/cpu:0"):
    tf.random.set_seed(0)
    a = tf.random.uniform((10000, 100), minval=0, maxval=3.0)
    b = tf.random.uniform((100, 100000), minval=0, maxval=3.0)
    c = a @ b
    tf.print(tf.reduce_sum(tf.reduce_sum(c, axis=0), axis=0))
printbar()
```
```
================================================================================17:37:34
2.24953795e+11
================================================================================17:37:40
```
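On this run the GPU finishes the matrix multiplication within the same second (17:37:01), while the CPU takes about six seconds (17:37:34 to 17:37:40). For measurements finer than printbar's one-second resolution, you could wrap the computation with Python's time module; a minimal sketch (a fair benchmark would also discard the first call, which pays one-off kernel-launch costs):

```python
import time

def timed_matmul(device):
    with tf.device(device):
        a = tf.random.uniform((10000, 100), minval=0, maxval=3.0)
        b = tf.random.uniform((100, 100000), minval=0, maxval=3.0)
        start = time.perf_counter()
        c = a @ b
        _ = c.numpy()  # force the (possibly asynchronous) GPU computation to finish
        return time.perf_counter() - start

print("gpu:", timed_matmul("/gpu:0"))
print("cpu:", timed_matmul("/cpu:0"))
```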

2. Data Preparation

```python
MAX_LEN = 300
BATCH_SIZE = 32
(x_train, y_train), (x_test, y_test) = datasets.reuters.load_data()
x_train = preprocessing.sequence.pad_sequences(x_train, maxlen=MAX_LEN)
x_test = preprocessing.sequence.pad_sequences(x_test, maxlen=MAX_LEN)

MAX_WORDS = x_train.max() + 1
CAT_NUM = y_train.max() + 1

ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
          .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()

ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test)) \
          .shuffle(buffer_size=1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()
```
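As an aside, the official tf.data performance guidance usually places cache() early (so raw elements are cached once) and keeps prefetch() as the last step (so prefetching overlaps the whole pipeline with training). A variant of the pipeline above in that order would look like this; both versions work for this small in-memory dataset:

```python
ds_train = tf.data.Dataset.from_tensor_slices((x_train, y_train)) \
          .cache() \
          .shuffle(buffer_size=1000) \
          .batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE)
```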

3. Model Definition

```python
tf.keras.backend.clear_session()

def create_model():
    model = models.Sequential()
    model.add(layers.Embedding(MAX_WORDS, 7, input_length=MAX_LEN))
    model.add(layers.Conv1D(filters=64, kernel_size=5, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Conv1D(filters=32, kernel_size=3, activation="relu"))
    model.add(layers.MaxPool1D(2))
    model.add(layers.Flatten())
    model.add(layers.Dense(CAT_NUM, activation="softmax"))
    return model

model = create_model()
model.summary()
```
```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding (Embedding)        (None, 300, 7)            216874
_________________________________________________________________
conv1d (Conv1D)              (None, 296, 64)           2304
_________________________________________________________________
max_pooling1d (MaxPooling1D) (None, 148, 64)           0
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 146, 32)           6176
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 73, 32)            0
_________________________________________________________________
flatten (Flatten)            (None, 2336)               0
_________________________________________________________________
dense (Dense)                (None, 46)                107502
=================================================================
Total params: 332,856
Trainable params: 332,856
Non-trainable params: 0
_________________________________________________________________
```
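The parameter counts above can be verified by hand: an Embedding layer has vocab_size × embedding_dim weights, a Conv1D layer has (kernel_size × input_channels + 1) weights per filter, and a Dense layer has (inputs + 1) weights per unit. A quick arithmetic check against the summary:

```python
# Sanity-check the counts reported by model.summary():
print(MAX_WORDS * 7)           # Embedding: vocab_size * embedding_dim           = 216874
print((5 * 7 + 1) * 64)        # Conv1D:   (kernel_size * in_channels + 1) * 64  = 2304
print((3 * 64 + 1) * 32)       # Conv1D_1: (kernel_size * in_channels + 1) * 32  = 6176
print((73 * 32 + 1) * 46)      # Dense:    (flattened 2336 inputs + 1) * CAT_NUM = 107502
```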

4. Model Training

```python
optimizer = optimizers.Nadam()
loss_func = losses.SparseCategoricalCrossentropy()

train_loss = metrics.Mean(name='train_loss')
train_metric = metrics.SparseCategoricalAccuracy(name='train_accuracy')

valid_loss = metrics.Mean(name='valid_loss')
valid_metric = metrics.SparseCategoricalAccuracy(name='valid_accuracy')

@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)

@tf.function
def valid_step(model, features, labels):
    predictions = model(features)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)

def train_model(model, ds_train, ds_valid, epochs):
    for epoch in tf.range(1, epochs + 1):

        for features, labels in ds_train:
            train_step(model, features, labels)

        for features, labels in ds_valid:
            valid_step(model, features, labels)

        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'

        if epoch % 1 == 0:
            printbar()
            tf.print(tf.strings.format(logs,
                (epoch, train_loss.result(), train_metric.result(),
                 valid_loss.result(), valid_metric.result())))
            tf.print("")

        train_loss.reset_states()
        valid_loss.reset_states()
        train_metric.reset_states()
        valid_metric.reset_states()

train_model(model, ds_train, ds_test, 10)
```
```
================================================================================17:13:26
Epoch=1,Loss:1.96735072,Accuracy:0.489200622,Valid Loss:1.64124215,Valid Accuracy:0.582813919
================================================================================17:13:28
Epoch=2,Loss:1.4640888,Accuracy:0.624805152,Valid Loss:1.5559175,Valid Accuracy:0.607747078
================================================================================17:13:30
Epoch=3,Loss:1.20681274,Accuracy:0.68581605,Valid Loss:1.58494771,Valid Accuracy:0.622439921
================================================================================17:13:31
Epoch=4,Loss:0.937500894,Accuracy:0.75361836,Valid Loss:1.77466083,Valid Accuracy:0.621994674
================================================================================17:13:33
Epoch=5,Loss:0.693960547,Accuracy:0.822199941,Valid Loss:2.00267363,Valid Accuracy:0.6197685
================================================================================17:13:35
Epoch=6,Loss:0.519614,Accuracy:0.870296121,Valid Loss:2.23463202,Valid Accuracy:0.613980412
================================================================================17:13:37
Epoch=7,Loss:0.408562034,Accuracy:0.901246965,Valid Loss:2.46969271,Valid Accuracy:0.612199485
================================================================================17:13:39
Epoch=8,Loss:0.339028627,Accuracy:0.920062363,Valid Loss:2.68585229,Valid Accuracy:0.615316093
================================================================================17:13:41
Epoch=9,Loss:0.293798745,Accuracy:0.92930305,Valid Loss:2.88995624,Valid Accuracy:0.613535166
================================================================================17:13:43
Epoch=10,Loss:0.263130337,Accuracy:0.936651051,Valid Loss:3.09705234,Valid Accuracy:0.612644672
```
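As noted at the start of this section, nothing about the single-GPU setup is specific to a custom training loop. For comparison, here is a minimal sketch of training the same model with the built-in fit method instead (reusing the datasets, optimizers, losses, and metrics modules imported above):

```python
model = create_model()
model.compile(optimizer=optimizers.Nadam(),
              loss=losses.SparseCategoricalCrossentropy(),
              metrics=[metrics.SparseCategoricalAccuracy()])
history = model.fit(ds_train, validation_data=ds_test, epochs=10)
```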

If you would like to discuss the content of this book further with the author, feel free to leave a message under the WeChat public account "Python与算法之美". The author's time and energy are limited, and replies will be made as circumstances allow.

You can also reply with the keyword 加群 (join group) in the public account to join the reader discussion group.
