/* fft.c

   C program for computing the FFT with OpenMP.

   main function tests by making a spike vector, calling FFT, and
   checking the result.  Args to main function are lg of the problem
   size, followed by an optional spike position (default is 1).

   To compile, use the -mp flag.

*/

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <omp.h>

#include "FFT.h"
#include "timer.h"


/*
 * alex iliev, feb 2003
 *
 * my approach was to write a routine which performs a fixed number of
 * butterflies within a level. then at each level call this routine once per
 * processor, using "parallel for", each processor thus doing an equal part of
 * the level.
 *
 */


/* work out a floor of lg(N), by repeated right shifting by 1 */
int lgN_floor (int N);

/* 2^x */
int raise2 (int x);


/* return a with its low lgN bits reversed */
unsigned int bit_reverse (unsigned int a, size_t lgN);


/* do a bit reversal permutation on a */
void bit_reversal_permute (complex_t *a, size_t N);


/* do 'count' butterflies starting at group g, bfly k */
void n_butterflies (complex_t * A, int l, int m,
		    int g, int k, int count,
		    complex_t omega_m);




/*
 * Compute the FFT of the array A of length N, overwriting A.
 *
 * A is first put into bit-reversed order, then floor(lg N) levels of
 * butterflies are applied.  Each level consists of N/2 independent
 * butterflies, which are split into contiguous chunks and handed to
 * n_butterflies() from an OpenMP "parallel for".
 *
 * NOTE(review): the level structure assumes N is a power of two; other
 * values of N are not handled here.
 */
void FFT(complex_t * A, int N)
{
    int l, m;
    complex_t omega_m;
    int n = lgN_floor (N);	/* number of levels */
    int i;

    int total = N / 2;		/* butterflies per level */
    int procs = omp_get_num_procs();
    int period;			/* butterflies per chunk */
    int chunks;			/* chunks per level */

    bit_reversal_permute (A, N);

    /*
     * Aim for one chunk per processor.  Both divisions round up so that
     * every butterfly belongs to some chunk even when the processor count
     * does not divide N/2; rounding down here would silently skip the
     * trailing butterflies of each level.
     */
    if (procs < 1) procs = 1;
    period = (total + procs - 1) / procs;
    if (period < 1) period = 1;
    chunks = (total + period - 1) / period;

    for (l = 0; l < n; l++) {	/* l is the level number */

	m = raise2 (l+1);	/* m is the size (in wires) of the groups */

	COMPLEX_ROOT (omega_m, m, 1);

	/*
	  do the butterflies at this level, one chunk per loop iteration;
	  butterflies within a level touch disjoint pairs of elements, so
	  the chunks can run concurrently
	*/
#pragma omp parallel for
	for (i = 0; i < chunks; i++) {
	    int start_bfly = i * period;
	    int count = total - start_bfly;

	    /* only the last chunk can be shorter than 'period' */
	    if (count > period) count = period;

	    n_butterflies (A, l, m,
			   start_bfly / (m/2), start_bfly % (m/2),
			   count,
			   omega_m);
	}

    }

}



/*
 * Perform 'count' consecutive butterflies of a single level, beginning with
 * butterfly k of group g and continuing into the following groups whenever
 * the current group is exhausted.
 */
void n_butterflies (complex_t * A, /* the array, updated in place */
		    int l,	/* level (not used in the body) */
		    int m,	/* group size in wires, 2^(l+1) */
		    int g,	/* number of the group to start in */
		    int k,	/* index of the starting bfly in the group */
		    int count,	/* how many butterflies to do */
		    complex_t omega_m /* omega_m^1 */
    )
{
    const int half = m/2;	/* butterflies per group */
    int base = half * g * 2;	/* first wire of the current group */
    int remaining;
    complex_t omega, t;

    /* 'omega' tracks omega_m^k throughout the loop; seed it for the first k */
    COMPLEX_ROOT (omega, m, k);

    for (remaining = count; remaining > 0; remaining--) {

	/* t = omega * (lower input of the butterfly) */
	COMPLEX_MULT (t, omega, A[base+k+half]);

	/* the lower output must be written before the upper input is
	   overwritten, so the subtraction comes first */
	COMPLEX_SUB (A[base+k+half], A[base+k], t);
	COMPLEX_ADD (A[base+k],      A[base+k], t);

	/* advance to omega_m^(k+1) */
	COMPLEX_MULT (omega, omega, omega_m);
	k++;

	/* past the end of this group: restart at bfly 0 of the next one */
	if (k >= half) {
	    k = 0;
	    base += half * 2;
	    omega.real = 1; omega.imag = 0; /* omega_m^0 */
	}

    }

}





/*
  Return the value formed by reversing the low lgN bits of a: bit i of the
  result is bit (lgN-1-i) of a, for 0 <= i < lgN.  Bits of a at position
  lgN and above do not contribute to the result.

  All shifts are by one position on unsigned values, so no signed-shift
  overflow can occur even when lgN equals the width of unsigned int.
*/
unsigned int bit_reverse (unsigned int a, size_t lgN) {

    unsigned int answer = 0;
    size_t i;

    for (i = 0; i < lgN; i++) {
	/* peel the lowest remaining bit off a and push it onto answer */
	answer = (answer << 1) | (a & 1u);
	a >>= 1;
    }

    return answer;
}


/*
  Reorder array a (length N) so that each element trades places with the
  element whose index is the bit reversal of its own, using floor(lg N)
  index bits.
*/
void bit_reversal_permute (complex_t *a, size_t N) {

    size_t bits = lgN_floor (N);	/* number of index bits to reverse */
    int idx;

    /* each pair is exchanged by exactly one iteration (the one holding the
       smaller index), so iterations never touch the same elements */
#pragma omp parallel for
    for (idx = 0; idx < N; idx++) {
	int rev = bit_reverse (idx, bits);

	if (rev > idx) {
	    complex_t held = a[rev];
	    a[rev] = a[idx];
	    a[idx] = held;
	}
    }

}


	     
/* Return 2 raised to the power x, computed as a left shift of 1. */
int raise2 (int x) {
    int power = 1;

    power <<= x;
    return power;
}

/* Return floor(lg N) by counting how many times N can be shifted right by
   one before it drops to 1 or below.  Any N <= 1 yields 0. */
int lgN_floor (int N) {

    int shifts = 0;

    while (N > 1) {
	N >>= 1;
	shifts++;
    }

    return shifts;
}



    

/* Check that the result of an FFT of a spike vector is correct.  We
   are given the resulting vector a of the FFT, its size N, and the
   spike position.  Entry i is compared against the value produced by
   COMPLEX_ROOT(target, N, i * spike); the real and imaginary parts must
   each lie within a fixed absolute tolerance of it.  Every mismatching
   position is reported on stdout.

   Returns 1 if all N entries match, 0 otherwise. */
int checkFFT(complex_t *a, int N, int spike)
{
  complex_t target;		/* what we want to see */
  int i;
  int OK = 1;			/* everything's OK until we see otherwise */
  double tolerance = 1e-9;	/* how close is close enough? */

  for (i = 0; i < N; i++)
    {
      COMPLEX_ROOT(target, N, i * spike);

      /* fabs() must wrap the difference only; the comparison against the
	 tolerance happens outside it so deviations of either sign count */
      if ((fabs(target.real - a[i].real) > tolerance) ||
	  (fabs(target.imag - a[i].imag) > tolerance))
	{
	  OK = 0;
	  printf("Error in position %d: actual = (%f, %f), target = (%f, %f)\n",
		 i, a[i].real, a[i].imag, target.real, target.imag);
	}
    }

  return OK;
}


/* Turn vector a, of length N, into a spike vector: every entry is set
   to zero and then the real part of entry 'spike' is set to 1.  The
   caller is responsible for passing a spike position inside the vector;
   no range check is made here. */
void makeSpike(complex_t *a, int N, int spike)
{
  complex_t origin = { 0.0, 0.0 };
  int idx;

  /* The zero fill runs as an OpenMP parallel loop; per the original
     author's note, this is meant to give good data placement across
     the nodes. */
#pragma omp parallel for
  for (idx = 0; idx < N; idx++)
    {
      a[idx] = origin;
    }

  a[spike].real = 1.0;
}


/* Debugging aid: write the N entries of vector a to stdout, one per
   line, as "index: (real, imag)". */
void printv(complex_t *a, int N)
{
  int pos = 0;

  while (pos < N)
    {
      printf("%d: (%f, %f)\n", pos, a[pos].real, a[pos].imag);
      pos++;
    }
}


/* Driver to test the FFT function.

   Usage: prog lgN [spike]

   lgN is the base-2 logarithm of the vector length and spike is the
   index of the single non-zero entry (default 1).  The program builds
   the spike vector, times one call to FFT(), checks the result with
   checkFFT() and prints the verdict and the elapsed time.

   Both arguments are validated before use: lgN must keep 1 << lgN
   representable as a positive int, and spike must index into the
   vector.  Invalid arguments or a failed allocation exit with status 1. */
int main(int argc, char **argv)
{
  /* largest shift that still yields a positive int; assumes int is at
     least 32 bits wide */
  enum { MAX_LG_N = 30 };

  int N, spike;			/* vector length and spike position */
  long parsed;			/* raw result of strtol */
  char *end;			/* first unparsed character */
  complex_t *a;			/* the vector */
  cs88_timer_t timer;		/* a timer */

  /* sasho: set the number of threads right at the start */
  omp_set_num_threads(omp_get_num_procs());

  /* Make sure we at least have lg of vector length on command line. */
  if (argc < 2)
    {
      printf("Usage: %s lgN [spike]\n", argv[0]);
      exit(1);
    }

  /* We do.  Parse it strictly: the whole argument must be a number and
     the shift below must not overflow. */
  parsed = strtol(argv[1], &end, 10);
  if (end == argv[1] || *end != '\0' || parsed < 0 || parsed > MAX_LG_N)
    {
      fprintf(stderr, "%s: lgN must be an integer between 0 and %d\n",
	      argv[0], MAX_LG_N);
      exit(1);
    }
  N = 1 << (int) parsed;

  /* Grab the spike position if there is one.  Otherwise, use the
     default of 1. */
  if (argc >= 3)
    {
      parsed = strtol(argv[2], &end, 10);
      if (end == argv[2] || *end != '\0')
	{
	  fprintf(stderr, "%s: spike must be an integer\n", argv[0]);
	  exit(1);
	}
    }
  else
    parsed = 1;

  /* The spike is written to a[spike], so it has to lie inside the
     vector.  This also rejects the default spike when N is 1. */
  if (parsed < 0 || parsed >= N)
    {
      fprintf(stderr, "%s: spike must be between 0 and %d\n",
	      argv[0], N - 1);
      exit(1);
    }
  spike = (int) parsed;

  /* Now we can allocate the vector... */
  a = malloc((size_t) N * sizeof *a);
  if (a == NULL)
    {
      fprintf(stderr, "%s: cannot allocate a vector of %d elements\n",
	      argv[0], N);
      exit(1);
    }

  /* ...and make the spike...*/
  makeSpike(a, N, spike);

  /* Start the timer. */
  TIMER_RESET(timer);
  TIMER_START(timer);

  /* ...and FFT it... */
  FFT(a, N);

  /* Stop the timer. */
  TIMER_STOP(timer);

  /* ...and see how we did... */
  if (checkFFT(a, N, spike))
    printf("Answer is OK!\n");
  else
    printf("Bad answer!\n");

  printf("Time = %f\n", TIMER_EVAL(timer));

  /* ...and go have milk and cookies! */
  free(a);

  return 0;
}
