If machine learning is rocket science then data is your fuel! So before doing anything we will have a close look at the data available and spend some time bringing it into the "right" form (i.e. tf.train.Example).
That's why we start by spending quite a lot of time on this notebook, downloading the data, understanding it, and transforming it into the right format for Tensorflow.
The data used in this workshop is taken from Google's quickdraw (click on the images to see loads of examples):
https://quickdraw.withgoogle.com/data
We will download the data below.
First, we'll choose where our data should be stored.
If you choose a path under "/content/gdrive/My Drive" then data will be stored in your Google drive and persisted across VM starts (preferable).
In [0]:
# Root directory for all downloaded and generated data of this notebook.
data_path = '/content/gdrive/My Drive/amld_data'
# Alternatively, you can also store the data in a local directory. This method
# will also work when running the notebook in Jupyter instead of Colab.
# data_path = './amld_data'
In [2]:
# Google Drive destination: mount the drive into the Colab VM.
if data_path.startswith('/content/gdrive/'):
    from google.colab import drive
    assert data_path.startswith('/content/gdrive/My Drive/'), 'Google Drive paths must start with "/content/gdrive/My Drive/"!'
    drive.mount('/content/gdrive')
# Cloud Storage destination: authenticate the user first.
if data_path.startswith('gs://'):
    from google.colab import auth
    auth.authenticate_user()
In [3]:
# In Jupyter, you would need to install TF 2 via !pip.
# NOTE(review): `%tensorflow_version` is an IPython magic, not Python syntax --
# presumably only available in Colab; confirm before running elsewhere.
%tensorflow_version 2.x
In [4]:
# Always make sure you are running the expected version.
# There are considerable differences between versions.
# This Colab was tested with 2.1.0.
import tensorflow as tf
tf.__version__
Out[4]:
In [0]:
import base64
import calendar
import collections
import functools
import io
import itertools
import json
import os
import random
import re
import textwrap
import time
import urllib
import urllib.parse
import urllib.request
import xml
import xml.dom.minidom

import numpy as np
import pandas as pd
from IPython import display
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
In [6]:
# Retrieve list of categories.
def list_bucket(bucket, regexp='.*'):
    """Returns a filtered list of Keys in specified GCS bucket.

    Only the keys contained in the single listing response returned for the
    bucket URL are considered; no further requests are made.

    Args:
      bucket: Name of the GCS bucket to list.
      regexp: Regular expression; a key is kept when `re.match` succeeds, i.e.
          the pattern is anchored at the start of the key.

    Returns:
      List of matching keys, in the order they appear in the response.
    """
    url = 'https://storage.googleapis.com/%s' % bucket
    # Use a context manager so the HTTP response is closed deterministically.
    with urllib.request.urlopen(url) as fh:
        content = xml.dom.minidom.parseString(fh.read())
    pattern = re.compile(regexp)
    keys = []
    for e in content.getElementsByTagName('Contents'):
        key = e.getElementsByTagName('Key')[0].firstChild.data
        if pattern.match(key):
            keys.append(key)
    return keys
# Keep only the keys ending in "ndjson".
all_ndjsons = list_bucket('quickdraw_dataset', '.*ndjson$')
print('available: (%d)' % len(all_ndjsons))
# Print the category names (basename of every key, without extension),
# joined by "|" and wrapped at 100 characters.
print('\n'.join(textwrap.wrap(
    '|'.join([key.split('/')[-1].split('.')[0] for key in all_ndjsons]),
    width=100)))
In [0]:
# Mini group of two animals.
pets = ['cat', 'dog']
# Somewhat larger group of zoo animals.
zoo = ['camel', 'crocodile', 'dolphin', 'elephant', 'flamingo', 'giraffe',
       'kangaroo', 'lion', 'monkey', 'penguin', 'rhinoceros']
# Even larger group of all animals.
# Note: names may contain spaces (e.g. 'sea turtle'); `retrieve()` URL-quotes
# the key before downloading.
animals = ['ant', 'bat', 'bear', 'bee', 'bird', 'butterfly', 'camel', 'cat',
           'cow', 'crab', 'crocodile', 'dog', 'dolphin', 'dragon', 'duck',
           'elephant', 'fish', 'flamingo', 'frog', 'giraffe', 'hedgehog',
           'horse', 'kangaroo', 'lion', 'lobster', 'monkey', 'mosquito',
           'mouse', 'octopus', 'owl', 'panda', 'parrot', 'penguin', 'pig',
           'rabbit', 'raccoon', 'rhinoceros', 'scorpion', 'sea turtle', 'shark',
           'sheep', 'snail', 'snake', 'spider', 'squirrel', 'swan']
# You could do something like:
# my_objects = ['shoe', 'shorts', 't-shirt']
Create your own group -- the more categories you include the more challenging the classification task will be...
In [0]:
# YOUR ACTION REQUIRED:
# Choose one of above groups for remainder of workshop.
# Note: This will result in ~100MB of download per class.
# `dataset_name` will be used to construct directories containing the data.
# The position of a name within `labels` is used as its numerical label when
# the datasets are generated further below.
labels, dataset_name = zoo, 'zoo'
# Or use another dataset defined above:
# labels, dataset_name = pets, 'pets'
# labels, dataset_name = animals, 'animals'
In [9]:
# Download above chosen group.
def valid_ndjson(filename):
    """Checks presence + completeness of .ndjson file.

    The file counts as complete when its last line parses as JSON; a truncated
    download typically ends in the middle of a JSON record.

    Args:
      filename: Path of the file to check (anything `tf.io.gfile` can open).

    Returns:
      True if the file can be read and its last line is valid JSON, False if
      the file is missing, unreadable, empty, or ends in an incomplete record.
    """
    try:
        last_line = None
        # Stream through the file instead of `readlines()` so that a ~100MB
        # download is not held in memory, and close the handle when done.
        with tf.io.gfile.GFile(filename) as f:
            for last_line in f:
                pass
        if last_line is None:
            # Empty file: previously this raised an uncaught IndexError.
            return False
        json.loads(last_line)
        return True
    except (ValueError, IOError, tf.errors.NotFoundError):
        return False
def retrieve(bucket, key, filename):
    """Downloads a file specified by its Key from a GCS bucket to `filename`.

    An already existing, valid file is kept as is. A missing or corrupted file
    is (re-)downloaded until `valid_ndjson()` accepts it; there is no limit on
    the number of retries.

    Args:
      bucket: Name of the GCS bucket.
      key: Key of the object within the bucket (URL-quoted before use).
      filename: Destination path (anything `tf.io.gfile` can write to).

    Returns:
      None. The result is the file written to `filename`.
    """
    url = 'https://storage.googleapis.com/%s/%s' % (
        bucket, urllib.parse.quote(key))
    print('\n' + url)

    def download():
        """Fetches `url` and writes the payload to `filename`."""
        with urllib.request.urlopen(url) as response:
            payload = response.read()
        with tf.io.gfile.GFile(filename, 'w') as f:
            f.write(payload)

    if not tf.io.gfile.exists(filename):
        download()
    while not valid_ndjson(filename):
        # Use tf.io.gfile for the size as well: `os.path.getsize()` only works
        # for local paths, while `filename` may live on another file system.
        print('*** Corrupted download (%.2f MB), retrying...' % (
            tf.io.gfile.stat(filename).length / 2.**20))
        download()
# Make sure the destination directory exists, then fetch one .ndjson file per
# label and report its size.
tf.io.gfile.makedirs(data_path)
print('\n%d labels:' % len(labels))
for name in labels:
    print(name, end=' ')
    dst = '%s/%s.ndjson' % (data_path, name)
    retrieve('quickdraw_dataset', 'full/simplified/%s.ndjson' % name, dst)
    print('%.2f MB' % (tf.io.gfile.stat(dst).length / 2.**20))
print('\nDONE :)')
In [10]:
# List every .ndjson file below `data_path` together with its size in MiB.
print('\n'.join([
    '%6.1fM : %s' % (tf.io.gfile.stat(path).length/1024**2, path)
    for path in tf.io.gfile.glob('{}/*.ndjson'.format(data_path))
]))
Let's further explore what the NDJSON file format is.
In [11]:
# Pick the alphabetically first .ndjson file and show its first 1000 characters.
path = sorted(tf.io.gfile.glob(os.path.join(data_path, '*.ndjson')))[0]
print(path)
# NOTE(review): `.read()` loads the whole file before slicing.
print(tf.io.gfile.GFile(path).read()[:1000] + '...')
As we can see, it's a format that contains one JSON dictionary per line.
Let's parse one single line.
In [12]:
# Parse only the first line (= first record) of the file.
data_json = json.loads(tf.io.gfile.GFile(path).readline())
data_json.keys()
Out[12]:
In [13]:
# So we have some meta information.
# Print every field except the (long) 'drawing' entry.
for k, v in data_json.items():
    if k != 'drawing':
        print('%20s -> %s' % (k, v))
In [14]:
# Extract the actual drawing.
drawing = data_json['drawing']
# The drawing consists of a series of strokes:
# print the array shape of every stroke, then the raw data of the first one.
print('Shapes:', [np.array(stroke).shape for stroke in drawing])
print('Example stroke:', drawing[0])
In [15]:
# Draw the image -- the strokes all have shape (2, n)
# so the first index seems to be x/y coordinate:
for stroke in drawing:
    # Each array has X coordinates at [0, :] and Y coordinates at [1, :].
    # Y is negated, which flips the plot vertically (presumably because the
    # data uses image coordinates with Y growing downwards).
    plt.plot(np.array(stroke[0]), -np.array(stroke[1]))
# Would YOU recognize this drawing successfully?
In [16]:
# Some more code to load many sketches at once.
# Let's ignore the difficult `unrecognized` sketches for now
# (i.e. unrecognized by the official quickdraw classifier).
def convert(line):
    """Parses a single JSON line and turns its 'drawing' strokes into np.arrays.

    Args:
      line: One line of an .ndjson file (a JSON encoded dictionary).

    Returns:
      The decoded dictionary, with 'drawing' replaced by a list holding one
      np.array per stroke.
    """
    record = json.loads(line)
    record['drawing'] = list(map(np.array, record['drawing']))
    return record
def loaditer(name, unrecognized=False):
    """Returns iterable of drawings in specified file.

    Args:
      name: Name of the downloaded object (e.g. "elephant").
      unrecognized: Whether to include drawings that were not recognized
          by Google AI (i.e. the hard ones).

    Yields:
      One dictionary per drawing, as returned by `convert()`.
    """
    # The context manager makes sure the file is closed once the generator is
    # exhausted, closed, or garbage collected (callers often stop early).
    with tf.io.gfile.GFile('%s/%s.ndjson' % (data_path, name)) as f:
        for line in f:
            d = convert(line)
            if d['recognized'] or unrecognized:
                yield d
def loadn(name, n, unrecognized=False):
    """Loads the first drawings of a downloaded object into a list.

    Args:
      name: Name of the downloaded object (e.g. "elephant").
      n: Number of drawings to load.
      unrecognized: Whether to include drawings that were not recognized
          by Google AI (i.e. the hard ones).

    Returns:
      List with (at most) the first `n` drawings yielded by `loaditer()`.
    """
    drawings = loaditer(name, unrecognized=unrecognized)
    first_n = itertools.islice(drawings, n)
    return list(first_n)
# Load the first `n` recognized drawings of the first label; `sample` is reused
# by later cells.
n = 100
print('Loading {} instances of "{}"...'.format(n, labels[0]), end='')
# Use `n` here so the message above and the number of loaded drawings agree.
sample = loadn(labels[0], n)
print('done.')
In [17]:
# Some more drawings.
# Plot the first rows*cols entries of `sample` in a grid of subplots.
rows, cols = 3, 3
plt.figure(figsize=(3*cols, 3*rows))
for y in range(rows):
    for x in range(cols):
        i = y * cols + x
        plt.subplot(rows, cols, i + 1)
        for stroke in sample[i]['drawing']:
            plt.plot(np.array(stroke[0]), -np.array(stroke[1]))
Idea: After converting the raw drawing data into rasterized images, we can use MNIST-like image processing to classify the drawings.
In [0]:
def dict_to_img(drawing, img_sz=64, lw=3, maximize=True):
    """Converts QuickDraw data to quadratic rasterized image.

    Args:
      drawing: Dictionary instance of QuickDraw dataset. drawing['drawing']
          must be a list of 2-D np.array strokes (as produced by `convert()`),
          with X coordinates at [0, :] and Y coordinates at [1, :].
      img_sz: Size output image (in pixels).
      lw: Line width (in pixels).
      maximize: Whether to maximize drawing within image pixels.

    Returns:
      A PIL.Image with the rasterized drawing. The image stays blank when the
      drawing contains no line segment (every stroke has fewer than 2 points).
    """
    img = Image.new('L', (img_sz, img_sz))
    draw = ImageDraw.Draw(img)
    # One (2, 2) array per line segment: rows are x/y, columns are start/end.
    lines = np.array([
        stroke[0:2, i:i+2]
        for stroke in drawing['drawing']
        for i in range(stroke.shape[1] - 1)
    ], dtype=np.float32)
    if lines.size == 0:
        # Without any segment `lines` is a 1-D empty array and the indexing
        # below would raise IndexError; there is nothing to draw anyway.
        return img
    if maximize:
        # Rescale x and y independently, with a small margin on both ends.
        for i in range(2):
            min_, max_ = lines[:, i, :].min() * 0.95, lines[:, i, :].max() * 1.05
            lines[:, i, :] = (lines[:, i, :] - min_) / max(max_ - min_, 1)
    else:
        # NOTE(review): 1024 is a hard-coded coordinate range -- confirm that it
        # matches the range of the input data.
        lines /= 1024
    for line in lines:
        # Flatten to (x0, y0, x1, y1) in pixel units.
        draw.line(tuple(line.T.reshape((-1,)) * img_sz), fill='white', width=lw)
    return img
In [19]:
# Show some examples.
def showimg(img):
    """Displays an image inline by emitting an HTML <img> tag.

    The image is encoded as PNG and embedded as a base64 data URL.

    Args:
      img: Can be a PIL.Image or a numpy.ndarray.
    """
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img, 'L')
    png_buffer = io.BytesIO()
    img.convert('RGB').save(png_buffer, format='png')
    payload = base64.b64encode(png_buffer.getvalue()).decode('utf-8')
    html = '<img src="data:image/png;base64,%s">' % payload
    display.display(display.HTML(html))
# Fetch `cols` drawings per label, in label order (no shuffling is done here).
rows, cols = len(labels), 10
# NOTE(review): `n_per_class` is not used in this cell -- confirm whether any
# other cell still needs it.
n_per_class = rows * cols // len(labels) + 1
drawings_list = [drawing for name in labels
                 for drawing in loadn(name, cols)]
# Create mosaic of rendered images.
lw = 4
img_sz = 64
# One row of `cols` tiles per label, every tile is img_sz x img_sz pixels.
tableau = np.zeros((img_sz * rows, img_sz * cols), dtype=np.uint8)
for y in range(rows):
    for x in range(cols):
        i = y * cols + x
        img = dict_to_img(drawings_list[i], img_sz=img_sz, lw=lw, maximize=True)
        tableau[y*img_sz:(y+1)*img_sz,
                x*img_sz:(x+1)*img_sz] = np.asarray(img)
showimg(tableau)
print('{} samples of : {}'.format(cols, ' '.join(labels)))
Tensorflow's "native" format for data storage is the tf.train.Example
protocol buffer.
In this section we briefly explore the API needed to access the data
inside the tf.train.Example protocol buffer. It's not necessary to read
through the "Protocol Buffer Basics: Python" documentation to follow along.
In [20]:
# Create a new (empty) instance.
example = tf.train.Example()
# An empty example will not print anything.
print(example)
# An example contains a map from feature name to "Feature".
# Every "Feature" contains a list of elements of the same
# type, which is one of:
# - bytes_list (similar to Python's "bytes")
# - float_list (float number)
# - int64_list (integer number)
# These values can be accessed as follows (no need to understand
# details):
# Add float value "3.1416" to feature "magic_numbers"
example.features.feature['magic_numbers'].float_list.value.append(3.1416)
# Add some more values to the float list "magic_numbers".
example.features.feature['magic_numbers'].float_list.value.extend([2.7183, 1.4142, 1.6180])
### YOUR ACTION REQUIRED:
# Create a second feature named "adversaries" and add the elements
# b'Alice' and b'Bob'.
# (Solution below; the exercise stub was: example.features.feature['adversaries'].)
example.features.feature['adversaries'].bytes_list.value.extend([b'Alice', b'Bob'])
# This will now print a serialized representation of our protocol buffer
# with features "magic_numbers" and "adversaries" set...
print(example)
# .. et voila : that's all you need to know about protocol buffers for this
# workshop.
Now let's create a "dataset" of tf.train.Example
protocol buffers ("protos").
A single example will contain all the information we want to use for training for a drawing (i.e. rasterized image, label, and maybe other information).
In [21]:
# Let's first check how many [recognized=True] examples we have in each class.
# Count by streaming through the data: building lists of all lines / all parsed
# drawings only to take their length needs a lot of memory for ~100MB files.
for name in labels:
    with tf.io.gfile.GFile('%s/%s.ndjson' % (data_path, name)) as f:
        num_all_samples = sum(1 for _ in f)
    num_recognized_samples = sum(1 for _ in loaditer(name))
    print(name, num_all_samples, 'recognized', num_recognized_samples)
Sharding
A dataset consists of non-overlapping sets of examples that will be used for training and evaluation of the classifier (the "test" set will be used for the final evaluation). As these files can quickly become very large, we split them into smaller files referred to as shards. For example, we could split a single dataset into a number of shards, like `train-00000-of-00005`, `train-00001-of-00005`, ..., `train-00004-of-00005`.
This way we have smaller individual files, and we can also easily access for example only 20% of all data, or have 5 threads which read through all the data simultaneously.
Generally, with large datasets, a recommendation is to split data into individual shards with a size of ~100 MB each. This workshop might use smaller sharding sizes for simplicity reasons.
In [0]:
#@title `make_sharded_files()` code
#@markdown Helper code to create sharded recordio files.
#@markdown Simply **click "execute"** and continue to the next cell.
#@markdown No need to read through this code to understand the remainder of the Colab.
#@markdown
#@markdown If you want to have a look anyways, you can double-click this cell or click on the three dots
#@markdown and then select "Form" and then "Show Code" (shortcut `<Ctrl-M> <F>`).
# Helper code to create sharded recordio files.
# (No need to read through this.)
# The code in this cell simply takes a list of iterators and then
# randomly distributes the values returned by these iterators into sharded
# datasets (e.g. a train/eval/test split).
def rand_key(counts):
    """Returns a random key from "counts", using values as distribution.

    The count of the returned key is decremented in place, so repeated calls
    draw without replacement until all counts reach zero.

    Args:
      counts: Dictionary mapping key to the (non-negative) number of remaining
          draws for that key. Modified in place.

    Returns:
      A key, chosen with probability proportional to its remaining count.

    Raises:
      ValueError: If no draws are left (all counts are zero or "counts" is
          empty).
    """
    total = sum(counts.values())
    if total <= 0:
        raise ValueError('"counts" has no remaining entries to draw from.')
    # Draw r from [0, total). `random.randint(0, total)` would include `total`
    # itself and thereby favor the first key with a non-zero count.
    r = random.randrange(total)
    for key, count in counts.items():
        if r < count:
            counts[key] -= 1
            return key
        r -= count
def get_split(i, splits):
    """Returns key from "splits" for iteration "i".

    The keys are visited in sorted order and every key covers as many
    consecutive iterations as its value; the pattern repeats after
    sum(splits.values()) iterations.

    Args:
      i: Iteration counter.
      splits: Dictionary mapping split name to its multiplicity.

    Returns:
      The name of the split that iteration "i" falls into.
    """
    position = i % sum(splits.values())
    upper_bound = 0
    for name in sorted(splits):
        upper_bound += splits[name]
        if position < upper_bound:
            return name
def make_counts(labels, total):
    """Generates counts for "labels" totaling "total".

    Every label receives an equal share of what is still left to distribute
    (integer division), so the counts always add up to exactly "total".

    Args:
      labels: Sequence of label names.
      total: Total number of examples to distribute.

    Returns:
      Dictionary mapping label name to its count.
    """
    counts = {}
    remaining = total
    for labels_left, name in zip(range(len(labels), 0, -1), labels):
        share = remaining // labels_left
        counts[name] = share
        remaining -= share
    return counts
def example_to_dict(example):
    """Converts a tf.train.Example to a dictionary.

    Args:
      example: A tf.train.Example protocol buffer.

    Returns:
      Dictionary mapping feature name to its value: the single element for
      features holding exactly one value, otherwise a np.array of all values.

    Raises:
      ValueError: If a feature has none of bytes_list / int64_list /
          float_list set.
    """
    example_dict = {}
    for name, feature in example.features.feature.items():
        if feature.HasField('bytes_list'):
            values = feature.bytes_list.value
        elif feature.HasField('int64_list'):
            values = feature.int64_list.value
        elif feature.HasField('float_list'):
            values = feature.float_list.value
        else:
            # Raising a plain string is itself a TypeError in Python 3, which
            # would hide the actual problem.
            raise ValueError('Unknown *_list type for feature "%s"!' % name)
        if len(values) == 1:
            example_dict[name] = values[0]
        else:
            example_dict[name] = np.array(values)
    return example_dict
def make_sharded_files(make_example, path, labels, iters, counts, splits,
                       shards=10, overwrite=False, report_dt=10, make_df=False):
    """Create sharded dataset from "iters".

    Args:
      make_example: Converts object returned by elements of "iters"
          to tf.train.Example() proto.
      path: Directory that will contain recordio files.
      labels: Names of labels, will be written to "labels.txt".
      iters: List of iterables returning drawing objects.
      counts: Dictionary mapping class to number of examples. Not modified.
      splits: Dictionary mapping filename to multiple examples. For example,
          splits=dict(a=2, b=1) will result in two examples being written to "a"
          for every example being written to "b".
      shards: Number of files to be created per split.
      overwrite: Whether a pre-existing directory should be overwritten.
      report_dt: Number of seconds between status updates (0=no updates).
      make_df: Also write data as pandas.DataFrame - do NOT use this with very
          large datasets that don't fit in memory!

    Returns:
      Total number of examples written to disk per split.

    Raises:
      AssertionError: If "iters" and "labels" differ in length, or if
          overwrite=False and the first shard file already exists.
    """
    assert len(iters) == len(labels)
    # Prepare output.
    if not tf.io.gfile.exists(path):
        tf.io.gfile.makedirs(path)
    paths = {
        split: ['%s/%s-%05d-of-%05d' % (path, split, i, shards)
                for i in range(shards)]
        for split in splits
    }
    if not overwrite:
        # dict views cannot be indexed in Python 3 (`paths.values()[0]` raises
        # TypeError), so fetch the first list of shard paths via an iterator.
        first_path = next(iter(paths.values()))[0]
        assert not tf.io.gfile.exists(first_path)
    writers = {
        split: [tf.io.TFRecordWriter(ps[i]) for i in range(shards)]
        for split, ps in paths.items()
    }
    t0 = time.time()
    examples_per_split = collections.defaultdict(int)
    i, n = 0, sum(counts.values())
    # Work on a copy: rand_key() decrements the counts in place.
    counts = dict(counts)
    rows = []
    try:
        # Create examples.
        while sum(counts.values()):
            name = rand_key(counts)
            split = get_split(i, splits)
            # Distribute the examples of a split round-robin over its shards.
            writer = writers[split][examples_per_split[split] % shards]
            label = labels.index(name)
            example = make_example(label, next(iters[label]))
            writer.write(example.SerializeToString())
            if make_df:
                # The split is only recorded in the DataFrame row; the example
                # has already been serialized to disk without it.
                example.features.feature['split'].bytes_list.value.append(
                    split.encode('utf8'))
                rows.append(example_to_dict(example))
            examples_per_split[split] += 1
            i += 1
            if report_dt > 0 and time.time() - t0 > report_dt:
                print('processed %d/%d (%.2f%%)' % (i, n, 100. * i / n))
                t0 = time.time()
    finally:
        # Always close the writers, also when creating an example failed.
        for split_writers in writers.values():
            for writer in split_writers:
                writer.close()
    # Store results.
    with tf.io.gfile.GFile('%s/labels.txt' % path, 'w') as f:
        f.write('\n'.join(labels))
    with tf.io.gfile.GFile('%s/counts.json' % path, 'w') as f:
        json.dump(examples_per_split, f)
    if make_df:
        df_path = '%s/dataframe.pkl' % path
        print('Writing %s...' % df_path)
        pd.DataFrame(rows).to_pickle(df_path)
    return dict(**examples_per_split)
In [0]:
# Uses `dict_to_img()` from previous cell to create raster image.
def make_example_img(label, drawing):
    """Converts QuickDraw dictionary to example with rasterized data.

    Args:
      label: Numerical representation of the label (e.g. '0' for labels[0]).
      drawing: Dictionary with QuickDraw data.

    Returns:
      A tf.train.Example protocol buffer (with 'label', 'img_64', and additional
      metadata features). 'img_64' holds the flattened 64x64 image; 'timestamp'
      holds whole seconds since the epoch.
    """
    example = tf.train.Example()
    example.features.feature['label'].int64_list.value.append(label)
    img_64 = np.asarray(dict_to_img(
        drawing, img_sz=64, lw=4, maximize=True)).reshape(-1)
    example.features.feature['img_64'].int64_list.value.extend(img_64)
    example.features.feature['countrycode'].bytes_list.value.append(
        drawing['countrycode'].encode())
    example.features.feature['recognized'].int64_list.value.append(
        drawing['recognized'])
    example.features.feature['word'].bytes_list.value.append(
        drawing['word'].encode())
    ts = drawing['timestamp']
    # Only the part before the fractional seconds is parsed.
    # The parsed time is interpreted as UTC so that the stored value does not
    # depend on the time zone of the machine generating the dataset (which is
    # what `time.mktime()` would do).
    # NOTE(review): assumes the dataset timestamps are given in UTC -- confirm.
    ts = calendar.timegm(time.strptime(ts[:ts.index('.')], '%Y-%m-%d %H:%M:%S'))
    example.features.feature['timestamp'].int64_list.value.append(int(ts))
    example.features.feature['key_id'].int64_list.value.append(
        int(drawing['key_id']))
    return example
We will now create a dataset with 80k samples consisting of 50k samples used for training, 20k samples used for evaluation, and 10k samples used for testing.
The generation below will take about ~5 minutes.
Note: Larger datasets take longer to generate and to train on, but also lead to better classification results.
In [24]:
# Create the (rasterized) dataset.
path = '%s/%s_img' % (data_path, dataset_name)
t0 = time.time()
examples_per_split = make_sharded_files(
    make_example=make_example_img,
    path=path,
    labels=labels,
    # One fresh iterator of recognized drawings per label.
    iters=[loaditer(name) for name in labels],
    # Creating 50k train, 20k eval and 10k test examples.
    counts=make_counts(labels, 80000),
    splits=dict(train=5, eval=2, test=1),
    overwrite=True,
    # Note: Set this to False when generating large datasets.
    make_df=True,
)
# If you don't see the final output below, it's probably because your VM
# has run out of memory and crashed!
# This can happen when make_df=True.
print('stored data to "%s"' % path)
print('generated %s examples in %d seconds' % (
    examples_per_split, time.time() - t0))
In [0]:
# Convert stroke coordinates into normalized relative coordinates,
# one single list, and add a "third dimension" that indicates when
# a new stroke starts.
def dict_to_stroke(d):
    """Converts a QuickDraw dictionary to one array of relative coordinates.

    All strokes are concatenated into a single point sequence, normalized with
    the global minimum/maximum over both axes, and turned into differences
    between consecutive points.

    Args:
      d: Dictionary with QuickDraw data; d['drawing'] is a list of np.array
          strokes with X coordinates at [0, :] and Y coordinates at [1, :].

    Returns:
      np.array with three rows and one column less than the total number of
      points: rows 0 and 1 hold dx and dy, row 2 is 1 where the point reached
      by this step is the first point of a new stroke, and 0 elsewhere.
    """
    strokes = d['drawing']
    points = np.concatenate(
        [np.array(s, dtype=np.float32) for s in strokes], axis=1)
    new_stroke = np.zeros(points.shape[1])
    if len(strokes) > 1:
        # Index of the first point of every stroke but the first one.
        starts = np.cumsum(np.array([s.shape[1] for s in strokes[:-1]]))
        new_stroke[starts] = 1
    lowest, highest = points.min(), points.max()
    normalized = (points - lowest) / max(1, (highest - lowest))
    deltas = np.diff(normalized)
    return np.concatenate([deltas, new_stroke[np.newaxis, 1:]])
In [26]:
# Visualize and control output of `dict_to_stroke()`.
stroke = dict_to_stroke(sample[0])
# The first 2 dimensions are normalized dx/dy coordinates, and
# the third dimension indicates a new stroke.
# Summing up the differences yields positions relative to the first point.
xy = stroke[:2, :].cumsum(axis=1)
plt.plot(xy[0,:], -xy[1,:])
pxy = xy[:, stroke[2] != 0]
# Indicate the new stroke with a red circle.
plt.plot(pxy[0], -pxy[1], 'ro');
In [0]:
# Uses `dict_to_stroke()` from previous cell to create stroke data.
def make_example_stroke(label, drawing):
    """Converts QuickDraw dictionary to example with stroke data.

    Args:
      label: Numerical representation of the label (e.g. '0' for labels[0]).
      drawing: Dictionary with QuickDraw data.

    Returns:
      A tf.train.Example protocol buffer (with 'label', 'stroke_x', 'stroke_y',
      'stroke_z', 'stroke_len', and additional metadata features). 'timestamp'
      holds whole seconds since the epoch.
    """
    example = tf.train.Example()
    example.features.feature['label'].int64_list.value.append(label)
    stroke = dict_to_stroke(drawing)
    example.features.feature['stroke_x'].float_list.value.extend(stroke[0, :])
    example.features.feature['stroke_y'].float_list.value.extend(stroke[1, :])
    example.features.feature['stroke_z'].float_list.value.extend(stroke[2, :])
    # Number of steps stored in each of the three stroke_* features.
    example.features.feature['stroke_len'].int64_list.value.append(
        stroke.shape[1])
    example.features.feature['countrycode'].bytes_list.value.append(
        drawing['countrycode'].encode())
    example.features.feature['recognized'].int64_list.value.append(
        drawing['recognized'])
    example.features.feature['word'].bytes_list.value.append(
        drawing['word'].encode())
    ts = drawing['timestamp']
    # Only the part before the fractional seconds is parsed.
    # The parsed time is interpreted as UTC so that the stored value does not
    # depend on the time zone of the machine generating the dataset (which is
    # what `time.mktime()` would do).
    # NOTE(review): assumes the dataset timestamps are given in UTC -- confirm.
    ts = calendar.timegm(time.strptime(ts[:ts.index('.')], '%Y-%m-%d %H:%M:%S'))
    example.features.feature['timestamp'].int64_list.value.append(int(ts))
    example.features.feature['key_id'].int64_list.value.append(
        int(drawing['key_id']))
    return example
In [28]:
# Create the stroke dataset (same counts and splits as the rasterized one).
path = '%s/%s_stroke' % (data_path, dataset_name)
t0 = time.time()
examples_per_split = make_sharded_files(
    make_example=make_example_stroke,
    path=path,
    labels=labels,
    # One fresh iterator of recognized drawings per label.
    iters=[loaditer(name) for name in labels],
    # Creating 50k train, 20k eval, 10k test examples. Takes ~2min
    counts=make_counts(labels, 80000),
    splits=dict(train=5, eval=2, test=1),
    overwrite=True,
    # Note: Set this to False when generating large datasets...
    make_df=True,
)
print('stored data to "%s"' % path)
print('generated %s examples in %d seconds' % (examples_per_split, time.time() - t0))
In [29]:
# YOUR ACTION REQUIRED:
# Check out the files generated in $data_path
# Note that you can also inspect the files in http://drive.google.com if you
# used Drive as the destination.
#--snip
# NOTE(review): the leading "!" makes this a notebook shell command, not
# Python; `$data_path` / `$dataset_name` presumably expand to the Python
# variables of the same name.
!ls -lh "$data_path"/"$dataset_name"*
In [30]:
# Let's look at a single file of the sharded dataset.
tf_record_path = '{}/{}_img/eval-00000-of-00010'.format(data_path, dataset_name)
# YOUR ACTION REQUIRED:
# Use `tf.data.TFRecordDataset()` to read a single record from the file and
# assign it to the variable `record`. What data type has this record?
# Hint: dataset is a Python "iterable".
#dataset = ...
#record
#--snip
# Leave the loop right away: `record` then refers to the first element.
for record in tf.data.TFRecordDataset(tf_record_path):
    break
# The record is a string Tensor that encodes the serialized protocol buffer.
record
Out[30]:
In [31]:
# Check out the features. They should correspond to what we generated in
# `make_example_img()` above.
example = tf.train.Example()
# Note: `.numpy()` returns the underlying serialized value from the Tensor,
# which is then parsed into `example`.
example.ParseFromString(record.numpy())
print(list(example.features.feature.keys()))
In [0]:
# YOUR ACTION REQUIRED:
# Extract the label and the image data from the example protobuf.
# Use above section "tf.train.Example" for reference.
# 'label' holds a single value; 'img_64' holds the flattened image written by
# `make_example_img()`.
label_int = example.features.feature['label'].int64_list.value[0] #label_int =
img_64 = example.features.feature['img_64'].int64_list.value #img_64 =
In [33]:
# Visualize the image:
print(labels[label_int])
# Restore the 64x64 shape of the flattened image before plotting.
plt.matshow(np.array(img_64).reshape((64, 64)))
Out[33]:
In [34]:
# YOUR ACTION REQUIRED:
# Check that we have an equal distribution of labels in the training files.
#--snip
# NOTE(review): `tf_record_path` was set to an "eval" shard in an earlier cell,
# so this histogram shows that shard and not a training file -- confirm.
sample = []
ds = tf.data.TFRecordDataset(tf_record_path)
# `islice()` already limits the loop to at most 10000 records, so no explicit
# `break` is needed.
for record in itertools.islice(ds, 10000):
    example = tf.train.Example()
    example.ParseFromString(record.numpy())
    sample.append(example.features.feature['label'].int64_list.value[0])
plt.hist(sample, bins=len(labels))
Out[34]:
In [35]:
# If we want to create our own protocol buffers, we first need to install
# some programs.
# NOTE(review): notebook shell command (leading "!"), not Python.
!apt-get -y install protobuf-compiler python-pil python-lxml
In [0]:
# Step 1: Write a proto file that describes our data format.
# YOUR ACTION REQUIRED: Complete the definition of the "Person" message (you
# can use the slide for inspiration).
with open('person.proto', 'w') as f:
    # The syntax declaration is written first, the message definition follows.
    f.write('''syntax = "proto3";''')
    #--snip
    f.write('''
message Person {
string name = 1;
string email = 2;
repeated int32 lucky_numbers = 3;
}''')
In [37]:
# Step 2: Compile proto definition to a Python file.
# NOTE(review): notebook shell commands (leading "!"), not Python. The next
# cell imports `person_pb2`, presumably the module generated here.
!protoc --python_out=. person.proto
!ls -lh
In [0]:
# Step 3: Import code from generated Python file.
from person_pb2 import Person
# Note: If you change the person_pb2 module, you'll need to restart the kernel
# to see the changes because Python will still remember the previous import.
In [39]:
# Fill in all three fields of the message and show its serialized form.
person = Person()
person.name = 'John Doe'
person.email = 'john.doe@gmail.com'
# `lucky_numbers` is declared as a repeated field in person.proto.
person.lucky_numbers.extend([13, 99])
person.SerializeToString()
Out[39]:
In [40]:
# YOUR ACTION REQUIRED:
# Compare the size of the serialized person structure in proto format
# vs. JSON encoded (you can use Python's json.dumps() and list members
# manually, or import google.protobuf.json_format).
# Which format is more efficient? Why?
# Which format is easier to use?
# Which format is more versatile?
#--snip
# Note: `json` is already imported by the import cell near the top.
import json
# Size of the binary proto serialization...
print(len(person.SerializeToString()))
from google.protobuf.json_format import MessageToJson
# ... versus the length of the JSON text for the same message.
len(MessageToJson(person))
Out[40]: