QRNN


Quasi-recurrent neural networks, introduced in Bradbury et al. (2016).


ForgetMult

    # Notebook-only setup: point __file__ at the library module so the
    # C++/CUDA sources can be located relative to it
    __file__ = Path.cwd().parent/'fastai'/'text'/'models'/'qrnn.py'

load_cpp[source]

load_cpp(name, files, path)
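
Compiles and loads the C++/CUDA extension containing the fused ForgetMult kernels. A minimal sketch of such a helper, assuming it wraps torch.utils.cpp_extension.load (the real version in fastai may add a build directory and compiler flags):

    # A sketch, not fastai's exact implementation
    from torch.utils.cpp_extension import load

    def load_cpp(name, files, path):
        # JIT-compile the given source files and import the resulting extension module
        return load(name=name, sources=[str(path/f) for f in files])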

dispatch_cuda[source]

dispatch_cuda(cuda_class, cpu_func, x)

Depending on x.device, uses cuda_class.apply (for CUDA tensors) or cpu_func (otherwise).
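
A minimal sketch of this dispatch pattern, assuming cuda_class is a torch.autograd.Function subclass such as ForgetMultGPU below:

    def dispatch_cuda(cuda_class, cpu_func, x):
        # Pick the fused CUDA kernel when the input lives on the GPU,
        # otherwise fall back to the pure-PyTorch CPU implementation
        return cuda_class.apply if x.device.type == 'cuda' else cpu_func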

The ForgetMult gate is the quasi-recurrent part of the network, computing the following from x and f:

    h[i+1] = f[i] * x[i] + (1-f[i]) * h[i]

The initial value for h[0] is either a tensor of zeros or the previous hidden state.
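
For intuition, here is the recurrence unrolled by hand in plain Python (values chosen arbitrarily, h[0] taken as zero):

    x = [1.0, 1.0, 1.0]
    f = [0.5, 0.5, 0.5]
    h = 0.0  # h[0]
    for xi, fi in zip(x, f):
        h = fi * xi + (1 - fi) * h
        print(h)  # 0.5, then 0.75, then 0.875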

forget_mult_CPU[source]

forget_mult_CPU(x, f, first_h=None, batch_first=True, backward=False)

ForgetMult gate applied to x and f on the CPU.

first_h is the tensor used for the value of h[0] (defaults to a tensor of zeros). If batch_first=True, x and f are expected to be of shape batch_size x seq_length x n_hid, otherwise they are expected to be of shape seq_length x batch_size x n_hid. If backward=True, the elements in x and f along the sequence dimension are read in reverse.

    def manual_forget_mult(x, f, h=None, batch_first=True, backward=False):
        if batch_first: x,f = x.transpose(0,1),f.transpose(0,1)
        out = torch.zeros_like(x)
        prev = h if h is not None else torch.zeros_like(out[0])
        idx_range = range(x.shape[0]-1,-1,-1) if backward else range(x.shape[0])
        for i in idx_range:
            out[i] = f[i] * x[i] + (1-f[i]) * prev
            prev = out[i]
        if batch_first: out = out.transpose(0,1)
        return out

    x,f = torch.randn(5,3,20).chunk(2, dim=2)
    for (bf, bw) in [(True,True), (False,True), (True,False), (False,False)]:
        th_out = manual_forget_mult(x, f, batch_first=bf, backward=bw)
        out = forget_mult_CPU(x, f, batch_first=bf, backward=bw)
        test_close(th_out,out)
        h = torch.randn((5 if bf else 3), 10)
        th_out = manual_forget_mult(x, f, h=h, batch_first=bf, backward=bw)
        out = forget_mult_CPU(x, f, first_h=h, batch_first=bf, backward=bw)
        test_close(th_out,out)

class ForgetMultGPU[source]

ForgetMultGPU() :: Function

Wrapper around the CUDA kernels for the ForgetMult gate.
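
ForgetMultGPU subclasses torch.autograd.Function, so it provides both the forward pass and its hand-written gradient. Below is a hedged, CPU-only sketch of that pattern, with Python loops standing in for the fused CUDA kernels; it uses sequence-first tensors, omits the batch_first/backward options, and fixes h[0] at zero:

    import torch

    class ForgetMultSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, f):
            # x, f: seq_length x batch_size x n_hid
            out = torch.zeros_like(x)
            prev = torch.zeros_like(x[0])  # h[0] = 0 for simplicity
            for i in range(x.shape[0]):
                out[i] = f[i] * x[i] + (1-f[i]) * prev
                prev = out[i]
            ctx.save_for_backward(x, f, out)
            return out

        @staticmethod
        def backward(ctx, grad_out):
            x, f, out = ctx.saved_tensors
            grad_x, grad_f = torch.zeros_like(x), torch.zeros_like(f)
            running = torch.zeros_like(grad_out[0])
            for i in range(x.shape[0]-1, -1, -1):
                running = running + grad_out[i]      # total gradient reaching h[i]
                prev = out[i-1] if i > 0 else torch.zeros_like(out[0])
                grad_x[i] = f[i] * running           # dh[i]/dx[i] = f[i]
                grad_f[i] = (x[i] - prev) * running  # dh[i]/df[i] = x[i] - h[i-1]
                running = (1-f[i]) * running         # dh[i]/dh[i-1] = 1 - f[i]
            return grad_x, grad_f

A call like torch.autograd.gradcheck(ForgetMultSketch.apply, (x.double().requires_grad_(), f.double().requires_grad_())) can be used to verify a hand-written backward such as this one.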

QRNN

class QRNNLayer[source]

QRNNLayer(input_size, hidden_size=None, save_prev_x=False, zoneout=0, window=1, output_gate=True, batch_first=True, backward=False) :: Module

Apply a single layer of a Quasi-Recurrent Neural Network (QRNN) to an input sequence.
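
As a hedged summary based on Bradbury et al. (not the library's docstring): window is the width of the convolution over the time dimension (1 or 2), zoneout stochastically forces some channels of the forget gate to keep their previous state (a dropout-like regularizer on the gate), save_prev_x caches the last inputs of a batch so a window=2 convolution can span consecutive calls (useful with truncated backpropagation through time), and output_gate adds a multiplicative gate on the output. With window=1 and an output gate, and writing the (hypothetical) weight names as Wz, Wf, Wo, a layer computes, in the same convention as the ForgetMult above:

    z[t] = tanh(Wz @ x[t]);  f[t] = sigmoid(Wf @ x[t]);  o[t] = sigmoid(Wo @ x[t])
    c[t] = f[t] * z[t] + (1-f[t]) * c[t-1]   # the ForgetMult recurrence
    h[t] = o[t] * c[t]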

    qrnn_fwd = QRNNLayer(10, 20, save_prev_x=True, zoneout=0, window=2, output_gate=True)
    qrnn_bwd = QRNNLayer(10, 20, save_prev_x=True, zoneout=0, window=2, output_gate=True, backward=True)
    qrnn_bwd.load_state_dict(qrnn_fwd.state_dict())
    x_fwd = torch.randn(7,5,10)
    x_bwd = x_fwd.clone().flip(1)
    y_fwd,h_fwd = qrnn_fwd(x_fwd)
    y_bwd,h_bwd = qrnn_bwd(x_bwd)
    test_close(y_fwd, y_bwd.flip(1), eps=1e-4)
    test_close(h_fwd, h_bwd, eps=1e-4)
    y_fwd,h_fwd = qrnn_fwd(x_fwd, h_fwd)
    y_bwd,h_bwd = qrnn_bwd(x_bwd, h_bwd)
    test_close(y_fwd, y_bwd.flip(1), eps=1e-4)
    test_close(h_fwd, h_bwd, eps=1e-4)

class QRNN[source]

QRNN(input_size, hidden_size, n_layers=1, batch_first=True, dropout=0, bidirectional=False, save_prev_x=False, zoneout=0, window=None, output_gate=True) :: Module

Apply a multi-layer Quasi-Recurrent Neural Network (QRNN) to an input sequence.

    qrnn = QRNN(10, 20, 2, bidirectional=True, batch_first=True, window=2, output_gate=False)
    x = torch.randn(7,5,10)
    y,h = qrnn(x)
    test_eq(y.size(), [7, 5, 40])
    test_eq(h.size(), [4, 7, 20])
    #Without an output gate, the last timestep of the forward output equals the
    #second-to-last hidden state (h[2], last layer forward), and the first timestep
    #of the backward output equals the last hidden state (h[3], last layer backward)
    test_close(y[:,-1,:20], h[2])
    test_close(y[:,0,20:], h[3])
