WNixalo –– 2018/6/1-2


In [20]:
from pathlib import Path
from shutil import rmtree, copyfile
import os
import sys
import numpy as np

In [3]:
def count_files(path, fullpath=Path(), count=0):
    """Return the number of files found beneath a directory, recursing into sub-directories.

    Parameters
    ----------
    path : str or Path
        Directory to inspect, interpreted relative to `fullpath`.
    fullpath : str or Path, optional
        Prefix that `path` is joined onto (used by the recursion to carry the
        parent directory). Defaults to the empty path.
    count : int, optional
        Starting value that the number of files found is added to.

    Returns
    -------
    int or None
        `count` plus the number of non-directory entries in the tree, or
        `None` (after printing a message) when the directory does not exist.
    """
    root = fullpath / path
    # Guard: nothing to count when the starting directory is missing.
    if not os.path.exists(root):
        print('Directory does not exist.')
        return None
    # Each non-directory entry counts as one; each sub-directory contributes
    # its own recursive total.
    return count + sum(
        count_files(name, root) if (root / name).is_dir() else 1
        for name in os.listdir(root)
    )

In [9]:
# Relative example path; the next cell uses it to preview the '_tmp' naming.
path = Path('data/cifar/')

In [12]:
# Appending '_tmp' to the stringified path yields a sibling directory name;
# this is the same expression create_cifar_subset uses for its default copy location.
Path(str(path) + '_tmp')


Out[12]:
PosixPath('data/cifar_tmp')

In [105]:
def create_cifar_subset(path, fullpath=Path(), copypath='', p=0.1, copydirs=('train', 'valid', 'test')):
    """Copy a random fraction of the files in a directory tree to a new location.

    At every directory level that is visited, a random sample (without
    replacement) of that level's files is copied; at least one file is copied
    from every level that contains any files.

    Parameters
    ----------
    path : str or Path
        Source directory, interpreted relative to `fullpath`.
    fullpath : str or Path, optional
        Prefix that `path` is joined onto (the recursion uses it to carry the
        parent source directory).
    copypath : str or Path, optional
        Parent of the destination; files are written to `copypath / path`.
        When empty, the destination is `str(path) + '_tmp'` and any existing
        directory at that location is deleted first.
    p : float, optional
        Fraction of the files at each level to copy. The sample size is
        `int(n * p)`, clamped to the range `[1, n]`.
    copydirs : sequence of str, optional
        Names of the sub-directories of `path` to descend into. An empty
        sequence means every sub-directory. Below the first level all
        sub-directories are always followed.

    Returns
    -------
    Path
        The destination directory.

    Raises
    ------
    FileExistsError
        If a destination sub-directory already exists (only possible when an
        explicit `copypath` is supplied).
    """
    if not copypath:
        # No explicit destination: start from a clean '<path>_tmp' directory.
        copypath = Path(str(path) + '_tmp')
        if os.path.exists(copypath):
            rmtree(copypath)
    else:
        # Path(...) so that plain-string arguments can be joined as well.
        copypath = Path(copypath) / path
    fullpath = Path(fullpath) / path

    files = []
    entries = os.listdir(fullpath)
    # Create the destination only after the source has been listed, so a
    # missing source does not leave an empty directory behind.
    os.makedirs(copypath, exist_ok=True)
    for name in entries:
        if (fullpath / name).is_dir():
            # Directories that are not selected are skipped entirely; they
            # must not be queued as files, since copyfile cannot copy them.
            if not copydirs or name in copydirs:
                # Deliberately strict: refuse to merge into an existing copy.
                os.makedirs(copypath / name)
                # Forward `p` so the requested fraction applies at every depth.
                create_cifar_subset(name, fullpath, copypath, p=p, copydirs=())
        else:
            files.append(name)

    if files:
        # Clamp so p > 1 copies everything instead of making the sampler raise.
        n_keep = min(len(files), max(1, int(len(files) * p)))
        chosen = np.random.choice(files, n_keep, replace=False).tolist()
        for name in chosen:
            copyfile(fullpath / name, copypath / name)

    return copypath

In [106]:
# Build a reduced copy of the data/cifar10 tree using create_cifar_subset's
# defaults; alt_PATH is the copy location the function returns.
PATH = Path('data/cifar10/')
alt_PATH = create_cifar_subset(PATH)

In [107]:
# Compare the number of files in the original tree with the number in the subset.
count_files(PATH), count_files(alt_PATH)


Out[107]:
(60003, 6001)