Data core

Core functionality for gathering data

The classes here provide functionality for applying a list of transforms to a set of items (TfmdLists, Datasets) or a DataLoader (TfmdDL), as well as the base class used to gather the data for model training: DataLoaders.

show_batch is a type-dispatched function that is responsible for showing decoded samples. x and y are the input and the target in the batch to be shown, and are passed along to dispatch on their types. For instance, there is a different implementation of show_batch when x is a TensorImage than when it is a TensorText (see vision.core or text.data for more details). ctxs can be passed, but the function is responsible for creating them if necessary. kwargs depend on the specific implementation.
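To make the idea of type dispatch concrete, here is a minimal sketch using only the standard library. It is not fastai's implementation: `ImageBatch`, `TextBatch` and `show_batch_sketch` are hypothetical names, and `functools.singledispatch` selects on the type of the first argument only, whereas the text above describes dispatching on the types of both x and y.

```python
from functools import singledispatch

# Hypothetical stand-ins for typed batch inputs; these are not fastai classes.
class ImageBatch(list): pass
class TextBatch(list): pass

@singledispatch
def show_batch_sketch(x, y, max_n=9):
    # Fallback used when no implementation is registered for type(x).
    return [f"{xi} -> {yi}" for xi, yi in list(zip(x, y))[:max_n]]

@show_batch_sketch.register
def _(x: ImageBatch, y, max_n=9):
    # Chosen automatically when x is an ImageBatch.
    return [f"image {xi} labelled {yi}" for xi, yi in list(zip(x, y))[:max_n]]

@show_batch_sketch.register
def _(x: TextBatch, y, max_n=9):
    # Chosen automatically when x is a TextBatch.
    return [f"text {xi!r} labelled {yi}" for xi, yi in list(zip(x, y))[:max_n]]

image_lines = show_batch_sketch(ImageBatch([1, 2]), ["cat", "dog"])
text_lines = show_batch_sketch(TextBatch(["hi"]), ["greeting"])
plain_lines = show_batch_sketch([1], [2])
```

The caller always uses the same function name; the type of x decides which body runs, which is the property the show_batch and show_results descriptions rely on.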

show_results is a type-dispatched function that is responsible for showing decoded samples and their corresponding outs. As in show_batch, x and y are the input and the target in the batch to be shown, and are passed along to dispatch on their types. ctxs can be passed, but the function is responsible for creating them if necessary. kwargs depend on the specific implementation.

class TfmdDL[source]

TfmdDL(dataset, bs=64, shuffle=False, num_workers=None, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None) :: DataLoader

Transformed DataLoader

A TfmdDL is a DataLoader that creates a Pipeline from a list of Transforms for each of the callbacks after_item, before_batch and after_batch. As a result, it can decode or show a processed batch.

    class _Category(int, ShowTitle): pass

    class NegTfm(Transform):
        def encodes(self, x): return torch.neg(x)
        def decodes(self, x): return torch.neg(x)

    tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)
    b = tdl.one_batch()
    test_eq(type(b[0]), TensorImage)
    b = (tensor([1.,1.,1.,1.]),)
    test_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)

    class A(Transform):
        def encodes(self, x): return x
        def decodes(self, x): return TitledInt(x)

    @Transform
    def f(x)->None: return fastuple((x,x))

    start = torch.arange(50)
    test_eq_type(f(2), fastuple((2,2)))

    a = A()
    tdl = TfmdDL(start, after_item=lambda x: (a(x), f(x)), bs=4)
    x,y = tdl.one_batch()
    test_eq(type(y), fastuple)
    s = tdl.decode_batch((x,y))
    test_eq(type(s[0][1]), fastuple)

    tdl = TfmdDL(torch.arange(0,50), after_item=A(), after_batch=NegTfm(), bs=4)
    test_eq(tdl.dataset[0], start[0])
    test_eq(len(tdl), (50-1)//4+1)
    test_eq(tdl.bs, 4)
    test_stdout(tdl.show_batch, '0\n1\n2\n3')
    test_stdout(partial(tdl.show_batch, unique=True), '0\n0\n0\n0')

    class B(Transform):
        parameters = 'a'
        def __init__(self): self.a = torch.tensor(0.)
        def encodes(self, x): return x

    tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=B(), bs=4)
    test_eq(tdl.after_batch.fs[0].a.device, torch.device('cpu'))
    tdl.to(default_device())
    test_eq(tdl.after_batch.fs[0].a.device, default_device())
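The reason a TfmdDL can decode a processed batch is that each transform carries both an encoding and a decoding step. The following is a minimal pure-Python sketch of that round trip, not fastai's Pipeline: `NegSketch`, `ScaleSketch` and `PipelineSketch` are hypothetical names, and the only assumption is that encoding runs the transforms in order while decoding runs them in reverse.

```python
class NegSketch:
    # Negation is its own inverse, like NegTfm in the examples above.
    def encodes(self, x): return -x
    def decodes(self, x): return -x

class ScaleSketch:
    def __init__(self, k): self.k = k
    def encodes(self, x): return x * self.k
    def decodes(self, x): return x / self.k

class PipelineSketch:
    def __init__(self, tfms): self.tfms = list(tfms)

    def __call__(self, x):
        # Encode: apply every transform in the order given.
        for t in self.tfms: x = t.encodes(x)
        return x

    def decode(self, x):
        # Decode: undo the transforms, last one first.
        for t in reversed(self.tfms): x = t.decodes(x)
        return x

pipe = PipelineSketch([NegSketch(), ScaleSketch(2)])
encoded = pipe(3)               # (-3) * 2
decoded = pipe.decode(encoded)  # back to the original value
```

Reversing the order on decode matters as soon as two transforms do not commute; with a single transform, as in the NegTfm examples, the order is irrelevant.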

Methods

DataLoader.one_batch[source]

DataLoader.one_batch()

Return one batch from DataLoader.

    tfm = NegTfm()
    tdl = TfmdDL(start, after_batch=tfm, bs=4)

    b = tdl.one_batch()
    test_eq(tensor([0,-1,-2,-3]), b)

TfmdDL.decode[source]

TfmdDL.decode(b)

Decode b using tfms

    test_eq(tdl.decode(b), tensor(0,1,2,3))

TfmdDL.decode_batch[source]

TfmdDL.decode_batch(b, max_n=9, full=True)

Decode b entirely

    test_eq(tdl.decode_batch(b), [0,1,2,3])
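One way to picture what "decoding a batch entirely" involves is to regroup the batch, which holds one sequence per field, into one tuple per sample and then decode each sample. The sketch below uses plain Python lists and hypothetical helper names (`split_batch_sketch`, `decode_batch_sketch`); it only loosely mirrors the behaviour tested above and should not be read as the library's code.

```python
def split_batch_sketch(batch, max_n=9):
    "Turn a tuple of per-field sequences into a list of per-sample tuples."
    n = min(max_n, len(batch[0]))
    return [tuple(field[i] for field in batch) for i in range(n)]

def decode_batch_sketch(batch, decode_item, max_n=9):
    "Regroup `batch` into samples, then apply `decode_item` to each sample."
    return [decode_item(s) for s in split_batch_sketch(batch, max_n)]

# A batch of four samples with two fields; the first field was negated when encoded.
batch = ([0, -1, -2, -3], [10, 11, 12, 13])
samples = decode_batch_sketch(batch, lambda s: (-s[0], s[1]), max_n=3)
```

Here max_n plays the same role as in the signature above: it caps how many samples are regrouped and decoded.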

TfmdDL.show_batch[source]

TfmdDL.show_batch(b=None, max_n=9, ctxs=None, show=True, unique=False, **kwargs)

Show b (defaults to one_batch), a list of lists of pipeline outputs (i.e. output of a DataLoader)

TfmdDL.to[source]

TfmdDL.to(device)

Put self and its transforms state on device

class DataLoaders[source]

DataLoaders(*loaders, path='.', device=None) :: GetAttr

Basic wrapper around several DataLoaders.

    dls = DataLoaders(tdl,tdl)
    x = dls.train.one_batch()
    x2 = first(tdl)
    test_eq(x,x2)
    x2 = dls.one_batch()
    test_eq(x,x2)

Multiple transforms can be added to multiple dataloaders using DataLoaders.add_tfms. You can specify the dataloaders by a list of names, dls.add_tfms(...,'valid',...), or by index, dls.add_tfms(...,1,....); by default, transforms are added to all dataloaders. event is a required argument and determines when the transform will be run; for more information on events, please refer to TfmdDL. tfms is a list of Transform, and is also a required argument.

    class _TestTfm(Transform):
        def encodes(self, o): return torch.ones_like(o)
        def decodes(self, o): return o

    tdl1,tdl2 = TfmdDL(start, bs=4),TfmdDL(start, bs=4)
    dls2 = DataLoaders(tdl1,tdl2)
    dls2.add_tfms([_TestTfm()],'after_batch',['valid'])
    dls2.add_tfms([_TestTfm()],'after_batch',[1])
    dls2.train.after_batch,dls2.valid.after_batch,

    (Pipeline: , Pipeline: _TestTfm -> _TestTfm)
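The selection logic described for add_tfms (by name, by index, or all loaders by default) can be sketched without fastai as follows. `LoaderSketch` and `LoadersSketch` are hypothetical classes, transforms are represented by plain strings, and the name-to-position mapping ('train' is 0, 'valid' is 1) is the only convention taken from this page.

```python
class LoaderSketch:
    def __init__(self):
        # One list of transforms per event; the event names follow TfmdDL's callbacks.
        self.events = {"after_item": [], "before_batch": [], "after_batch": []}

class LoadersSketch:
    names = ["train", "valid"]

    def __init__(self, *loaders): self.loaders = list(loaders)

    def _resolve(self, which):
        # None selects every loader; otherwise accept names or integer positions.
        if which is None: return self.loaders
        return [self.loaders[self.names.index(w)] if isinstance(w, str) else self.loaders[w]
                for w in which]

    def add_tfms(self, tfms, event, loaders=None):
        for dl in self._resolve(loaders):
            dl.events[event].extend(tfms)

dls_sketch = LoadersSketch(LoaderSketch(), LoaderSketch())
dls_sketch.add_tfms(["t1"], "after_batch", ["valid"])  # by name
dls_sketch.add_tfms(["t2"], "after_batch", [1])        # by index, same loader
dls_sketch.add_tfms(["t3"], "after_item")              # default: all loaders
```

As in the example above, the training loader's after_batch stays empty while the validation loader receives both transforms.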

Methods

DataLoaders.__getitem__[source]

DataLoaders.__getitem__(i)

Retrieve DataLoader at i (0 is training, 1 is validation)

    x2 = dls[0].one_batch()
    test_eq(x,x2)

DataLoaders.train[source]

Training DataLoader

DataLoaders.valid[source]

Validation DataLoader

DataLoaders.train_ds[source]

Training Dataset

DataLoaders.valid_ds[source]

Validation Dataset

class FilteredBase[source]

FilteredBase(*args, dl_type=None, **kwargs)

Base class for lists with subsets

FilteredBase.dataloaders[source]

FilteredBase.dataloaders(bs=64, shuffle_train=None, shuffle=True, val_shuffle=False, n=None, path='.', dl_type=None, dl_kwargs=None, device=None, drop_last=None, val_bs=None, num_workers=None, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, indexed=None, persistent_workers=False, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)

Get a DataLoaders

class TfmdLists[source]

TfmdLists(items=None, *rest, use_list=False, match=None) :: FilteredBase

A Pipeline of tfms applied to a collection of items

decode_at[source]

decode_at(o, idx)

Decoded item at idx

    def decode_at(o, idx):
        "Decoded item at `idx`"
        return o.decode(o[idx])

show_at[source]

show_at(o, idx, **kwargs)

    def show_at(o, idx, **kwargs):
        "Show item at `idx`"
        return o.show(o[idx], **kwargs)

A TfmdLists combines a collection of objects with a Pipeline. tfms can either be a Pipeline or a list of transforms, in which case it will be wrapped in a Pipeline. use_list is passed along to L with the items, and split_idx is passed to each transform of the Pipeline. do_setup indicates whether the Pipeline.setup method should be called during initialization.

    class _IntFloatTfm(Transform):
        def encodes(self, o): return TitledInt(o)
        def decodes(self, o): return TitledFloat(o)
    int2f_tfm=_IntFloatTfm()

    def _neg(o): return -o
    neg_tfm = Transform(_neg, _neg)

    items = L([1.,2.,3.]); tfms = [neg_tfm, int2f_tfm]
    tl = TfmdLists(items, tfms=tfms)
    test_eq_type(tl[0], TitledInt(-1))
    test_eq_type(tl[1], TitledInt(-2))
    test_eq_type(tl.decode(tl[2]), TitledFloat(3.))
    test_stdout(lambda: show_at(tl, 2), '-3')
    test_eq(tl.types, [float, float, TitledInt])
    tl

    TfmdLists: [1.0, 2.0, 3.0]
    tfms - [_neg:
    encodes: (object,object) -> _neg
    decodes: (object,object) -> _neg, _IntFloatTfm:
    encodes: (object,object) -> encodes
    decodes: (object,object) -> decodes
    ]

    splits = [[0,2],[1]]
    tl = TfmdLists(items, tfms=tfms, splits=splits)
    test_eq(tl.n_subsets, 2)
    test_eq(tl.train, tl.subset(0))
    test_eq(tl.valid, tl.subset(1))
    test_eq(tl.train.items, items[splits[0]])
    test_eq(tl.valid.items, items[splits[1]])
    test_eq(tl.train.tfms.split_idx, 0)
    test_eq(tl.valid.tfms.split_idx, 1)
    test_eq(tl.train.new_empty().split_idx, 0)
    test_eq(tl.valid.new_empty().split_idx, 1)
    test_eq_type(tl.splits, L(splits))
    assert not tl.overlapping_splits()

    df = pd.DataFrame(dict(a=[1,2,3],b=[2,3,4]))
    tl = TfmdLists(df, lambda o: o.a+1, splits=[[0],[1,2]])
    test_eq(tl[1,2], [3,4])
    tr = tl.subset(0)
    test_eq(tr[:], [2])
    val = tl.subset(1)
    test_eq(val[:], [3,4])

    class _B(Transform):
        def __init__(self): self.m = 0
        def encodes(self, o): return o+self.m
        def decodes(self, o): return o-self.m
        def setups(self, items):
            print(items)
            self.m = tensor(items).float().mean().item()

    # test for setup, which updates `self.m`
    tl = TfmdLists(items, _B())
    test_eq(tl.m, 2)

    TfmdLists: [1.0, 2.0, 3.0]
    tfms - []

Here’s how we can use TfmdLists.setup to implement a simple category list, getting labels from a mock file list:

    class _Cat(Transform):
        order = 1
        def encodes(self, o): return int(self.o2i[o])
        def decodes(self, o): return TitledStr(self.vocab[o])
        def setups(self, items): self.vocab,self.o2i = uniqueify(L(items), sort=True, bidir=True)
    tcat = _Cat()

    def _lbl(o): return TitledStr(o.split('_')[0])

    # Check that tfms are sorted by `order` & `_lbl` is called first
    fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']
    tl = TfmdLists(fns, [tcat,_lbl])
    exp_voc = ['cat','dog']
    test_eq(tcat.vocab, exp_voc)
    test_eq(tl.tfms.vocab, exp_voc)
    test_eq(tl.vocab, exp_voc)
    test_eq(tl, (1,0,0,0,1))
    test_eq([tl.decode(o) for o in tl], ('dog','cat','cat','cat','dog'))

    tl = TfmdLists(fns, [tcat,_lbl], splits=[[0,4], [1,2,3]])
    test_eq(tcat.vocab, ['dog'])

    tfm = NegTfm(split_idx=1)
    tds = TfmdLists(start, A())
    tdl = TfmdDL(tds, after_batch=tfm, bs=4)
    x = tdl.one_batch()
    test_eq(x, torch.arange(4))
    tds.split_idx = 1
    x = tdl.one_batch()
    test_eq(x, -torch.arange(4))
    tds.split_idx = 0
    x = tdl.one_batch()
    test_eq(x, torch.arange(4))

    tds = TfmdLists(start, A())
    tdl = TfmdDL(tds, after_batch=NegTfm(), bs=4)
    test_eq(tdl.dataset[0], start[0])
    test_eq(len(tdl), (len(tds)-1)//4+1)
    test_eq(tdl.bs, 4)
    test_stdout(tdl.show_batch, '0\n1\n2\n3')
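The category example depends on two behaviours: transforms are sorted by their order attribute, and setup sees only the training items, each passed through the transforms that come earlier. A minimal pure-Python sketch of those two behaviours is given below. `LabelSketch`, `CategorySketch`, `setup_pipeline` and `encode` are hypothetical names, and the setup loop is an assumption about the general mechanism rather than the library's code.

```python
class LabelSketch:
    order = 0
    def setup(self, items): pass                  # nothing to learn from the data
    def encodes(self, o): return o.split("_")[0]

class CategorySketch:
    order = 1
    def setup(self, items):
        # Build the vocabulary from the items seen at setup time.
        self.vocab = sorted(set(items))
        self.o2i = {v: i for i, v in enumerate(self.vocab)}
    def encodes(self, o): return self.o2i[o]
    def decodes(self, o): return self.vocab[o]

def setup_pipeline(items, tfms, train_idxs=None):
    "Sort `tfms` by `order`, then set each one up on the progressively encoded training items."
    tfms = sorted(tfms, key=lambda t: t.order)
    seen = list(items) if train_idxs is None else [items[i] for i in train_idxs]
    for t in tfms:
        t.setup(seen)
        seen = [t.encodes(o) for o in seen]
    return tfms

def encode(tfms, o):
    for t in tfms: o = t.encodes(o)
    return o

fns = ['dog_0.jpg', 'cat_0.jpg', 'cat_2.jpg', 'cat_1.jpg', 'dog_1.jpg']
cat = CategorySketch()
tfms = setup_pipeline(fns, [cat, LabelSketch()])   # passed out of order on purpose
codes = [encode(tfms, f) for f in fns]

# With a split, only the training items (indices 0 and 4) contribute to the vocabulary.
cat_train_only = CategorySketch()
setup_pipeline(fns, [cat_train_only, LabelSketch()], train_idxs=[0, 4])
```

The second call reproduces the situation tested above, where the vocabulary contains only 'dog' because both training items are dogs.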

TfmdLists.subset[source]

TfmdLists.subset(i)

New TfmdLists with same tfms that only includes items in ith split

TfmdLists.infer_idx[source]

TfmdLists.infer_idx(x)

Finds the index where self.tfms can be applied to x, depending on the type of x

TfmdLists.infer[source]

TfmdLists.infer(x)

Apply self.tfms to x starting at the right tfm depending on the type of x

    def mult(x): return x*2
    mult.order = 2

    fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']
    tl = TfmdLists(fns, [_lbl,_Cat(),mult])
    test_eq(tl.infer_idx('dog_45.jpg'), 0)
    test_eq(tl.infer('dog_45.jpg'), 2)
    test_eq(tl.infer_idx(4), 2)
    test_eq(tl.infer(4), 8)
    test_fail(lambda: tl.infer_idx(2.0))
    test_fail(lambda: tl.infer(2.0))
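The behaviour tested above can be pictured as a lookup in the list of types recorded at each stage of the pipeline: the first stage whose type matches x tells us which transforms still need to run. The sketch below is written under that assumption; `infer_idx_sketch`, `infer_sketch` and the hand-written `types` list are hypothetical and stand in for what the library records itself.

```python
def infer_idx_sketch(types, x):
    "Index of the first stage whose recorded type matches `x`."
    for idx, t in enumerate(types):
        if isinstance(x, t): return idx
    raise TypeError(f"no stage accepts an input of type {type(x).__name__}")

def infer_sketch(tfms, types, x):
    "Apply only the transforms from the matching stage onwards."
    for f in tfms[infer_idx_sketch(types, x):]: x = f(x)
    return x

vocab = ["cat", "dog"]
tfms = [lambda o: o.split("_")[0],   # file name -> label
        lambda o: vocab.index(o),    # label -> category index
        lambda o: o * 2]             # category index -> doubled index
# types[i] is the type entering transform i; the last entry is the final output type.
types = [str, str, int, int]

from_name = infer_sketch(tfms, types, "dog_45.jpg")  # runs all three transforms
from_code = infer_sketch(tfms, types, 4)             # skips straight to the doubling
```

A float matches none of the recorded types, so the lookup fails, which corresponds to the two test_fail checks above.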

class Datasets[source]

Datasets(items=None, tfms=None, tls=None, n_inp=None, dl_type=None, use_list=None, do_setup=True, split_idx=None, train_setup=True, splits=None, types=None, verbose=False) :: FilteredBase

A dataset that creates a tuple from each tfms, passed through item_tfms

A Datasets creates a tuple from items (typically input,target) by applying to them each list of Transform (or Pipeline) in tfms. Note that if tfms contains only one list of tfms, the items given by Datasets will be tuples of one element.

n_inp is the number of elements in the tuples that should be considered part of the input. It defaults to 1 if tfms consists of one set of transforms, and to len(tfms)-1 otherwise. In most cases, the tuples produced by Datasets have 2 elements (input, target), but they can have 3 (Siamese networks or tabular data), in which case we need to be able to determine where the inputs end and the targets begin.
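The two rules just described (one tuple element per list of transforms, and the default value of n_inp) can be written out in a few lines of plain Python. The helper names `default_n_inp`, `make_item` and `split_inputs_targets` are hypothetical and transforms are plain functions; this is a sketch of the stated rules, not the Datasets class.

```python
def default_n_inp(tfms):
    "1 when there is a single list of transforms, otherwise all but the last element."
    return 1 if len(tfms) == 1 else len(tfms) - 1

def make_item(x, tfms):
    "Build one tuple element per list of transforms, each starting from the raw item."
    out = []
    for pipeline in tfms:
        o = x
        for f in pipeline: o = f(o)
        out.append(o)
    return tuple(out)

def split_inputs_targets(item, n_inp):
    return item[:n_inp], item[n_inp:]

neg = lambda o: -o
add1 = lambda o: o + 1

two = [[neg], [add1]]                 # input, target
three = [[neg], [add1], [neg, add1]]  # two inputs, one target
item2 = make_item(1, two)
item3 = make_item(1, three)
inputs3, targets3 = split_inputs_targets(item3, default_n_inp(three))
```

Passing n_inp explicitly, as in the Datasets signature, simply replaces the value returned by the default rule.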

    items = [1,2,3,4]
    dsets = Datasets(items, [[neg_tfm,int2f_tfm], [add(1)]])
    t = dsets[0]
    test_eq(t, (-1,2))
    test_eq(dsets[0,1,2], [(-1,2),(-2,3),(-3,4)])
    test_eq(dsets.n_inp, 1)
    dsets.decode(t)

    (1.0, 2)

    class Norm(Transform):
        def encodes(self, o): return (o-self.m)/self.s
        def decodes(self, o): return (o*self.s)+self.m
        def setups(self, items):
            its = tensor(items).float()
            self.m,self.s = its.mean(),its.std()

    items = [1,2,3,4]
    nrm = Norm()
    dsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm,nrm]])
    x,y = zip(*dsets)
    test_close(tensor(y).mean(), 0)
    test_close(tensor(y).std(), 1)
    test_eq(x, (-1,-2,-3,-4,))
    test_eq(nrm.m, -2.5)
    test_stdout(lambda:show_at(dsets, 1), '-2')
    test_eq(dsets.m, nrm.m)
    test_eq(dsets.norm.m, nrm.m)
    test_eq(dsets.train.norm.m, nrm.m)

    test_fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','kid_1.jpg']
    tcat = _Cat()
    dsets = Datasets(test_fns, [[tcat,_lbl]], splits=[[0,1,2], [3,4]])
    test_eq(tcat.vocab, ['cat','dog'])
    test_eq(dsets.train, [(1,),(0,),(0,)])
    test_eq(dsets.valid[0], (0,))
    test_stdout(lambda: show_at(dsets.train, 0), "dog")

    inp = [0,1,2,3,4]
    dsets = Datasets(inp, tfms=[None])
    test_eq(*dsets[2], 2)          # Retrieve one item (subset 0 is the default)
    test_eq(dsets[1,2], [(1,),(2,)])    # Retrieve two items by index
    mask = [True,False,False,True,False]
    test_eq(dsets[mask], [(0,),(3,)])   # Retrieve two items by mask

    inp = pd.DataFrame(dict(a=[5,1,2,3,4]))
    dsets = Datasets(inp, tfms=attrgetter('a')).subset(0)
    test_eq(*dsets[2], 2)          # Retrieve one item (subset 0 is the default)
    test_eq(dsets[1,2], [(1,),(2,)])    # Retrieve two items by index
    mask = [True,False,False,True,False]
    test_eq(dsets[mask], [(5,),(3,)])   # Retrieve two items by mask

    inp = [0,1,2,3,4]
    dsets = Datasets(inp, tfms=[None])
    test_eq(dsets.n_inp, 1)
    dsets = Datasets(inp, tfms=[[None],[None],[None]])
    test_eq(dsets.n_inp, 2)
    dsets = Datasets(inp, tfms=[[None],[None],[None]], n_inp=1)
    test_eq(dsets.n_inp, 1)

    dsets = Datasets(range(5), tfms=[None], splits=[tensor([0,2]), [1,3,4]])
    test_eq(dsets.subset(0), [(0,),(2,)])
    test_eq(dsets.train, [(0,),(2,)])       # Subset 0 is aliased to `train`
    test_eq(dsets.subset(1), [(1,),(3,),(4,)])
    test_eq(dsets.valid, [(1,),(3,),(4,)])  # Subset 1 is aliased to `valid`
    test_eq(*dsets.valid[2], 4)
    #assert '[(1,),(3,),(4,)]' in str(dsets) and '[(0,),(2,)]' in str(dsets)
    dsets

    (#5) [(0,),(1,),(2,),(3,),(4,)]

    splits = [[False,True,True,False,True], [True,False,False,False,False]]
    dsets = Datasets(range(5), tfms=[None], splits=splits)
    test_eq(dsets.train, [(1,),(2,),(4,)])
    test_eq(dsets.valid, [(0,)])

    tfm = [[lambda x: x*2,lambda x: x+1]]
    splits = [[1,2],[0,3,4]]
    dsets = Datasets(range(5), tfm, splits=splits)
    test_eq(dsets.train,[(3,),(5,)])
    test_eq(dsets.valid,[(1,),(7,),(9,)])
    test_eq(dsets.train[False,True], [(5,)])

    class _Tfm(Transform):
        split_idx=1
        def encodes(self, x): return x*2
        def decodes(self, x): return TitledStr(x//2)

    dsets = Datasets(range(5), [_Tfm()], splits=[[1,2],[0,3,4]])
    test_eq(dsets.train,[(1,),(2,)])
    test_eq(dsets.valid,[(0,),(6,),(8,)])
    test_eq(dsets.train[False,True], [(2,)])
    dsets

    (#5) [(0,),(1,),(2,),(3,),(4,)]

    ds = dsets.train
    with ds.set_split_idx(1):
        test_eq(ds,[(2,),(4,)])
    test_eq(dsets.train,[(1,),(2,)])

    dsets = Datasets(range(5), [_Tfm(),noop], splits=[[1,2],[0,3,4]])
    test_eq(dsets.train,[(1,1),(2,2)])
    test_eq(dsets.valid,[(0,0),(6,3),(8,4)])

    start = torch.arange(0,50)
    tds = Datasets(start, [A()])
    tdl = TfmdDL(tds, after_item=NegTfm(), bs=4)
    b = tdl.one_batch()
    test_eq(tdl.decode_batch(b), ((0,),(1,),(2,),(3,)))
    test_stdout(tdl.show_batch, "0\n1\n2\n3")

    class _Tfm(Transform):
        split_idx=1
        def encodes(self, x): return x*2

    dsets = Datasets(range(8), [None], splits=[[1,2,5,7],[0,3,4,6]])
    dls = dsets.dataloaders(bs=4, after_batch=_Tfm(), shuffle=False, device=torch.device('cpu'))
    test_eq(dls.train, [(tensor([1,2,5, 7]),)])
    test_eq(dls.valid, [(tensor([0,6,8,12]),)])
    test_eq(dls.n_inp, 1)
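The examples above pass splits either as lists of indices or as boolean masks. A minimal sketch of how both forms can be normalised to indices before building the subsets is shown below; `split_to_idxs` and `subsets_sketch` are hypothetical helpers written for illustration, not part of the library.

```python
def split_to_idxs(split):
    "Accept either a list of indices or a boolean mask and return a list of indices."
    split = list(split)
    if split and all(isinstance(o, bool) for o in split):
        return [i for i, keep in enumerate(split) if keep]
    return [int(i) for i in split]

def subsets_sketch(items, splits):
    "One list of items per split; by convention subset 0 is `train` and subset 1 is `valid`."
    items = list(items)
    return [[items[i] for i in split_to_idxs(s)] for s in splits]

by_index = subsets_sketch(range(5), [[0, 2], [1, 3, 4]])
by_mask = subsets_sketch(range(5), [[False, True, True, False, True],
                                    [True, False, False, False, False]])
```

Note that an item may be left out of every subset: in the mask example, item 3 belongs to neither split.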

Methods

    items = [1,2,3,4]
    dsets = Datasets(items, [[neg_tfm,int2f_tfm]])

Datasets.dataloaders[source]

Datasets.dataloaders(bs=64, shuffle_train=None, shuffle=True, val_shuffle=False, n=None, path='.', dl_type=None, dl_kwargs=None, device=None, drop_last=None, val_bs=None, num_workers=None, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, indexed=None, persistent_workers=False, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)

Get a DataLoaders

Used to create dataloaders. You may prepend ‘val_’ as in val_shuffle to override functionality for the validation set. dl_kwargs gives finer per dataloader control if you need to work with more than one dataloader.
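To illustrate how shared arguments, `val_`-prefixed overrides and per-loader dl_kwargs could combine, here is a sketch of one plausible merging order. It is an assumption made for illustration, not the library's code: `loader_kwargs_sketch` is a hypothetical helper, only bs and shuffle are modelled, and the fallback of the validation batch size to bs when val_bs is None is also assumed.

```python
def loader_kwargs_sketch(n_subsets, bs=64, val_bs=None, shuffle=True,
                         val_shuffle=False, dl_kwargs=None, **common):
    "Return one keyword dictionary per dataloader (index 0 is the training loader)."
    dl_kwargs = dl_kwargs or [{} for _ in range(n_subsets)]
    # Training loader: shared kwargs, then the training defaults, then its own overrides.
    train = {**common, "bs": bs, "shuffle": shuffle, **dl_kwargs[0]}
    # Other loaders: the `val_` arguments replace the training defaults.
    valid_defaults = {"bs": bs if val_bs is None else val_bs, "shuffle": val_shuffle}
    others = [{**common, **valid_defaults, **dl_kwargs[i]} for i in range(1, n_subsets)]
    return [train] + others

kw = loader_kwargs_sketch(2, bs=4, val_bs=8, num_workers=0,
                          dl_kwargs=[{}, {"shuffle": True}])
```

In this sketch the most specific setting wins: an entry in dl_kwargs overrides both the shared keyword arguments and the `val_` defaults for that loader only.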

Datasets.decode[source]

Datasets.decode(o, full=True)

Compose decode of all tuple_tfms, then all tfms, on o

    test_eq(*dsets[0], -1)
    test_eq(*dsets.decode((-1,)), 1)

Datasets.show[source]

Datasets.show(o, ctx=None, **kwargs)

Show item o in ctx

    test_stdout(lambda:dsets.show(dsets[1]), '-2')

Datasets.new_empty[source]

Datasets.new_empty()

Create a new empty version of self, keeping only the transforms

    items = [1,2,3,4]
    nrm = Norm()
    dsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm]])
    empty = dsets.new_empty()
    test_eq(empty.items, [])

Add test set for inference

test_set[source]

test_set(dsets, test_items, rm_tfms=None, with_labels=False)

Create a test set from test_items using validation transforms of dsets

    class _Tfm1(Transform):
        split_idx=0
        def encodes(self, x): return x*3

    dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
    test_eq(dsets.train, [(3,),(6,),(15,),(21,)])
    test_eq(dsets.valid, [(0,),(6,),(8,),(12,)])

    # Transforms of the validation set are applied
    tst = test_set(dsets, [1,2,3])
    test_eq(tst, [(2,),(4,),(6,)])
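The numbers in this example follow from one rule: a transform whose split_idx is set runs only on the subset with that index, and a test set is processed like the validation subset (index 1). The sketch below reproduces the arithmetic with a hypothetical `TfmSketch` wrapper and `apply_all` helper; it illustrates the rule rather than the Transform class itself.

```python
class TfmSketch:
    def __init__(self, f, split_idx=None):
        # split_idx=None means "apply on every subset".
        self.f, self.split_idx = f, split_idx

    def __call__(self, x, split_idx=None):
        # Skip the transform when it is tied to a different subset.
        if self.split_idx is not None and self.split_idx != split_idx: return x
        return self.f(x)

def apply_all(tfms, x, split_idx):
    for t in tfms: x = t(x, split_idx=split_idx)
    return x

valid_only = TfmSketch(lambda x: x * 2, split_idx=1)  # plays the role of _Tfm
train_only = TfmSketch(lambda x: x * 3, split_idx=0)  # plays the role of _Tfm1
tfms = [valid_only, train_only]

train = [apply_all(tfms, x, 0) for x in [1, 2, 5, 7]]
valid = [apply_all(tfms, x, 1) for x in [0, 3, 4, 6]]
test = [apply_all(tfms, x, 1) for x in [1, 2, 3]]     # test items use the validation rule
```

Only the doubling transform touches the test items, which is why [1,2,3] becomes [2,4,6] in the example above.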

DataLoaders.test_dl[source]

DataLoaders.test_dl(test_items, rm_type_tfms=None, with_labels=False, bs=64, shuffle=False, num_workers=None, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)

Create a test dataloader from test_items using validation transforms of dls

    dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
    dls = dsets.dataloaders(bs=4, device=torch.device('cpu'))
    tst_dl = dls.test_dl([2,3,4,5])
    test_eq(tst_dl._n_inp, 1)
    test_eq(list(tst_dl), [(tensor([ 4, 6, 8, 10]),)])
    # Test you can change transforms
    tst_dl = dls.test_dl([2,3,4,5], after_item=add1)
    test_eq(list(tst_dl), [(tensor([ 5, 7, 9, 11]),)])

©2021 fast.ai. All rights reserved.
Site last generated: Mar 31, 2021