//*****************************************************************************
// FILE:        VpTree.hpp
//
//    Copyright (C)  2012 Kristian Damkjer.
//
// DESCRIPTION:
//>   The template implementation for vantage-point trees.
//<
//
// LIMITATIONS:
//>   This class template file is a section of the Damkjer::VpTree interface
//    definition and should not be directly included.
//<
//
// SOFTWARE HISTORY:
//> 2012-SEP-11  K. Damkjer
//               Initial Coding.
//  2013-JUL-23  K. Damkjer
//               Set typedefs to make code more readable and to allow for
//               varying container and metric types. This is useful when the
//               default double-precision is overkill. It is now easy to set
//               types to float.
//<
//*****************************************************************************

#if _OPENMP
#include <omp.h>
#endif

#include <algorithm>
#include <iostream>

#include "VpTree.h"

namespace Damkjer
{

//*****************************************************************************
// VpTree::knn(const PointT&, const IndexT&, const DistT&)
//>   Perform a k nearest neighbor search on the tree returning the indices of
//    and distances to the k nearest neighbors.
//
//    The search itself is delegated to the root node, which fills a results
//    set. That set is then drained into two parallel containers: the stored
//    index of each matched item and the matching distance.
//
//    @tparam MetricT The metric search space.
//    @param  query   The query point to focus the search about.
//    @param  k       The number of neighbors to identify.
//    @param  limit   The radial reach used to bound the search.
//    @return         The search results ordered by distance to the query.
//                    Both containers are empty when the tree has no root.
//<
//*****************************************************************************
template<typename MetricT>
typename VpTree<MetricT>::SearchResultsT
VpTree<MetricT>::knn(const PointT& query,
                     const IndexT& k,
                     const DistT& limit)
const
{
   ResultsSetT candidates;

   //***
   // makeTree() yields a null node for an empty element range, so a tree
   // without elements presumably has no root. Skip the search in that case
   // instead of dereferencing a null pointer.
   //***
   if (theRoot)
   {
      // Working search radius handed to the node search alongside the results.
      DistT tau = limit;

      theRoot->knn(query, k, candidates, tau);
   }

   std::deque<IndexT> indices;
   std::deque<DistT> distances;

   //***
   // The results set hands candidates back in the reverse of the output
   // order, so each one is placed at the front of its container.
   //***
   while (!candidates.empty())
   {
      indices.push_front(theItems[candidates.top().theIndex].theIndex);
      distances.push_front(candidates.top().theDistance);
      candidates.pop();
   }

   return std::make_pair(indices, distances);
}

//*****************************************************************************
// VpTree::rnn(const PointT&, const DistT&)
//>   Perform a fixed radius nearest neighbor search on the tree returning the
//    indices of and distances to the neighbors in the fixed radius.
//
//    The search itself is delegated to the root node, which fills a results
//    set. That set is then drained into two parallel containers: the stored
//    index of each matched item and the matching distance.
//
//    @tparam MetricT The metric search space.
//    @param  query   The query point to focus the search about.
//    @param  range   The radial distance used to bound the search.
//    @return         The search results ordered by distance to the query.
//                    Both containers are empty when the tree has no root.
//<
//*****************************************************************************
template<typename MetricT>
typename VpTree<MetricT>::SearchResultsT
VpTree<MetricT>::rnn(const PointT& query,
                     const DistT& range)
const
{
   ResultsSetT candidates;

   //***
   // makeTree() yields a null node for an empty element range, so a tree
   // without elements presumably has no root. Skip the search in that case
   // instead of dereferencing a null pointer.
   //***
   if (theRoot)
   {
      theRoot->rnn(query, range, candidates);
   }

   std::deque<IndexT> indices;
   std::deque<DistT> distances;

   //***
   // The results set hands candidates back in the reverse of the output
   // order, so each one is placed at the front of its container.
   //***
   while (!candidates.empty())
   {
      indices.push_front(theItems[candidates.top().theIndex].theIndex);
      distances.push_front(candidates.top().theDistance);
      candidates.pop();
   }

   return std::make_pair(indices, distances);
}

//*****************************************************************************
// VpTree::makeTree(const IndexT&, const IndexT&)
//>   Recursively build the sub-tree for the elements in [start, stop) and
//    return its root node.
//
//    Ranges no larger than theLeafCapacity become a single Leaf. Larger ranges
//    become a Branch: a vantage point is selected and moved to the front of
//    the range, the distance from it to every other element is recorded, and
//    the remaining elements are partitioned about the median position into an
//    inner and an outer half, each of which is built recursively.
//
//    When compiled with OpenMP, the two halves may be built in parallel
//    sections while a function-local counter stays below the processor count.
//
//    @tparam MetricT The metric search space.
//    @param  start   The beginning of the element range to transform into a
//                    binary tree.
//    @param  stop    The end (exclusive) of the element range to transform
//                    into a binary tree.
//    @return         The root of the constructed tree, or a null pointer when
//                    the range is empty. The node is allocated with new.
//<
//*****************************************************************************
template<typename MetricT>
typename VpTree<MetricT>::Node*
VpTree<MetricT>::makeTree(const IndexT& start,
                          const IndexT& stop)
{
   // An empty (or inverted) range produces no node.
   if (stop <= start) return 0;

   IndexT setSize = stop - start;

   // Small ranges are stored whole in a leaf rather than being split.
   if (setSize <= theLeafCapacity)
   {
      return new Leaf(this, start, stop);
   }

   Branch* node=new Branch(this);
   node->theIndex=start;

   // Move the selected vantage point to the front of the range.
   const IndexT vp = selectVp(start, stop);
   std::swap(theItems[start], theItems[vp]);

   //***
   // Identify bound elements. The inner half is
   // [innerLowerBound, outerLowerBound) and the outer half is
   // [outerLowerBound, stop); the vantage point itself sits at start.
   //***
   IndexT outerLowerBound = (start + stop + 1)/2;
   IndexT innerLowerBound = start + 1;

   //***
   // Record the distance from the vantage point to every other element in
   // the range, tracking the overall extremes as we go.
   //***
   DistT d_max=0;
   DistT d_min=std::numeric_limits<DistT>::max();

   for (IndexT elem = stop; elem-- > innerLowerBound; )
   {
      theItems[elem].theDistance = theMetric(theItems[start].theElement,
                                             theItems[elem].theElement);
      d_max = (theItems[elem].theDistance > d_max)
              ? theItems[elem].theDistance : d_max;
      d_min = (theItems[elem].theDistance < d_min)
              ? theItems[elem].theDistance : d_min;
   }

   //***
   // Put the median element in place. This relies on the ordering defined
   // for the stored items (presumably by theDistance -- confirm against the
   // item type's comparison operator).
   //***
   std::nth_element(theItems.begin() + static_cast<DiffT>(innerLowerBound),
                    theItems.begin() + static_cast<DiffT>(outerLowerBound),
                    theItems.begin() + static_cast<DiffT>(stop));

   // Largest recorded distance among the inner half.
   DistT d_mid=0;

   for (IndexT elem = outerLowerBound; elem-- > innerLowerBound; )
   {
      d_mid = (theItems[elem].theDistance > d_mid)
              ? theItems[elem].theDistance : d_mid;
   }

   node->theInnerLowerBound=d_min;
   node->theInnerUpperBound=d_mid;

   //***
   // The outer half is empty only when the range holds the vantage point
   // alone (reachable when theLeafCapacity is 0). In that case
   // theItems[outerLowerBound] lies outside [start, stop) and must not be
   // read; fall back to d_min, which is still at its initial maximum and so
   // leaves the outer interval empty, matching the inner one.
   //***
   node->theOuterLowerBound=(outerLowerBound < stop)
                            ? theItems[outerLowerBound].theDistance
                            : d_min;
   node->theOuterUpperBound=d_max;

   #if _OPENMP
   // Count of parallel regions currently opened by this function.
   static int threads = 0;

   //***
   // NOTE(review): threads is read here without synchronization while other
   // threads may be updating it under the atomic directives below.
   //***
   if (threads < (omp_get_num_procs()-1))
   {
      if (omp_get_num_threads() == 1)
      {
         omp_set_dynamic(1);
         omp_set_num_threads(omp_get_num_procs());
      }

      // Nested parallelism is needed because this function recurses.
      if (!omp_get_nested())
      {
         omp_set_nested(1);
      }

      #pragma omp atomic
      ++threads;

      #pragma omp parallel num_threads(2)
      {
         //***
         // The first statement is the implicit first section; the second is
         // introduced by the explicit section directive.
         //***
         #pragma omp sections nowait
         {
            node->theInnerBranch = makeTree(innerLowerBound,
                                            outerLowerBound);

            #pragma omp section
            node->theOuterBranch = makeTree(outerLowerBound, stop);
         }
      }

      #pragma omp atomic
      --threads;
   }
   else
   #endif
   {
      // Serial construction of both halves.
      node->theInnerBranch = makeTree(innerLowerBound, outerLowerBound);
      node->theOuterBranch = makeTree(outerLowerBound, stop);
   }

   return node;
}

//*****************************************************************************
// VpTree::randomSample(const IndexT&, const IndexT&)
//>   Select a random sample of indices in the range between the provided
//    indices.
//
//    At most five indices are drawn. Three strategies are used depending on
//    the size of the range:
//      - five or fewer elements: every index in the range is returned;
//      - fewer than ten elements: a single pass over the range selects each
//        index with probability (samples still needed)/(indices remaining);
//      - otherwise: indices are drawn at random until five distinct values
//        have been collected.
//
//    Randomness comes from rand(); seeding is left to the caller.
//
//    @tparam MetricT The metric search space.
//    @param  start   The beginning of the index range to sample from.
//    @param  stop    The end (exclusive) of the index range to sample from.
//    @return         A set of randomly sampled indices from the provided
//                    range. Empty when the range is empty or inverted.
//<
//*****************************************************************************
template<typename MetricT>
std::set<typename VpTree<MetricT>::IndexT>
VpTree<MetricT>::randomSample(const IndexT& start,
                              const IndexT& stop)
const
{
   std::set<IndexT> samples;

   //***
   // Nothing to sample. Rejecting an inverted range up front also keeps the
   // subtraction below from wrapping or going negative.
   //***
   if (stop <= start) return samples;

   const IndexT rangeSize = stop - start;

   //***
   // Sampling the sqrt of inputs, while thorough, is completely unnecessary.
   // Leaving the note here for future reference.
   //***
   // IndexT numSamps=(IndexT)(ceil(sqrt((double)(stop - start))));

   // A very small sample set of the population is sufficient
   const IndexT numSamps = 5;

   //***
   // If the range is no larger than the number of samples, just return the
   // elements in the range.
   //***
   if (rangeSize <= numSamps)
   {
      for (IndexT i = start; i < stop; ++i)
      {
         samples.insert(i);
      }

      return samples;
   }

   //***
   // If the range is close to the number of samples, select with better
   // worst-case behavior: one pass, choosing each index with probability
   // itemsNeeded/(indices remaining).
   //***
   if (rangeSize < numSamps*2)
   {
      IndexT itemsNeeded=numSamps;

      for (IndexT i = start; samples.size() < numSamps && i < stop; ++i)
      {
         //***
         // The ratio must be formed in floating point. An integer quotient
         // is 0 until the remaining indices no longer exceed the samples
         // needed, which would always pick the tail of the range.
         //***
         const double chance = static_cast<double>(itemsNeeded) /
                               static_cast<double>(stop - i);

         if ((rand()/(RAND_MAX + 1.0)) < chance)
         {
            samples.insert(i);
            --itemsNeeded;
         }
      }
   }
   else
   {
      //***
      // Otherwise, if range dominates samples, select expecting to find
      // unique samples. Duplicates are absorbed by the set.
      //***
      while (samples.size() < numSamps)
      {
         // Choose an arbitrary point in [start, stop)
         IndexT val=static_cast<IndexT>(rand() / (RAND_MAX + 1.0) *
                                        rangeSize + start);

         samples.insert(val);
      }
   }

   return samples;
}
                                
//*****************************************************************************
// VpTree::selectVp(const IndexT&, const IndexT&)
//>   Select a vantage point in the range between the provided indices.
//
//    The active implementation returns a pseudo-random index drawn with
//    rand(). An alternative that picks the sampled candidate with the largest
//    spread of distances is kept for reference and is only compiled when
//    USE_OBSOLETED_CODE is defined to a non-zero value.
//
//    @tparam MetricT The metric search space.
//    @param  start   The beginning of the index range to choose from.
//    @param  stop    The end (exclusive) of the index range to choose from.
//    @return         The index of the selected vantage point. Equals start
//                    when the range is empty.
//
//    @todo Implement a median of 5's to impose stability into the selection.
//<
//*****************************************************************************
template<typename MetricT>
typename VpTree<MetricT>::IndexT
VpTree<MetricT>::selectVp(const IndexT& start,
                          const IndexT& stop)
const
{
   //***
   // Choosing a vantage point that maximizes the balance of the sub-trees is
   // theoretically advantageous. This involves selecting a vantage point in a
   // "corner" of the population. However, in practice, the trade for increased
   // selection time dwarfs the benefit realized through search times. The
   // incredibly simple approach of selecting a random member of the population
   // is much simpler and yields almost identical search times in the data sets
   // tested.
   //
   // The "intelligent" selection mode is included in the source code for
   // reference, but dropped by the pre-processor.
   //***

#if USE_OBSOLETED_CODE
   typedef typename VpTree<MetricT>::Item ItemT;
   typedef typename std::set<IndexT>::const_iterator SampleItrT;

   // Choose a point from a small sample set that maximizes spread
   const std::set<IndexT> p=randomSample(start, stop);

   // An empty sample offers no candidate; mirror the simple mode's result.
   if (p.empty()) return start;

   IndexT bestP=*(p.begin());
   DistT bestSpread=0;

   for (SampleItrT pItr=p.begin(); pItr != p.end(); ++pItr)
   {
      const ItemT& pItem = theItems[*pItr];

      // Draw a fresh sample to estimate the spread about this candidate.
      const std::set<IndexT> d=randomSample(start, stop);

      std::vector<ItemT> dItems(d.size());

      IndexT slot = 0;

      for (SampleItrT dItr=d.begin(); dItr != d.end(); ++dItr)
      {
          dItems[slot]=theItems[*dItr];
          ++slot;
      }

      // Place the median sampled item (by the item ordering) in the middle.
      std::nth_element(dItems.begin(),
                       dItems.begin() + static_cast<DiffT>(dItems.size()/2),
                       dItems.end());

      DistT mu = theMetric(pItem.theElement,
                           dItems[dItems.size()/2].theElement);

      //***
      // Running mean/variance of the candidate-to-sample distances, taken
      // relative to mu.
      //***
      IndexT k=1;
      DistT x, oldM, newM;

      x=oldM=newM=theMetric(pItem.theElement,
                            dItems[0].theElement)-mu;

      DistT oldS, newS;
      oldS=newS=0;

      for (std::size_t n = 1; n < dItems.size(); ++n)
      {
         x=theMetric(pItem.theElement,
                     dItems[n].theElement)-mu;

         ++k;
         newM=oldM+(x-oldM)/k;
         newS=oldS+(x-oldM)*(x-newM);
         oldM=newM;
         oldS=newS;
      }

      DistT spread=static_cast<DistT>((k>1)?newS/(k-1):0.0);

      if (spread > bestSpread)
      {
          bestSpread=spread;

          bestP=*pItr;
      }
   }

   return bestP;
#else
   //***
   // Simplest working case: just choose an arbitrary point. The draw is
   // scaled from [0, 1) onto [start, stop).
   //
   // NOTE(review): makeTree() can reach this call from inside OpenMP parallel
   // sections, and rand() is not required to be safe for concurrent use --
   // confirm this is acceptable on the target platforms.
   //***
   const double unit = rand()/(RAND_MAX + 1.0);

   return static_cast<IndexT>(unit*(stop-start)+start);
#endif
}

}
