#define _GNU_SOURCE 1
#include <dlfcn.h>

#include <vector>
#include <iostream>
#include <string>
#include <algorithm>
#include <numeric>
#include <cmath>
#include <cstring>
#include <cstdlib>

#include "hp-timing.h"
#include "perf.h"

// Timing overhead estimate printed in the header line of both run_tests_*
// functions.  No assignment to it is visible in this file; presumably
// HP_TIMING_DIFF_INIT (invoked from main) fills it in -- TODO confirm
// against hp-timing.h.
static hp_timing_t hp_timing_overhead;

// Thread-local variable whose address main passes to npt2test_base as the
// "tval" result location for the loaded tests.
__thread long tval;
// Second name for the same thread-local object, declared as an alias of tval.
extern __thread long tvalie __attribute__((alias ("tval")));

// Common base for a single timed test.  It stores the name used in reports
// and the location of the variable the test is expected to read (or take the
// address of), and it validates the value a test function returned.
class test_base {
  // Name printed in timing reports and in failure messages.
  std::string const name;
  // Variable whose value, or whose address, the test must return.
  long *const resultvar;
  // true: the test returns the address of *resultvar; false: its value.
  bool const addr;
public:
  test_base(std::string const&nm, long *rv, bool byref) :
    name (nm), resultvar (rv), addr (byref) {}
  // Virtual because the derived tests are created with new and handled
  // through test_base pointers; destroying one through such a pointer
  // would otherwise be undefined behavior.
  virtual ~test_base() {}
  // Runs the test once and returns the hp_timing_t measurement for that run.
  virtual hp_timing_t operator () () const = 0;
  // Compares RES, the value returned by the test function, against the
  // expectation.  Returns normally on a match; on a mismatch it reports the
  // expected and actual values on std::cerr and throws this object's address
  // (a test_base const *).
  void check (ntr res) const {
    if (addr)
      {
        // By-reference tests must return the address of the variable.
        if (res == (intptr_t) resultvar)
          return;

        // std::endl already flushes the stream.
        std::cerr << name << " failed check, expected " << resultvar
                  << ", got " << (long *)res << std::endl;
      }
    else
      {
        // By-value tests must return the variable's current contents.
        if (res == *resultvar)
          return;

        std::cerr << name << " failed check, expected " << *resultvar
                  << ", got " << res << std::endl;
      }
    throw this;
  }
  // Returns a copy of the test's report name.
  std::string const get_name () const { return name; }
};

class test_self_timing : public test_base {
  self_timing_t const *f;
public:
  test_self_timing (std::string const&nm, long *rv, bool byref,
		    self_timing_t *ft) :
    test_base (nm, rv, byref), f (ft) {}
  hp_timing_t operator () () const {
    hp_timing_t t;
    check (f (&t));
    return t;
  }
};

class test_caller_timing : public test_base {
  caller_timing_t const *f;
public:
  test_caller_timing (std::string const&nm, long *rv, bool byref,
		      caller_timing_t const *ft) :
    test_base (nm, rv, byref), f (ft) {}
  hp_timing_t operator () () const {
    hp_timing_t before, after, diff;
    ntr res;
    HP_TIMING_NOW (before);
    res = f ();
    HP_TIMING_NOW (after);
    check (res);
    HP_TIMING_DIFF (diff, before, after);
    return diff;
  }
};

// List of tests to run; each element points to a const test_base.
typedef std::vector<test_base const *> testvec;

// Builds the test object that corresponds to one npt descriptor.
//
// GROUP is prefixed to the descriptor's name (joined with ".") to form the
// report name.  TVAL and LTVAL are the two candidate result locations.  The
// NTT_ADDRLOAD, NTT_TVAL and NTT_TIMING selectors applied to test->name
// decide, respectively, whether the test returns an address or a value,
// which of the two locations it refers to, and whether it times itself or
// is timed by the caller.
//
// Returns a newly allocated test object; nothing in this function releases
// it, so ownership passes to the caller.  Calls abort () when any selector
// yields an unrecognized value.
test_base const *npt2test_base (std::string const& group, const npt *test,
                                long *tval, long *ltval) {
  long *rv;
  bool byref;

  // Does the test return the variable's value or its address?
  switch (NTT_ADDRLOAD (test->name))
    {
    case NTT_LOAD: byref = false; break;
    case NTT_ADDR: byref = true; break;
    default: abort ();
    }

  // Which of the two variables does the test refer to?
  switch (NTT_TVAL (test->name))
    {
    case NTT_TVAL: rv = tval; break;
    case NTT_LTVAL: rv = ltval; break;
    default: abort ();
    }

  // Choose the wrapper that matches how the test is timed.
  switch (NTT_TIMING (test->name))
    {
    case NTT_SELF_TIMING:
      return new test_self_timing (group + "." + test->name, rv, byref,
                                   test->func.self_timing);

    case NTT_CALLER_TIMING:
      return new test_caller_timing (group + "." + test->name, rv, byref,
                                     test->func.caller_timing);

    default:
      abort ();
    }
}

// Runs every test COUNT times, visiting the tests in a freshly shuffled
// order on each round.  Each individual timing is logged to std::cerr; at
// the end the minimum and the average per test are written to std::cout.
static void
run_tests_randomized (testvec const& tests, size_t count)
{
  typedef std::vector<size_t> index_list;
  typedef std::vector<hp_timing_t> timing_list;
  typedef std::vector<unsigned long long> integer_sum_list;
  typedef std::vector<long double> float_sum_list;

  size_t const total = tests.size ();

  // Visiting order, reshuffled at the start of every round.
  index_list order (total);
  for (size_t k = 0; k != total; ++k)
    order[k] = k;

  // Per-test minimum, starting from the all-ones sentinel.
  timing_list best (total, (hp_timing_t)-1);
  // The same timings are accumulated twice: once in an integer that may
  // wrap around and once in a long double, so a mismatch can be reported.
  integer_sum_list wrapping_sum (total);
  float_sum_list wide_sum (total);

  for (size_t pass = 0; pass != count; ++pass)
    {
      std::random_shuffle (order.begin (), order.end ());

      std::cerr << "round " << pass << ":" << std::endl;

      for (index_list::const_iterator pos = order.begin ();
           pos != order.end (); ++pos)
        {
          size_t const which = *pos;
          test_base const *const current = tests[which];
          hp_timing_t const elapsed = (*current) ();

          best[which] = std::min (best[which], elapsed);

          wide_sum[which] += elapsed;
          wrapping_sum[which] += elapsed;

          std::cerr << current->get_name () << ": " << elapsed << std::endl;
        }
    }

  std::cout << "name: min, avg over " << count;
  std::cout << " runs (estimated timing overhead is " << hp_timing_overhead;
  std::cout << ")" << std::endl;

  for (size_t k = 0; k != total; ++k)
    {
      char const *const note
        = (wide_sum[k] != wrapping_sum[k]) ? "(overflowed)" : "";
      long double const average = wide_sum[k] / count;

      std::cout << tests[k]->get_name () << ": " << best[k] << ", ";
      std::cout << note << average << std::endl;
    }
}

// Runs each test COUNT times back to back and writes the smallest timing
// seen for it to std::cout, one line per test, after a header line.
static void
run_tests_tight_loop (testvec const& tests, size_t count)
{
  std::cout << "name: min in "
            << count << " runs (estimated timing overhead is "
            << hp_timing_overhead << ")" << std::endl;

  for (size_t i = 0, ntests = tests.size (); i < ntests; i++)
    {
      test_base const &test = *tests[i];
      // Start from the same sentinel run_tests_randomized uses, so that a
      // COUNT of zero prints a defined value instead of reading an
      // uninitialized variable.
      hp_timing_t tmin = (hp_timing_t)-1;

      for (size_t round = 0; round < count; round++)
        {
          hp_timing_t ttime = test ();

          // The first round always replaces the sentinel, whatever the
          // signedness of hp_timing_t.
          if (round == 0 || ttime < tmin)
            tmin = ttime;

          // std::cerr << test.get_name () << ": " << ttime << std::endl;
        }
      std::cout << test.get_name () << ": " << tmin << std::endl;
    }
}

// Command-line driver.
//
// Usage: PROGRAM [-h?] [-n count] [-r | -t] LIBRARY...
//   -n count  number of runs per test (must be a positive integer)
//   -r        randomized mode: shuffle the test order on every round
//   -t        tight-loop mode (the default): run each test back to back
//   -h, -?    print the usage line and exit
//
// Every remaining argument is dlopen'ed.  Its "ntls_tests" symbol is read as
// a null-terminated array of npt pointers and its "ltval" symbol as the
// library's own result variable.  Libraries that fail to load or that lack
// the test list are reported on std::cerr and skipped.
//
// A failed test check throws from test_base::check and is not caught here,
// so it terminates the process.
int
main(int argc, char *argv[]) {
  int i = 1;
  size_t count = 0;
  bool randomized = false, count_set = false;

  for (; i < argc && argv[i][0] == '-'; i++)
    {
      std::string arg = argv[i];

      if (arg == "-n")
        {
          ++i;
          if (i == argc)
            {
              std::cerr << "-n needs an argument" << std::endl;
              std::exit (1);
            }
          char *end;
          long tcount = strtol (argv[i], &end, 0);
          // Validate the value that was just parsed.  Trailing junk leaves
          // *end non-zero; an empty argument parses as 0 and is rejected
          // together with negative counts.
          if (*end != 0 || tcount <= 0)
            {
              std::cerr << "invalid argument given to -n" << std::endl;
              std::exit (1);
            }
          count = tcount;
          count_set = true;
        }
      else if (arg == "-r")
        randomized = true;
      else if (arg == "-t")
        randomized = false;
      else if (arg == "-h" || arg == "-?")
        {
          std::cout << "Usage: " << argv[0] << " [-h?] [-n count] [-r | -t]"
                    << std::endl;
          std::exit (0);
        }
      else
        // Unrecognized option: stop option parsing and treat this and all
        // following arguments as libraries to load.
        break;
    }

  // Default run counts differ per mode.
  if (! count_set)
    count = randomized ? 10000 : 100000000;

  HP_TIMING_DIFF_INIT ();

  testvec tv;

  for (; i < argc; ++i)
    {
      void *dlhandle = dlopen(argv[i], RTLD_LOCAL | RTLD_LAZY);
      npt const *const *tests;
      long *ltvalp;

      if (! dlhandle)
        {
          std::cerr << argv[i] << ": could not dlopen: "
                    << dlerror () << std::endl;
          continue;
        }

      tests = static_cast<npt const *const *>(dlsym (dlhandle, "ntls_tests"));
      // NOTE(review): ltvalp is not checked for null; a library without an
      // "ltval" symbol would hand a null result location to its tests.
      ltvalp = static_cast<long *>(dlsym (dlhandle, "ltval"));

      if (! tests)
        {
          std::cerr << argv[i] << ": could not find test list" << std::endl;
          dlclose (dlhandle);
          continue;
        }

      // The handle is deliberately kept open on success: the collected tests
      // call functions that live inside the loaded library.
      for (npt const *const *testit = tests;
           *testit; ++testit)
        tv.push_back (npt2test_base (argv[i], *testit, &tval, ltvalp));
    }

  if (randomized)
    run_tests_randomized (tv, count);
  else
    run_tests_tight_loop (tv, count);
}
