#include <cstdio>
#include <iostream>
#include <ctime>
#include "adolc.h"
//#include "taputil.h"
//#include "adouble.h"
//#include "oplate.h"
#include "RevolveDriver.h"



extern "C" {
#include "revolve.h"
}

/* The following are defined in adolc.cpp */
/* Reverse sweep used by Revolve::FirstTurn and Revolve::YouTurn in this file:
   they call it with depen = 1, indep = m_NumberOfIndependents, lagrange
   pointing at a single weight of 1.0, and results pointing at m_Derivative. */
int revolve_fos_reverse(short   tnum,
			int     depen,    
			int     indep,    
			double  *lagrange,
			double  *results); 
/* Tracing state shared with adolc.cpp: trace_flag is switched on/off by
   take_stock_intermediate()/keep_stock_intermediate(), and location_cnt is
   returned by keep_stock_intermediate() and printed by SetFunction(). */
extern int trace_flag;
extern locint location_cnt;


//void take_stock_intermediate();
//void trace_on_intermediate(short tnum,int revals);
//locint keep_stock_intermediate();
//void trace_off_intermediate(int flag);





namespace adolc
{
   
//@@@@@@@  functions needed for checkpointing added by EGK derived from the
//checkpointing package (the parts that are still under the original common
//public license). 
   
// Switch the global trace_flag on (counterpart of keep_stock_intermediate()).
void take_stock_intermediate() { trace_flag = 1; }
// Open tape `tnum` through start_trace() (forwarding `revals` unchanged)
// and then re-enable trace_flag via take_stock_intermediate().
void trace_on_intermediate(short tnum, int revals)
{
   start_trace(tnum, revals);   // start recording on the requested tape
   take_stock_intermediate();   // sets trace_flag = 1
}
   
  

// Switch the global trace_flag off and hand back the current location_cnt.
locint keep_stock_intermediate()
{
   trace_flag = 0;
   const locint current_count = location_cnt;
   return current_count;
}
// Close the current tape: switch trace_flag off via keep_stock_intermediate()
// and pass the location count it returns (as an int) plus `flag` on to
// stop_trace().
void trace_off_intermediate(int flag)
{
   const int used_locations = keep_stock_intermediate();
   stop_trace(used_locations, flag);
}
   
     
/*
 * Evaluate the objective over all labelled time steps by driving the
 * revolve() checkpointing schedule, optionally computing the gradient.
 *
 * parameters : m_NumberOfIndependents input values, copied into the
 *              independents of m_Stack[0]. Must not be null.
 * value      : receives m_FinalObjectiveFunctionValue (filled in by
 *              SetDependent()). Must not be null.
 * deriv      : if non-null, receives the m_NumberOfIndependents entries of
 *              m_Derivative after the schedule terminates, and timing
 *              statistics are printed to std::cout. If null, only the
 *              function value is computed and the schedule loop is left at
 *              the first turn.
 *
 * Prints a message and calls abort() on null parameters/value, zero
 * independents, no labels (zero time steps), an empty m_Stack, or when
 * revolve() reports an error.
 */
void
Revolve::
Evaluate(const double* parameters,double* value, double* deriv)
{
   if(parameters == 0)
   {
      std::cout << "Error: parameters is null." << std::endl;
      abort();
   }

   if(value == 0)
   {
      std::cout << "Error: value is null." << std::endl;
      abort();
   }

   if(m_NumberOfIndependents == 0)
   {
      std::cout << "Error: number of independents is zero." 
                << std::endl;
      abort();
   }

   // One time step per label.
   m_Steps = m_Labels.size();

   // Check to make sure labels have been set
   if(m_Steps == 0)
   {
      std::cout << "Error: Number Of Timesteps is zero" << std::endl;
      abort();
   }

   if(m_Stack.size() == 0)
   {
      std::cout << "Error: Stack size is zero" << std::endl;
      abort();
   }

   // Arguments for the first revolve() call.
   m_Capo = 0;
   m_Fine = m_Steps + m_Capo;
   m_Check = -1;                  /* Necessary for first call */
   m_Keep = 1;

   if(m_Info > 0)
   {
      printf("m_Check:   %d\n",m_Check);
      printf("m_Capo:    %d\n",m_Capo);
      printf("m_Fine:    %d\n",m_Fine);
      printf("m_Snaps:   %d\n",m_Snaps);
      printf("m_Steps:   %d\n",m_Steps);
   }

   m_TopOfStack = 0;

   // Sometimes it may be useful for a function to know if it needs to
   // calculate derivatives, therefore we set this special flag here.
   if(0 == deriv)
   {
      m_Stack[m_TopOfStack]->m_JustEvaluateFunction=true;
   }
   else
   {
      m_Stack[m_TopOfStack]->m_JustEvaluateFunction=false;
   }

   m_Derivative.resize(m_NumberOfIndependents);

   currently_in_initial_timestep = 1;
   if(0 == deriv)
      currently_in_initial_timestep = 0;

   m_DoingForwardSweep = true;
   m_Stack[m_TopOfStack]->m_DoingForwardSweep = 1;

   for(unsigned int i=0;i<m_NumberOfIndependents;i++)
      m_Stack[m_TopOfStack]->m_Independents[i] = parameters[i];

   enum action whatodo;
   m_CurrentTime = 0;

   m_TimeSpentInReverse = 0.0;
   m_TimeSpentRecording = 0.0;
   double timeSpentMovingForwardWithoutRecording = 0.0;

   do
   {
      whatodo = revolve(&m_Check, &m_Capo, &m_Fine, m_Snaps, &m_Info);

      if (whatodo == takeshot)
      {
         if(m_Info > 1)
            printf(" takeshot at %6d,   m_Fine: %6d,  m_Check:  %6d  \n",
                   m_Capo,m_Fine,m_Check);

         this->TakeShot();
      }
      else if (whatodo == advance)
      {
         if(m_Info > 1)
            printf(" advance to %7d,   m_Fine: %6d,  m_Check:  %6d \n",
                   m_Capo,m_Fine,m_Check);

         // Step forward to m_Capo; no tape is opened in this branch.
         // clock() returns clock_t; keep it in that type rather than
         // narrowing it to unsigned int.
         const clock_t start_time = clock();

         for( ; m_CurrentTime < m_Capo; m_CurrentTime++)
            this->AdvanceOneStep(m_CurrentTime);

         const clock_t end_time = clock();
         timeSpentMovingForwardWithoutRecording +=
            static_cast<double>(end_time - start_time) / CLOCKS_PER_SEC;
      }
      else if (whatodo == firsturn)
      {
         if(m_Info > 0)
         {
            printf(" firsturn at %6d,   m_Fine: %6d,  m_Check:  %6d \n",
                   m_Capo,m_Fine,m_Check);
            fflush(0);
         }

         if(deriv)
         {
            this->FirstTurn(m_CurrentTime);
         }
         else
         {
            // If we're only evaluating the function, then break out after
            // doing the final step.
            this->AdvanceOneStep(m_CurrentTime);
            this->SetDependent();
            break;
         }
      }
      else if (whatodo == youturn)
      {
         if(m_Info > 1)
            printf(" youturn at %7d,   m_Fine: %6d,  m_Check:  %6d \n",
                   m_Capo,m_Fine,m_Check);

         this->YouTurn(m_CurrentTime);
      }
      else if (whatodo == restore)
      {
         if(m_Info > 1)
            printf(" restore at %7d,   m_Fine: %6d,  m_Check:  %6d \n",
                   m_Capo,m_Fine,m_Check);

         this->Restore();
         m_CurrentTime = m_Capo;
      }
      else if (whatodo == error)
      {
         // revolve() receives m_Info by pointer; on error it is switched on
         // below to explain the failure.
         printf(" irregular termination of revolve \n");
         switch(m_Info)
         {
            case 10:
               printf(" number of checkpoints stored exceeds checkup, \n");
               printf(" increase constant 'checkup' and recompile \n");
               break;
            case 11:
               printf(" number of checkpoints stored = %d exceeds"
                      " m_Snaps = %d, \n", m_Check+1,m_Snaps);
               printf(" ensure 'm_Snaps' > 0 and increase initial 'm_Fine'\n");
               break;
            case 12:
               printf(" error occurs in numforw \n");
               break;
            case 13:
               printf(" enhancement of 'm_Fine', 'm_Snaps' checkpoints"
                      " stored, \n");
               printf(" increase 'm_Snaps'\n");
               break;
            case 14:
               printf(" number of m_Snaps exceeds snapsup, ");
               printf(" increase constant 'snapsup' and recompile \n");
               break;
            case 15:
               printf(" number of reps exceeds repsup, ");
               printf(" increase constant 'repsup' and recompile \n");
               break;
         }

         fflush(0);
         abort();
      }

      if(m_Info > 1)
         fflush(0);

   } while((whatodo != terminate) && (whatodo != error));

   if(deriv)
   {
      for(unsigned int i=0;i<m_NumberOfIndependents;i++)
         deriv[i] = m_Derivative[i];

      // Print time spent reversing and recording
      std::cout << "\n\nTime spent reversing was " 
                << m_TimeSpentInReverse << " seconds.\n";
      std::cout << "Time spent recording was " << m_TimeSpentRecording 
                << " seconds.\n";
      std::cout << "Time spent moving forward without recording was " 
                << timeSpentMovingForwardWithoutRecording 
                << " seconds.\n\n" << std::endl;
   }

   // Set back to one
   currently_in_initial_timestep = 1;

   *value = m_FinalObjectiveFunctionValue;
}


   
   
/*
 * Install the objective: allocate m_Snaps+1 stack objects via func->Create(),
 * record the number of independents, allocate the shared independents and
 * copy the time-step labels from func.
 *
 * func : prototype objective; must not be null. The objects it creates are
 *        owned by this Revolve (they are deleted here on a repeated call and
 *        in the destructor).
 *
 * Prints a message and calls abort() if func is null, reports zero
 * independents, or if Create() returns null.
 */
void
Revolve::
SetFunction(const ObjectiveFunctionBase* func)
{
   if(func == 0)
   {
      std::cout << "Error: func is null." << std::endl;
      abort();
   }

   m_NumberOfIndependents = func->GetNumberOfIndependents();

   if(m_NumberOfIndependents == 0)
   {
      std::cout << "Error: number of independents is zero." 
                << std::endl;
      abort();
   }

   // Release objects created by a previous call before the stack is resized
   // and refilled; otherwise they would be overwritten (or dropped by a
   // shrinking resize) and leaked.
   for(unsigned int i=0;i<m_Stack.size();i++)
   {
      delete m_Stack[i];
      m_Stack[i] = 0;
   }

   m_Stack.resize(m_Snaps+1);

   for(unsigned int i=0;i<m_Stack.size();i++)
   {
      m_Stack[i] = func->Create();

      if(m_Stack[i] == 0)
      {
         std::cout << "Error: Create failed. (m_Stack[" << i 
                   << "] is NULL.)" << std::endl;
         abort();
      }
   }

   ObjectiveFunctionBase::AllocateIndependents(m_NumberOfIndependents);

   m_Labels = func->GetLabels();

   std::cout << "location_cnt: " << location_cnt << std::endl;

   currently_in_initial_timestep = 1;
}
   

/*
 * Default configuration: verbosity m_Info = 0, m_Snaps = 20 checkpoints,
 * tape number 0, and no independents until SetFunction() is called.
 */
Revolve::
Revolve(): 
   m_Info(0),
   m_Snaps(20),
   m_Tape(0),
   m_NumberOfIndependents(0)
{
   // Give the schedule/bookkeeping members defined values so a freshly
   // constructed object never holds indeterminate data. These are the same
   // values Evaluate() assigns at the start of every run.
   m_Steps = 0;
   m_Capo = 0;
   m_Fine = 0;
   m_Check = -1;
   m_Keep = 1;
   m_TopOfStack = 0;
   m_CurrentTime = 0;
   m_DoingForwardSweep = true;
   m_TimeSpentInReverse = 0.0;
   m_TimeSpentRecording = 0.0;
   m_FinalObjectiveFunctionValue = 0.0;

   currently_in_initial_timestep = 1;
}


/*
 * Reset the global currently_in_initial_timestep flag to 1 and release every
 * object SetFunction() placed in m_Stack.
 */
Revolve::
~Revolve() 
{
   // Set back to one
   currently_in_initial_timestep = 1;

   // delete on a null entry is a no-op, so no separate null check is needed.
   for(std::size_t k = 0; k != m_Stack.size(); ++k)
      delete m_Stack[k];
}
   

   
/*
 * Store a checkpoint: save the global store into the current top object,
 * copy that object into the next slot and make the copy the new top.
 *
 * Aborts if the next slot does not exist. SetFunction() allocates m_Snaps+1
 * objects, so this can only happen if the schedule needs more slots than were
 * allocated (for instance if m_Snaps was changed after SetFunction()).
 */
void 
Revolve::
TakeShot()
{
   if(static_cast<std::size_t>(m_TopOfStack) + 1 >= m_Stack.size())
   {
      std::cout << "Error: TakeShot would exceed m_Stack (size "
                << m_Stack.size() << ")." << std::endl;
      abort();
   }

   m_Stack[m_TopOfStack]->CopyGlobalStoreToLocalStore();

   m_Stack[m_TopOfStack+1]->CopyData( m_Stack[m_TopOfStack] );
   ++m_TopOfStack;
}
   
/*
 * Return to the most recent checkpoint.
 *
 * When exactly one step separates m_Capo and m_Fine, the checkpoint slot
 * itself becomes the working object (the top is popped). Otherwise the
 * checkpoint below the top is copied into the top object. In both cases the
 * top object's local store is then written back to the global store.
 */
void 
Revolve::
Restore()
{
   const bool oneStepLeft = (m_Fine - m_Capo == 1);

   if(!oneStepLeft)
   {
      m_Stack[m_TopOfStack]->CopyData( m_Stack[m_TopOfStack-1] );
   }
   else
   {
      --m_TopOfStack;
   }

   m_Stack[m_TopOfStack]->CopyLocalStoreToGlobalStore();
}
   
void 
Revolve::
AdvanceOneStep(int t) const
{
   m_Stack[m_TopOfStack]->m_DoingForwardSweep = m_DoingForwardSweep;

   ObjectiveFunctionBase::Label label = m_Labels[t];
   label.Reset();

   m_Stack[m_TopOfStack]->Function(label);
}
   

   
/*
 * Declare every independent of the bottom stack object with ADOL-C's <<=,
 * seeding each with its own current value. Aborts if called while any
 * checkpoint is on the stack (m_TopOfStack != 0).
 */
void Revolve::
SetIndependents() const
{
   if(m_TopOfStack != 0)
   {
      std::cout << "Error: m_TopOfStack should be zero at first time step" 
                << std::endl;
      abort();
   }

   unsigned int idx = 0;
   while(idx < m_Stack[m_TopOfStack]->GetNumberOfIndependents())
   {
      m_Stack[m_TopOfStack]->m_Independents[idx] <<=
         m_Stack[m_TopOfStack]->m_Independents[idx].value();
      ++idx;
   }
}


/*
 * Apply ADOL-C's >>= to the top object's m_Dependent, storing its value in
 * m_FinalObjectiveFunctionValue (which Evaluate() returns through `value`).
 */
void 
Revolve::
SetDependent()
{
   (*m_Stack[m_TopOfStack]).m_Dependent >>= m_FinalObjectiveFunctionValue;
}
   
   
/*
 * Handle revolve's `youturn` action for time step t.
 *
 * Re-records step t on tape m_Tape from the current top-of-stack state and
 * runs revolve_fos_reverse() over that tape with a single unit weight, passing
 * m_Derivative as the results buffer. For t == 0 the tape is opened with
 * trace_on() and the independents are declared; otherwise it is opened with
 * trace_on_intermediate(). CPU time spent recording and reversing is added to
 * m_TimeSpentRecording and m_TimeSpentInReverse.
 */
void 
Revolve::
YouTurn(int t)
{
   if(t > 0)
   {
      trace_on_intermediate(m_Tape, m_Keep);
   }
   else
   {
      trace_on(m_Tape, m_Keep);
      this->SetIndependents();
   }

   // Record step t. clock() returns clock_t; keep it in that type instead of
   // narrowing it to unsigned int.
   {
      const clock_t start_time = clock();
      this->AdvanceOneStep(t);
      const clock_t end_time = clock();
      m_TimeSpentRecording +=
         static_cast<double>(end_time - start_time) / CLOCKS_PER_SEC;
   }

   trace_off_intermediate(0);

   // Reverse sweep over the tape just recorded (one dependent, weight 1.0).
   double weight = 1.0;

   {
      const clock_t start_time = clock();
      revolve_fos_reverse(m_Tape,1,m_NumberOfIndependents,&weight,&m_Derivative[0]);
      const clock_t end_time = clock();
      m_TimeSpentInReverse +=
         static_cast<double>(end_time - start_time) / CLOCKS_PER_SEC;
   }
}

/*
 * Handle revolve's `firsturn` action for the last time step t.
 *
 * Records step t on tape m_Tape, declares the dependent via SetDependent(),
 * closes the tape with trace_off() and runs revolve_fos_reverse() over it with
 * a single unit weight, passing m_Derivative as the results buffer. For t == 0
 * the tape is opened with trace_on() and the independents are declared;
 * otherwise it is opened with trace_on_intermediate(). Afterwards
 * m_DoingForwardSweep is cleared. CPU time spent recording and reversing is
 * added to m_TimeSpentRecording and m_TimeSpentInReverse.
 */
void 
Revolve::
FirstTurn(int t)
{
   if(t > 0)
   {
      trace_on_intermediate(m_Tape, m_Keep);
   }
   else
   {
      trace_on(m_Tape, m_Keep);
      this->SetIndependents();
   }

   // Record step t. clock() returns clock_t; keep it in that type instead of
   // narrowing it to unsigned int.
   {
      const clock_t start_time = clock();
      this->AdvanceOneStep(t);
      const clock_t end_time = clock();
      m_TimeSpentRecording +=
         static_cast<double>(end_time - start_time) / CLOCKS_PER_SEC;
   }

   this->SetDependent();

   trace_off();

   // Reverse sweep over the tape just recorded (one dependent, weight 1.0).
   double weight = 1.0;

   {
      const clock_t start_time = clock();
      revolve_fos_reverse(m_Tape,1,m_NumberOfIndependents,&weight,&m_Derivative[0]);
      const clock_t end_time = clock();
      m_TimeSpentInReverse +=
         static_cast<double>(end_time - start_time) / CLOCKS_PER_SEC;
   }

   // All subsequent steps belong to the reverse phase.
   m_DoingForwardSweep = false;
}


}
