Imdb

class paddle.text.datasets.Imdb [源代码]

该类是对`IMDB <https://www.imdb.com/interfaces/>`_ 测试数据集的实现。

参数

  • data_file(str) - 保存压缩数据的路径,如果参数:attr:`download`设置为True,

可设置为None。默认为None。 - mode(str) - ‘train’ 或’test’ 模式。默认为’train’。 - cutoff(int) - 构建词典的截止大小。默认为Default 150。 - download(bool) - 如果:attr:`data_file`未设置,是否自动下载数据集。默认为True。

返回值

Dataset, IMDB数据集实例。

代码示例

  1. import paddle
  2. from paddle.text.datasets import Imdb
  3. class SimpleNet(paddle.nn.Layer):
  4. def __init__(self):
  5. super(SimpleNet, self).__init__()
  6. def forward(self, doc, label):
  7. return paddle.sum(doc), label
  8. imdb = Imdb(mode='train')
  9. for i in range(10):
  10. doc, label = imdb[i]
  11. doc = paddle.to_tensor(doc)
  12. label = paddle.to_tensor(label)
  13. model = SimpleNet()
  14. image, label = model(doc, label)
  15. print(doc.numpy().shape, label.numpy().shape)