BigQuery ML Semi-supervised Self-training Classification with the MNIST Dataset

Imports and project variables


In [3]:
import os
import shutil
from google.cloud import bigquery
import numpy as np
import pandas as pd
import tensorflow as tf
print(tf.__version__)


1.14.0

In [2]:
# Allow you to easily have Python variables in SQL query.
from IPython.core.magic import register_cell_magic
from IPython import get_ipython


@register_cell_magic("with_globals")
def with_globals(line, cell):
    """Cell magic that substitutes global Python variables into the cell.

    The cell body is treated as a str.format template over globals();
    pass "print" on the magic line to echo the rendered cell first.
    """
    rendered = cell.format(**globals())
    if "print" in line:
        print(rendered)
    get_ipython().run_cell(rendered)

In [4]:
# change these to try this notebook out
# PROJECT = "cloud-training-demos"
# BUCKET = "cloud-training-demos-ml"
PROJECT = "qwiklabs-gcp-8312a1428d9eb5e2"  # GCP project hosting the BigQuery dataset
BUCKET = "qwiklabs-gcp-8312a1428d9eb5e2-bucket"  # GCS bucket used to stage the CSV exports
REGION = "us-central1"  # Region for GCP services

In [5]:
# Export the project configuration so shell cells (%%bash, gsutil) can read it.
for env_key, env_value in (("PROJECT", PROJECT),
                           ("BUCKET", BUCKET),
                           ("REGION", REGION)):
  os.environ[env_key] = env_value

Create data


In [6]:
# Download MNIST and scale pixel intensities from [0, 255] into [0, 1].
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

In [7]:
# Sanity-check the raw array shapes before flattening.
for split_name, split_array in (("x_train", x_train), ("y_train", y_train),
                                ("x_test", x_test), ("y_test", y_test)):
  print("{}.shape = {}".format(split_name, split_array.shape))


x_train.shape = (60000, 28, 28)
y_train.shape = (60000,)
x_test.shape = (10000, 28, 28)
y_test.shape = (10000,)

In [8]:
# Flatten each 28x28 image into a single 784-element feature row.
x_train_flat = x_train.reshape(x_train.shape[0], -1)
x_train_flat.shape


Out[8]:
(60000, 784)

In [9]:
# Flatten the test images the same way as the training images.
x_test_flat = x_test.reshape(x_test.shape[0], -1)
x_test_flat.shape


Out[9]:
(10000, 784)

In [10]:
# Append the label column plus a uniform random column (used later to
# simulate which training rows are "labeled") to the flattened pixels.
label_column = np.expand_dims(y_train, -1)
split_column = np.random.rand(x_train_flat.shape[0], 1)
train = np.concatenate([x_train_flat, label_column, split_column], axis = 1)
train.shape


Out[10]:
(60000, 786)

In [11]:
# The test set keeps its true labels and needs no random split column.
test_label_column = np.expand_dims(y_test, -1)
test = np.concatenate([x_test_flat, test_label_column], axis = 1)
test.shape


Out[11]:
(10000, 785)

In [12]:
# Wrap the training matrix in a DataFrame: v_0..v_783 pixel features,
# the true label, and the random split value.
train_feature_names = ["v_" + str(i) for i in range(x_train_flat.shape[1])]
train_df = pd.DataFrame(train, columns=train_feature_names + ["label", "rand"])
train_df.head()


Out[12]:
v_0 v_1 v_2 v_3 v_4 v_5 v_6 v_7 v_8 v_9 ... v_776 v_777 v_778 v_779 v_780 v_781 v_782 v_783 label rand
0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 5.0 0.287787
1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.284469
2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4.0 0.916785
3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.378841
4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 9.0 0.363079

5 rows × 786 columns


In [13]:
# Same layout for the test DataFrame, minus the "rand" split column.
test_feature_names = ["v_" + str(i) for i in range(x_test_flat.shape[1])]
test_df = pd.DataFrame(test, columns=test_feature_names + ["label"])
test_df.head()


Out[13]:
v_0 v_1 v_2 v_3 v_4 v_5 v_6 v_7 v_8 v_9 ... v_775 v_776 v_777 v_778 v_779 v_780 v_781 v_782 v_783 label
0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 7.0
1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2.0
2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0
3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4.0

5 rows × 785 columns


In [14]:
# Summary statistics: border pixels are all zero, label spans 0-9, rand ~ U(0,1).
train_df.describe()


Out[14]:
v_0 v_1 v_2 v_3 v_4 v_5 v_6 v_7 v_8 v_9 ... v_776 v_777 v_778 v_779 v_780 v_781 v_782 v_783 label rand
count 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 60000.0 ... 60000.000000 60000.000000 60000.000000 60000.000000 60000.0 60000.0 60000.0 60000.0 60000.000000 60000.000000
mean 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000179 0.000076 0.000059 0.000008 0.0 0.0 0.0 0.0 4.453933 0.499276
std 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.011137 0.006615 0.006582 0.001359 0.0 0.0 0.0 0.0 2.889270 0.288386
min 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 0.000000 0.000022
25% 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 2.000000 0.249056
50% 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 4.000000 0.498355
75% 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 7.000000 0.749470
max 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.992157 0.992157 0.996078 0.243137 0.0 0.0 0.0 0.0 9.000000 0.999963

8 rows × 786 columns


In [15]:
# Same summary for the test split (no rand column here).
test_df.describe()


Out[15]:
v_0 v_1 v_2 v_3 v_4 v_5 v_6 v_7 v_8 v_9 ... v_775 v_776 v_777 v_778 v_779 v_780 v_781 v_782 v_783 label
count 10000.0 10000.0 10000.0 10000.0 10000.0 10000.0 10000.0 10000.0 10000.0 10000.0 ... 10000.000000 10000.000000 10000.000000 10000.0 10000.0 10000.0 10000.0 10000.0 10000.0 10000.000000
mean 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000642 0.000206 0.000002 0.0 0.0 0.0 0.0 0.0 0.0 4.443400
std 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.022494 0.009490 0.000235 0.0 0.0 0.0 0.0 0.0 0.0 2.895865
min 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 0.000000
25% 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 2.000000
50% 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 4.000000
75% 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.000000 0.000000 0.0 0.0 0.0 0.0 0.0 0.0 7.000000
max 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.992157 0.611765 0.023529 0.0 0.0 0.0 0.0 0.0 0.0 9.000000

8 rows × 785 columns


In [16]:
# Persist both splits locally as CSV before copying them to GCS.
train_df.to_csv("mnist_train.csv", index=False)
test_df.to_csv("mnist_test.csv", index=False)

In [17]:
# Peek at the CSV header plus the first data row.
!head -2 mnist_train.csv


v_0,v_1,v_2,v_3,v_4,v_5,v_6,v_7,v_8,v_9,v_10,v_11,v_12,v_13,v_14,v_15,v_16,v_17,v_18,v_19,v_20,v_21,v_22,v_23,v_24,v_25,v_26,v_27,v_28,v_29,v_30,v_31,v_32,v_33,v_34,v_35,v_36,v_37,v_38,v_39,v_40,v_41,v_42,v_43,v_44,v_45,v_46,v_47,v_48,v_49,v_50,v_51,v_52,v_53,v_54,v_55,v_56,v_57,v_58,v_59,v_60,v_61,v_62,v_63,v_64,v_65,v_66,v_67,v_68,v_69,v_70,v_71,v_72,v_73,v_74,v_75,v_76,v_77,v_78,v_79,v_80,v_81,v_82,v_83,v_84,v_85,v_86,v_87,v_88,v_89,v_90,v_91,v_92,v_93,v_94,v_95,v_96,v_97,v_98,v_99,v_100,v_101,v_102,v_103,v_104,v_105,v_106,v_107,v_108,v_109,v_110,v_111,v_112,v_113,v_114,v_115,v_116,v_117,v_118,v_119,v_120,v_121,v_122,v_123,v_124,v_125,v_126,v_127,v_128,v_129,v_130,v_131,v_132,v_133,v_134,v_135,v_136,v_137,v_138,v_139,v_140,v_141,v_142,v_143,v_144,v_145,v_146,v_147,v_148,v_149,v_150,v_151,v_152,v_153,v_154,v_155,v_156,v_157,v_158,v_159,v_160,v_161,v_162,v_163,v_164,v_165,v_166,v_167,v_168,v_169,v_170,v_171,v_172,v_173,v_174,v_175,v_176,v_177,v_178,v_179,v_180,v_181,v_182,v_183,v_184,v_185,v_186,v_187,v_188,v_189,v_190,v_191,v_192,v_193,v_194,v_195,v_196,v_197,v_198,v_199,v_200,v_201,v_202,v_203,v_204,v_205,v_206,v_207,v_208,v_209,v_210,v_211,v_212,v_213,v_214,v_215,v_216,v_217,v_218,v_219,v_220,v_221,v_222,v_223,v_224,v_225,v_226,v_227,v_228,v_229,v_230,v_231,v_232,v_233,v_234,v_235,v_236,v_237,v_238,v_239,v_240,v_241,v_242,v_243,v_244,v_245,v_246,v_247,v_248,v_249,v_250,v_251,v_252,v_253,v_254,v_255,v_256,v_257,v_258,v_259,v_260,v_261,v_262,v_263,v_264,v_265,v_266,v_267,v_268,v_269,v_270,v_271,v_272,v_273,v_274,v_275,v_276,v_277,v_278,v_279,v_280,v_281,v_282,v_283,v_284,v_285,v_286,v_287,v_288,v_289,v_290,v_291,v_292,v_293,v_294,v_295,v_296,v_297,v_298,v_299,v_300,v_301,v_302,v_303,v_304,v_305,v_306,v_307,v_308,v_309,v_310,v_311,v_312,v_313,v_314,v_315,v_316,v_317,v_318,v_319,v_320,v_321,v_322,v_323,v_324,v_325,v_326,v_327,v_328,v_329,v_330,v_331,v_332,v_333,v_334,v_335,v_336,v_337,v_338,v_339,v_340,v_341,v_342,v_343,v_344,v_345,v_346,v_347,v_348,v_349,v_350,v_35
1,v_352,v_353,v_354,v_355,v_356,v_357,v_358,v_359,v_360,v_361,v_362,v_363,v_364,v_365,v_366,v_367,v_368,v_369,v_370,v_371,v_372,v_373,v_374,v_375,v_376,v_377,v_378,v_379,v_380,v_381,v_382,v_383,v_384,v_385,v_386,v_387,v_388,v_389,v_390,v_391,v_392,v_393,v_394,v_395,v_396,v_397,v_398,v_399,v_400,v_401,v_402,v_403,v_404,v_405,v_406,v_407,v_408,v_409,v_410,v_411,v_412,v_413,v_414,v_415,v_416,v_417,v_418,v_419,v_420,v_421,v_422,v_423,v_424,v_425,v_426,v_427,v_428,v_429,v_430,v_431,v_432,v_433,v_434,v_435,v_436,v_437,v_438,v_439,v_440,v_441,v_442,v_443,v_444,v_445,v_446,v_447,v_448,v_449,v_450,v_451,v_452,v_453,v_454,v_455,v_456,v_457,v_458,v_459,v_460,v_461,v_462,v_463,v_464,v_465,v_466,v_467,v_468,v_469,v_470,v_471,v_472,v_473,v_474,v_475,v_476,v_477,v_478,v_479,v_480,v_481,v_482,v_483,v_484,v_485,v_486,v_487,v_488,v_489,v_490,v_491,v_492,v_493,v_494,v_495,v_496,v_497,v_498,v_499,v_500,v_501,v_502,v_503,v_504,v_505,v_506,v_507,v_508,v_509,v_510,v_511,v_512,v_513,v_514,v_515,v_516,v_517,v_518,v_519,v_520,v_521,v_522,v_523,v_524,v_525,v_526,v_527,v_528,v_529,v_530,v_531,v_532,v_533,v_534,v_535,v_536,v_537,v_538,v_539,v_540,v_541,v_542,v_543,v_544,v_545,v_546,v_547,v_548,v_549,v_550,v_551,v_552,v_553,v_554,v_555,v_556,v_557,v_558,v_559,v_560,v_561,v_562,v_563,v_564,v_565,v_566,v_567,v_568,v_569,v_570,v_571,v_572,v_573,v_574,v_575,v_576,v_577,v_578,v_579,v_580,v_581,v_582,v_583,v_584,v_585,v_586,v_587,v_588,v_589,v_590,v_591,v_592,v_593,v_594,v_595,v_596,v_597,v_598,v_599,v_600,v_601,v_602,v_603,v_604,v_605,v_606,v_607,v_608,v_609,v_610,v_611,v_612,v_613,v_614,v_615,v_616,v_617,v_618,v_619,v_620,v_621,v_622,v_623,v_624,v_625,v_626,v_627,v_628,v_629,v_630,v_631,v_632,v_633,v_634,v_635,v_636,v_637,v_638,v_639,v_640,v_641,v_642,v_643,v_644,v_645,v_646,v_647,v_648,v_649,v_650,v_651,v_652,v_653,v_654,v_655,v_656,v_657,v_658,v_659,v_660,v_661,v_662,v_663,v_664,v_665,v_666,v_667,v_668,v_669,v_670,v_671,v_672,v_673,v_674,v_675,v_676,v_677,v_678,v_679,v_680,v_681,v_682,v_683,v_684,
v_685,v_686,v_687,v_688,v_689,v_690,v_691,v_692,v_693,v_694,v_695,v_696,v_697,v_698,v_699,v_700,v_701,v_702,v_703,v_704,v_705,v_706,v_707,v_708,v_709,v_710,v_711,v_712,v_713,v_714,v_715,v_716,v_717,v_718,v_719,v_720,v_721,v_722,v_723,v_724,v_725,v_726,v_727,v_728,v_729,v_730,v_731,v_732,v_733,v_734,v_735,v_736,v_737,v_738,v_739,v_740,v_741,v_742,v_743,v_744,v_745,v_746,v_747,v_748,v_749,v_750,v_751,v_752,v_753,v_754,v_755,v_756,v_757,v_758,v_759,v_760,v_761,v_762,v_763,v_764,v_765,v_766,v_767,v_768,v_769,v_770,v_771,v_772,v_773,v_774,v_775,v_776,v_777,v_778,v_779,v_780,v_781,v_782,v_783,label,rand
0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.011764705882352941,0.07058823529411765,0.07058823529411765,0.07058823529411765,0.49411764705882355,0.5333333333333333,0.6862745098039216,0.10196078431372549,0.6509803921568628,1.0,0.9686274509803922,0.4980392156862745,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.11764705882352941,0.1411764705882353,0.3686274509803922,0.6039215686274509,0.6666666666666666,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.8823529411764706,0.6745098039215687,0.9921568627450981,0.9490196078431372,0.7647058823529411,0.25098039215686274,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.19215686274509805,0.9333333333333333,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.984313725490196,0.36470588235294116,0.3215686274509804,0.3215686274509804,0.2196078431372549,0.15294117647058825,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.07058823529411765,0.8588235294117647,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.7764705882352941,0.7137254901960784,0.9686274509803922,0.9450980392156862,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3137254901960784,0.611764705882353,0.4196078431372549,0.9921568627450981,0.9921568627450981,0.803921568627451,0.043137254901960784,0.0,0
.16862745098039217,0.6039215686274509,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.054901960784313725,0.00392156862745098,0.6039215686274509,0.9921568627450981,0.35294117647058826,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5450980392156862,0.9921568627450981,0.7450980392156863,0.00784313725490196,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.043137254901960784,0.7450980392156863,0.9921568627450981,0.27450980392156865,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.13725490196078433,0.9450980392156862,0.8823529411764706,0.6274509803921569,0.4235294117647059,0.00392156862745098,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.3176470588235294,0.9411764705882353,0.9921568627450981,0.9921568627450981,0.4666666666666667,0.09803921568627451,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.17647058823529413,0.7294117647058823,0.9921568627450981,0.9921568627450981,0.5882352941176471,0.10588235294117647,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.06274509803921569,0.36470588235294116,0.9882352941176471,0.9921568627450981,0.7333333333333333,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.9764705882352941,0.9921568627450981,0.9764705882352941,0.25098039215686274,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.1803921568627451,0.5098039215686274,0.7176470588235294,0.9921568627450981,0.9921568627450981,0.8117647058823529,0.00784313725490196,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.15294117647058825,0.5803921568627451,0.8980392156862745,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9803921568627451,0.71372549
01960784,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.09411764705882353,0.4470588235294118,0.8666666666666667,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.788235294117647,0.3058823529411765,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.09019607843137255,0.25882352941176473,0.8352941176470589,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.7764705882352941,0.3176470588235294,0.00784313725490196,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.07058823529411765,0.6705882352941176,0.8588235294117647,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.7647058823529411,0.3137254901960784,0.03529411764705882,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.21568627450980393,0.6745098039215687,0.8862745098039215,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.9568627450980393,0.5215686274509804,0.043137254901960784,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5333333333333333,0.9921568627450981,0.9921568627450981,0.9921568627450981,0.8313725490196079,0.5294117647058824,0.5176470588235295,0.06274509803921569,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,0.28778708060614955

In [ ]:
%%bash
# Copy both CSV files to the staging bucket in parallel (-m).
gsutil -m cp mnist*.csv gs://${BUCKET}

Write data to BigQuery


In [19]:
# Build the BigQuery load configuration: 784 FLOAT64 pixel features,
# a FLOAT64 label, and a FLOAT64 random split column.
client = bigquery.Client()
dataset_id = "semi"
dataset_ref = client.dataset(dataset_id)

feature_schema = []
for i in range(x_train_flat.shape[-1]):
  feature_schema.append(bigquery.SchemaField(
    name="v_{}".format(i),
    field_type="FLOAT64",
    mode="NULLABLE",
    description="Feature {}".format(i)))

label_schema = [bigquery.SchemaField(
  name="label",
  field_type="FLOAT64",
  mode="NULLABLE",
  description="Label")]

rand_schema = [bigquery.SchemaField(
  name="rand",
  field_type="FLOAT64",
  mode="NULLABLE",
  description="Random number")]

job_config = bigquery.LoadJobConfig()
job_config.schema = feature_schema + label_schema + rand_schema
job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE
job_config.skip_leading_rows = 1  # the CSV files carry a header row
# The source format defaults to CSV, so the line below is optional.
job_config.source_format = bigquery.SourceFormat.CSV

In [20]:
def load_csv_data_to_bigquery(client, dataset_ref, job_config, name):
  """Load gs://BUCKET/<name>.csv into the BigQuery table <name>.

  Blocks until the load job finishes, then prints the destination
  table's row count. Uses the module-level BUCKET for the GCS path.
  """
  source_uri = "gs://{bucket}/{name}.csv".format(bucket=BUCKET, name=name)
  table_ref = dataset_ref.table(name)

  job = client.load_table_from_uri(
      source_uri, table_ref, job_config=job_config)  # API request
  print("Starting job {}".format(job.job_id))

  job.result()  # Waits for table load to complete.
  print("Job finished.")

  loaded_table = client.get_table(table_ref)
  print("Loaded {} rows.".format(loaded_table.num_rows))

  return None

Train set


In [21]:
# The train table includes the rand column used for the labeled/unlabeled split.
job_config.schema = feature_schema + label_schema + rand_schema
load_csv_data_to_bigquery(client, dataset_ref, job_config, "mnist_train")


Starting job d0624e4a-8806-4f6b-8758-2305f7bd447c
Job finished.
Loaded 60000 rows.

Test set


In [22]:
# The test table has no rand column (it is never split).
job_config.schema = feature_schema + label_schema
load_csv_data_to_bigquery(client, dataset_ref, job_config, "mnist_test")


Starting job 3e514686-96b1-4919-bff9-4c67a1c6d323
Job finished.
Loaded 10000 rows.

Create semi-supervised simulated splits


In [23]:
# Percent of training rows treated as "labeled" in the simulated split.
PERCENT_LABELED = 10.0

In [24]:
def create_semi_supervised_simulated_splits_in_bigquery(dataset_id, sql, name):
  """Run `sql` and overwrite table <dataset_id>.<name> with its results."""
  destination = client.dataset(dataset_id).table(name)

  job_config = bigquery.QueryJobConfig()
  job_config.destination = destination
  job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE

  # Location must match that of the dataset(s) referenced in the query
  # and of the destination table.
  query_job = client.query(
      sql,
      location="US",
      job_config=job_config)  # API request - starts the query

  query_job.result()  # Waits for the query to finish
  print('Query results loaded to table {}'.format(destination.path))

  return None

Labeled


In [25]:
def create_labeled_train_set(project, dataset_id, percent_labeled):
  """Materialize mnist_train_labeled: rows with rand < percent_labeled / 100."""
  threshold = percent_labeled / 100.0
  labeled_sql = """
  SELECT
    * EXCEPT(rand)
  FROM
    `{project}.{dataset}.{table}`
  WHERE rand < {percent}
  """.format(
    project=project,
    dataset=dataset_id,
    table="mnist_train",
    percent=threshold)

  create_semi_supervised_simulated_splits_in_bigquery(
    dataset_id, labeled_sql, "mnist_train_labeled")

  return None

In [26]:
# Materialize the ~10%-labeled simulated split.
create_labeled_train_set(PROJECT, dataset_id, PERCENT_LABELED)


Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_labeled

Unlabeled


In [27]:
def create_unlabeled_train_set(project, dataset_id, percent_labeled):
  """Materialize mnist_train_unlabeled: rows with rand >= percent_labeled / 100."""
  threshold = percent_labeled / 100.0
  unlabeled_sql = """
  SELECT
    * EXCEPT(rand)
  FROM
    `{project}.{dataset}.{table}`
  WHERE rand >= {percent}
  """.format(
    project=project,
    dataset=dataset_id,
    table="mnist_train",
    percent=threshold)

  create_semi_supervised_simulated_splits_in_bigquery(
    dataset_id, unlabeled_sql, "mnist_train_unlabeled")

  return None

In [28]:
# Materialize the remaining ~90% as the unlabeled simulated split.
create_unlabeled_train_set(PROJECT, dataset_id, PERCENT_LABELED)


Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_unlabeled

BQML

Train model on labeled train set


In [29]:
def bqml_train_model_on_labeled_dataset():
  """(Re)train the BQML logistic regression on the current labeled table.

  Creates or replaces the model `bqml_ssl.self_training` from
  `semi.mnist_train_labeled`, blocking until the training job finishes.
  Raises the underlying google-cloud-bigquery error if training fails.
  """
  query_job = client.query("""
  CREATE OR REPLACE MODEL
    `bqml_ssl.self_training`
  OPTIONS
    ( model_type="logistic_reg",
      auto_class_weights=true,
      input_label_cols = ["label"]) AS
  SELECT
    *
  FROM
    `semi.mnist_train_labeled`
  """)

  # Fix: the original printed "Training complete." from a `finally` block,
  # so a failed training job still reported success right before the
  # exception surfaced. Only report completion after result() succeeds.
  query_job.result()  # Waits for the training job to finish.
  print("Training complete.")

  return None

In [30]:
# Initial training pass on the 10% labeled split.
bqml_train_model_on_labeled_dataset()


Training complete.

Look at training info


In [31]:
def bqml_training_info():
  """Return per-iteration training statistics for `bqml_ssl.self_training`."""
  sql = """
  SELECT
      *
  FROM
      ML.TRAINING_INFO(MODEL `bqml_ssl.self_training`)
  """
  return client.query(sql).result()  # Waits for job to complete.

In [32]:
# Show training iterations (loss, eval loss, learning rate) as a DataFrame.
pd.DataFrame([dict(row.items()) for row in bqml_training_info()])


Out[32]:
duration_ms eval_loss iteration learning_rate loss training_run
0 56049 0.034381 9 1.6 0.025464 0
1 48802 0.034698 8 0.8 0.026480 0
2 45422 0.035927 7 3.2 0.028079 0
3 50042 0.037224 6 1.6 0.030945 0
4 56957 0.039116 5 0.8 0.033941 0
5 48610 0.040741 4 0.4 0.036025 0
6 50531 0.047081 3 1.6 0.041920 0
7 55793 0.051897 2 0.8 0.049430 0
8 55009 0.069936 1 0.4 0.068228 0
9 49424 0.115370 0 0.2 0.114392 0

Evaluate on test set


In [33]:
def bqml_evaluate_on_test_dataset():
  """Evaluate `bqml_ssl.self_training` against the held-out test table."""
  sql = """
  SELECT
    *
  FROM
    ML.EVALUATE(MODEL `bqml_ssl.self_training`,
    (SELECT * FROM `semi.mnist_test`))
  """
  return client.query(sql).result()  # Waits for job to complete.

In [34]:
# Render the evaluation metrics (accuracy, f1, log loss, ...) as a DataFrame.
pd.DataFrame([dict(row.items())
              for row in bqml_evaluate_on_test_dataset()])


Out[34]:
accuracy f1_score log_loss precision recall roc_auc
0 0.8993 0.897872 1.840486 0.898469 0.898071 0.960578

Predict on unlabeled train set


In [35]:
def bqml_predict_unlabeled_dataset():
  """Predict labels for a small sample of the unlabeled training table.

  Returns one row per (example, class) pair: the LIMIT 10 sample is
  cross-joined with its unnested per-class probability array, so 10
  examples yield 100 result rows. The true `label` column and the raw
  probability array are dropped from the output.
  """
  query_job = client.query("""
  SELECT
      * EXCEPT(predicted_label_probs, label)
  FROM
      ML.PREDICT(MODEL `bqml_ssl.self_training`,
                 (SELECT * FROM `semi.mnist_train_unlabeled` LIMIT 10)),
    UNNEST(predicted_label_probs) AS unnested_predicted_label_probs
  """)

  results = query_job.result()  # Waits for job to complete.

  return results

In [36]:
# Preview per-class prediction probabilities for 10 unlabeled examples.
pd.DataFrame([dict(row.items())
              for row in bqml_predict_unlabeled_dataset()])


Out[36]:
predicted_label prob v_0 v_1 v_10 v_100 v_101 v_102 v_103 v_104 ... v_90 v_91 v_92 v_93 v_94 v_95 v_96 v_97 v_98 v_99
0 0.0 0.149127 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 0.0 0.133352 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 0.0 0.133298 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 0.0 0.127621 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4 0.0 0.107521 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
95 0.0 0.061369 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
96 0.0 0.059155 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
97 0.0 0.058982 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
98 0.0 0.058966 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
99 0.0 0.058668 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

100 rows × 786 columns

Check confidence


In [37]:
# Confidence threshold for accepting a pseudo-label: the top class probability
# must beat a uniform random guess (1 / number_of_classes) by
# percent_over_random percent.
percent_over_random = 80.0
number_of_classes = 10
confidence_percent = (1.0 + percent_over_random / 100.0) / number_of_classes  # = 0.18

In [38]:
# Comma-separated feature column list ("v_0,\n  v_1, ...") for SQL templates.
features_list = []
for i in range(x_train_flat.shape[-1]):
  features_list.append("v_{}".format(i))
features = ",\n  ".join(features_list)

In [39]:
# SQL template: score every unlabeled row, keep each row's max class
# probability, filter rows by `max_prob {inequality} {confidence_percent}`,
# then join back to emit the {features} columns (and, optionally via {label},
# the predicted class as a pseudo-label) for the matching rows.
confidence_query = """
WITH
  CTE_gen_ids AS (
  SELECT
    ROW_NUMBER() OVER () AS row_id,
    *
  FROM
    ML.PREDICT(MODEL `bqml_ssl.self_training`,
      (
      SELECT
        *
      FROM
        `semi.mnist_train_unlabeled`))),
  CTE_max_probs AS (
  SELECT
    row_id,
    MAX(unnested_predicted_label_probs.prob) AS max_prob
  FROM
    CTE_gen_ids,
    UNNEST(predicted_label_probs) AS unnested_predicted_label_probs
  GROUP BY
    row_id),
  CTE_filtered_max_probs AS (
  SELECT
    *
  FROM
    CTE_max_probs
  WHERE
    max_prob {inequality} {confidence_percent})
SELECT
  {features}{label}
FROM
  CTE_filtered_max_probs AS A
INNER JOIN
  CTE_gen_ids AS B
ON
  A.row_id = B.row_id
"""

In [40]:
# High-confidence rows WITH pseudo-labels: used to grow the labeled table.
high_confidence_features_label_query = confidence_query.format(
  inequality=">=",
  confidence_percent=confidence_percent,
  features=features,
  label=", predicted_label AS label")

In [41]:
# High-confidence rows, features only: used to count promoted examples.
high_confidence_features_query = confidence_query.format(
  inequality=">=",
  confidence_percent=confidence_percent,
  features=features,
  label="")

In [42]:
# Low-confidence rows, features only: becomes the new unlabeled table.
low_confidence_features_query = confidence_query.format(
  inequality="<",
  confidence_percent=confidence_percent,
  features=features,
  label="")

In [43]:
%%with_globals
%%bigquery --project $PROJECT
-- Preview the high-confidence examples with their pseudo-labels.
{high_confidence_features_label_query}


Out[43]:
v_0 v_1 v_2 v_3 v_4 v_5 v_6 v_7 v_8 v_9 ... v_775 v_776 v_777 v_778 v_779 v_780 v_781 v_782 v_783 label
0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 8.0
1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2.0
2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 4.0
3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 7.0
4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 7.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
457 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 3.0
458 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2.0
459 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 9.0
460 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 7.0
461 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 8.0

462 rows × 785 columns

Check initial table counts


In [44]:
%%with_globals
%%bigquery --project $PROJECT
-- Row count of the labeled training table (before promotion).
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_labeled`


Out[44]:
row_count
0 5963

In [45]:
%%with_globals
%%bigquery --project $PROJECT
-- Row count of the unlabeled training table (before promotion).
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_unlabeled`


Out[45]:
row_count
0 54037

In [46]:
%%with_globals
%%bigquery --project $PROJECT
-- Number of unlabeled examples predicted with high confidence.
SELECT COUNT(*) AS row_count
FROM ({high_confidence_features_query})


Out[46]:
row_count
0 462

In [47]:
%%with_globals
%%bigquery --project $PROJECT
SELECT COUNT(*) AS row_count
FROM ({low_confidence_features_query})


Out[47]:
row_count
0 53575

Adjust tables based on confidence of predictions

Add high confidence examples to labeled dataset with predicted labels


In [48]:
def add_high_confidence_examples_to_labeled(
  dataset_id, high_confidence_features_label_query):
  """Append pseudo-labeled high-confidence rows to mnist_train_labeled."""
  destination = client.dataset(dataset_id).table("mnist_train_labeled")

  job_config = bigquery.QueryJobConfig()
  job_config.destination = destination
  # APPEND: keep the existing labeled rows and add the new pseudo-labels.
  job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND

  # Location must match that of the dataset(s) referenced in the query
  # and of the destination table.
  query_job = client.query(
      high_confidence_features_label_query,
      location="US",
      job_config=job_config)  # API request - starts the query

  query_job.result()  # Waits for the query to finish
  print('Query results loaded to table {}'.format(destination.path))

  return None

In [49]:
# Append the pseudo-labeled high-confidence rows to mnist_train_labeled.
add_high_confidence_examples_to_labeled(
  dataset_id, high_confidence_features_label_query)


Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_labeled

Check updated table counts


In [50]:
%%with_globals
%%bigquery --project $PROJECT
-- Labeled table count after appending pseudo-labeled rows.
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_labeled`


Out[50]:
row_count
0 6425

In [51]:
%%with_globals
%%bigquery --project $PROJECT
-- Unlabeled table count (not yet reduced at this point).
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_unlabeled`


Out[51]:
row_count
0 54037

Remove high confidence examples from unlabeled dataset


In [52]:
def remove_high_confidence_examples_from_unlabeled(
  dataset_id, low_confidence_features_query):
  """Rebuild mnist_train_unlabeled from only the low-confidence rows."""
  destination = client.dataset(dataset_id).table("mnist_train_unlabeled")

  job_config = bigquery.QueryJobConfig()
  job_config.destination = destination
  # TRUNCATE: replace the table so the promoted rows disappear from it.
  job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE

  # Location must match that of the dataset(s) referenced in the query
  # and of the destination table.
  query_job = client.query(
      low_confidence_features_query,
      location="US",
      job_config=job_config)  # API request - starts the query

  query_job.result()  # Waits for the query to finish
  print('Query results loaded to table {}'.format(destination.path))

  return None

In [53]:
# Rebuild mnist_train_unlabeled with only the low-confidence rows.
remove_high_confidence_examples_from_unlabeled(
  dataset_id, low_confidence_features_query)


Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_unlabeled

Check updated table counts


In [54]:
%%with_globals
%%bigquery --project $PROJECT
-- Labeled table count after the promote/remove round trip.
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_labeled`


Out[54]:
row_count
0 6425

In [55]:
%%with_globals
%%bigquery --project $PROJECT
-- Unlabeled table count after removing the promoted rows.
SELECT COUNT(*) AS row_count FROM `{PROJECT}.semi.mnist_train_unlabeled`


Out[55]:
row_count
0 53575

Semi-supervised Self-training Loop

Reset labeled and unlabeled datasets


In [56]:
# Reset the labeled split before running the full self-training loop.
create_labeled_train_set(PROJECT, dataset_id, PERCENT_LABELED)


Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_labeled

In [57]:
# Reset the unlabeled split before running the full self-training loop.
create_unlabeled_train_set(PROJECT, dataset_id, PERCENT_LABELED)


Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_unlabeled

Loop until no improvement


In [58]:
# Self-training loop: retrain on the labeled table, evaluate on the test
# table, and — while test accuracy keeps improving — promote high-confidence
# pseudo-labeled examples from the unlabeled table into the labeled table.
MIN_ACCURACY_IMPROVEMENT = 0.01  # stop once gains fall below one accuracy point

old_accuracy = 0.0
max_iterations = 5  # hard cap on promotion rounds
iteration = 0
while iteration < max_iterations:
  print("Iteration = {}".format(iteration))

  # Train model on the current labeled dataset.
  print("Starting training.")
  bqml_train_model_on_labeled_dataset()

  # Evaluate the freshly trained model on the held-out test set.
  print("Starting evaluation.")
  eval_metrics = pd.DataFrame([{key: value for key, value in row.items()}
                               for row in bqml_evaluate_on_test_dataset()])
  print("eval_metrics = {}".format(eval_metrics))

  # Extract accuracy from the evaluation metrics.
  accuracy = eval_metrics["accuracy"][0]

  accuracy_improvement = accuracy - old_accuracy
  old_accuracy = accuracy

  if accuracy_improvement > MIN_ACCURACY_IMPROVEMENT:
    # Promote high-confidence pseudo-labeled examples into the labeled table.
    print("Adding high confidence examples to labeled.")
    add_high_confidence_examples_to_labeled(
      dataset_id, high_confidence_features_label_query)

    # Drop the promoted examples from the unlabeled table.
    print("Removing high confidence examples from unlabeled.")
    remove_high_confidence_examples_from_unlabeled(
      dataset_id, low_confidence_features_query)

    iteration += 1
  else:
    print("Not enough improvement, breaking loop!")
    break


Iteration = 0
Starting training.
Training complete.
Starting evaluation.
eval_metrics =    accuracy  f1_score  log_loss  precision    recall   roc_auc
0    0.8993  0.897872  1.840486   0.898469  0.898071  0.960578
Adding high confidence examples to labeled.
Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_labeled
Removing high confidence examples from unlabeled.
Query results loaded to table /projects/qwiklabs-gcp-8312a1428d9eb5e2/datasets/semi/tables/mnist_train_unlabeled
Iteration = 1
Starting training.
Training complete.
Starting evaluation.
eval_metrics =    accuracy  f1_score  log_loss  precision    recall   roc_auc
0    0.9005  0.898888  1.839274   0.899643  0.899122  0.959887
Not enough improvement, breaking loop!

In [ ]: