泛型算法_k近邻_KD-Tree(kd树)

一、数据集和算法:


数据:

T={(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)}


创建KD树的算法比较容易看懂,参考这篇:点我

看了网上很多查询的算法,大多都是给的伪代码,很多都是互相抄的,也不一定正确。这里我自己写了一个可以运行的代码,测试了几次,没什么问题。


我贴一个比较靠谱的最近邻算法(后面的代码给的是一个k近邻查询,但是原理都差不多):

(1)将查询数据Q从根结点开始,按照Q与各个结点的比较结果向下访问Kd-Tree,直至达到叶子结点。
其中Q与结点的比较指的是将Q对应于结点中的k维度上的值与m进行比较,若Q(k) < m,则访问左子树,否则访问右子树。达到叶子结点时,计算Q与叶子结点上保存的数据之间的距离,记录下最小距离对应的数据点,记为当前“最近邻点”Pcur和最小距离Dcur。
(2)进行回溯(Backtracking)操作,该操作是为了找到离Q更近的“最近邻点”。即判断未被访问过的分支里是否还有离Q更近的点,它们之间的距离小于Dcur。
如果Q与其父结点下的未被访问过的分支之间的距离小于Dcur,则认为该分支中存在离P更近的数据,进入该结点,进行(1)步骤一样的查找过程,如果找到更近的数据点,则更新为当前的“最近邻点”Pcur,并更新Dcur。
如果Q与其父结点下的未被访问过的分支之间的距离大于Dcur,则说明该分支内不存在与Q更近的点。
回溯的判断过程是从下往上进行的,直到回溯到根结点时已经不存在与P更近的分支为止。


下面是运行结果:




给一个可以直接编译运行的VS2008代码:点我下载


一、构造一些类模板和函数模板以便后面计算方便


关于函数对象这一块的内容,请参考《C++标准程序库》、《STL源码解析》

这里大家只要了解bind3rd的用法,由于bind3rd标准库并未提供,需要自己动手写。不懂用法,可以百度搜索下bind2nd函数,以便参考。

1. 头文件 myfunctional.hpp


#include <functional>


template<class _Arg1,
class _Arg2,
class _Arg3,
class _Result>
struct tenary_function
{	
	typedef _Arg1 first_argument_type;
	typedef _Arg2 second_argument_type;
	typedef _Arg3 third_argument_type;
	typedef _Result result_type;
};

template<class _Operation>
class binder3rd : public std::binary_function<typename _Operation::first_argument_type, 
	typename _Operation::second_argument_type, typename _Operation::result_type>
{
protected:
	_Operation op;
	typename _Operation::third_argument_type value;
public:
	binder3rd(const _Operation &_Func, const typename _Operation::third_argument_type &_Third) : op(_Func), value(_Third){}

	typename _Operation::result_type operator()(const typename _Operation::first_argument_type &__x, 
		const typename _Operation::second_argument_type &__y) const
	{
		return op(__x, __y, value);
	}
};


//将三元函数对象适配成二元函数对象
template<class _Operation, class _Ty> inline
binder3rd<_Operation> bind3rd(const _Operation& _Func, const _Ty& _Third)
{
	typename _Operation::third_argument_type _Val(_Third);
	return (binder3rd<_Operation>(_Func, _Val));
}


2.heap.hpp

最大堆类模板,用于存放前k近邻,这部分代码是我用的@江南烟雨 ,部分代码我做了修改为了方便KdTree使用

关于最大堆的概念不清楚可以参考他的博客:点我进入

#pragma once

//STL堆算法实现(大顶堆)

//包含容器vector的头文件:Heap用vector来存储元素
#include <vector>
#include <functional>

#define MAX_VALUE 999999 //某个很大的值,存放在vector的第一个位置(最大堆)

const int StartIndex = 1;//容器中堆元素起始索引

using namespace std;

//堆类定义
//默认比较规则less
template <class ElemType,class Compare = less<ElemType> >
class MyHeap{
private:
	vector<ElemType> heapDataVec;//存放元素的容器
	int numCounts;//堆中元素个数
	Compare comp;//比较规则

public:
	MyHeap();

	vector<ElemType> getVec();

	bool empty();
	int size();
	void initHeap(ElemType *data,const int n);//初始化操作
	void makeHeap();//建堆
	void pushHeap(ElemType elem);//向堆中插入元素
	void popHeap();//删除堆顶的元素
	void clear();
	vector<ElemType> sortHeap();
	ElemType getTop();//获取堆顶元素

private:
	void adjustHeap(int childTree,ElemType adjustValue);//调整子树
	void percolateUp(int holeIndex,ElemType adjustValue);//上溯操作
};

template <class ElemType,class Compare>
vector<ElemType> MyHeap<ElemType, Compare>::sortHeap()
{
	std::vector<ElemType> result(numCounts);
	for (int i = numCounts - 1; i >=0 ; --i)
	{
		ElemType topElem = getTop();
		popHeap();
		result[i] = topElem;
	}
	return result;
}

template <class ElemType,class Compare>
void MyHeap<ElemType, Compare>::clear()
{
	heapDataVec.clear();
	ElemType e;
	heapDataVec.push_back(e);
	numCounts = 0;
}

template <class ElemType,class Compare>
int MyHeap<ElemType, Compare>::size()
{
	return numCounts;
}

template <class ElemType,class Compare>
bool MyHeap<ElemType, Compare>::empty()
{
	return numCounts == 0 ? true : false;
}

template <class ElemType,class Compare>
ElemType MyHeap<ElemType, Compare>::getTop()
{
	return heapDataVec[1];
}


template <class ElemType,class Compare>
MyHeap<ElemType,Compare>::MyHeap()
:numCounts(0)
{
	ElemType e;
	heapDataVec.push_back(e);
}

template <class ElemType,class Compare>
vector<ElemType> MyHeap<ElemType,Compare>::getVec()
{
	return heapDataVec;
}

template <class ElemType,class Compare>
void MyHeap<ElemType,Compare>::initHeap(ElemType *data,const int n)
{
	//拷贝元素数据到vector中
	for (int i = 0;i < n;++i)
	{
		heapDataVec.push_back(*(data + i));
		++numCounts;
	}
}

template <class ElemType,class Compare>
void MyHeap<ElemType,Compare>::makeHeap()
{
	//建堆的过程就是一个不断调整堆的过程,循环调用函数adjustHeap依次调整子树
	if (numCounts < 2)
		return;
	//第一个需要调整的子树的根节点多音
	int parent = numCounts / 2;
	while(1)
	{
		adjustHeap(parent,heapDataVec[parent]);
		if (StartIndex == parent)//到达根节点
			return;

		--parent;
	}
}



template <class ElemType,class Compare>
void MyHeap<ElemType,Compare>::pushHeap(ElemType elem)
{
	//将新元素添加到vector中
	heapDataVec.push_back(elem);
	++numCounts;

	//执行一次上溯操作,调整堆,以使其满足最大堆的性质
	percolateUp(numCounts,heapDataVec[numCounts]);
}

template <class ElemType,class Compare>
void MyHeap<ElemType,Compare>::popHeap()
{
	//将堆顶的元素放在容器的最尾部,然后将尾部的原元素作为调整值,重新生成堆
	ElemType adjustValue = heapDataVec[numCounts];
	//堆顶元素为容器的首元素
	heapDataVec[numCounts] = heapDataVec[StartIndex];
	//堆中元素数目减一
	--numCounts;

	adjustHeap(StartIndex,adjustValue);

	//直接删除
	heapDataVec.pop_back();
}

//调整以childTree为根的子树为堆
template <class ElemType,class Compare>
void MyHeap<ElemType,Compare>::adjustHeap(int childTree,ElemType adjustValue)
{
	//洞节点索引
	int holeIndex = childTree;
	int secondChid = 2 * holeIndex + 1;//洞节点的右子节点(注意:起始索引从1开始)
	while(secondChid <= numCounts)
	{
		if (comp(heapDataVec[secondChid],heapDataVec[secondChid - 1]))
		{
			--secondChid;//表示两个子节点中值较大的那个
		}

		//上溯
		heapDataVec[holeIndex] = heapDataVec[secondChid];//令较大值为洞值
		holeIndex = secondChid;//洞节点索引下移
		secondChid = 2 * secondChid + 1;//重新计算洞节点右子节点
	}
	//如果洞节点只有左子节点
	if (secondChid == numCounts + 1)
	{
		//令左子节点值为洞值
		heapDataVec[holeIndex] = heapDataVec[secondChid - 1];
		holeIndex = secondChid - 1;
	}
	//将调整值赋予洞节点
	heapDataVec[holeIndex] = adjustValue;

	//此时可能尚未满足堆的特性,需要再执行一次上溯操作
	percolateUp(holeIndex,adjustValue);
}

//上溯操作
template <class ElemType,class Compare>
void MyHeap<ElemType,Compare>::percolateUp(int holeIndex,ElemType adjustValue)
{
	//将新节点与其父节点进行比较,如果键值比其父节点大,就父子交换位置。
	//如此,知道不需要对换或直到根节点为止
	int parentIndex = holeIndex / 2;
	while(holeIndex > StartIndex && comp(heapDataVec[parentIndex],adjustValue))
	{
		heapDataVec[holeIndex] = heapDataVec[parentIndex];
		holeIndex = parentIndex;
		parentIndex /= 2;
	}
	heapDataVec[holeIndex] = adjustValue;//将新值放置在正确的位置
}


3. KdTree.hpp

这个模板类接受任意数据类型,客户端需要自己继承该类并重写虚方法

#pragma once
#include <vector>
#include <stack>
#include <algorithm>
#include <cmath>
#include "myfunctional.hpp"
#include "heap.hpp"
#define INFINITE 0xFFFFFFFF
template<class DataType, unsigned N>
class KdTree;

template<class DataType, unsigned N>
class KdNode
{
	friend KdTree<DataType, N>;
public:
	~KdNode()
	{
		if (_left != NULL)
		{
			delete _left;
			_left = NULL;
		}
		if (_right != NULL)
		{
			delete _right;
			_right = NULL;
		}
	}
private:
	std::vector<DataType> _data;
	int _split;
	KdNode<DataType, N>* _left;
	KdNode<DataType, N>* _right;
};


//
template<class DataType, unsigned N>
class KdTree
{
public:
	KdTree();
	virtual ~KdTree();
	//数据必须能够度量距离
	virtual double getDist(const std::vector<DataType> &first, const std::vector<DataType> &second) = 0;
	virtual double getDist(const DataType &first, const DataType &second) = 0;

	//任一维度之间可比较大小
	virtual bool less(const DataType &first, const DataType &second) const = 0;
	void createKdTree(const std::vector<DataType> *dataset, int size);
	std::vector<std::pair<double, std::vector<DataType>>> query(const std::vector<DataType> &queryData, int k);
	//寻找split维度上的中位数
	std::vector<DataType> getMedium(std::vector<DataType> *first, std::vector<DataType> *last, int split);
private:
	KdTree(const KdTree<DataType, N>&);
	KdTree<DataType, N>& operator=(KdTree<DataType, N>&);
	KdNode<DataType, N>* createKdTree(std::vector<DataType> *first, std::vector<DataType> *last, int split);
private:
	KdNode<DataType, N> *_head;
	std::vector<DataType> *_copydata;
	std::stack<KdNode<DataType, N>*> _search_path;
	MyHeap<std::pair<double, std::vector<DataType>>> _heap;
};



//按维度排序准则
template<class DataType, unsigned N>
struct _less : public tenary_function<std::vector<DataType>, std::vector<DataType>, int, bool>
{
	const KdTree<DataType, N> *_kdTree;
	bool operator()(const std::vector<DataType> &__x, const std::vector<DataType> &__y, const int __z) const 
	{
		return _kdTree->less(__x[__z], __y[__z]);
	}
	_less(const KdTree<DataType, N> *kdTree) : _kdTree(kdTree)
	{

	}
};
template<class DataType, unsigned N>
std::vector<std::pair<double, std::vector<DataType>>> KdTree<DataType, N>::query(const std::vector<DataType> &queryData, int k)
{
	_heap.clear();
	KdNode<DataType, N> *p = _head;
	KdNode<DataType, N> *curNearest = NULL;
	double minDist = INFINITE;

	//查询至叶节点
	while(p != NULL)
	{
		_search_path.push(p);
		if (queryData[p->_split] < p->_data[p->_split])
		{
			p = p->_left;
		}
		else
		{
			p = p->_right;
		}

	}

	if (!_search_path.empty())
	{
		curNearest = _search_path.top();
		_search_path.pop();
		minDist = getDist(curNearest->_data, queryData);
		_heap.pushHeap(std::make_pair(minDist, curNearest->_data));
	}
	KdNode<DataType, N>* backPoint = NULL;
	while(!_search_path.empty())
	{
		backPoint = _search_path.top(); 
		_search_path.pop();

		double temp = getDist(backPoint->_data, queryData);

		//如果堆小于k, 直接添加到堆
		if (_heap.size() < k)
		{
			_heap.pushHeap(std::make_pair(temp, backPoint->_data));
		}
		else
		{
			// 如果距离小于堆顶元素,则删除堆顶元素,添加此元素
			std::pair<double, std::vector<DataType>> topElement = _heap.getTop();
			minDist = topElement.first;
			if (temp < minDist)
			{
				_heap.popHeap();
				_heap.pushHeap(std::make_pair(temp, backPoint->_data));
			}
		}
		
		std::pair<double, std::vector<DataType>> topElement = _heap.getTop();
		minDist = topElement.first;


		//更新最小超球

		
		if (temp < minDist)
		{
			minDist = temp;
			curNearest = backPoint;
		}

		//查看backPoint所在维度的超平面是否和当前最小超球相交,若相交则进入另一半空间查找
		if (getDist(backPoint->_data[backPoint->_split], queryData[backPoint->_split]) <= minDist)
		{
			//当前节点是否在左子空间,如果在则进入右子空间继续搜索直至叶结点,如果不在则进入左子空间搜索直至叶结点
			if (queryData[backPoint->_split] < backPoint->_data[backPoint->_split])
			{
				p = backPoint->_right;
			}
			else
			{
				p = backPoint->_left;
			}

			//搜索至叶节点
			while(p != NULL)
			{
				_search_path.push(p);
				if (queryData[p->_split] < p->_data[p->_split])
				{
					p = p->_left;
				}
				else
				{
					p = p->_right;
				}
			}
		}
		


	}

	std::vector<std::pair<double, std::vector<DataType>>> result = _heap.sortHeap();
	return result;
}

template<class DataType, unsigned N>
KdTree<DataType, N>::~KdTree()
{
	if(_head != NULL)
	{
		delete _head;
	}
	if (_copydata != NULL)
	{
		delete[] _copydata;
	}
}




template<class DataType, unsigned N>
KdTree<DataType, N>::KdTree() : _head(NULL), _copydata(NULL)
{
	
}

template<class DataType, unsigned N>
std::vector<DataType> KdTree<DataType, N>::getMedium( std::vector<DataType> *first, std::vector<DataType> *last, int split )
{
	std::size_t size = last - first;
	std::sort(first, last, bind3rd(_less<DataType, N>(this), split));
	return *(first+size/2);
}


template<class DataType, unsigned N>
KdNode<DataType, N>* KdTree<DataType, N>::createKdTree(std::vector<DataType> *first, std::vector<DataType> *last, int split)
{
	if (first == last)
	{
		return NULL;
	}
	std::size_t size = last  - first;
	KdNode<DataType, N>* newNode = new KdNode<DataType, N>;
	std::vector<DataType> data = getMedium(first, last, split);
	newNode->_split = split;
	newNode->_data = data;

	newNode->_left = createKdTree(first, first + size/2, (split+1)%N);
	newNode->_right = createKdTree(first + size/2 + 1, last, (split+1)%N);
	return newNode;
}

template<class DataType, unsigned N>
void KdTree<DataType, N>::createKdTree(const std::vector<DataType> *dataset, int size)
{
	_copydata = new std::vector<DataType>[size];
	std::copy(dataset, dataset + size, _copydata);
	_head = createKdTree(_copydata, _copydata + size, 0);
}


二、客户端实现

// kd_tree.cpp : 定义控制台应用程序的入口点。
//

#include "stdafx.h"
#include <iostream>
#include <string>
#include "KdTree.hpp"


class MyKdTree : public KdTree<double, 2>
{
public:
	virtual double getDist(const std::vector<double> &first, const std::vector<double> &second)
	{
		double sum = 0;
		for (std::size_t i = 0; i < first.size(); ++ i)
		{
			sum += std::pow(first[i]-second[i], 2);
		}
		return std::sqrt(sum);
	}
	virtual double getDist(const double &first, const double &second)
	{
		return fabs(first - second);
	}

	virtual bool less(const double &first, const double &second) const
	{
		return first < second;
	}
};



void DoubleKdTree();

int _tmain(int argc, _TCHAR* argv[])
{

	DoubleKdTree();


	
	return 0;
}

void DoubleKdTree()
{
	//创建数据
	double dataset[6][2] = {
		2.0, 3.0,
		5.0, 4.0,
		9.0, 6.0,
		4.0, 7.0,
		8.0, 1.0,
		7.0, 2.0,
	};
	MyKdTree myKdTree;
	std::vector<double> vDataSet[6];
	for (int i = 0; i < 6; ++i)
	{
		double *p = (double*)(&dataset[i]);
		std::vector<double> temp(p, p + 2);
		vDataSet[i] = temp;
	}

	//构建KD树
	myKdTree.createKdTree(vDataSet, 6);
	double data[2] = {0};
	int k = 1;
	while (data[0] != -1 && data[1] != -1)
	{
		std::cout << "输入一个二维数据:";
		std::cin >> data[0] >> data[1];
		std::cout << "\n输入第k近邻的k值:";
		std::cin >> k;
		std::vector<double> test;
		test.push_back(data[0]);
		test.push_back(data[1]);
		std::vector<std::pair<double, std::vector<double>>> result = myKdTree.query(test, k);

		for (std::size_t i = 0; i < result.size(); ++i)
		{
			std::cout << "距离: " << result[i].first << "\t";
			std::cout << "[" << result[i].second[0] << ", " << result[i].second[1] << "]" << std::endl;
		}
		std::cout << std::endl;
	}
}




相关推荐
©️2020 CSDN 皮肤主题: 游动-白 设计师:白松林 返回首页