diff --git a/test.py b/test.py index 222f829a..686a8f5d 100644 --- a/test.py +++ b/test.py @@ -48,7 +48,7 @@ def test( dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, - pin_memory=False, + pin_memory=True, collate_fn=dataset.collate_fn) seen = 0 diff --git a/train.py b/train.py index 4dca56bc..5bce4adc 100644 --- a/train.py +++ b/train.py @@ -93,7 +93,7 @@ def train( batch_size=batch_size, num_workers=num_workers, shuffle=False, - pin_memory=False, + pin_memory=True, collate_fn=dataset.collate_fn, sampler=sampler)