//*****************************************************************************
// FILE:        fastcov.cpp
//
//    Copyright (C)  2012 Kristian Damkjer.
//
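// DESCRIPTION:
//>   MATLAB MEX gateway that computes a sample covariance matrix for every
//    matrix stored in a cell array (rows are points, columns are dimensions).
//    The per-cell work is distributed with OpenMP when it is available.
//<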
//
// LIMITATIONS:
//>   Does not work for cell-arrays of complex matrices.
//<
//
// SOFTWARE HISTORY:
//> 2012-SEP-11  K. Damkjer
//               Initial Coding.
//<
//*****************************************************************************

#ifdef _OPENMP
#include <omp.h>
#endif

#include <cmath>
#include <sstream>
#include <vector>

//***
// Fix "wide char" definition for older versions of MATLAB. This must be placed
// after other includes and before the mex.h include.
//***
#if (defined(MATLAB_MAJOR) && defined(MATLAB_MINOR))
   #if MATLAB_MAJOR <= 7 && MATLAB_MINOR <= 10 && defined(_CHAR16T)
      #define CHAR16_T
   #endif
#endif

#include "mex.h"

//*****************************************************************************
// FUNCTION: mexFunction
//>   The MATLAB Executable Gateway Function.
//
//    @todo Describe this MEX function
//
//    @param nlhs the number of left-hand side parameters.
//    @param plhs the array of left-hand side parameters.
//    @param nrhs the number of right-hand side parameters.
//    @param prhs the array of right-hand side parameters.
//<
//*****************************************************************************
void mexFunction(
        int nlhs, mxArray* plhs[],
        int nrhs, const mxArray* prhs[])
{
   if (nrhs != 1 || !mxIsCell(prhs[0]))
   {
      mexErrMsgIdAndTxt("Damkjer:fastcov:varargin",
                        "Missing or invalid input argument.");
   }

   if (nlhs > 3)
   {
      mexErrMsgIdAndTxt("Damkjer:fastcov:varargout",
                        "Too many output arguments.");
   }
   
   mwSize cells = mxGetNumberOfElements(prhs[0]);
//   mwSize ctrs  = mxGetNumberOfElements(prhs[1]);

//   if (useCenters && cells != ctrs)
//   {
//      mexErrMsgIdAndTxt("Damkjer:fastcov:varargout",
//                        "Size of neighborhoods and centers does not agree.");
//   }
   
   plhs[0] = mxCreateCellMatrix(cells, 1);

   // Better way?
   if (nlhs > 1)
   {
      plhs[1] = mxCreateCellMatrix(cells, 1);
   }
   
   // Better way?
   if (nlhs > 2)
   {
      plhs[2] = mxCreateCellMatrix(cells, 1);
   }
   
   std::vector<const double*> vals(cells,0);
   std::vector<mwSize> Ms(cells,0);
   std::vector<mwSize> Ns(cells,0);

   std::vector<mxArray*> covs(cells,0);
   std::vector<double*> cov_vals(cells,0);

   //HACK: Probably a more efficient way to do this
   std::vector<mxArray*> dist2mean(cells,0);
   std::vector<double*> dist2mean_vals(cells,0);
   
   //HACK: Probably a more efficient way to do this
   std::vector<mxArray*> inty(cells,0);
   std::vector<double*> inty_vals(cells,0);
   
   // Note for future: Ms - points, Ns - dimensions
   for (int cell = 0; cell < cells; ++cell)
   {
       vals[cell]=mxGetPr(mxGetCell(prhs[0], cell));
       Ms[cell]=mxGetM(mxGetCell(prhs[0], cell));
       Ns[cell]=mxGetN(mxGetCell(prhs[0], cell));

       // We will be setting each value, so don't bother to initialize to zero.
       covs[cell] = mxCreateDoubleMatrix(0, 0, mxREAL);
       mxSetM(covs[cell], Ns[cell]);
       mxSetN(covs[cell], Ns[cell]);
       mxSetData(covs[cell], mxMalloc(sizeof(double)*Ns[cell]*Ns[cell]));
       cov_vals[cell] = mxGetPr(covs[cell]);
       
       //HACK: Probably a more efficient way to do this
       dist2mean[cell] = mxCreateDoubleMatrix(1, 1, mxREAL);
       dist2mean_vals[cell] = mxGetPr(dist2mean[cell]);       
       
       //HACK: Probably a more efficient way to do this
       inty[cell] = mxCreateDoubleMatrix(1, 1, mxREAL);
       inty_vals[cell] = mxGetPr(inty[cell]);       
   }

#ifdef _OPENMP
   omp_set_dynamic(1);
   omp_set_num_threads(omp_get_num_procs());
#endif

   #pragma omp parallel for schedule(guided)
   for (int cellp = 0; cellp < cells; ++cellp)
   {
      std::vector<double> mean(Ns[cellp], 0.);
      std::vector<double> skewmean(Ns[cellp], 0.);
      
      // Comment out to remove mean calculation when centering to first point.
      double w1 = 1./Ms[cellp];
      
      for (mwSize n = Ns[cellp]; n --> 0;)
      {
         for (mwSize m = Ms[cellp]; m --> 1;)
         {
            mean[n] += vals[cellp][m + Ms[cellp] * n] * w1;
         }
         
         skewmean[n] = mean[n] / ((Ms[cellp] - 1) * w1);
         
         mean[n] += vals[cellp][Ms[cellp] * n] * w1;
      }
      // End comment.
      
      double w2 = 1./(Ms[cellp]-1);
      
      for (mwSize n1 = Ns[cellp]; n1 --> 0;)
      {
//         for (mwSize n2 = Ns[cellp]; n2 --> 0;)
         for (mwSize n2 = Ns[cellp]; n2 --> n1;)
         {
            cov_vals[cellp][n2 + Ns[cellp] * n1] = 0;

            for (mwSize mc = Ms[cellp]; mc --> 0;)
            {
               // Center to mean.
               cov_vals[cellp][n2 + Ns[cellp] * n1] +=
                        w2
                        * (vals[cellp][mc + Ms[cellp] * n1]-mean[n1])
                        * (vals[cellp][mc + Ms[cellp] * n2]-mean[n2]);

               // Center to first point.
//               cov_vals[cellp][n2 + Ns[cellp] * n1] +=
//                       w2
//                       * (vals[cellp][mc + Ms[cellp] * n1]-vals[cellp][Ms[cellp] * n1])
//                       * (vals[cellp][mc + Ms[cellp] * n2]-vals[cellp][Ms[cellp] * n2]);
               // End comment.
            }

            cov_vals[cellp][n1 + Ns[cellp] * n2] =
                                           cov_vals[cellp][n2 + Ns[cellp] * n1];
         }
      }
      
      if (nlhs > 1)
      {
         // Experiment: Estimate distance from first item to mean
         
         //std::vector<double> diff_loc(Ns[cellp], 0.);
         //std::vector<double> diff_ext(Ns[cellp], 0.);

         double diff_loc = 0;
         double diff_ext = 0;
         double temp;
         
         for (mwSize n = Ns[cellp]; n --> 0;)
         {
            // We can't assume anything about point ordering
            temp = vals[cellp][Ms[cellp] * n] - skewmean[n];
            diff_loc += temp * temp;
            
            temp = vals[cellp][Ms[cellp] * n] -
                   vals[cellp][Ms[cellp] * (n + 1) - 1];
            diff_ext +=  temp * temp;
         }
         
//         dist2mean_vals[cellp][0] = std::sqrt(diff_loc)/std::sqrt(diff_ext);
         dist2mean_vals[cellp][0] = std::sqrt(diff_loc);
         // End comment.

         if (nlhs > 2)
         {
            // Experiment: Estimate intensity, i.e., how "dense" is our mass?
            // Assumes last point is farthest from first.
            inty_vals[cellp][0] = Ms[cellp]/(4./3.*M_PI*diff_ext*sqrt(diff_ext));
            // End comment.
         }
      }
   }

   for (int cell = 0; cell < cells; ++cell)
   {
      mxSetCell(plhs[0], cell, covs[cell]);
   }

   if (nlhs > 1)
   {
      for (int cell = 0; cell < cells; ++cell)
      {
         mxSetCell(plhs[1], cell, dist2mean[cell]);
      }
   }

   if (nlhs > 2)
   {
      for (int cell = 0; cell < cells; ++cell)
      {
         mxSetCell(plhs[2], cell, inty[cell]);
      }
   }
}
