//*****************************************************************************
// FILE:        VpTree.cpp
//
//    Copyright (C)  2012 Kristian Damkjer.
//
// DESCRIPTION: This class is an implementation of the vantage point tree
//              data structure described by Peter Yianilos in "Data
//              Structures and Algorithms for Nearest Neighbor Search in
//              General Metric Spaces".
//
// LIMITATIONS: See VpTree.h for full list of limitations.
//
// SOFTWARE HISTORY:
//> 2012-SEP-11  K. Damkjer
//               Initial Coding.
//<
//*****************************************************************************

#ifdef _OPENMP
#include <omp.h>
#endif

#include "VpTree.h"

#include <algorithm>
#include <cstddef>
#include <iostream>

namespace Damkjer
{

//*****************************************************************************
// CLASS: VpTree::Node
//*****************************************************************************
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
class VpTree<T, DISTANCE>::Node
{
public:
   Node();
      //> The default constructor creates a null node: theTree is set to
      //  null, so the node is not attached to any tree.
      //<
    
   Node(VpTree*);
      //> Construct with a database reference. The pointer is stored in
      //  theTree and is later used by the search visitors to reach the
      //  tree's items.
      //<
    
   virtual ~Node();
      //> The destructor. Virtual to ensure proper destruction of derived
      //  nodes through a Node pointer. It clears theTree but does not
      //  delete the tree it points to.
      //<
    
   virtual void knn(const T&,
                    const std::size_t&,
                    std::priority_queue<ResultsCandidate>&,
                    double&) const = 0;
      //> The visitor that accumulates k nearest neighbor results.
      //  Arguments, in order: the query element, the number of neighbors
      //  requested (k), the queue of candidates found so far, and the
      //  current search radius. Implementations in this file only accept
      //  candidates closer than that radius and tighten it once k
      //  candidates are held.
      //<

   virtual void rnn(const T&,
                    const double,
                    std::priority_queue<ResultsCandidate>&) const = 0;
      //> The visitor that accumulates fixed radius nearest neighbor
      //  results. Arguments, in order: the query element, the search
      //  radius, and the queue that receives every candidate whose
      //  distance to the query does not exceed the radius.
      //<

protected:
   VpTree<T, DISTANCE>* theTree;
      //> The tree this node belongs to, or null for a detached node. The
      //  node only observes the tree; it never deletes it.
      //<
};

//*****************************************************************************
// CLASS: VpTree::Internal
//*****************************************************************************
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
class VpTree<T, DISTANCE>::Internal : public VpTree<T, DISTANCE>::Node
{
public:
   Internal();
      //> The default constructor creates a null internal node: no tree,
      //  index zero, all bounds zero and no branches.
      //<
    
   Internal(VpTree*);
      //> Construct with a database reference. Index, bounds and branches
      //  start out zeroed and are filled in by the tree builder.
      //<
    
   virtual ~Internal();
      //> The destructor. Virtual to ensure proper Node destruction.
      //  Deletes both branches, so destroying a node destroys the whole
      //  sub-tree below it.
      //<
    
   std::size_t theIndex;
      //> The index of this node's vantage point in the internal database.
      //<

   double theInnerLowerBound;
      //> The lower bound distance to elements on inner branch.
      //<

   double theInnerUpperBound;
      //> The upper bound distance to elements on inner branch.
      //<

   double theOuterLowerBound;
      //> The lower bound distance to elements on outer branch.
      //<

   double theOuterUpperBound;
      //> The upper bound distance to elements on outer branch.
      //<

   Node* theInnerBranch;
      //> The inner branch partition containing the elements nearest to
      //  this node's vantage point (distances between theInnerLowerBound
      //  and theInnerUpperBound). Owned by this node; may be null.
      //<

   Node* theOuterBranch;
      //> The outer branch partition containing the elements farthest from
      //  this node's vantage point (distances between theOuterLowerBound
      //  and theOuterUpperBound). Owned by this node; may be null.
      //<

   virtual void knn(const T&,
                    const std::size_t&,
                    std::priority_queue<ResultsCandidate>&,
                    double&) const;
      //> The visitor that accumulates k nearest neighbor results. Tests
      //  the vantage point, then descends into each branch whose distance
      //  bounds can still contain a closer candidate.
      //<

   virtual void rnn(const T&,
                    const double,
                    std::priority_queue<ResultsCandidate>&) const;
      //> The visitor that accumulates fixed radius nearest neighbor
      //  results. Tests the vantage point, then descends into each branch
      //  whose distance bounds intersect the search radius.
      //<

private:
   Internal(const Internal&);
   Internal& operator=(const Internal&);
      //> Copying is disabled (declared but intentionally never defined).
      //  The destructor deletes both branches, so a member-wise copy
      //  would make two nodes delete the same sub-trees.
      //<
};

//*****************************************************************************
// CLASS: VpTree::Leaf
//*****************************************************************************
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
class VpTree<T, DISTANCE>::Leaf : public VpTree<T, DISTANCE>::Node
{
public:
   Leaf();
      //> The default constructor creates a null leaf: no tree and an
      //  empty index range (head and tail both zero).
      //<
    
   Leaf(VpTree*, const std::size_t&, const std::size_t&);
      //> Construct a leaf representing the items in the provided index
      //  range. Arguments, in order: the owning tree, the first index of
      //  the range, and the index one past the last item of the range.
      //<
    
   virtual ~Leaf(){}
      //> The destructor. Virtual to ensure proper Node destruction. A
      //  leaf owns nothing beyond its index range, so there is nothing
      //  to release.
      //<
    
   virtual void knn(const T&,
                    const std::size_t&,
                    std::priority_queue<ResultsCandidate>&,
                    double&) const;
      //> The visitor that accumulates k nearest neighbor results by
      //  scanning every item in the leaf's index range.
      //<

   virtual void rnn(const T&,
                    const double,
                    std::priority_queue<ResultsCandidate>&) const;
      //> The visitor that accumulates fixed radius nearest neighbor
      //  results by scanning every item in the leaf's index range.
      //<

private:
   std::size_t theHead;
      //> The start of the index range represented by this Leaf
      //  (inclusive).
      //<

   std::size_t theTail;
      //> The end of the index range represented by this Leaf (exclusive:
      //  the scans stop before reaching this index).
      //<
};

//*****************************************************************************
// CLASS: VpTree::Item
//*****************************************************************************
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
class VpTree<T, DISTANCE>::Item
{
public:
   Item();
      //> The default constructor creates a null item.
      //<
    
   ~Item(){}
      //> The default destructor. Intentionally non-virtual since Item is
      //  a private inner class on VpTree.
      //<
    
   bool operator< (const Item&) const;
      //> Order items by theDistance only. The tree builder relies on this
      //  ordering when it partitions items around the median distance to
      //  a vantage point.
      //<

   std::size_t theIndex;
      //> The index of the item in the input data set.
      //<

   T theElement;
      //> The database object
      //<
   
   double theDistance;
      //> The most recent ancestral pivot history distance for this item,
      //  i.e. its distance to the vantage point it was last partitioned
      //  against while the tree was built.
      //<
};

//*****************************************************************************
// CLASS: VpTree::ResultsCandidate
//*****************************************************************************
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
class VpTree<T, DISTANCE>::ResultsCandidate
{
public:
   ResultsCandidate(const std::size_t&, const double);
      //> Create a results candidate from an item's index in the internal
      //  database and its distance to the query.
      //<
    
   ~ResultsCandidate(){}
      //> The default destructor. Intentionally non-virtual since
      //  ResultsCandidate is a private inner class on VpTree.
      //<

   bool operator< (const ResultsCandidate&) const;
      //> Compare result candidate distances to determine which is closer
      //  to the query. With the default std::priority_queue ordering this
      //  places the farthest candidate on top of the queue.
      //<

   std::size_t theIndex;
      //> The index of the candidate in the internal database.
      //<
    
   double theDistance;
      //> The candidate's distance to the query point.
      //<
};

//*****************************************************************************
// VpTree::VpTree
//*****************************************************************************
// Builds an empty tree: no root node, no items, and the smallest permitted
// leaf capacity of one item per leaf.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::VpTree()
   : theRoot(0),
     theItems(),
     theLeafCapacity(1)
{
   // Nothing else to do; every member is set in the initializer list.
}

//*****************************************************************************
// VpTree::VpTree
//*****************************************************************************
// Builds a tree over a copy of every element in `elems`.
//
//   elems        - indexable container (size() and operator[]) holding the
//                  elements to organize.
//   leafCapacity - maximum number of items stored in one leaf; values
//                  below one are raised to one.
//
// Each element is copied into the internal item list together with its
// original position, then the tree is built over the full item range.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
template<typename CONTAINER>
inline VpTree<T, DISTANCE>::VpTree(const CONTAINER& elems,
                                   const std::size_t& leafCapacity)
   : theRoot(0)
   , theItems(elems.size())
   , theLeafCapacity((leafCapacity<1)?1:leafCapacity)
{
   #ifdef _OPENMP
   omp_set_dynamic(1);
   omp_set_num_threads(omp_get_num_procs());
   #endif

   // The loop counter is kept signed because older OpenMP implementations
   // only accept signed induction variables in a parallel for. It is as
   // wide as std::size_t so it neither overflows on large inputs nor mixes
   // signed and unsigned operands in the loop condition.
   const std::ptrdiff_t itemCount =
      static_cast<std::ptrdiff_t>(theItems.size());

   #pragma omp parallel for
   for (std::ptrdiff_t i = 0; i < itemCount; ++i)
   {
      const std::size_t position = static_cast<std::size_t>(i);

      theItems[position].theIndex = position;
      theItems[position].theElement = elems[position];
   }

   theRoot = makeTree(0, theItems.size());
}
                
//*****************************************************************************
// VpTree::~VpTree
//*****************************************************************************
// Releases the whole node hierarchy. Deleting the root is sufficient
// because every internal node deletes its own branches.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::~VpTree()
{
   Node* const doomedRoot = theRoot;

   theRoot = 0;
   delete doomedRoot;
}

//*****************************************************************************
// VpTree::Node::Node
//*****************************************************************************
// Creates a detached node that refers to no tree.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Node::Node()
   : theTree(0)
{
   // A null tree pointer makes the search visitors return immediately.
}

//*****************************************************************************
// VpTree::Node::Node
//*****************************************************************************
// Creates a node attached to `tree`. Only the pointer is stored; the node
// never deletes the tree.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Node::Node(VpTree* tree)
   : theTree(tree)
{
   // The pointer is all the state a base node carries.
}

//*****************************************************************************
// VpTree::Node::~Node
//*****************************************************************************
// Detaches the node from its tree. The tree itself is left untouched.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Node::~Node()
{
   this->theTree = 0;
}

//*****************************************************************************
// VpTree::Internal::Internal
//*****************************************************************************
// Creates a detached internal node: no tree, vantage point index zero,
// all four distance bounds zero and no branches.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Internal::Internal()
   : Node(),
     theIndex(0),
     theInnerLowerBound(0.0),
     theInnerUpperBound(0.0),
     theOuterLowerBound(0.0),
     theOuterUpperBound(0.0),
     theInnerBranch(0),
     theOuterBranch(0)
{
   // Everything is set in the initializer list.
}

//*****************************************************************************
// VpTree::Internal::Internal
//*****************************************************************************
// Creates an internal node attached to `tree`. The vantage point index,
// the distance bounds and the branches all start zeroed; the tree builder
// fills them in afterwards.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Internal::Internal(VpTree<T, DISTANCE>* tree)
   : Node(tree),
     theIndex(0),
     theInnerLowerBound(0.0),
     theInnerUpperBound(0.0),
     theOuterLowerBound(0.0),
     theOuterUpperBound(0.0),
     theInnerBranch(0),
     theOuterBranch(0)
{
   // Everything is set in the initializer list.
}

//*****************************************************************************
// VpTree::Internal::~Internal
//*****************************************************************************
// Destroys the sub-tree rooted at this node: both branches are deleted
// and their pointers cleared. Deleting a null branch is a no-op.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Internal::~Internal()
{
   // Inner branch first, then outer.
   delete theInnerBranch;
   theInnerBranch = 0;

   delete theOuterBranch;
   theOuterBranch = 0;
}

//*****************************************************************************
// VpTree::Leaf::Leaf
//*****************************************************************************
// Creates a detached leaf covering the empty index range [0, 0).
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Leaf::Leaf()
   : Node(),
     theHead(0),
     theTail(0)
{
   // Everything is set in the initializer list.
}

//*****************************************************************************
// VpTree::Leaf::Leaf
//*****************************************************************************
// Creates a leaf attached to `tree` that covers the item positions from
// `head` up to, but not including, `tail`.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Leaf::Leaf(VpTree* tree,
                                         const std::size_t& head,
                                         const std::size_t& tail)
   : Node(tree),
     theHead(head),
     theTail(tail)
{
   // Everything is set in the initializer list.
}

//*****************************************************************************
// VpTree::Item::Item
//*****************************************************************************
// Creates a null item: index zero, a value-initialized element and a
// zero pivot distance.
//
// The element is value-initialized rather than constructed from the
// literal 0. Constructing from 0 only works for element types that accept
// an integer or null pointer, and is actively harmful for some (e.g. a
// string type built from a null character pointer). Value-initialization
// still yields zero for arithmetic types and null for pointers.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::Item::Item()
   : theIndex(0)
   , theElement()
   , theDistance(0)
{
}

//*****************************************************************************
// VpTree::Item::operator<
//*****************************************************************************
// Orders items by their stored pivot distance only; the index and the
// element take no part in the comparison.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline bool VpTree<T, DISTANCE>::Item::operator<(const Item& other) const
{
   const double mine = theDistance;
   const double theirs = other.theDistance;

   return mine < theirs;
}

//*****************************************************************************
// VpTree::ResultsCandidate::ResultsCandidate
//*****************************************************************************
// Creates a search candidate from an item's position in the internal
// database (`index`) and its distance to the query (`distance`).
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline VpTree<T, DISTANCE>::ResultsCandidate::ResultsCandidate(
                                                  const std::size_t& index,
                                                  const double distance)
   : theIndex(index),
     theDistance(distance)
{
   // Everything is set in the initializer list.
}

//*****************************************************************************
// VpTree::ResultsCandidate::operator<
//*****************************************************************************
// Orders candidates by distance to the query: a candidate is "less" when
// it is strictly closer than `other`.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
inline bool VpTree<T, DISTANCE>::ResultsCandidate::operator<(
                                       const ResultsCandidate& other) const
{
   const double mine = theDistance;
   const double theirs = other.theDistance;

   return mine < theirs;
}

//*****************************************************************************
// VpTree::randomSample
//*****************************************************************************
// Draws a set of distinct positions from the half-open range
// [start, stop).
//
//   start - first position that may be sampled.
//   stop  - one past the last position that may be sampled.
//
// Returns at most five positions. When the range holds five or fewer
// positions all of them are returned; an empty or inverted range yields an
// empty set. Randomness comes from rand().
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
std::set<std::size_t> VpTree<T, DISTANCE>::randomSample(
                                             const std::size_t& start,
                                             const std::size_t& stop) const
{
   std::set<std::size_t> samples;

   // Nothing to sample. This also keeps the unsigned subtraction below
   // from wrapping around when the range is inverted.
   if (stop <= start)
   {
      return samples;
   }

   const std::size_t rangeSize = stop - start;

//   std::size_t numSamps=(std::size_t)(ceil(sqrt((double)(stop - start))));
   const std::size_t numSamps = (rangeSize > 5) ? 5 : rangeSize;

   // If the range is no larger than the number of samples, just return the
   // elements in the range.
   if (rangeSize <= numSamps)
   {
      for (std::size_t i = start; i < stop; ++i)
      {
         samples.insert(i);
      }

      return samples;
   }

   // If the range is close to the number of samples, select with better
   // worst-case behavior: walk the range once and keep each position with
   // probability (samples still needed)/(positions still remaining).
   if (rangeSize < numSamps*2)
   {
      std::size_t itemsNeeded = numSamps;

      for (std::size_t i = start;
           samples.size() < numSamps && i < stop;
           ++i)
      {
         // The ratio must be computed in floating point. As an integer
         // division it is zero until the remaining positions equal the
         // samples needed, which always selects the tail of the range.
         const double acceptance = static_cast<double>(itemsNeeded) /
                                   static_cast<double>(stop - i);

         if ((rand()/(RAND_MAX + 1.0)) < acceptance)
         {
            samples.insert(i);
            --itemsNeeded;
         }
      }
   }
   else
   {
      // Otherwise, if range dominates samples, select expecting to find
      // unique samples; the set silently absorbs the occasional repeat.
      while (samples.size() < numSamps)
      {
         // Choose an arbitrary point
         const std::size_t val = static_cast<std::size_t>(
                          rand()/(RAND_MAX + 1.0)*rangeSize+start);

         samples.insert(val);
      }
   }

   return samples;
}
                                
//*****************************************************************************
// VpTree::selectVp
//*****************************************************************************
// Chooses the position of the vantage point for the item range
// [start, stop).
//
//   start - first position of the range.
//   stop  - one past the last position of the range.
//
// Returns a position drawn with rand() from [start, stop). The spread-based
// selection between #if 0 and #endif is compiled out and has no effect.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
std::size_t VpTree<T, DISTANCE>::selectVp(const std::size_t& start,
                                            const std::size_t& stop) const
{
// Disabled alternative: sample candidate vantage points, estimate for each
// the variance of the distances to a second sample around their median,
// and keep the candidate with the largest spread.
// NOTE(review): as written this would not compile if re-enabled: dItems is
// a vector of std::pair yet is assigned Items and read through
// .theElement. MetricComparison is not defined in this file; confirm it
// exists before reviving this code.
#if 0
    std::set<std::size_t> p=randomSample(start, stop);
    
   std::size_t bestP=*(p.begin());
   double bestSpread=0;
    
   for (std::set<std::size_t>::const_iterator pItr=p.begin();
        pItr != p.end();
        ++pItr)
   {
      const T& pItem = theItems[*pItr].theElement;
      
      std::set<std::size_t> d=randomSample(start, stop);
      
      std::vector<std::pair<T,std::size_t> > dItems(d.size());
      unsigned int i = 0;
      
      for (std::set<std::size_t>::const_iterator dItr=d.begin();
           dItr != d.end();
           ++dItr)
      {
          dItems[i]=theItems[*dItr];
          ++i;
      }

      std::nth_element(
              dItems.begin(),
              dItems.begin() + dItems.size()/2,
              dItems.end(),
              MetricComparison(pItem));
      
      double mu = DISTANCE(pItem, dItems[dItems.size()/2].theElement);
      
      unsigned int k=1;
      double x, oldM, newM;
      x=oldM=newM=DISTANCE(pItem, dItems[0].theElement)-mu;
      double oldS, newS;
      oldS=newS=0;
      
      for (unsigned int i = 1; i < dItems.size(); ++i)
      {
         x=DISTANCE(pItem, dItems[i].theElement)-mu;
         ++k;
         newM=oldM+(x-oldM)/k;
         newS=oldS+(x-oldM)*(x-newM);
         oldM=newM;
         oldS=newS;
      }
      
      double spread=(k>1)?newS/(k-1):0.0;
      
      if (spread > bestSpread)
      {
          bestSpread=spread;
          
          bestP=*pItr;
      }
   }
   
   return bestP;
#endif
   // Choose an arbitrary point: rand()/(RAND_MAX + 1.0) lies in [0, 1), so
   // the scaled value falls in [start, stop) before truncation.
   return (std::size_t)(rand()/(RAND_MAX + 1.0)*(stop-start)+start);
}

//*****************************************************************************
// VpTree::makeTree
//*****************************************************************************
// Recursively builds the sub-tree for the item range [start, stop) and
// returns its root.
//
//   start - first item position of the range.
//   stop  - one past the last item position of the range.
//
// Returns null for an empty range, a Leaf when the range fits within
// theLeafCapacity, and otherwise an Internal node whose vantage point has
// been moved to position `start`. The caller owns the returned node.
// Items inside the range are reordered in place.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
typename VpTree<T, DISTANCE>::Node*
VpTree<T, DISTANCE>::makeTree(const std::size_t& start,
                                const std::size_t& stop)
{
   if (stop <= start) return 0;

   std::size_t setSize = stop - start;
   
   // Small enough to scan linearly: stop partitioning.
   if (setSize <= theLeafCapacity)
   {
      return new Leaf(this, start, stop);
   }
   
   // The vantage point always ends up at the front of the range.
   Internal* node=new Internal(this);
   node->theIndex=start;

   const std::size_t vp = selectVp(start, stop);
   std::swap(theItems[start], theItems[vp]);
   
   // Identify bound elements: the inner partition is
   // [innerLowerBound, outerLowerBound) and the outer partition is
   // [outerLowerBound, stop). The vantage point itself is excluded.
   std::size_t outerLowerBound = (start + stop + 1)/2;
   std::size_t innerLowerBound = start + 1;
   
   // Update Histories: record each remaining item's distance to the
   // vantage point and track the overall extremes. The loop condition
   // "elem --> innerLowerBound" is "elem-- > innerLowerBound", so elem
   // runs from stop-1 down to innerLowerBound.
   double d_max=0;
   double d_min=std::numeric_limits<double>::max();
   for (std::size_t elem = stop; elem --> innerLowerBound; )
   {
      theItems[elem].theDistance = DISTANCE(theItems[start].theElement,
                                            theItems[elem].theElement);
      d_max = (theItems[elem].theDistance > d_max)
              ? theItems[elem].theDistance : d_max;
      d_min = (theItems[elem].theDistance < d_min)
              ? theItems[elem].theDistance : d_min;
   }
   
   // Put the median element in place. No comparator is passed, so items
   // are compared with operator<; afterwards nothing before
   // outerLowerBound is farther from the vantage point than the item at
   // outerLowerBound.
   std::nth_element(
           theItems.begin() + innerLowerBound,
           theItems.begin() + outerLowerBound,
           theItems.begin() + stop);

   // Largest distance inside the inner partition. It stays zero when the
   // inner partition is empty (a range of exactly two items).
   double d_mid=0;
   for (std::size_t elem = outerLowerBound; elem --> innerLowerBound; )
   {
      d_mid = (theItems[elem].theDistance > d_mid)
              ? theItems[elem].theDistance : d_mid;
   }
   
   // Distance bounds used by the searches to prune each branch.
   node->theInnerLowerBound=d_min;
   node->theInnerUpperBound=d_mid;
   node->theOuterLowerBound=theItems[outerLowerBound].theDistance;
   node->theOuterUpperBound=d_max;
   
   #ifdef _OPENMP
   // Counts the parallel regions currently open across all recursive
   // calls, to cap how many are spawned.
   // NOTE(review): the counter is updated atomically but read in the
   // condition below without synchronization - confirm this is acceptable.
   // NOTE(review): both sections reach selectVp(), which calls rand();
   // rand() is not guaranteed to be safe to call concurrently - verify.
   static int threads = 0;
   
   if (threads < (omp_get_num_procs()-1))
   {
      omp_set_dynamic(1);
      omp_set_num_threads(omp_get_num_procs());
      omp_set_nested(1);
      
      #pragma omp atomic
      ++threads;
      
      #pragma omp parallel
      {         
         // The statement before the first "omp section" directive forms
         // the first section; the two branches cover disjoint item ranges.
         #pragma omp sections nowait
         {
            node->theInnerBranch = makeTree(innerLowerBound,
                                            outerLowerBound);
              
            #pragma omp section
            node->theOuterBranch = makeTree(outerLowerBound,
                                            stop);
         }
      }
      
      #pragma omp atomic
      --threads;
   }
   else
   #endif
   {
      // Serial path: used without OpenMP or once the cap is reached.
      node->theInnerBranch = makeTree(innerLowerBound,
                                      outerLowerBound);
      node->theOuterBranch = makeTree(outerLowerBound,
                                      stop);
   }
   
   return node;
}

//*****************************************************************************
// VpTree::knn
//*****************************************************************************
// Finds up to k nearest neighbors of `query`.
//
//   query - the element to search around.
//   k     - the maximum number of neighbors to return.
//   limit - initial search radius; only items strictly closer than this
//           are considered.
//
// Returns a pair of equally long sequences ordered nearest first: the
// neighbors' indices in the input data set and their distances to the
// query. Both are empty when the tree holds no items or k is zero.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
std::pair<std::deque<std::size_t>, std::deque<double> >
VpTree<T, DISTANCE>::knn(const T& query,
                           const std::size_t& k,
                           const double limit) const
{
   std::deque<std::size_t> indices;
   std::deque<double> distances;

   // An empty tree has no root to visit, and a request for zero
   // neighbors would make the visitors evict from an empty queue.
   if (theRoot == 0 || k == 0)
   {
      return std::make_pair(indices, distances);
   }

   std::priority_queue<ResultsCandidate> candidates;
   double tau = limit;

   theRoot->knn(query, k, candidates, tau);

   // The queue releases the farthest candidate first, so prepending each
   // one leaves the results ordered nearest first. Candidates carry
   // positions in the internal database, which are translated back to
   // indices in the input data set here.
   while (!candidates.empty())
   {
      const ResultsCandidate& farthest = candidates.top();

      indices.push_front(theItems[farthest.theIndex].theIndex);
      distances.push_front(farthest.theDistance);
      candidates.pop();
   }

   return std::make_pair(indices, distances);
}

//*****************************************************************************
// VpTree::Internal::knn
//*****************************************************************************
// k-nearest-neighbor visitor for an internal node.
//
//   query       - the element to search around.
//   k           - the number of neighbors requested.
//   candidates  - the best candidates found so far, farthest on top.
//   kth_closest - current search radius; tightened to the farthest kept
//                 candidate once k candidates are held.
//
// Tests this node's vantage point, then visits each branch whose distance
// bounds could still contain an item within the search radius. The branch
// on the query's side of the partition is visited first so the radius can
// shrink before the other branch is tested.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
void
VpTree<T, DISTANCE>::Internal::knn(
        const T& query,
        const std::size_t& k,
        std::priority_queue<ResultsCandidate>& candidates,
        double& kth_closest) const
{
   // Without a tree there is no database to read. With k == 0 there is
   // nothing to find, and the eviction below would pop an empty queue.
   if (!(this->theTree) || k == 0) return;

   const double distance =
      DISTANCE(this->theTree->theItems[theIndex].theElement, query);

   if (distance < kth_closest)
   {
      // Make room by evicting the current farthest candidate.
      if (candidates.size() == k)
      {
         candidates.pop();
      }

      candidates.push(ResultsCandidate(theIndex, distance));

      // Once k candidates are held the farthest one bounds the search.
      if (candidates.size() == k)
      {
         kth_closest = candidates.top().theDistance;
      }
   }

   if (!(theInnerBranch || theOuterBranch))
   {
      return;
   }

   const double middle = 0.5 * (theInnerUpperBound + theOuterLowerBound);

   // The bound checks are re-evaluated immediately before each visit
   // because the first visit may have tightened kth_closest.
   if (distance < middle)
   {
      if (theInnerBranch &&
          (distance <= theInnerUpperBound + kth_closest) &&
          (distance >= theInnerLowerBound - kth_closest))
      {
         theInnerBranch->knn(query, k, candidates, kth_closest);
      }

      if (theOuterBranch &&
          (distance >= theOuterLowerBound - kth_closest) &&
          (distance <= theOuterUpperBound + kth_closest))
      {
         theOuterBranch->knn(query, k, candidates, kth_closest);
      }
   }
   else
   {
      if (theOuterBranch &&
          (distance >= theOuterLowerBound - kth_closest) &&
          (distance <= theOuterUpperBound + kth_closest))
      {
         theOuterBranch->knn(query, k, candidates, kth_closest);
      }

      if (theInnerBranch &&
          (distance <= theInnerUpperBound + kth_closest) &&
          (distance >= theInnerLowerBound - kth_closest))
      {
         theInnerBranch->knn(query, k, candidates, kth_closest);
      }
   }
}

//*****************************************************************************
// VpTree::Leaf::knn
//*****************************************************************************
// k-nearest-neighbor visitor for a leaf.
//
//   query       - the element to search around.
//   k           - the number of neighbors requested.
//   candidates  - the best candidates found so far, farthest on top.
//   kth_closest - current search radius; tightened to the farthest kept
//                 candidate once k candidates are held.
//
// Scans every item in the leaf's range and keeps those strictly closer
// than the current search radius.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
void
VpTree<T, DISTANCE>::Leaf::knn(
        const T& query,
        const std::size_t& k,
        std::priority_queue<ResultsCandidate>& candidates,
        double& kth_closest) const
{
   // Without a tree there is no database to read. With k == 0 there is
   // nothing to find, and the eviction below would pop an empty queue.
   if (!(this->theTree) || k == 0) return;

   // Scan the leaf
   for (std::size_t item = theHead; item < theTail; ++item)
   {
      const double distance =
         DISTANCE(this->theTree->theItems[item].theElement, query);

      if (distance < kth_closest)
      {
         // Make room by evicting the current farthest candidate.
         if (candidates.size() == k)
         {
            candidates.pop();
         }

         candidates.push(ResultsCandidate(item, distance));

         // Once k candidates are held the farthest one bounds the search.
         if (candidates.size() == k)
         {
            kth_closest = candidates.top().theDistance;
         }
      }
   }
}

//*****************************************************************************
// VpTree::rnn
//*****************************************************************************
// Finds every item within a fixed radius of `query`.
//
//   query - the element to search around.
//   range - the search radius; items at exactly this distance are
//           included.
//
// Returns a pair of equally long sequences ordered nearest first: the
// matching items' indices in the input data set and their distances to
// the query. Both are empty when the tree holds no items.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
std::pair<std::deque<std::size_t>, std::deque<double> >
VpTree<T, DISTANCE>::rnn(const T& query,
                           const double range) const
{
   std::deque<std::size_t> indices;
   std::deque<double> distances;

   // An empty tree has no root to visit.
   if (theRoot == 0)
   {
      return std::make_pair(indices, distances);
   }

   std::priority_queue<ResultsCandidate> candidates;

   theRoot->rnn(query, range, candidates);

   // The queue releases the farthest candidate first, so prepending each
   // one leaves the results ordered nearest first. Candidates carry
   // positions in the internal database, which are translated back to
   // indices in the input data set here.
   while (!candidates.empty())
   {
      const ResultsCandidate& farthest = candidates.top();

      indices.push_front(theItems[farthest.theIndex].theIndex);
      distances.push_front(farthest.theDistance);
      candidates.pop();
   }

   return std::make_pair(indices, distances);
}

//*****************************************************************************
// VpTree::Internal::rnn
//*****************************************************************************
// Fixed-radius visitor for an internal node.
//
//   query      - the element to search around.
//   range      - the search radius (inclusive).
//   candidates - receives every item found within the radius.
//
// Tests this node's vantage point, then visits each branch whose distance
// bounds intersect the search radius. The branch on the query's side of
// the partition is visited first.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
void
VpTree<T, DISTANCE>::Internal::rnn(
        const T& query,
        const double range,
        std::priority_queue<ResultsCandidate>& candidates) const
{
   const VpTree<T, DISTANCE>* const tree = this->theTree;

   if (tree == 0) return;

   const double vpDistance = DISTANCE(tree->theItems[theIndex].theElement,
                                      query);

   if (vpDistance <= range)
   {
      candidates.push(ResultsCandidate(theIndex, vpDistance));
   }

   if (theInnerBranch == 0 && theOuterBranch == 0)
   {
      return;
   }

   // The radius never changes during the search, so whether each branch
   // needs a visit can be decided once, up front.
   const bool searchInner = (theInnerBranch != 0) &&
                            (vpDistance <= theInnerUpperBound + range) &&
                            (vpDistance >= theInnerLowerBound - range);

   const bool searchOuter = (theOuterBranch != 0) &&
                            (vpDistance >= theOuterLowerBound - range) &&
                            (vpDistance <= theOuterUpperBound + range);

   // Queries nearer than the midpoint between the partitions start with
   // the inner branch; all others start with the outer branch.
   const bool innerFirst =
      vpDistance < 0.5 * (theInnerUpperBound + theOuterLowerBound);

   if (innerFirst && searchInner)
   {
      theInnerBranch->rnn(query, range, candidates);
   }

   if (searchOuter)
   {
      theOuterBranch->rnn(query, range, candidates);
   }

   if (!innerFirst && searchInner)
   {
      theInnerBranch->rnn(query, range, candidates);
   }
}

//*****************************************************************************
// VpTree::Leaf::rnn
//*****************************************************************************
// Fixed-radius visitor for a leaf.
//
//   query      - the element to search around.
//   range      - the search radius (inclusive).
//   candidates - receives every item found within the radius.
//
// Scans every item in the leaf's range and records those whose distance
// to the query does not exceed the radius.
template<typename T,
         double (*DISTANCE)(const T&, const T&)>
void
VpTree<T, DISTANCE>::Leaf::rnn(
        const T& query,
        const double range,
        std::priority_queue<ResultsCandidate>& candidates) const
{
   const VpTree<T, DISTANCE>* const tree = this->theTree;

   if (tree == 0) return;

   // Walk the leaf's positions from head (inclusive) to tail (exclusive).
   std::size_t slot = theHead;

   while (slot < theTail)
   {
      const double separation = DISTANCE(tree->theItems[slot].theElement,
                                         query);

      if (separation <= range)
      {
         candidates.push(ResultsCandidate(slot, separation));
      }

      ++slot;
   }
}

}
