In [1]:
import torch
import numpy as np

import torch.nn.functional as F

from torch import nn, optim

from torchvision import datasets
import torchvision.transforms as transforms

In [2]:
# Loader settings shared by every DataLoader built in this notebook.
# num_workers: worker subprocesses used for loading (0 = load in the main process).
num_workers = 0
# batch_size: number of samples each loader yields per batch.
batch_size = 20
# valid_size: fraction of the training set held out for validation.
valid_size = 0.2

In [3]:
from torch.utils.data.sampler import SubsetRandomSampler

In [5]:
# Transform applied to every sample as it is read: turn the image into a tensor.
transform = transforms.ToTensor()

# Training portion of MNIST, stored under ./MNIST_data (fetched there if needed).
train_data = datasets.MNIST(
    root="MNIST_data",
    train=True,
    download=True,
    transform=transform,
)

# Test portion of MNIST: same root directory and transform, train=False.
test_data = datasets.MNIST(
    root="MNIST_data",
    train=False,
    download=True,
    transform=transform,
)

In [6]:
# Build the list of sample positions in train_data and shuffle it once, so the
# train/validation split taken from it is random but repeatable.
num_train = len(train_data)
indices = list(range(num_train))

# Fixed seed: with an unseeded shuffle the split differs on every kernel
# restart, so validation results from separate runs are not comparable.
split_seed = 42
# Dedicated generator, so seeding the split does not reset NumPy's global RNG.
split_rng = np.random.default_rng(split_seed)
split_rng.shuffle(indices)  # shuffles in place; indices stays a plain Python list

In [7]:
# Number of samples held out for validation: the valid_size fraction of the
# training set, rounded down to a whole count.
split = int(np.floor(num_train * valid_size))

# The first `split` shuffled positions form the validation subset;
# everything after them is used for training.
valid_idx = indices[:split]
train_idx = indices[split:]

In [8]:
# Each sampler draws only from its own subset of positions in train_data.
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# Training and validation loaders both wrap train_data; the samplers are what
# keep the two subsets apart.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=train_sampler,
    num_workers=num_workers,
)
valid_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    sampler=valid_sampler,
    num_workers=num_workers,
)

# The test loader wraps the separate test_data set and is given no sampler.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [ ]: