
How To Construct A Matrix That Contains All Pairs Of Rows Of A Matrix In Tensorflow

I need to construct a matrix z that would contain combinations of pairs of rows of a matrix x. x = tf.constant([[1, 3], [2, 4], [0, 2],

Solution 1:

Unfortunately, this is a bit more complex than one would like when using only TensorFlow operators. I would build the indices for all combinations with a tf.while_loop, then use tf.gather to collect the values:

import tensorflow as tf
x = tf.constant([[1, 3],
                 [2, 4],
                 [3, 2],
                 [0, 1]], dtype=tf.int32)
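# Number of rows
n = tf.shape(x)[0]
# Start from an empty (0, 2) list of index pairs and grow it in the loop
m = tf.constant([], shape=(0, 2), dtype=tf.int32)
_, idxs = tf.while_loop(
  lambda i, m: i < n - 1,
  # For row i, append the pairs (i, i + 1), ..., (i, n - 1)
  lambda i, m: (i + 1, tf.concat([m, tf.stack([tf.tile([i], (n - 1 - i,)),
                                               tf.range(i + 1, n)], axis=1)], axis=0)),
  loop_vars=(0, m),
  shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, 2])))
# Gather both rows of every pair, then regroup the values column by column
z = tf.reshape(tf.transpose(tf.gather(x, idxs), (2, 0, 1)), (-1, 2))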

# <tf.Tensor: shape=(12, 2), dtype=int32, numpy=
# array([[1, 2],
#        [1, 3],
#        [1, 0],
#        [2, 3],
#        [2, 0],
#        [3, 0],
#        [3, 4],
#        [3, 2],
#        [3, 1],
#        [4, 2],
#        [4, 1],
#        [2, 1]])>

This should work in both TF1 and TF2.

If the number of rows of x is known in advance, you don't need the while_loop: you can simply precompute the index pairs in Python and place them in a constant.
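As a minimal sketch of that precomputed variant, assuming the number of rows is static and that itertools.combinations is an acceptable way to enumerate the pairs (the answer does not say how to build them), the indices could be generated like this. The gather, transpose and reshape steps are the same as in the loop version above.

```python
import itertools

import tensorflow as tf

x = tf.constant([[1, 3],
                 [2, 4],
                 [3, 2],
                 [0, 1]], dtype=tf.int32)

# Assumption: the row count is static, so it can be read as a plain Python int.
n = int(x.shape[0])
# All index pairs (i, j) with i < j, in the same order the while_loop builds them.
pairs = [list(p) for p in itertools.combinations(range(n), 2)]
idxs = tf.constant(pairs, dtype=tf.int32)
# Gather both rows of every pair, then regroup the values column by column.
z = tf.reshape(tf.transpose(tf.gather(x, idxs), (2, 0, 1)), (-1, 2))
```

With the x above, z should have the same (12, 2) shape and contents as the loop version, since the index pairs come out in the same order.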

Solution 2:

Here is a way to do that without a loop:

import tensorflow as tf

x = tf.constant([[1, 3],
                 [2, 4],
                 [0, 2],
                 [0, 1]], dtype=tf.int32)
# Number of rows
n = tf.shape(x)[0]
# Grid of indices
ri = tf.range(0, n - 1)
rj = ri + 1
ii, jj = tf.meshgrid(ri, rj, indexing='ij')
# Stack together
grid = tf.stack([ii, jj], axis=-1)
# Get upper triangular part
m = ii < jj
idx = tf.boolean_mask(grid, m)
# Get values
g = tf.gather(x, idx, axis=0)
# Rearrange result
result = tf.transpose(g, [2, 0, 1])
print(result.numpy())
# [[[1 2]
#   [1 0]
#   [1 0]
#   [2 0]
#   [2 0]
#   [0 0]]
#
#  [[3 4]
#   [3 2]
#   [3 1]
#   [4 2]
#   [4 1]
#   [2 1]]]
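Note that this result has shape (2, 6, 2), one block of pairs per column of x, whereas Solution 1 returns a single (12, 2) matrix. If the flat matrix is what you need, a final reshape should be enough. Below is a rough sketch that wraps the steps above in a helper; the name all_row_pairs is made up for illustration and is not part of TensorFlow.

```python
import tensorflow as tf


def all_row_pairs(x):
    """Hypothetical helper: the loop-free steps, flattened to a 2-D matrix."""
    n = tf.shape(x)[0]
    # Grid of candidate index pairs (i, j)
    ri = tf.range(0, n - 1)
    rj = ri + 1
    ii, jj = tf.meshgrid(ri, rj, indexing='ij')
    # Keep only the pairs with i < j
    idx = tf.boolean_mask(tf.stack([ii, jj], axis=-1), ii < jj)
    g = tf.gather(x, idx, axis=0)
    # (pairs, 2, cols) -> (cols, pairs, 2) -> (cols * pairs, 2)
    return tf.reshape(tf.transpose(g, [2, 0, 1]), (-1, 2))


x = tf.constant([[1, 3],
                 [2, 4],
                 [0, 2],
                 [0, 1]], dtype=tf.int32)
z = all_row_pairs(x)
```

Here z stacks the two blocks printed above into one (12, 2) matrix, matching the layout produced by Solution 1.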
