#!/usr/bin/env python3

import os
import sys
import re

def cstring(x):
  """Return x rendered as a double-quoted C string literal.

  Backslashes, double quotes and newlines are escaped; every other
  character is copied through unchanged.
  """
  escapes = str.maketrans({'\\': '\\\\', '"': '\\"', '\n': '\\n'})
  # translate() maps each character exactly once, so the backslashes it
  # inserts are never escaped a second time.
  return '"' + x.translate(escapes) + '"'

def sanitize(x):
  """Turn x into a fragment usable inside a C identifier.

  Hyphens and underscores are removed entirely. Any remaining character
  that is not a lowercase ASCII letter or a digit becomes '_'.
  """
  allowed = '0123456789abcdefghijklmnopqrstuvwxyz'
  pieces = []
  for ch in x:
    if ch == '-' or ch == '_':
      continue
    pieces.append(ch if ch in allowed else '_')
  return ''.join(pieces)

# ifunc is True when any line of dispatch/ifunc is exactly '1' (ignoring
# surrounding whitespace). printfun_auto then binds the public symbol with
# __attribute__((ifunc(...))) instead of a constructor-initialized pointer.
with open('dispatch/ifunc') as flagfile:
  ifunc = any(row.strip() == '1' for row in flagfile)

# Map each primitive name to the C type that its first parameter points to
# (used below as f'{typename[p]} *,long long'). The 'types' file holds
# "primitive:ctype" lines; lines that do not split into exactly two
# ':'-separated fields are ignored.
typename = {}
with open('types') as f:
  for line in f:
    fields = line.strip().split(':')
    if len(fields) == 2:
      primitive, ctype = fields
      typename[primitive] = ctype

# Command line: goal (auto|manual), o, p (primitive name; key into typename
# and part of every generated symbol), host, namespace (prefix of every
# generated C symbol). o and host are not referenced below.
goal = sys.argv[1]
if goal not in ('auto', 'manual'):
  # Explicit check rather than assert: `python -O` strips asserts.
  raise SystemExit(f'{sys.argv[0]}: goal must be auto or manual, not {goal!r}')
o = sys.argv[2]
p = sys.argv[3]
host = sys.argv[4]
namespace = sys.argv[5]

# Parse the selection description read from stdin, one record per line:
#   impl <implementation> <compiler>
#       an available (implementation, compiler) pair -> impls
#   selected <arch> <implementation> <compiler>
#       a run-time choice for <arch> -> selected (kept in input order, which
#       printfun_auto emits as first-match-wins priority) and usedimpls
# Lines with any other first word are ignored.
impls = []
selected = []
usedimpls = set()

for line in sys.stdin:
  fields = line.split()
  if not fields:
    continue  # blank line: previously crashed on fields[0]
  tag, rest = fields[0], fields[1:]
  if tag == 'impl':
    impls.append(rest)
  elif tag == 'selected':
    arch, impl, compiler = rest
    selected.append((arch, impl, compiler))
    usedimpls.add((impl, compiler))

with open('allarches') as archfile:
  allarches = archfile.read().splitlines()

# For every (implementation, compiler) pair, record the architecture its
# compiler targets (compilerarch/<compiler>) and the compiler's version
# string (compilerversion/<compiler>).
icarch = {}
iccompiler = {}
for impl, compiler in impls:
  key = (impl, compiler)
  with open(f'compilerarch/{compiler}') as fh:
    icarch[key] = fh.read().strip()
  with open(f'compilerversion/{compiler}') as fh:
    iccompiler[key] = fh.read().strip()

if goal == 'manual':
  # Manual dispatch only needs architectures some implementation targets.
  targeted = set(icarch.values())
  allarches = [arch for arch in allarches if arch in targeted]

# C preamble: the include, prototypes for the functions generated further
# down, and prototypes for the {namespace}_supports_<arch>() checks they call.
print('#include <inttypes.h>')

if goal != 'auto':
  prototypes = [
    f'extern const char *{namespace}_{p}_implementation(void);',
    f'extern const char *{namespace}_{p}_compiler(void);',
    f'extern const char *{namespace}_dispatch_{p}_implementation(long long) __attribute__((visibility("default")));',
    f'extern const char *{namespace}_dispatch_{p}_compiler(long long) __attribute__((visibility("default")));',
    f'extern long long {namespace}_numimpl_{p}(void) __attribute__((visibility("default")));',
  ]
else:
  prototypes = [
    f'extern const char *{namespace}_{p}_implementation(void) __attribute__((visibility("default")));',
    f'extern const char *{namespace}_{p}_compiler(void) __attribute__((visibility("default")));',
  ]
print('\n'.join(prototypes))

for arch in allarches:
  if arch != 'default':
    print(f'extern int {namespace}_supports_{sanitize(arch)}(void);')
if len(allarches) > 1:
  print('')

def _print_auto_entry_point(defaultpointer):
  """Emit the exported {namespace}_{p} symbol, bound to {namespace}_auto_{p}.

  With ifunc, the symbol is declared with __attribute__((ifunc(...))) naming
  the resolver. Without it, three things are emitted: a static function
  pointer initialised to defaultpointer, a constructor that overwrites it
  with the resolver's result, and an exported wrapper that forwards every
  call through the pointer.

  Reads module globals rettype, args, namespace, p and ifunc.
  """
  print('')
  if ifunc:
    print(f'{rettype} {namespace}_{p}({args}) __attribute__((visibility("default"))) __attribute__((ifunc("{namespace}_auto_{p}")));')
    return

  print(f'static {rettype} (*{namespace}_{p}_pointer)({args}) = {defaultpointer};')
  print('')
  print('__attribute__((constructor(25521)))')
  print(f'static void {namespace}_{p}_pointer_constructor(void)')
  print('{')
  print(f'  {namespace}_{p}_pointer = {namespace}_auto_{p}();')
  print('}')
  print('')
  # Name each C parameter argN so the wrapper can forward it; no space is
  # inserted between a type ending in '*' and its name.
  params = []
  argnames = []
  for n, ctype in enumerate(args.split(',')):
    argname = f'arg{n}'
    params.append(ctype + ('' if ctype.endswith('*') else ' ') + argname)
    argnames.append(argname)
  namedparams = ','.join(params)
  namedargs = ','.join(argnames)
  print('__attribute__((visibility("default")))')
  print(f'{rettype} {namespace}_{p}({namedparams})')
  print('{')
  if rettype == 'void':
    print(f'  {namespace}_{p}_pointer({namedargs});')
  else:
    print(f'  return {namespace}_{p}_pointer({namedargs});')
  print('}')


def printfun_auto(which):
  """Emit one run-time selection function for goal 'auto'.

  which:
    'resolver'       -- void *{namespace}_auto_{p}(void), returning the
                        address of the chosen implementation, followed by the
                        exported {namespace}_{p} entry point bound to it;
    'implementation' -- {namespace}_{p}_implementation(), returning the
                        chosen implementation's name as a C string;
    'compiler'       -- {namespace}_{p}_compiler(), returning the chosen
                        implementation's compiler version string.

  Entries of the global `selected` are emitted in order. Each non-default
  arch is guarded by its {namespace}_supports_<arch>() call, and the first
  'default' entry is an unconditional return that ends the list. If no
  'default' entry is reached, the C function ends with `return 0`, so
  control never falls off the end of a non-void function.

  Also reads globals namespace, p, iccompiler and (via the resolver's
  entry point) rettype, args and ifunc.

  Raises:
    ValueError: for any other value of which.
  """
  if which == 'resolver':
    print(f'void *{namespace}_auto_{p}(void)')
  elif which == 'implementation':
    print(f'const char *{namespace}_{p}_implementation(void)')
  elif which == 'compiler':
    print(f'const char *{namespace}_{p}_compiler(void)')
  else:
    raise ValueError(f'unknown printfun {which}')
  print('{')
  defaultpointer = '0'
  for a, i, c in selected:
    cond = '' if a == 'default' else f'if ({namespace}_supports_{sanitize(a)}()) '
    if which == 'resolver':
      thispointer = f'{namespace}_{p}_{sanitize(i)}_{c}'
      if a == 'default':
        defaultpointer = thispointer
      print(f'  {cond}return {thispointer};')
    elif which == 'implementation':
      print(f'  {cond}return {cstring(i)};')
    else:
      print(f'  {cond}return {cstring(iccompiler[i,c])};')
    if a == 'default':
      break
  else:
    # No unconditional return was emitted: without this the generated C
    # function would fall off its end when no supported arch matches.
    if selected:
      print('  return 0; /* no default implementation; defer crash to run time */')
    else:
      print('  return 0; /* no compiled implementations; defer crash to run time */')
  print('}')

  if which == 'resolver':
    _print_auto_entry_point(defaultpointer)

# One pass per C signature (return type, parameter list). rettype and args
# stay bound as module globals because printfun_auto('resolver') reads them.
for rettype, args in [('void', f'{typename[p]} *,long long')]:
  # Prototype of the public entry point (plus the dispatcher in manual mode).
  if goal != 'auto':
    print(f'extern {rettype} {namespace}_{p}({args});')
    print(f'extern {rettype} (*{namespace}_dispatch_{p}(long long))({args}) __attribute__((visibility("default")));')
  else:
    print(f'extern {rettype} {namespace}_{p}({args}) __attribute__((visibility("default")));')
  print('')

  # Prototypes of the concrete implementations; auto mode declares only the
  # ones that appear in `selected`.
  for implname, compiler in impls:
    if goal == 'auto' and (implname, compiler) not in usedimpls:
      continue
    print(f'extern {rettype} {namespace}_{p}_{sanitize(implname)}_{compiler}({args}) __attribute__((visibility("default")));')
  print('')

  if goal == 'auto':
    printfun_auto('resolver')
  elif goal == 'manual':
    # dispatch(impl): the impl-th implementation usable on this CPU, else the
    # default entry point.
    print(f'{rettype} (*{namespace}_dispatch_{p}(long long impl))({args})')
    print('{')
    for arch in allarches:
      if arch != 'default':
        sym = sanitize(arch)
        print(f'  int supports_{sym} = {namespace}_supports_{sym}();')
    print('  if (impl >= 0) {')
    for implname, compiler in impls:
      arch = icarch[implname, compiler]
      guard = '' if arch == 'default' else f'if (supports_{sanitize(arch)}) '
      target = f'{namespace}_{p}_{sanitize(implname)}_{compiler}'
      print(f'    {guard}if (!impl--) return {target};')
    print('  }')
    print(f'  return {namespace}_{p};')
    print('}')
  print('')

def _print_manual_lookup(header, fallback, value_of):
  """Emit a manual-mode C function that maps an index to a per-implementation value.

  header:   the C signature line; its index parameter must be named `impl`.
  fallback: C expression returned when impl is negative or exceeds the
            number of implementations usable on this CPU.
  value_of: callable (implementation, compiler) -> C expression returned for
            that entry.

  Entries follow the order of `impls`; a non-default entry only counts when
  its supports_<arch> flag is non-zero.
  """
  print(header)
  print('{')
  for arch in allarches:
    if arch == 'default': continue
    sym = sanitize(arch)
    print(f'  int supports_{sym} = {namespace}_supports_{sym}();')
  print('  if (impl >= 0) {')
  for i, c in impls:
    arch = icarch[i,c]
    guard = '' if arch == 'default' else f'if (supports_{sanitize(arch)}) '
    print(f'    {guard}if (!impl--) return {value_of(i, c)};')
  print('  }')
  print(f'  return {fallback};')
  print('}')


if goal == 'auto':
  printfun_auto('implementation')
  print('')
  printfun_auto('compiler')
else:
  # Name of the impl-th usable implementation.
  _print_manual_lookup(
    f'const char *{namespace}_dispatch_{p}_implementation(long long impl)',
    f'{namespace}_{p}_implementation()',
    lambda i, c: cstring(i))
  print('')

  # Compiler version string of the impl-th usable implementation.
  _print_manual_lookup(
    f'const char *{namespace}_dispatch_{p}_compiler(long long impl)',
    f'{namespace}_{p}_compiler()',
    lambda i, c: cstring(iccompiler[i,c]))
  print('')

  # Number of usable implementations: all 'default' ones plus, per arch,
  # supports_<arch> (assumed 0/1 — TODO confirm) times that arch's count.
  print(f'long long {namespace}_numimpl_{p}(void)')
  print('{')
  terms = ['%d' % sum(1 for (i,c) in impls if icarch[i,c] == 'default')]
  for arch in allarches:
    if arch == 'default': continue
    sym = sanitize(arch)
    print(f'  long long supports_{sym} = {namespace}_supports_{sym}();')
    count = sum(1 for (i,c) in impls if icarch[i,c] == arch)
    terms.append(f'supports_{sym}*{count}')
  print(f'  return {"+".join(terms)};')
  print('}')
