In [1]:
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn, optim
from torchvision import datasets
import torchvision.transforms as transforms
In [2]:
# Data-loading configuration: everything a reader might tune lives here.
num_workers = 0   # DataLoader worker processes; 0 loads batches in the main process
batch_size = 20   # samples per batch for every loader below
valid_size = 0.2  # fraction of the training set held out for validation
seed = 42         # makes the train/valid split and the sampler order reproducible

assert 0.0 <= valid_size < 1.0, "valid_size must be a fraction in [0, 1)"

# The index shuffle (np.random.shuffle) and SubsetRandomSampler's batch order
# both draw from these global RNGs. Seeding them here makes a fresh
# Restart & Run All produce the same split and batches every time.
torch.manual_seed(seed)
np.random.seed(seed)
In [3]:
from torch.utils.data.sampler import SubsetRandomSampler
In [5]:
# Load the MNIST train and test splits, converting each image to a tensor.
# download=True fetches the files under `root` when they are not already present.
transform = transforms.ToTensor()
mnist_kwargs = dict(root="MNIST_data", download=True, transform=transform)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)
In [6]:
# Build the list of every training-set index and shuffle it in place
# (draws from NumPy's global RNG); the next cell slices it into train/valid.
indices = [*range(len(train_data))]
num_train = len(indices)
np.random.shuffle(indices)
In [7]:
# Hold out the first `split` shuffled indices for validation; train on the rest.
split = int(np.floor(num_train * valid_size))
valid_idx = indices[:split]
train_idx = indices[split:]
In [8]:
def make_loader(dataset, sampler=None):
    """Wrap `dataset` in a DataLoader using the notebook-wide batch/worker settings."""
    return torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers
    )

# Train and validation loaders both read the MNIST training split, each
# restricted to its own (disjoint) subset of shuffled indices.
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = make_loader(train_data, train_sampler)
valid_loader = make_loader(train_data, valid_sampler)
test_loader = make_loader(test_data)
In [ ]: