//=========================================================================
// FILE:        kannVpTree.cpp
//
//    Copyright (C)  2012 Kristian Damkjer.
//
// DESCRIPTION: This MEX source file provides an implementation of the
//              k all nearest neighbors search algorithm given a set of
//              query points, a number of neighbors, and a VpTree
//              representing the point data base.
//
// LIMITATIONS: This function has the potential to consume a large amount
//              of memory. While I have confirmed that there are no memory
//              leaks in the code to the best of my current ability, I have
//              induced predictable crashes in MATLAB in the related
//              MEX function, frannVpTree, by finding all nearest neighbors
//              for a 4M point data base. The crash did not come from this
//              code, but rather from Handle Graphics when attempting to
//              render a waitbar. As a result, it is recommended that
//              points be streamed or blocked together into more manageable
//              chunks.
//
// SOFTWARE HISTORY:
//> 2012-SEP-11  K. Damkjer
//               Initial Coding.
//  2013-FEB-04  K. Damkjer
//               Restructure parallel sections to see if any performance
//               gain is realized.
//<
//=========================================================================

#include <limits>

#ifdef _OPENMP
#include <omp.h>
#endif

#include "Util/MATLAB/ClassHandle.h"
#include "VpTree.h"

#ifdef _CHAR16T
#define CHAR16_T
#endif

#include "mex.h"

void mexFunction(
        int nlhs, mxArray* plhs[],
        int nrhs, const mxArray* prhs[])
{
   // check the arguments
   if (nrhs != 4 || !mxIsNumeric(prhs[0]))
   {
      mexErrMsgIdAndTxt("Damkjer:kannVpTree:varargin",
                        "Invalid number of arguments");
   }

   // retrieve the tree pointer
   const VpTree<>& tree = matAsObj<VpTree<> >(prhs[0]);

   const mxArray* points=prhs[1];
    
   // Check to make sure that points are real numerics
   if (mxIsSparse(points) ||
       mxGetNumberOfDimensions(points) != 2 ||
       mxIsComplex(points) ||
       !mxIsNumeric(points))
   {
      mexErrMsgIdAndTxt("Damkjer:kannVpTree:prhs",
                        "Point input to kannVpTree must be a full, 2-D matrix representing ND queries.");
   }
    
   const mwSize dims = mxGetM(points);
   const mwSize elems = mxGetN(points);
    
   double* data = mxGetPr(points);
   std::deque<std::vector<double> > pointData(elems, std::vector<double>(dims));
    
   for (mwIndex elem = elems; elem --> 0;)
   {
      for (mwIndex dim = dims; dim --> 0;)
      {
         pointData[elem][dim]=data[elem*dims+dim];
      }
   }

   mwSize k = 1;
   
   const mxArray* kData=prhs[2];

   if (mxIsSparse(kData) ||
       mxGetNumberOfElements(kData) != 1 ||
       mxIsComplex(kData) ||
       !mxIsNumeric(kData))
   {
       mexErrMsgIdAndTxt("Damkjer:kannVpTree:prhs",
                "First arguement to kannVpTree must be a real valued scalar.");
   }
       
   k = (mwSize)(*((double*)mxGetData(kData)));

   double radius = std::numeric_limits<double>::max();
   
   const mxArray* rData=prhs[3];

   if (mxIsSparse(rData) ||
       mxGetNumberOfElements(rData) != 1 ||
       mxIsComplex(rData) ||
       !mxIsNumeric(rData))
   {
       mexErrMsgIdAndTxt("Damkjer:kannVpTree:prhs",
                "Second arguement to kannVpTree must be a real valued scalar.");
   }

   radius = (*(double*)mxGetData(rData));

   plhs[0] = mxCreateCellMatrix(elems, 1);

   if (nlhs==2)
   {
      plhs[1] = mxCreateCellMatrix(elems, 1);
   }
   
#ifdef _OPENMP
   omp_set_dynamic(1);
   omp_set_num_threads(omp_get_num_procs());
#endif

   std::vector<std::pair<std::deque<mwIndex>, std::deque<double> > > results(pointData.size());
   
   #pragma omp parallel for
   for (int p = 0; p < pointData.size(); ++p)
   {
      std::vector<double> q=pointData[p];

//      std::pair<std::deque<mwIndex>, std::deque<double> > results;

      results[p] = tree.knn(q, k, radius);
   }
   
   #pragma omp parallel for
   for (int p = 0; p < pointData.size(); ++p)
   {  
      mwSize neighbors = results[p].first.size();

	   mxArray* neighbor_idxs = 0;
	   mwIndex* idxs = 0;
	   mxArray* neighbor_dists = 0;
	   double* dists = 0;

      #pragma omp critical //(VPSB_KNN_CREATE_IDXS)
      {
         neighbor_idxs = mxCreateNumericMatrix(0, 0, mxINDEX_CLASS, mxREAL);
         mxSetM(neighbor_idxs, neighbors);
         mxSetN(neighbor_idxs, 1);
         mxSetData(neighbor_idxs, mxMalloc(sizeof(mwIndex)*neighbors*1));

         idxs = (mwIndex*)mxGetData(neighbor_idxs);
	   }

      for (mwIndex idx = neighbors; idx --> 0;)
      {
         idxs[idx]=results[p].first[idx]+1;
      }

      #pragma omp critical //(VPSB_KNN_SET_CELL_IDX)
	   {
         mxSetCell(plhs[0], p, neighbor_idxs);
      }

      if (nlhs==2)
      {
         #pragma omp critical //(VPSB_KNN_CREATE_DISTS)
         {
            neighbor_dists = mxCreateDoubleMatrix(0, 0, mxREAL);
            mxSetM(neighbor_dists, neighbors);
            mxSetN(neighbor_dists, 1);
            mxSetData(neighbor_dists, mxMalloc(sizeof(double)*neighbors*1));

            dists = mxGetPr(neighbor_dists);
		   }

         for (mwIndex idx = neighbors; idx --> 0;)
         {
            dists[idx]=results[p].second[idx];
         }

         #pragma omp critical //(VPSB_KNN_SET_CELL_DISTS)
		   {
            mxSetCell(plhs[1], p, neighbor_dists);
         }
      }
   }
}
