In [ ]:
from __future__ import print_function
import os
import tfx_utils
def _make_default_sqlite_uri(pipeline_name):
    """Return the default metadata SQLite path for ``pipeline_name``.

    The path is ``<home>/airflow/tfx/metadata/<pipeline_name>/metadata.db``.

    Args:
        pipeline_name: Name of the pipeline, used as a directory component
            (e.g. ``'taxi'`` or ``'taxi_solution'``).

    Returns:
        The absolute path (as a string) of the pipeline's ``metadata.db``.
    """
    # expanduser('~') returns $HOME when it is set (identical to the old
    # os.environ['HOME'] lookup on POSIX), but falls back to the platform's
    # home directory instead of raising KeyError when HOME is unset.
    home = os.path.expanduser('~')
    # Pass each component separately so os.path.join uses the native separator.
    return os.path.join(home, 'airflow', 'tfx', 'metadata', pipeline_name,
                        'metadata.db')
def get_metadata_store(pipeline_name):
    """Build a ``TFXReadonlyMetadataStore`` for a pipeline's default DB.

    Args:
        pipeline_name: Pipeline whose default SQLite metadata file (as
            computed by ``_make_default_sqlite_uri``) should be opened.

    Returns:
        The store object produced by
        ``tfx_utils.TFXReadonlyMetadataStore.from_sqlite_db``.
    """
    sqlite_path = _make_default_sqlite_uri(pipeline_name)
    return tfx_utils.TFXReadonlyMetadataStore.from_sqlite_db(sqlite_path)
# Which pipeline's metadata to inspect: 'taxi' or 'taxi_solution'.
pipeline_name = 'taxi'
pipeline_db_path = _make_default_sqlite_uri(pipeline_name)
print('Pipeline DB:\n{}'.format(pipeline_db_path))
# Warn early when the DB file is missing (e.g. the pipeline has not run yet).
# NOTE(review): it is not visible here what from_sqlite_db does with a
# non-existent path, and the cells below would otherwise fail or show empty
# results without saying why.
if not os.path.isfile(pipeline_db_path):
    print('Warning: no metadata DB found at {}; has the {!r} pipeline '
          'been run yet?'.format(pipeline_db_path, pipeline_name))
store = get_metadata_store(pipeline_name)
Now print out the model artifacts:
In [ ]:
# List all artifacts of type MODEL recorded in the metadata store. As the
# cell's last expression, the returned table (presumably a pandas DataFrame,
# per the `_df` suffix) is displayed; its IDs are used in the cells below.
store.get_artifacts_of_type_df(tfx_utils.TFXArtifactTypes.MODEL)
Now analyze the model performance. First set the artifact ID in the cell below to one of the IDs listed above:
In [ ]:
# Set this to an artifact ID taken from the artifact listing above.
tfma_artifact_id = None
if tfma_artifact_id is None:
    raise ValueError(
        'Set tfma_artifact_id to an artifact ID before running this cell.')
# Kept as the cell's last expression so the notebook displays its result.
store.display_tfma_analysis(tfma_artifact_id, slicing_column='trip_start_hour')
Now plot the lineage of an artifact. Set the artifact ID in the cell below first:
In [ ]:
# Try different IDs here. Click stop in the plot when changing IDs.
%matplotlib notebook
store.plot_artifact_lineage(<insert artifact ID here>)