Content-based recommender using Deep Structured Semantic Model

This example shows how to build a Deep Structured Semantic Model (DSSM) that incorporates complex content-based features into a recommender system. See "Learning Deep Structured Semantic Models for Web Search using Clickthrough Data". It does not attempt to provide a datasource or train a model; it only shows how to structure a complex DSSM network.


In [1]:
import warnings

import mxnet as mx
from mxnet import gluon, nd, autograd, sym
import numpy as np
from sklearn.random_projection import johnson_lindenstrauss_min_dim

In [2]:
# Define some constants
max_user = int(1e5)
title_vocab_size = int(3e4)
query_vocab_size = int(3e4)
num_samples = int(1e4)
hidden_units = 128
epsilon_proj = 0.25

ctx = mx.gpu() if len(mx.test_utils.list_gpus()) > 0 else mx.cpu()

Bag of words random projection

A previous version of this example contained a bag-of-words random projection. It is kept here for reference but is not used in the next example. Random projection is a dimensionality reduction technique that keeps the distortion of the pairwise distances between your original data points within a given bound. More interestingly, the number of dimensions you need to project onto to guarantee that bound does not depend on the original number of dimensions, only on the total number of data points. You can see more explanation in this blog post.


In [3]:
proj_dim = johnson_lindenstrauss_min_dim(num_samples, epsilon_proj)
print("To keep a distance disruption ~< {}% of our {} samples we need to randomly project to at least {} dimensions".format(epsilon_proj*100, num_samples, proj_dim))


To keep a distance disruption ~< 25.0% of our 10000 samples we need to randomly project to at least 1414 dimensions
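The number printed above can be checked by hand. As a minimal sketch, assuming the standard Johnson-Lindenstrauss estimate k >= 4 ln(n) / (eps^2/2 - eps^3/3), which is the formula documented for scikit-learn's johnson_lindenstrauss_min_dim, the helper below recomputes the minimum dimension without scikit-learn. The name jl_min_dim is introduced here for illustration only. Note that the original dimension never appears in the formula.

```python
import math

def jl_min_dim(n_samples, eps):
    """Minimum projection dimension from the Johnson-Lindenstrauss bound.

    Hypothetical helper; the result is truncated to an integer.
    """
    if not 0.0 < eps < 1.0:
        raise ValueError("eps must be in (0, 1)")
    denominator = (eps ** 2) / 2 - (eps ** 3) / 3
    return int(4 * math.log(n_samples) / denominator)

# Same constants as in the notebook: 1e4 samples, 25% distortion
print(jl_min_dim(10_000, 0.25))  # 1414
# A tighter distortion bound requires more dimensions
print(jl_min_dim(10_000, 0.10))
```

Because the bound grows only logarithmically with the number of samples, it is the choice of epsilon that dominates the projection size.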

In [4]:
class BagOfWordsRandomProjection(gluon.HybridBlock):
    def __init__(self, vocab_size, output_dim, random_seed=54321, pad_index=0):
        """
        :param int vocab_size: number of elements in the vocabulary
        :param int output_dim: projection dimension
        :param int random_seed: seed to use to guarantee same projection
        :param int pad_index: index of the vocabulary used for padding sentences
        """
        super(BagOfWordsRandomProjection, self).__init__()
        self._vocab_size = vocab_size
        self._output_dim = output_dim
        proj = self._random_unit_vecs(vocab_size=vocab_size, output_dim=output_dim, random_seed=random_seed)
        # we set the projection of the padding word to 0
        proj[pad_index, :] = 0
        self.proj = self.params.get_constant('proj', value=proj)

    def _random_unit_vecs(self, vocab_size, output_dim, random_seed):
        rs = np.random.RandomState(seed=random_seed)
        W = rs.normal(size=(vocab_size, output_dim))
        Wlen = np.linalg.norm(W, axis=1)
        W_unit = W / Wlen[:,None]
        return W_unit

    def hybrid_forward(self, F, x, proj):
        """
        :param nd or sym F:
        :param nd.NDArray x: index of tokens
        returns the sum of the projected embeddings of each token
        """
        embedded = F.Embedding(x, proj, input_dim=self._vocab_size, output_dim=self._output_dim)
        return embedded.sum(axis=1)

In [5]:
bowrp = BagOfWordsRandomProjection(1000, 20)
bowrp.initialize()

In [6]:
bowrp(mx.nd.array([[10, 50, 100], [5, 10, 0]]))


Out[6]:
[[ 0.35554492  0.0736109  -0.1220893   0.11155054 -0.20963743  0.21141198
   0.12296599  0.12428369 -0.10999548 -0.16867855 -0.09068598  0.14154953
  -0.24029303  0.11956739  0.02830955 -0.14226514 -0.45963028 -0.5456747
  -0.5663947  -0.10585886]
 [-0.31655627 -0.13582113 -0.13815539  0.42596683  0.25674546  0.5024462
  -0.3122709   0.01826438 -0.0277671  -0.14526835  0.44378105  0.09626544
   0.24572927  0.36588538  0.17922089 -0.21583243 -0.30497772  0.19484927
  -0.20705326 -0.13759173]]
<NDArray 2x20 @cpu(0)>

With padding:


In [7]:
bowrp(mx.nd.array([[10, 50, 100, 0], [5, 10, 0, 0]]))


Out[7]:
[[ 0.35554492  0.0736109  -0.1220893   0.11155054 -0.20963743  0.21141198
   0.12296599  0.12428369 -0.10999548 -0.16867855 -0.09068598  0.14154953
  -0.24029303  0.11956739  0.02830955 -0.14226514 -0.45963028 -0.5456747
  -0.5663947  -0.10585886]
 [-0.31655627 -0.13582113 -0.13815539  0.42596683  0.25674546  0.5024462
  -0.3122709   0.01826438 -0.0277671  -0.14526835  0.44378105  0.09626544
   0.24572927  0.36588538  0.17922089 -0.21583243 -0.30497772  0.19484927
  -0.20705326 -0.13759173]]
<NDArray 2x20 @cpu(0)>
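The two outputs are identical because the projection row of the padding index is set to zero, so padded positions add nothing to the sum. Below is a minimal NumPy sketch of the same idea, independent of MXNet. The function names random_unit_rows and bag_of_words_projection are illustrative and not part of the notebook.

```python
import numpy as np

def random_unit_rows(vocab_size, output_dim, seed=54321, pad_index=0):
    """Random Gaussian rows scaled to unit length, with the padding row zeroed."""
    rs = np.random.RandomState(seed)
    w = rs.normal(size=(vocab_size, output_dim))
    w /= np.linalg.norm(w, axis=1, keepdims=True)
    w[pad_index, :] = 0.0
    return w

def bag_of_words_projection(token_ids, proj):
    """Look up one row per token and sum over the sequence axis."""
    return proj[np.asarray(token_ids)].sum(axis=1)

proj = random_unit_rows(vocab_size=1000, output_dim=20)
short = bag_of_words_projection([[10, 50, 100], [5, 10, 0]], proj)
padded = bag_of_words_projection([[10, 50, 100, 0], [5, 10, 0, 0]], proj)
print(short.shape)                 # (2, 20)
print(np.allclose(short, padded))  # True
```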

Content-based recommender / ranking system using DSSM

Take the search result ranking problem as an example: users have performed text-based searches, were presented with results, and selected one of them. Each result is composed of a title and an image.

Your positive examples are the clicked items in the search results, and your negative examples are sampled from the results that were not clicked.
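As a rough illustration of that sampling scheme, assume each search impression is logged as the list of shown item ids plus the clicked id. The record layout, the field names and the build_training_pairs helper are all hypothetical, since this example does not ship a datasource.

```python
import random

def build_training_pairs(impressions, num_negatives=2, seed=0):
    """Turn click logs into (user, query, item, label) rows.

    The clicked item is the positive (label 1); negatives (label 0) are
    sampled without replacement from the shown-but-not-clicked items.
    """
    rng = random.Random(seed)
    rows = []
    for imp in impressions:
        rows.append((imp["user"], imp["query"], imp["clicked"], 1))
        not_clicked = [i for i in imp["shown"] if i != imp["clicked"]]
        k = min(num_negatives, len(not_clicked))
        for item in rng.sample(not_clicked, k):
            rows.append((imp["user"], imp["query"], item, 0))
    return rows

# Hypothetical impressions: two searches, one click each
impressions = [
    {"user": 7, "query": "red shoes", "shown": [11, 12, 13, 14], "clicked": 12},
    {"user": 9, "query": "desk lamp", "shown": [21, 22], "clicked": 21},
]
for row in build_training_pairs(impressions):
    print(row)
```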

The network jointly learns embeddings for the user and the query text, which together make up the "Query", and for the title and the image, which together make up the "Item", and learns how similar the two are.

After training, you can index the embeddings of your items and run a k-nearest-neighbor (kNN) search with your query embeddings, using cosine similarity, to return ranked items.
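The retrieval step can be sketched with a brute-force search in NumPy. This is only an illustration under stated assumptions: the item embeddings below are random stand-ins rather than outputs of the network, the top_k_cosine helper is hypothetical, and a large catalog would more likely use an approximate nearest-neighbor index instead of a full scan.

```python
import numpy as np

def top_k_cosine(query_vec, item_matrix, k=3, eps=1e-9):
    """Return (indices, scores) of the k items most similar to the query."""
    q = query_vec / (np.linalg.norm(query_vec) + eps)
    items = item_matrix / (np.linalg.norm(item_matrix, axis=1, keepdims=True) + eps)
    scores = items @ q                 # cosine similarity with every item
    order = np.argsort(-scores)[:k]    # highest similarity first
    return order, scores[order]

rs = np.random.RandomState(0)
item_embeddings = rs.normal(size=(1000, 128))  # stand-in for indexed item embeddings
# Build a query that is a slightly perturbed copy of item 42
query_embedding = item_embeddings[42] + 0.01 * rs.normal(size=128)

idx, scores = top_k_cosine(query_embedding, item_embeddings, k=3)
print(idx[0])  # 42, the item the query was built from
```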


In [8]:
proj_dim = 128

In [9]:
class DSSMRecommenderNetwork(gluon.HybridBlock):
    def __init__(self, query_vocab_size, proj_dim, max_user, title_vocab_size, hidden_units, random_seed=54321, p=0.5):
        super(DSSMRecommenderNetwork, self).__init__()
        with self.name_scope():
            
            # User/Query pipeline
            self.user_embedding = gluon.nn.Embedding(max_user, proj_dim)
            self.user_mlp = gluon.nn.Dense(hidden_units, activation="relu")
            
            # Instead of bag of words, we use learned embeddings + stacked biLSTM average
            self.query_text_embedding = gluon.nn.Embedding(query_vocab_size, proj_dim)
            self.query_lstm = gluon.rnn.LSTM(hidden_units, 2, bidirectional=True)
            self.query_text_mlp = gluon.nn.Dense(hidden_units, activation="relu")            
            
            self.query_dropout = gluon.nn.Dropout(p)
            self.query_mlp = gluon.nn.Dense(hidden_units, activation="relu")

            # Item pipeline
            # Instead of bag of words, we use learned embeddings + stacked biLSTM average
            self.title_embedding = gluon.nn.Embedding(title_vocab_size, proj_dim)
            self.title_lstm = gluon.rnn.LSTM(hidden_units, 2, bidirectional=True)
            self.title_mlp = gluon.nn.Dense(hidden_units, activation="relu")
            
            # You could use vgg here for example
            self.image_embedding = gluon.model_zoo.vision.resnet18_v2(pretrained=False).features 
            self.image_mlp = gluon.nn.Dense(hidden_units, activation="relu")
            
            self.item_dropout = gluon.nn.Dropout(p)
            self.item_mlp = gluon.nn.Dense(hidden_units, activation="relu")
    
    def hybrid_forward(self, F, user, query_text, title, image):
        # Query
        user = self.user_embedding(user)
        user = self.user_mlp(user)

        query_text = self.query_text_embedding(query_text)
        query_text = self.query_lstm(query_text.transpose((1,0,2)))
        # average the states
        query_text = query_text.mean(axis=0)
        query_text = self.query_text_mlp(query_text)
        
        query = F.concat(user, query_text)
        query = self.query_dropout(query)
        query = self.query_mlp(query)
        
        # Item
        title_text = self.title_embedding(title)
        title_text = self.title_lstm(title_text.transpose((1,0,2)))
        # average the states
        title_text = title_text.mean(axis=0)
        title_text = self.title_mlp(title_text)
        
        image = self.image_embedding(image)
        image = self.image_mlp(image)
        
        item = F.concat(title_text, image)
        item = self.item_dropout(item)
        item = self.item_mlp(item)
        
        # Cosine Similarity
        query = query.expand_dims(axis=2)
        item = item.expand_dims(axis=2)
        sim = F.batch_dot(query, item, transpose_a=True) / (query.norm(axis=1) * item.norm(axis=1) + 1e-9).expand_dims(axis=2)
        
        return sim.squeeze(axis=2)
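The last lines of hybrid_forward compute a per-row cosine similarity between the query and item embeddings. Below is a NumPy sketch of the same computation, assuming both embeddings have shape (batch, hidden_units). The rowwise_cosine name is illustrative, and the random inputs are non-negative stand-ins for the ReLU outputs of the two branches.

```python
import numpy as np

def rowwise_cosine(query, item, eps=1e-9):
    """Cosine similarity between matching rows, returned with shape (batch, 1)."""
    dot = np.sum(query * item, axis=1, keepdims=True)
    norms = (np.linalg.norm(query, axis=1, keepdims=True)
             * np.linalg.norm(item, axis=1, keepdims=True))
    return dot / (norms + eps)  # eps guards against all-zero embeddings

rs = np.random.RandomState(0)
query = rs.rand(4, 128)
item = rs.rand(4, 128)

sim = rowwise_cosine(query, item)
print(sim.shape)  # (4, 1)
# A vector compared with a rescaled copy of itself has similarity 1
print(np.allclose(rowwise_cosine(query, 2.0 * query), 1.0))  # True
```

Since both branches end in a ReLU, the embeddings are non-negative and the similarity falls between 0 and 1.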

In [10]:
network = DSSMRecommenderNetwork(
    query_vocab_size,
    proj_dim,
    max_user,
    title_vocab_size,
    hidden_units
)


network.initialize(mx.init.Xavier(), ctx)

# Load pre-trained resnet18_v2 weights for the image branch
with network.name_scope():
    network.image_embedding = gluon.model_zoo.vision.resnet18_v2(pretrained=True, ctx=ctx).features

The network is relatively complex and therefore quite hard to visualize, but you can still see the two-pronged structure and the resnet18 branch.


In [11]:
mx.viz.plot_network(network(
                        mx.sym.var('user'), mx.sym.var('query_text'), mx.sym.var('title'), mx.sym.var('image')),
                    shape={'user': (1,1), 'query_text': (1,30), 'title': (1,30), 'image': (1,3,224,224)},
                    node_attrs={"fixedsize":"False"})


Out[11]:
[Figure: graph of the network rendered by mx.viz.plot_network. The query branch (user embedding and Dense layer; query-text embedding, stacked biLSTM, mean and Dense layer) and the item branch (title embedding, stacked biLSTM, mean and Dense layer; ResNet-18 v2 image features and Dense layer) are each concatenated, passed through dropout and a Dense layer, and then joined by the cosine-similarity operations (batch_dot, norm, division, squeeze).]

We can print a summary of the network using dummy data. Even before any training, the model already holds about 33M parameters.


In [12]:
# Dummy batch of two samples
user  = mx.nd.array([[200], [100]], ctx=ctx)  # One user id per sample
query = mx.nd.array([[10, 20, 0, 0, 0], [40, 50, 0, 0, 0]], ctx=ctx)  # Encoded text: token indices padded with 0
title = mx.nd.array([[10, 20, 0, 0, 0], [40, 50, 0, 0, 0]], ctx=ctx)  # Encoded text: token indices padded with 0
image = mx.nd.random.uniform(shape=(2, 3, 224, 224), ctx=ctx)  # Random values standing in for two 3x224x224 images


network.summary(user, query, title, image)


--------------------------------------------------------------------------------
        Layer (type)                                Output Shape         Param #
================================================================================
               Input    (2, 1), (2, 5), (2, 5), (2, 3, 224, 224)               0
         Embedding-1                                 (2, 1, 128)        12800000
        Activation-2  <Symbol dssmrecommendernetwork0_dense0_relu_fwd>               0
        Activation-3                                    (2, 128)               0
             Dense-4                                    (2, 128)           16512
         Embedding-5                                 (2, 5, 128)         3840000
              LSTM-6                                 (5, 2, 256)          659456
        Activation-7  <Symbol dssmrecommendernetwork0_dense1_relu_fwd>               0
        Activation-8                                    (2, 128)               0
             Dense-9                                    (2, 128)           32896
          Dropout-10                                    (2, 256)               0
       Activation-11  <Symbol dssmrecommendernetwork0_dense2_relu_fwd>               0
       Activation-12                                    (2, 128)               0
            Dense-13                                    (2, 128)           32896
        Embedding-14                                 (2, 5, 128)         3840000
             LSTM-15                                 (5, 2, 256)          659456
       Activation-16  <Symbol dssmrecommendernetwork0_dense3_relu_fwd>               0
       Activation-17                                    (2, 128)               0
            Dense-18                                    (2, 128)           32896
        BatchNorm-19                            (2, 3, 224, 224)              12
           Conv2D-20                           (2, 64, 112, 112)            9408
        BatchNorm-21                           (2, 64, 112, 112)             256
       Activation-22                           (2, 64, 112, 112)               0
        MaxPool2D-23                             (2, 64, 56, 56)               0
        BatchNorm-24                             (2, 64, 56, 56)             256
           Conv2D-25                             (2, 64, 56, 56)           36864
        BatchNorm-26                             (2, 64, 56, 56)             256
           Conv2D-27                             (2, 64, 56, 56)           36864
     BasicBlockV2-28                             (2, 64, 56, 56)               0
        BatchNorm-29                             (2, 64, 56, 56)             256
           Conv2D-30                             (2, 64, 56, 56)           36864
        BatchNorm-31                             (2, 64, 56, 56)             256
           Conv2D-32                             (2, 64, 56, 56)           36864
     BasicBlockV2-33                             (2, 64, 56, 56)               0
        BatchNorm-34                             (2, 64, 56, 56)             256
           Conv2D-35                            (2, 128, 28, 28)            8192
           Conv2D-36                            (2, 128, 28, 28)           73728
        BatchNorm-37                            (2, 128, 28, 28)             512
           Conv2D-38                            (2, 128, 28, 28)          147456
     BasicBlockV2-39                            (2, 128, 28, 28)               0
        BatchNorm-40                            (2, 128, 28, 28)             512
           Conv2D-41                            (2, 128, 28, 28)          147456
        BatchNorm-42                            (2, 128, 28, 28)             512
           Conv2D-43                            (2, 128, 28, 28)          147456
     BasicBlockV2-44                            (2, 128, 28, 28)               0
        BatchNorm-45                            (2, 128, 28, 28)             512
           Conv2D-46                            (2, 256, 14, 14)           32768
           Conv2D-47                            (2, 256, 14, 14)          294912
        BatchNorm-48                            (2, 256, 14, 14)            1024
           Conv2D-49                            (2, 256, 14, 14)          589824
     BasicBlockV2-50                            (2, 256, 14, 14)               0
        BatchNorm-51                            (2, 256, 14, 14)            1024
           Conv2D-52                            (2, 256, 14, 14)          589824
        BatchNorm-53                            (2, 256, 14, 14)            1024
           Conv2D-54                            (2, 256, 14, 14)          589824
     BasicBlockV2-55                            (2, 256, 14, 14)               0
        BatchNorm-56                            (2, 256, 14, 14)            1024
           Conv2D-57                              (2, 512, 7, 7)          131072
           Conv2D-58                              (2, 512, 7, 7)         1179648
        BatchNorm-59                              (2, 512, 7, 7)            2048
           Conv2D-60                              (2, 512, 7, 7)         2359296
     BasicBlockV2-61                              (2, 512, 7, 7)               0
        BatchNorm-62                              (2, 512, 7, 7)            2048
           Conv2D-63                              (2, 512, 7, 7)         2359296
        BatchNorm-64                              (2, 512, 7, 7)            2048
           Conv2D-65                              (2, 512, 7, 7)         2359296
     BasicBlockV2-66                              (2, 512, 7, 7)               0
        BatchNorm-67                              (2, 512, 7, 7)            2048
       Activation-68                              (2, 512, 7, 7)               0
  GlobalAvgPool2D-69                              (2, 512, 1, 1)               0
          Flatten-70                                    (2, 512)               0
       Activation-71  <Symbol dssmrecommendernetwork0_dense4_relu_fwd>               0
       Activation-72                                    (2, 128)               0
            Dense-73                                    (2, 128)           65664
          Dropout-74                                    (2, 256)               0
       Activation-75  <Symbol dssmrecommendernetwork0_dense5_relu_fwd>               0
       Activation-76                                    (2, 128)               0
            Dense-77                                    (2, 128)           32896
DSSMRecommenderNetwork-78                                      (2, 1)               0
================================================================================
Parameters in forward computation graph, duplicate included
   Total params: 33195468
   Trainable params: 33187520
   Non-trainable params: 7948
Shared params in forward computation graph: 0
Unique parameters in model: 33195468
--------------------------------------------------------------------------------
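
To see where most of these parameters come from, the counts in the table can be reproduced by hand. The sketch below is only an illustration: the helper functions `dense_params` and `lstm_params` are hypothetical, and the LSTM configuration (2 layers, bidirectional, 128 hidden units) is an assumption that happens to match the 659456 reported in the summary.

```python
# Constants copied from the top of this example
max_user = int(1e5)
title_vocab_size = int(3e4)
query_vocab_size = int(3e4)
hidden_units = 128

def dense_params(in_units, out_units):
    """Weights plus biases of a fully connected layer (hypothetical helper)."""
    return in_units * out_units + out_units

def lstm_params(input_size, hidden_size, num_layers, bidirectional):
    """Parameter count of a stacked LSTM (hypothetical helper).

    Each direction of each layer has 4 gates with input-to-hidden and
    hidden-to-hidden weights, plus two bias vectors of size 4 * hidden_size.
    """
    directions = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        in_size = input_size if layer == 0 else hidden_size * directions
        per_direction = 4 * hidden_size * (in_size + hidden_size) + 2 * 4 * hidden_size
        total += directions * per_direction
    return total

user_embedding = max_user * hidden_units            # Embedding-1
query_embedding = query_vocab_size * hidden_units   # Embedding-5
title_embedding = title_vocab_size * hidden_units   # Embedding-14
user_dense = dense_params(hidden_units, hidden_units)   # Dense-4
text_dense = dense_params(2 * hidden_units, hidden_units)  # Dense-9, Dense-18
image_dense = dense_params(512, hidden_units)            # Dense-73
# Assumed configuration: 2 layers, bidirectional, 128 hidden units
text_lstm = lstm_params(hidden_units, hidden_units, num_layers=2, bidirectional=True)

total_params = 33195468  # "Total params" from the summary above
embedding_share = (user_embedding + query_embedding + title_embedding) / total_params
print("Embeddings: {:.0%} of all parameters".format(embedding_share))
```

The three embedding tables alone account for well over half of the model, so the vocabulary sizes and `max_user` are the main levers on model size.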

In [13]:
network(user, query, title, image)


Out[13]:
[[0.34404233]
 [0.3254302 ]]
<NDArray 2x1 @gpu(0)>

The output is the cosine similarity between the two final 128-dimensional embeddings. To train the network on real data, we would minimize the cosine loss, 1 - cosine_similarity.
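
As a minimal sketch of that loss, written in plain NumPy rather than MXNet so it can run on its own, the functions below compute a row-wise cosine similarity and the corresponding loss for matching pairs. The names `cosine_similarity` and `cosine_loss` and the `eps` value are assumptions made for illustration, not part of the network defined above.

```python
import numpy as np

def cosine_similarity(a, b, eps=1e-8):
    """Row-wise cosine similarity between two (batch, dim) arrays.

    eps is an assumed small constant that avoids division by zero.
    """
    dot = np.sum(a * b, axis=1)
    norms = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    return dot / (norms + eps)

def cosine_loss(a, b):
    """Mean of 1 - cosine_similarity over the batch, for matching pairs."""
    return np.mean(1.0 - cosine_similarity(a, b))

# Toy embeddings: identical, orthogonal and opposite pairs
left = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
right = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])

similarities = cosine_similarity(left, right)
loss = cosine_loss(left, right)
print(similarities, loss)
```

Since the network above already returns the similarity, the loss for a batch of matching pairs would simply be one minus its output, averaged over the batch. Pairs that should not match would need a different target, which this sketch does not cover.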