In [1]:
import tensorflow as tf
In [5]:
# Two float32 graph inputs; their values are supplied at run time through feed_dict.
# No `shape` is passed, so the fed values are not shape-constrained.
x, y = tf.placeholder(tf.float32), tf.placeholder(tf.float32)
In [6]:
# Graph-mode conditional: x + y when x < y, otherwise y squared.
# Both branches are traced into the graph; only the one the predicate selects
# is executed when the graph is run.
result = tf.cond(
    tf.less(x, y),
    true_fn=lambda: tf.add(x, y),
    false_fn=lambda: tf.square(y),
)
In [7]:
# Interactive session used by the evaluation cells below.
# NOTE(review): never closed in the visible cells — call sess.close() when finished.
sess = tf.InteractiveSession()
`tf.cond` returns the tensor(s) produced by whichever branch callable — `true_fn` or `false_fn` — the predicate selects.
In [11]:
# x < y holds here, so the true branch runs: 1.0 + 2.0 evaluates to 3.0.
r1 = sess.run(result, {x: 1.0, y: 2.0})
print(r1)
In [12]:
# x < y is false here, so the false branch runs: 2.0 squared evaluates to 4.0.
r1 = sess.run(result, {x: 3.0, y: 2.0})
print(r1)