# Let us explicitly ask for TensorFlow2.
# This installs a lot of stuff - and will take a while.
# NOTE: The leading `!` is IPython/Jupyter shell-escape syntax, so this line
# only works when the file is run as a notebook cell; it is a syntax error in
# a plain Python script.
!pip3 install tensorflow==2.0.1

import base64
import collections
import dataclasses
import hashlib
import itertools
import math
import pprint
import sys

import numpy
import scipy.linalg
import scipy.optimize
import tensorflow as tf

# Report the library versions in use, to make runs easier to reproduce.
print(f'TF version is: {tf.__version__}')
print(f'NumPy version is: {numpy.__version__}')

# `eq=False` keeps identity-based comparison/hashing: the generated `__eq__`
# would compare the `pos` arrays element-wise and fail with an ambiguous
# truth value.
@dataclasses.dataclass(eq=False)
class Solution(object):
  """One numerical solution produced by the critical-point scan.

  Instances are created with keyword arguments
  (`Solution(potential=..., stationarity=..., pos=...)`), which requires the
  dataclass-generated `__init__`.
  """
  # Value of the scalar potential at `pos`.
  potential: float
  # Stationarity measure reported by the optimizer's objective at `pos`.
  stationarity: float
  # Coordinates of the solution.
  pos: numpy.ndarray

def np_esum(spec, *arrays, optimize='greedy'):
  """Thin wrapper around `numpy.einsum` that defaults to greedy optimization.

  Args:
    spec: Einsum subscripts string, such as 'ab,bc->ac'.
    *arrays: The operand arrays, in the order named by `spec`.
    optimize: Forwarded unchanged as `numpy.einsum`'s `optimize` argument.

  Returns:
    Whatever `numpy.einsum` returns for the given contraction.
  """
  einsum_options = {'optimize': optimize}
  return numpy.einsum(spec, *arrays, **einsum_options)

def get_onb_transform(k_ab):
  """Computes a basis change that makes the Gramian `k_ab` +/- the identity.

  Args:
    k_ab: Square Gramian matrix. Must be (numerically) real and symmetric,
      and either positive-definite or negative-definite.

  Returns:
    A pair `(onb_transform, onb_transform_inverse)`. `onb_transform[A, a]`
    satisfies `sum_ab k_ab[a, b] * T[A, a] * T[B, b] == s * delta[A, B]`,
    where `s` is +1 for a positive-definite and -1 for a negative-definite
    Gramian.

  Raises:
    ValueError: If `k_ab` is not real-symmetric, or is not definite.
  """
  if not numpy.allclose(k_ab, k_ab.real) or not numpy.allclose(k_ab, k_ab.T):
    raise ValueError('Bad Gramian.')
  eigvals, eigvecsT = numpy.linalg.eigh(k_ab)
  # All eigenvalues must share the sign of the first one.
  if not all(v * eigvals[0] > 0 for v in eigvals):
    raise ValueError('Non-definite Gramian.')
  # Scale by |eigenvalue|^(-1/2). Using the signed eigenvalue here would
  # produce NaN for a negative-definite Gramian, which the check below
  # explicitly allows for (target is -identity in that case).
  onb_transform = numpy.einsum(
      'a,na->an', numpy.abs(eigvals)**(-.5), eigvecsT)
  g = np_esum('ab,Aa,Bb->AB', k_ab, onb_transform, onb_transform)
  assert numpy.allclose(
    g, numpy.eye(g.shape[0]) * ((-1, 1)[int(eigvals[0] > 0)])
    ), 'Bad ONB-transform.'
  return onb_transform, numpy.linalg.inv(onb_transform)

def numpy_signature(a, digits=3):
  """Produces a signature-fingerprint of a numpy array.

  Args:
    a: The numpy array to fingerprint.
    digits: Number of decimal digits entries are rounded to before hashing.

  Returns:
    A string: the base64-encoded SHA-256 digest of a textual rendering of
    the array's shape and rounded entries, with '=' padding stripped.
  """
  # Hack to ensure that -0.0 gets consistently shown as 0.0.
  minus_zero_hack = 1e-100+1e-100j
  rounded = numpy.round(a + minus_zero_hack, digits)
  # NOTE(review): the lines between `base64.b64encode(` and the generator
  # expression were missing from this file. The hash pipeline below is a
  # reconstruction (SHA-256 matches the 43-character signatures asserted
  # elsewhere in this file); confirm the exact hashed text format against a
  # known-good signature before relying on stored fingerprints.
  fingerprint_text = str(
      (a.shape, ','.join(repr(x) for x in rounded.flat)))
  digest = hashlib.sha256(fingerprint_text.encode('utf-8')).digest()
  return base64.b64encode(digest).decode('utf-8').strip('\n=')

def tformat(array,
            name=None,
            elem_filter=lambda x: abs(x) > 1e-8,
            fmt='%s',
            max_rows=numpy.inf,
            cols=120):
  """Formats a numpy-array in human readable table form.

  NOTE(review): part of the parameter list was missing from this file. The
  parameters `name`, `fmt`, `max_rows` and `cols` are all used by the body;
  their order and default values here are a reconstruction - please confirm.

  Args:
    array: The numpy array to format.
    name: Optional name. If given, a header row with name, shape and entry
      counts is prepended and every output row is whitespace-stripped.
    elem_filter: Predicate selecting which entries get listed.
    fmt: %-format specifier used for each listed entry.
    max_rows: Maximum number of entries to list.
    cols: Target line width used to arrange entries in multiple columns,
      or None to list one entry per line.

  Returns:
    The formatted table as a single string (rows joined by newlines).
  """
  # Width of each index column: enough decimal digits for the largest index.
  dim_widths = [
      max(1, int(math.ceil(math.log(dim + 1e-100, 10))))
      for dim in array.shape]
  format_str = '%s: %s' % (' '.join('%%%dd' % w for w in dim_widths), fmt)
  rows = []
  for indices in itertools.product(*[range(dim) for dim in array.shape]):
    v = array[indices]
    if elem_filter(v):
      rows.append(format_str % (indices + (v, )))
  num_entries = len(rows)
  if num_entries > max_rows:
    rows = rows[:max_rows]
  if cols is not None and rows:
    # Lay the entries out column-major in as many columns as fit into `cols`.
    width = max(map(len, rows))
    num_cols = max(1, cols // (3 + width))
    num_xrows = int(math.ceil(len(rows) / num_cols))
    padded = [('%%-%ds' % width) % s
              for s in rows + [''] * (num_cols * num_xrows - len(rows))]
    table = numpy.array(padded, dtype=object).reshape(num_cols, num_xrows).T
    xrows = [' | '.join(row) for row in table]
  else:
    # Single-column output; also covers the case of no entries to show,
    # where computing a column width would fail.
    xrows = rows
  if name is not None:
    # Report how many rows are actually shown when the list got truncated.
    shown = '' if num_entries == len(rows) else '(%d shown)' % len(rows)
    return '\n'.join(
        ['=== %s, shape=%r, %d%s / %d non-small entries ===' % (
            name, array.shape, num_entries, shown, array.size)] +
        [r.strip() for r in xrows])
  return '\n'.join(xrows)

def tprint(array, sep=' ', end='\n', file=None, **tformat_kwargs):
  """Prints a numpy array in human readable table form.

  Args:
    array: The numpy array to print.
    sep: Forwarded to `print`.
    end: Forwarded to `print`.
    file: Forwarded to `print`. The default of None makes `print` use the
      `sys.stdout` that is current at call time, so later redirection of
      stdout is respected.
    **tformat_kwargs: Forwarded to `tformat`.
  """
  print(tformat(array, **tformat_kwargs), sep=sep, end=end, file=file)

### Lie Algebra definitions for Spin(8), SU(8), E7.

def permutation_sign(p):
  """Returns the sign (+1 or -1) of the permutation `p` of range(len(p)).

  The permutation is sorted in place on a private copy by repeated
  transpositions; every transposition flips the sign.
  """
  perm = [entry for entry in p]  # Work on a list copy of the input.
  sign = 1
  for slot in range(len(p)):
    while perm[slot] != slot:
      target = perm[slot]
      # Move the value `target` to its home position perm[target].
      perm[slot], perm[target] = perm[target], perm[slot]
      sign *= -1
  return sign

def asymm2(a, einsum_spec):
  """Antisymmetrizes `a` under the index permutation given by `einsum_spec`.

  Args:
    a: Input numpy array.
    einsum_spec: Single-operand einsum spec (e.g. 'ij->ji') producing the
      index-permuted copy of `a` that gets subtracted.

  Returns:
    Half the difference between `a` and its permuted copy.
  """
  permuted = numpy.einsum(einsum_spec, a)
  return 0.5 * (a - permuted)

class Spin8(object):
  """Container class for Spin(8) tensor invariants.

  Attributes set by the constructor (all numpy arrays):
    gamma_vsc: [8, 8, 8] gamma matrices, indexed (vector, spinor, cospinor).
    gamma_vvss: [8, 8, 8, 8] product of two gammas, antisymmetrized in the
      two vector indices, with two spinor indices.
    gamma_vvcc: Same as gamma_vvss, but with two cospinor indices.
    gamma_vvvvss: [8] * 6 product of four gammas, antisymmetrized over the
      four vector indices, with two spinor indices.
    gamma_vvvvcc: Same as gamma_vvvvss, but with two cospinor indices.
  """

  def __init__(self):
    self.gamma_vsc = gamma_vsc = self._get_gamma_vsc()
    # Un-antisymmetrized two-gamma products; each is used both for the
    # 2-index and for the 4-index invariants, so compute them only once.
    g_ijsS = numpy.einsum('isc,jSc->ijsS', gamma_vsc, gamma_vsc)
    g_ijcC = numpy.einsum('isc,jsC->ijcC', gamma_vsc, gamma_vsc)
    # The gamma^{ab}_{alpha beta} tensor that translates between antisymmetric
    # matrices over vectors [ij] and antisymmetric matrices over spinors [sS].
    self.gamma_vvss = asymm2(g_ijsS, 'ijsS->jisS')
    # The gamma^{ab}_{alpha* beta*} tensor that translates between antisymmetric
    # matrices over vectors [ij] and antisymmetric matrices over cospinors [cC].
    self.gamma_vvcc = asymm2(g_ijcC, 'ijcC->jicC')
    # The gamma^{ijkl}_{alpha beta} tensor that translates between antisymmetric
    # 4-forms [ijkl] and symmetric traceless matrices over the spinors (sS).
    g_ijklsS = numpy.einsum('ijst,kltS->ijklsS', g_ijsS, g_ijsS)
    g_ijklcC = numpy.einsum('ijcd,kldC->ijklcC', g_ijcC, g_ijcC)
    gamma_vvvvss = numpy.zeros([8] * 6)
    gamma_vvvvcc = numpy.zeros([8] * 6)
    # Signed sum over all 4! permutations of the vector indices.
    for perm in itertools.permutations(range(4)):
      perm_ijkl = ''.join('ijkl'[p] for p in perm)
      sign = permutation_sign(perm)
      gamma_vvvvss += sign * numpy.einsum(perm_ijkl + 'sS->ijklsS', g_ijklsS)
      gamma_vvvvcc += sign * numpy.einsum(perm_ijkl + 'cC->ijklcC', g_ijklcC)
    self.gamma_vvvvss = gamma_vvvvss / 24.0
    self.gamma_vvvvcc = gamma_vvvvcc / 24.0

  def _get_gamma_vsc(self):
    """Computes SO(8) gamma-matrices.

    Returns:
      [8, 8, 8] float array with entries in {-1, 0, +1}.
    """
    # Conventions match Green, Schwarz, Witten's, but with index-counting
    # starting at zero.
    # Each entry is three index digits followed by the sign of the entry.
    entries = (
        "007+ 016- 025- 034+ 043- 052+ 061+ 070- "
        "101+ 110- 123- 132+ 145+ 154- 167- 176+ "
        "204+ 215- 226+ 237- 240- 251+ 262- 273+ "
        "302+ 313+ 320- 331- 346- 357- 364+ 375+ "
        "403+ 412- 421+ 430- 447+ 456- 465+ 474- "
        "505+ 514+ 527+ 536+ 541- 550- 563- 572- "
        "606+ 617+ 624- 635- 642+ 653+ 660- 671- "
        "700+ 711+ 722+ 733+ 744+ 755+ 766+ 777+")
    ret = numpy.zeros([8, 8, 8])
    for ijkc in entries.split():
      ijk = tuple(map(int, ijkc[:-1]))
      ret[ijk] = +1 if ijkc[-1] == '+' else -1
    return ret

class SU8(object):
  """Container class for su(8) tensor invariants.

  Attributes set by the constructor:
    ij_map: List of the 28 index pairs (i, j) with i < j, in lexicographic
      order.
    m_35_8_8: [35, 8, 8] complex basis of symmetric traceless 8x8 matrices:
      7 diagonal entries first, then one symmetric matrix per `ij_map` pair.
    m_28_8_8: [28, 8, 8] complex basis of antisymmetric 8x8 matrices, one
      per `ij_map` pair.
    t_aij: [63, 8, 8] complex su(8) 'generator matrices'.
  """

  def __init__(self):
    # Index pairs labelling the off-diagonal basis elements.
    index_pairs = list(itertools.combinations(range(8), 2))
    sym_basis = numpy.zeros([35, 8, 8], dtype=numpy.complex128)
    asym_basis = numpy.zeros([28, 8, 8], dtype=numpy.complex128)
    # Diagonal traceless elements: +1 at (n, n), -1 at (n + 1, n + 1).
    diag = numpy.arange(7)
    sym_basis[diag, diag, diag] = +1.0
    sym_basis[diag, diag + 1, diag + 1] = -1.0
    # The su8 'Generator Matrices': the antisymmetric part is filled in the
    # loop below, the symmetric part (times i) afterwards.
    generators = numpy.zeros([63, 8, 8], dtype=numpy.complex128)
    for pair_index, (lo, hi) in enumerate(index_pairs):
      sym_basis[pair_index + 7, lo, hi] = 1.0
      sym_basis[pair_index + 7, hi, lo] = 1.0
      asym_basis[pair_index, lo, hi] = 1.0
      asym_basis[pair_index, hi, lo] = -1.0
      generators[pair_index + 35, lo, hi] = -1.0
      generators[pair_index + 35, hi, lo] = 1.0
    generators[:35, :, :] = 1.0j * sym_basis
    self.ij_map = index_pairs
    self.m_35_8_8 = sym_basis
    self.m_28_8_8 = asym_basis
    self.t_aij = generators

class E7(object):
  """Container class for e7 tensor invariants.

  NOTE(review): three einsum calls in the constructor were truncated in this
  file (a spec string or an operand was missing). They were reconstructed
  from the shapes of the remaining operands and of the assignment targets;
  please confirm against a known-good copy.

  Attributes set by the constructor:
    t_a_ij_kl: [133, 56, 56] complex generator matrices. Entries 0..34 and
      35..69 are built from the spinor/cospinor 4-gamma tensors, entries
      70..132 from the su(8) generators acting on the 28 and its conjugate.
    k_ab: [133, 133] Gramian `einsum('aMN,bNM->ab', t, t)` of the generators.
    v70_as_sc8x8: [70, 2, 8, 8] array built from `su8.m_35_8_8`.
    v70_onb_onbinv: Pair (transform, inverse) from `get_onb_transform`
      for the leading 70x70 block of `k_ab`.
  """

  def __init__(self, spin8, su8):
    self._spin8 = spin8
    self._su8 = su8
    t_a_ij_kl = numpy.zeros([133, 56, 56], dtype=numpy.complex128)
    # First 35 generators: gamma_vvvvss [8]*6, contracted with the 35-basis
    # on the spinor indices and with the 28-basis on each vector-index pair.
    t_a_ij_kl[:35, 28:, :28] = (1 / 8.0) * (
        np_esum('ijklsS,qsS,Iij,Kkl->qIK',
                spin8.gamma_vvvvss, su8.m_35_8_8, su8.m_28_8_8, su8.m_28_8_8))
    t_a_ij_kl[:35, :28, 28:] = t_a_ij_kl[:35, 28:, :28]
    # Next 35 generators: same construction with the cospinor tensor, with
    # a factor i and opposite sign in the other off-diagonal block.
    t_a_ij_kl[35:70, 28:, :28] = (1.0j / 8.0) * (
        np_esum('ijklcC,qcC,Iij,Kkl->qIK',
                spin8.gamma_vvvvcc, su8.m_35_8_8, su8.m_28_8_8, su8.m_28_8_8))
    t_a_ij_kl[35:70, :28, 28:] = -t_a_ij_kl[35:70, 28:, :28]
    # We need to find the action of the su(8) algebra on the
    # 28-representation.
    su8_28 = 2 * np_esum('aij,mn,Iim,Jjn->aIJ',
                         su8.t_aij,
                         numpy.eye(8, dtype=numpy.complex128),
                         su8.m_28_8_8, su8.m_28_8_8)
    t_a_ij_kl[70:, :28, :28] = su8_28
    t_a_ij_kl[70:, 28:, 28:] = su8_28.conjugate()
    self.t_a_ij_kl = t_a_ij_kl
    self.k_ab = numpy.einsum('aMN,bNM->ab', t_a_ij_kl, t_a_ij_kl)
    # The reshape to (70, 2, 8, 8) fixes the first operand to be 2x2.
    # NOTE(review): the identity matrix is an assumption - confirm.
    self.v70_as_sc8x8 = numpy.einsum('sc,xab->sxcab',
                                     numpy.eye(2),
                                     su8.m_35_8_8).reshape(70, 2, 8, 8)
    # For e7, there actually is a better orthonormal basis:
    # the sd/asd 4-forms. The approach used here however readily generalizes
    # to all other groups.
    self.v70_onb_onbinv = get_onb_transform(self.k_ab[:70, :70])

def get_proj_35_8888(want_selfdual=True):
  """Computes the (35, 8, 8, 8, 8)-projector to the (anti)self-dual 4-forms.

  Args:
    want_selfdual: If true, build the projector onto self-dual 4-forms,
      otherwise onto anti-self-dual 4-forms.

  Returns:
    float64 numpy array of shape [35, 8, 8, 8, 8].
  """
  # We first need some basis for the 35 self-dual 4-forms.
  # Our convention is that we lexicographically list those 8-choose-4
  # combinations that contain the index 0.
  sign_selfdual = 1 if want_selfdual else -1
  ret = numpy.zeros([35, 8, 8, 8, 8], dtype=numpy.float64)
  def get_selfdual(ijkl):
    # The complementary index set, and the relative sign of the dual term.
    mnpq = tuple(n for n in range(8) if n not in ijkl)
    return (sign_selfdual * permutation_sign(ijkl + mnpq),
            ijkl, mnpq)
  selfduals = [get_selfdual(ijkl)
               for ijkl in itertools.combinations(range(8), 4)
               if 0 in ijkl]
  for num_sd, (sign_sd, ijkl, mnpq) in enumerate(selfduals):
    # Antisymmetrize: fill every index ordering with the permutation's sign.
    for abcd in itertools.permutations(range(4)):
      sign_abcd = permutation_sign(abcd)
      ret[num_sd,
          ijkl[abcd[0]],
          ijkl[abcd[1]],
          ijkl[abcd[2]],
          ijkl[abcd[3]]] = sign_abcd
      ret[num_sd,
          mnpq[abcd[0]],
          mnpq[abcd[1]],
          mnpq[abcd[2]],
          mnpq[abcd[3]]] = sign_abcd * sign_sd
  return ret / 24.0

# Module-level instances of the Lie-algebra containers; the functions below
# refer to `su8` and `e7` as globals.
spin8 = Spin8()
su8 = SU8()
e7 = E7(spin8, su8)

# Regression check: the fingerprint of the generator matrices must match the
# recorded value, otherwise the definitions above have changed.
assert (numpy_signature(e7.t_a_ij_kl) ==
        'MMExYjC6Qr6gunZIYfRLLgM2PDtwUDYujBNzAIukAVY'), 'Bad E7(7) definitions.'

### SO(p, 8-p) gaugings

def get_so_pq_E(p=8):
  """Returns the [56, 56] complex matrix selecting the SO(p, 8-p) gauging.

  Args:
    p: Integer in 0..8. For p == 8 and p == 0 the result is the identity.

  Returns:
    [56, 56] complex numpy array.

  Raises:
    ValueError: If `p` is outside the range 0..8.
  """
  if not 0 <= p <= 8:
    raise ValueError('p must be in the range 0..8, got: %r' % (p,))
  if p == 8 or p == 0:
    return numpy.eye(56, dtype=complex)
  q = 8 - p
  pq_ratio = p / q
  # Traceless diagonal 8x8 matrix: p entries -1 and q entries p/q.
  x88 = numpy.diag([-1.0] * p + [1.0 * pq_ratio] * q)
  t = 0.25j * numpy.pi / (1 + pq_ratio)
  # Expand x88 in the (non-orthonormal) 35-basis via the basis' Gramian.
  k_ab = numpy.einsum('aij,bij->ab', su8.m_35_8_8, su8.m_35_8_8)
  v35 = numpy.einsum('mab,ab,mM->M', su8.m_35_8_8, x88, numpy.linalg.inv(k_ab))
  # NOTE(review): the einsum spec and its first operand were missing from
  # this file. The padding of v35 to length 133 implies a contraction with
  # all 133 generator matrices; the output index order is an assumption.
  gen_E = numpy.einsum(
      'aMN,a->NM',
      e7.t_a_ij_kl,
      numpy.pad(v35, [(0, 133 - 35)], 'constant'))
  return scipy.linalg.expm(-t * gen_E)

### Supergravity.

# `eq=False` keeps identity-based comparison: a generated `__eq__` would
# compare tensors element-wise.
@dataclasses.dataclass(eq=False)
class SUGRATensors(object):
  """Bundle of the tensors computed by `tf_sugra_tensors`.

  Instances are created with keyword arguments, which requires the
  dataclass-generated `__init__`.
  """
  # The 70 scalar coordinates the other tensors were computed from.
  v70: tf.Tensor
  # The complex [56, 56] vielbein matrix.
  vielbein: tf.Tensor
  # The T-tensor.
  tee_tensor: tf.Tensor
  # The A1 tensor.
  a1: tf.Tensor
  # The A2 tensor.
  a2: tf.Tensor
  # The scalar potential.
  potential: tf.Tensor

def get_tf_stationarity(fn_potential, **fn_potential_kwargs):
  """Returns a @tf.function that computes |grad potential|^2.

  Args:
    fn_potential: Callable mapping a position tensor (plus keyword
      arguments) to a scalar potential tensor.
    **fn_potential_kwargs: Extra keyword arguments passed to `fn_potential`
      on every call.

  Returns:
    A function mapping a position tensor to the sum of squares of the
    gradient of the potential at that position.
  """
  @tf.function
  def stationarity(pos):
    tape = tf.GradientTape()
    with tape:
      # Only trainable variables are watched automatically; `pos` may be a
      # plain tensor, for which the gradient would otherwise be None.
      tape.watch(pos)
      potential = fn_potential(pos, **fn_potential_kwargs)
    grad_potential = tape.gradient(potential, pos)
    return tf.reduce_sum(grad_potential * grad_potential)
  return stationarity

def dwn_stationarity(t_a1, t_a2):
  """Computes the de Wit-Nicolai stationarity-condition tensor.

  Args:
    t_a1: The (complex) A1 tensor, with two indices.
    t_a2: The (complex) A2 tensor, with four indices.

  Returns:
    Scalar tensor: squared length of the self-dual part of the real part,
    plus squared length of the anti-self-dual part of the imaginary part,
    of the 4-index stationarity tensor.
  """
  # NOTE(review): the literature reference in the original comment was
  # garbled ("See:, text after (3.2)."); the cited source is missing.
  def squared_length_35(want_selfdual, t_ijkl):
    # Project onto the 35 (anti)self-dual 4-forms and take the squared norm.
    tc_proj = tf.constant(get_proj_35_8888(want_selfdual))
    t_35 = tf.einsum('aijkl,ijkl->a', tc_proj, t_ijkl)
    return tf.einsum('a,a->', t_35, t_35)

  t_x0 = (
      +4.0 * tf.einsum('mi,mjkl->ijkl', t_a1, t_a2)
      -3.0 * tf.einsum('mnij,nklm->ijkl', t_a2, t_a2))
  real_sd_part = squared_length_35(True, tf.math.real(t_x0))
  imag_asd_part = squared_length_35(False, tf.math.imag(t_x0))
  return real_sd_part + imag_asd_part

def tf_sugra_tensors(t_v70, compute_masses, t_lhs_vielbein, t_rhs_E):
  """Returns key tensors for D=4 supergravity.

  NOTE(review): several lines of this function were missing from this file
  (an einsum spec, an `else:`, the tail of `expand_ijkl`, and the leading
  term of the A2 antisymmetrization). They are reconstructed here; the
  individual spots are marked below. Please confirm against a known-good
  copy.

  Args:
    t_v70: float64 tensor with the 70 scalar coordinates.
    compute_masses: If true, the vielbein is left-multiplied with
      `t_lhs_vielbein`.
    t_lhs_vielbein: [56, 56] complex matrix; only used if `compute_masses`.
    t_rhs_E: [56, 56] complex matrix the vielbein is right-multiplied with.

  Returns:
    Tuple (t_v70, vielbein, T-tensor, A1, A2, potential).
  """
  tc_28_8_8 = tf.constant(su8.m_28_8_8)
  # NOTE(review): reconstructed spec; the output index order ('JI' vs 'IJ')
  # is an assumption.
  t_e7_generator_v70 = tf.einsum(
      'v,vIJ->JI',
      tf.complex(t_v70, tf.constant([0.0] * 70, dtype=tf.float64)),
      tf.constant(e7.t_a_ij_kl[:70, :, :], dtype=tf.complex128))
  t_complex_vielbein0 = tf.linalg.expm(t_e7_generator_v70) @ t_rhs_E
  if compute_masses:
    t_complex_vielbein = t_lhs_vielbein @ t_complex_vielbein0
  else:
    t_complex_vielbein = t_complex_vielbein0
  def expand_ijkl(t_ab):
    # Expands both 28-indices of a [28, 28] block into antisymmetric [8, 8]
    # index pairs. NOTE(review): the outer spec and second operand were
    # reconstructed from the inner einsum's output indices.
    return 0.5 * tf.einsum(
        'ijB,BIJ->ijIJ',
        tf.einsum('AB,Aij->ijB', t_ab, tc_28_8_8),
        tc_28_8_8)
  t_u_ijIJ = expand_ijkl(t_complex_vielbein[:28, :28])
  t_u_klKL = expand_ijkl(t_complex_vielbein[28:, 28:])
  t_v_ijKL = expand_ijkl(t_complex_vielbein[:28, 28:])
  t_v_klIJ = expand_ijkl(t_complex_vielbein[28:, :28])
  t_uv = t_u_klKL + t_v_klIJ
  t_uuvv = (tf.einsum('lmJK,kmKI->lkIJ', t_u_ijIJ, t_u_klKL) -
            tf.einsum('lmJK,kmKI->lkIJ', t_v_ijKL, t_v_klIJ))
  t_T = tf.einsum('ijIJ,lkIJ->lkij', t_uv, t_uuvv)
  t_A1 = (-4.0 / 21.0) * tf.linalg.trace(tf.einsum('mijn->ijmn', t_T))
  t_A2 = (-4.0 / (3 * 3)) * (
      # Antisymmetrize in last 3 indices, taking into account antisymmetry
      # in last two indices.
      # NOTE(review): the un-permuted `t_T` term was reconstructed; the sum
      # over the three cyclic arrangements needs it.
      t_T
      + tf.einsum('lijk->ljki', t_T)
      + tf.einsum('lijk->lkij', t_T))
  t_A1_real = tf.math.real(t_A1)
  t_A1_imag = tf.math.imag(t_A1)
  t_A2_real = tf.math.real(t_A2)
  t_A2_imag = tf.math.imag(t_A2)
  # Potential = -3/4 |A1|^2 + 1/24 |A2|^2.
  t_A1_potential = (-3.0 / 4) * (
      tf.einsum('ij,ij->', t_A1_real, t_A1_real) +
      tf.einsum('ij,ij->', t_A1_imag, t_A1_imag))
  t_A2_potential = (1.0 / 24) * (
      tf.einsum('ijkl,ijkl->', t_A2_real, t_A2_real) +
      tf.einsum('ijkl,ijkl->', t_A2_imag, t_A2_imag))
  t_potential = t_A1_potential + t_A2_potential
  return t_v70, t_complex_vielbein, t_T, t_A1, t_A2, t_potential

def so8_sugra_tensors(t_v70, tc_rhs_E):
  """Computes the supergravity tensors at `t_v70` as a `SUGRATensors`.

  Args:
    t_v70: float64 tensor with the 70 scalar coordinates.
    tc_rhs_E: [56, 56] complex matrix passed on as `t_rhs_E`.

  Returns:
    A `SUGRATensors` instance holding the results of `tf_sugra_tensors`
    (called without the mass-computation left-multiplication).
  """
  t_v70, t_complex_vielbein, t_T, t_A1, t_A2, t_potential = (
     tf_sugra_tensors(t_v70, False, 0.0, tc_rhs_E))
  return SUGRATensors(
      v70=t_v70,
      vielbein=t_complex_vielbein,
      tee_tensor=t_T,
      a1=t_A1,
      a2=t_A2,
      potential=t_potential)

def so8_sugra_scalar_masses(v70, so_pq_p):
  """Computes the scalar mass matrix at the point `v70`.

  NOTE(review): parts of the inner function were missing from this file
  (the einsum specs, the generator operand, and the potential expression).
  They are reconstructed here and marked below; please confirm.

  Args:
    v70: The 70 scalar coordinates of the point (array-like of floats).
    so_pq_p: The `p` of the SO(p, 8-p) gauging.

  Returns:
    [70, 70] numpy array: second derivatives of the potential with respect
    to orthonormal-basis displacements, scaled by 36 / |potential|.
  """
  # Note: In some situations, small deviations in the input give quite
  # noticeable deviations in the scalar mass-spectrum.
  # Getting reliable numbers here really requires satisfying
  # the stationarity-condition to high accuracy.
  tc_rhs_E = tf.constant(get_so_pq_E(so_pq_p), dtype=tf.complex128)
  tc_e7_onb = tf.constant(e7.v70_onb_onbinv[0], dtype=tf.complex128)
  tc_e7_taMN = tf.constant(e7.t_a_ij_kl[:70, :, :], dtype=tf.complex128)
  t_v70 = tf.constant(v70, dtype=tf.float64)
  def tf_grad_potential_lhs_onb(t_d_v70_onb):
    # Gradient of the potential with respect to a small displacement
    # (in orthonormal-basis coordinates) applied from the left.
    tape = tf.GradientTape()
    with tape:
      # NOTE(review): reconstructed. The inner einsum maps orthonormal-basis
      # coordinates to generator coefficients; the outer one contracts them
      # with the generator matrices. The output index order of the outer
      # einsum is an assumption.
      t_d_gen_e7 = tf.einsum(
          'a,aMN->NM',
          tf.einsum('Aa,A->a',
                    tc_e7_onb,
                    tf.complex(t_d_v70_onb, tf.zeros_like(t_d_v70_onb))),
          tc_e7_taMN)
      # Second-order expansion of exp(generator); sufficient for a Hessian
      # evaluated at zero displacement.
      t_lhs_vielbein = (tf.eye(56, dtype=tf.complex128) +
                        t_d_gen_e7 + 0.5 * t_d_gen_e7 @ t_d_gen_e7)
      # NOTE(review): reconstructed; the potential is the last entry of the
      # tuple returned by `tf_sugra_tensors`.
      t_potential = (
          tf_sugra_tensors(t_v70, True, t_lhs_vielbein, tc_rhs_E))[-1]
    return tape.gradient(t_potential, t_d_v70_onb)
  t_d_v70_onb = tf.Variable(numpy.zeros(70), dtype=tf.float64)
  # Persistent tape: it is queried once per gradient component below.
  tape = tf.GradientTape(persistent=True)
  with tape:
    grad_potential = tf.unstack(tf_grad_potential_lhs_onb(t_d_v70_onb))

  t_mm = tf.stack([tape.gradient(grad_potential[k], t_d_v70_onb)
                  for k in range(70)], axis=1)
  stensors = so8_sugra_tensors(t_v70, tc_rhs_E)
  return (t_mm * (36.0 / tf.abs(stensors.potential))).numpy()

### Scanning

def scanner(
    use_dwn_stationarity=True,
    so_pq_p=8,
    seed=1,
    scale=0.15,
    stationarity_threshold=1e-4,
    relu_coordinate_threshold=3.0,
    gtol=1e-4,
    f_squashed=tf.math.asinh):
  """Scans for critical points in the scalar potential.

  NOTE(review): the parameter list of this function was missing from this
  file. Parameter names and order follow the argument documentation below;
  the default values are a reconstruction - please confirm them.

  Args:
    use_dwn_stationarity: Whether to use the explicit stationarity condition
      from `dwn_stationarity`.
    so_pq_p: SO(p, 8-p) non-compact form of the gauge group to use.
    seed: Random number generator seed for generating starting points.
    scale: Scale for normal-distributed search starting point coordinates.
    stationarity_threshold: Upper bound on permissible post-optimization
      stationarity for a solution to be considered good.
    relu_coordinate_threshold: Threshold for any coordinate-value at which
      a ReLU-term kicks in, in order to move coordinates back to near zero.
      (This is relevant for noncompact gaugings with flat directions,
      where solutions can move 'very far out'.)
    gtol: `gtol` parameter for scipy.optimize.fmin_bfgs.
    f_squashed: Squashing-function for stationarity.
      Should be approximately linear near zero, monotonic, and not growing
      faster than logarithmic.

  Yields:
    `Solution` numerical solutions.
  """
  # Use a seeded random number generator for better reproducibility
  # (but note that scipy's optimizers may themselves use independent
  # and not-easily-controllable random state).
  rng = numpy.random.RandomState(seed=seed)
  def get_x0():
    return rng.normal(scale=scale, size=70)
  tc_rhs_E = tf.constant(get_so_pq_E(so_pq_p), dtype=tf.complex128)
  def f_potential(scalars):
    return so8_sugra_tensors(tf.constant(scalars), tc_rhs_E).potential.numpy()
  # Only needed when not using the explicit stationarity condition.
  f_grad_pot_sq_stationarity = (
      None if use_dwn_stationarity
      else get_tf_stationarity(
          lambda t_pos: so8_sugra_tensors(t_pos, tc_rhs_E).potential))
  def f_t_stationarity(t_pos):
    if use_dwn_stationarity:
      stensors = so8_sugra_tensors(t_pos, tc_rhs_E)
      stationarity = dwn_stationarity(stensors.a1, stensors.a2)
    else:
      stationarity = f_grad_pot_sq_stationarity(t_pos)
    # Penalize coordinates whose magnitude exceeds the threshold.
    eff_stationarity = stationarity + tf.reduce_sum(
        tf.nn.relu(abs(t_pos) - relu_coordinate_threshold))
    return eff_stationarity
  def f_opt(pos):
    t_pos = tf.constant(pos)
    t_stationarity = f_squashed(f_t_stationarity(t_pos))
    return t_stationarity.numpy()
  def fprime_opt(pos):
    t_pos = tf.constant(pos)
    tape = tf.GradientTape()
    with tape:
      # `t_pos` is a constant, not a variable, so it must be watched
      # explicitly; otherwise the gradient below would be None.
      tape.watch(t_pos)
      t_stationarity = f_squashed(f_t_stationarity(t_pos))
    t_grad_opt = tape.gradient(t_stationarity, t_pos)
    return t_grad_opt.numpy()
  while True:
    opt = scipy.optimize.fmin_bfgs(
        f_opt, get_x0(), fprime=fprime_opt, gtol=gtol, maxiter=10**4, disp=0)
    opt_pot = f_potential(opt)
    opt_stat = f_opt(opt)
    if numpy.isnan(opt_pot) or not opt_stat < stationarity_threshold:
      continue  # Optimization ran into a bad solution.
    solution = Solution(potential=opt_pot,
                        stationarity=opt_stat,
                        pos=opt)
    yield solution

### Demo.

def demo(seed=0,
         scale=0.2,
         so_pq_p=8,
         num_solutions=5,
         f_squashed=tf.math.asinh):
  """Finds a few critical points and prints their scalar mass spectra.

  NOTE(review): most of the parameter list, and the start of the final
  `print` call, were missing from this file. The parameters are the names
  the body uses; their order after `seed` and their default values are a
  reconstruction - please confirm.

  Args:
    seed: Random number generator seed passed to `scanner`.
    scale: Starting-point scale passed to `scanner`.
    so_pq_p: The `p` of the SO(p, 8-p) gauging.
    num_solutions: Number of solutions to print.
    f_squashed: Squashing function passed to `scanner`.
  """
  solutions_iter = scanner(scale=scale, seed=seed,
                           so_pq_p=so_pq_p, f_squashed=f_squashed)
  for _ in range(num_solutions):
    sol = next(solutions_iter)
    print('=== Solution ===')
    pprint.pprint(sol)
    mm0 = so8_sugra_scalar_masses(sol.pos, so_pq_p)
    print('\nScalar Masses for: V/g^2=%s:' % sol.potential)
    # Rounded eigenvalues of the mass matrix, with their multiplicities.
    print(sorted(collections.Counter(
        numpy.round(numpy.linalg.eigh(mm0)[0], 3)).items()))
