Введение

Классификация методом поиска ближайших соседей - относительно простой для понимания метод классификации, суть которого подробно рассматриваться в этой статье не будет.

Метод предполагает наличие алгоритма поиска ближайших соседей. Можно использовать разные алгоритмы. Самый простой, но при этом не эффективный по времени алгоритм - полный перебор всех соседей для поиска ближайших. Существуют так же методы поиска, называемые KDtree и BallTree.

Передо мной стояла задача написать классификатор с алгоритмом поиска BallTree, используя заданный интерфейс, примерно такой же, как в sklearn). В рунете я не нашёл инструкцию или статью, сильно помогающие написать код. Поэтому решился написать пост здесь. Заранее спасибо за критику, наверняка в коде будут моменты, требующие серьёзной переработки. Но в общем и целом решение рабочее.

Для того, чтобы понять статью - необходимо знать основы Python и библиотеки numpy.


Теория

Дано:

1) Пространство с точками (точка - представление объекта) и метрика, определяющая каким образом находить расстояние между точками.

2) Множество классов. Классы - не пересекающиеся множества, содержащие точки. С каждой точкой ассоциируется один класс. Говорят, что точка A принадлежит классу 1, а точка B к классу "2", например.

3) Точка/несколько точек класс которой/которых не известен.

Задача:

Предположить, к какому классу наиболее вероятно принадлежит данная точка.

Идея в том, что точки, принадлежащие одному классу, часто располагаются рядом друг с другом. Поэтому определить класс точки можно посмотрев на её соседей. Чем больше соседей класса n располагаются близко к точке - тем вероятнее эта точка принадлежит к классу n.

Появляется задача найти ближайших соседей среди всего множества точек. Для этого можно проверять расстояние до каждой точки в пространстве и запоминать классы самых близких точек. Данный подход требует много времени, ведь проверить нужно каждую точку. Но если заранее структурировать пространство точек, то проверять каждую точку во всех случаях не понадобится, так как структура пространства подскажет, какие точки точно не могут быть ближайшими соседями.

BallTree

В конечном итоге у нас должна образоваться структура, состоящая из вложенных шаров. Каждая точка относится только к одному шару. Все точки разбиты на не пересекающиеся кластеры-шары. Эта структура образует двоичное дерево, в вершинах которого будут лежать центроиды, а кроной которого будут листья, содержащие точки пространства. При этом размер листьев регулируется отдельным параметром LeafSize.

Точки
Наше пространство точек. Точки разделены на 3 класса по цветам.
Наше пространство точек. Точки разделены на 3 класса по цветам.

Структурируем пространство следующим образом:

1) Отсортируем все точки по размерности с наибольшим разбросом значений. Для этого необходимо найти её. Дальше будем работать с этой размерностью.

2) Выберем так называемую точку - центроид. Центроид - это точка, которая находится примерно по середине множества точек, которое мы хотим объединить в один кластер. Для того, чтобы найти центроид, необходимо выбрать точку, у которой координата меньше всего отличается от среднего значения координат по всем точкам рассматриваемого множества вдоль найденной в пункте 1) размерности.

3) Найдём радиус до самой дальней от центроида точки. Это радиус шара-кластера, который покроет всё множество точек (при первой итерации пункта).

Первый шар
Шар, который ассоциирован с вершиной захватывает все точки пространства.
Шар, который ассоциирован с вершиной захватывает все точки пространства.
Граф ballTree пока что выглядит так. Есть только корень.
Граф ballTree пока что выглядит так. Есть только корень.

4) Множество всех точек разделилось на две части - до центроида и после центроида. Возвращаемся к пункту два, применяя его на первой, а далее второй половине множества. До тех пор, пока оставшиеся количество точек не будет меньше, чем LeafSize.

Вложенные шары
Шары, созданные на второй итерации, помечены фиолетовым цветом. Это левый и правый потомки вершины, которые являются листьями, так как охватывают все оставшиеся точки множества.
Шары, созданные на второй итерации, помечены фиолетовым цветом. Это левый и правый потомки вершины, которые являются листьями, так как охватывают все оставшиеся точки множества.
У графа ballTree появились вершины-листья. Если бы точек было бы больше, то ветвление графа происходило бы до тех пор, пока все точки не распределились бы по листьям.
У графа ballTree появились вершины-листья. Если бы точек было бы больше, то ветвление графа происходило бы до тех пор, пока все точки не распределились бы по листьям.

5) Если осталось меньше, чем LeafSize (размер листа) точек, то записываем их в лист и приписываем к шару, покрывающему эти точки.

6) Завершаем работу тогда, когда все точки распределены по листьям.

Теперь, после построения дерева, процесс поиска k ближайших соседей рассматриваемой точки ускорится. Почему?

В процессе поиска нам будет необходимо спускаться по дереву сверху вниз начиная от корня. На каждой вершине нужно принимать решение - рассматривать точки, находящиеся по правую или левую сторону, или не рассматривать? При этом кандидаты в ближайшие соседи будут записываться в множество размером k (количество соседей, которых необходимо найти). Если будет найдена точка, которая расположена ближе, чем самый дальний сосед - то его необходимо заменить на эту точку. Если ещё не набрано необходимое количество ближайших соседей, то записывается любая найденная точка.

При достижении очередной вершины выберем среди всех текущих найденных ближайших соседей точку с максимальным расстоянием от рассматриваемой. Выберем один из двух центроидов, лежащих в нижних вершинах. Если это максимальное расстояние меньше, чем расстояние от рассматриваемой точки до границы шара этого центроида - то нам нет никакой надобности рассматривать точки, входящие в этот шар этого центроида. В любом случае там не будет соседей, которые будут лежать ближе, чем текущий самый далёкий сосед, выгоднее рассматривать точки, находящиеся в другом шаре.

Это и есть момент оптимизации. Заранее структурированное пространство точек даёт возможность знать, насколько далеко будут располагаться ВСЕ точки в кластере (шаре) без вычисления расстояний до каждой точки из этого кластера.

Момент оптимизации

Допустим, что нам дали оранжевую точку для классификации. Допустим, что нам нужно найти одного ближайшего соседа, и изначально кандидатом в него была одна из синих точек. Дальше необходимо сделать выбор - какие точки рассматривать? Может быть, есть точки поближе? Так вот очевидно, что рассматривать точки из шара 3 нет никакого смысла, так как расстояние от нашей точки до границы шара длинее, чем до максимально далёкого на данный момент соседа.


Практика - KnnBalltreeClassifier на Python

Граф будем хранить в виде массива точек, оперируя их индексами таким образом, что индекс левой вершины будет равен 2 * v, а правой 2 * v + 1 (v - индекс текущей вершины). При этом индекс корня = 1. Нам понадобятся классы для вершин графа, класс для точки и реализация вспомогательных функций. Так же необходимо подключить numpy, collections, math и random (для создания множеств точек).

Найденных ближайших соседей будем записывать в set. Понадобится свойство множества (set) не допускать к хранению дублирующиеся элементы (расстояния до точек могут повторяться).

  • Узлы графа BallTree

class Node(object):        
    def __init__(self, pivot = None, points = [], radius = None): 
      #Индекс корня - 1
        self.pivot = pivot #Индекс левой верщины - 2*v
        self.radius = radius #Индекс правой вершины - 2*v + 1
        self.points = points
    def __str__(self):
        self.stringList = ''.join(str(symbol) for symbol in self.points)
        return("pivot = " + str(self.pivot) + " radius = " 
        + str(self.radius) + " points = " 
        + str(self.stringList))

pivot - точка центроид шара, который ассоциирован с этой вершиной графа.
radius - длина радиуса шара, который ассоциирован с этой вершиной графа.
points - массив точек, которые лежат в этом шаре. Важное примечание! Точки будут лежать только в последнем вложенном шаре, который их покрывает. С точки зрения графа - только в листовой вершине (листе).
Дальше идёт метод для красивой печати класса в консоли. Удобно для отладки и просмотра ballTree на текущем этапе работы программы.

  • Опишем класс точки

class Point(object):
    def __init__(self, point, index):
        self.point = point
        self.index = index
    def __str__(self):
        return("Vector " + str(self.point) + " with index " + str(self.index))

point - контейнер, содержащий координаты точки. То есть по сути сама точка как набор её координат.

index - изначальный индекс точки во входном множестве (пространстве точек). Понадобится для того, чтобы узнать классы ближайших соседей (классы хранятся в другом множестве и ассоциированы с точками через индексы).

И снова метод для удобного вывода в консоль при печати.

  • Конструктор

def __init__(self, n_neighbors=1, leaf_size=30):
    if(n_neighbors <= 0):
        print("Ошибка! Соседей не может быть меньше, чем 1")
        return
    self.n_neighbors = n_neighbors
    self.leaf_size = leaf_size
    self.nodes = {}
    self.kNSet = set()
    self.points = []
    self.classes = []
    self.pointsOriginal = []

n_neighbors - количество ближайших соседей, которых требуется найти для классификации новой точки.
leaf_size - размер листа, то есть количество точек в листовых вершинах. Чем больше размер листа - тем меньше листовых вершин, и наоборот.
nodes - словарь для хранения вершин
kNSet - множество, которое будет содержать ближайших соседей. Важно сделать его именно set, чтобы не было повторений среди элементов.
points - изначальный массив точек, будет использоваться для работы.
classes - массив классов (для каждой точки - один класс)
pointsOriginal - изначальный массив точек, который не будет меняться в процессе работы программы (пока не будет загружено новое пространство точек).

  • Вспомогательные функции

def GetnpArrayOfPoints(self, xStruct):
    x = []
    for i in range(len(xStruct)):
        x.append(xStruct[i].point)
    return(np.array(x))
      
def Distance(self, a, b):
    if(len(a) != len(b)):
        print("Размерности объектов не одинаковые!")
        return("Error")
    return np.linalg.norm(a-b)
     
     
def GetMaxDistance(self, point, points):
    points = self.GetnpArrayOfPoints(points)
    point = point.point
    #print(points) #Для отладки
    max = 0
    for i in range(len(points)):
        currentDist = self.Distance(point, points[i])
        if(currentDist > max):
            max = currentDist
    return(max)

GetnpArrayOfPoints упаковывает точки в массив numpy из контейнера xStruct, который содержит элементы типа Point.
Distance - ищет евклидово расстояние между двумя точками a и b.
GetMaxDistance - ищет максимальное расстояние от текущей точки point до какой либо точки из массива points.

  • Функция для нахождения размерности с наибольшим разбросом значений

def GetMaxSpreadDimension(self, x): #Нужно передать подотрезок array[a:b]
    x = self.GetnpArrayOfPoints(x)       
    difference = []
    for i in range(x.shape[1]):
        maximum = x[0][i]
        minimum = x[0][i]
        for j in range(x.shape[0]):
            maximum = max(maximum, x[j][i])
            minimum = min(minimum, x[j][i])
        difference.append((maximum - minimum, i))
    return(max(difference)[1])

Сначала переведём точки в массив нампай, а дальше просто будем перебирать размерности, для того, чтобы найти размерность с максимальным разбросом. Функция понадобится в дальнейшем для построения ballTree.

  • Сортировка точек вдоль указанной размерности

def SortingByDimension(self, x, startPoint, endPoint, dimension): 
    #Нужно передать начало и конец
    left = x[:startPoint]                                  
    rigth = x[endPoint:]
    x = np.array(sorted(x[startPoint:endPoint], key = lambda Point: 
    Point.point[dimension]))
    x = np.concatenate((left, x, rigth), 0)
    return(x)

Упорядочим точки вдоль конкретной размерности. Понадобится в дальнейшем для построения дерева ballTree. Заметим, что сортировать будем только точки, находящиеся на определённом отрезке нашего массива, то есть все точки с индекса startPoint и до индекса endPoint. После сортировки конкретного участка - соединим его с множеством x.

  • Поиск индекса центроида

def GetCentroidIndex(self, x, dimension): 
  #Нужно передать подотрезок array[a:b]
        x = self.GetnpArrayOfPoints(x)
        if(x.shape[0] == 0):
            print("Ошибка! Передан нулевой массив в поиск центроида")
        sum = 0
        for i in range(x.shape[0]): #Размер x не нулевой
            sum = sum + x[i][dimension]
        sum = sum / x.shape[0]
        minimum = abs(sum - x[0][dimension])
        index = 0
        for i in range(x.shape[0]):
            if(abs(x[i][dimension] - sum) < minimum):
                minimum = abs(x[i][dimension] - sum)
                index = i
        return(index)

Центроидом необходимо выбрать точку, которая располагалась бы примерно по середине передаваемого множества x. Найдём среднее значение среди координат всех точек из x. Выбирать будем центроид вдоль размерности dimension. Ищем точку, разница между средним значением по всему x и её координатами минимальна.

  • Построение ballTree

def ConstructTree(self, leafSize, vertexIndex, startPoint, endPoint):
    if(endPoint - startPoint <= 0): 
        return #endPoint не включается в отрезок [ )
        
    dimension = self.GetMaxSpreadDimension(self.points[startPoint:endPoint])
    self.points = self.SortingByDimension(self.points, startPoint, endPoint, 
    dimension)
    centroidIndex = self.GetCentroidIndex(self.points[startPoint:endPoint], 
    dimension) + startPoint 
    radius = self.GetMaxDistance(self.points[centroidIndex], 
    self.points[startPoint:endPoint])
    self.nodes[vertexIndex] = self.Node(self.points[centroidIndex], 
    [], radius)
    
    if(endPoint - startPoint <= leafSize):
        for i in range(startPoint, endPoint):
            self.nodes[vertexIndex].points.append(self.points[i])
        return
    else:
        self.nodes[vertexIndex].points.append(self.points[centroidIndex])
    self.ConstructTree(leafSize, vertexIndex * 2, startPoint, centroidIndex)
    self.ConstructTree(leafSize, vertexIndex * 2 + 1, centroidIndex + 1, 
    endPoint)

Рекурсивная функция. Принимает на вход отрезок, содержащий точки рассматриваемого множества. Начало отрезка - startPoint, конец - endPoint. Сам по себе отрезок обозначает часть множества всех точек, хранящихся в массиве points. vertexIndex - индекс родительской вершины.

Завершаем работу, если дошли от начала до конца.

Находим размерность, вдоль которой будем работать (как вы знаете - это размерность с наибольшим разбросом).

Сортируем точки вдоль найденной размерности.

Дальше находим центроид переданного отрезка с точками. Не забываем про индексацию с нуля (точки индексируются с единицы, а внутри функции поиска центроида - с нуля), поэтому индекс центроида складываем со StartPoint.

Дальше необходимо создать вершину. Для этого находим радиус шара (это будет расстояние до самой отдалённой точки текущего отрезка) и создаём вершину (узел дерева), в который запишем центроид (найденную точку) и пока что пустой массив точек, содержащихся в этом шаре.

Далее - если осталось меньше чем leafSize точек, то добавляем оставшиеся точки внутрь узла (в массив Points), то есть наша вершина становится листовой. Как только заканчиваются точки или размер листа - выходим.

Если точки ещё есть, то условно делим всё множество на две части - от начала до центройда и от цендроида до конца. Одну из частей передаём в качестве отрезка рекурсивно функции для построения левой вершины, вторую - для построения правой вершины.

По итогу имеем, что упорядоченный массив точек поделён на кластеры, к которым можно обратиться через поле points класса Node (вершина дерева). При этом индекс каждой вершины в общем упорядоченном массиве points класса KnnBalltreeClassifier можно вычислить как 2*v для всех левых вершин, 2*v+1 для всех правых вершин и 1 для корня.

  • Поиск ближайших соседей

def searchBallSubtree(self, vertexIndex, newPoint, kN):
    if(len(self.nodes[vertexIndex].points) > 1):
        for point in self.nodes[vertexIndex].points:
            distance = self.Distance(newPoint, point.point)
            if(len(self.kNSet) >= kN and max(self.kNSet)[0] > distance):
                self.kNSet.remove(max(self.kNSet))
            if(len(self.kNSet) < kN):
                self.kNSet.add((distance, point.index))
        return
    leftChild = vertexIndex * 2
    rightChild = vertexIndex * 2 + 1
    distance = self.Distance(newPoint, self.nodes[vertexIndex].pivot.point)

    if(len(self.kNSet) >= kN and max(self.kNSet)[0] > distance):
        self.kNSet.remove(max(self.kNSet))
    if(len(self.kNSet) < kN):
        self.kNSet.add((distance, self.nodes[vertexIndex].pivot.index))
        
    if(len(self.nodes[leftChild].points) != 0):
        distance = self.Distance(newPoint, 
                                 self.nodes[leftChild].pivot.point)
        if(len(self.kNSet) < kN or 
           max(self.kNSet)[0] > distance - self.nodes[leftChild].radius):
            self.searchBallSubtree(leftChild, newPoint, kN)

    if(len(self.nodes[rightChild].points) != 0):
        distance = self.Distance(newPoint, 
                                 self.nodes[rightChild].pivot.point)
        if(len(self.kNSet) < kN or 
           max(self.kNSet)[0] > distance - self.nodes[rightChild].radius):
            self.searchBallSubtree(rightChild, newPoint, kN)

Рекурсивная функция.

Если в узле (шаре, который ассоциирован с вершиной) есть точки, то нужно проверить, нет ли среди них ближайших соседей. Вычисляем расстояния до каждой точки и смотрим, меньше ли оно, чем у максимально далёкой точки в kNSet (то есть среди кандидатов на ближайших соседей).

Если количество необходимых ближайших соседей уже достигнуто, то удаляем самую далёкую точку и добавляем пару (расстояние, индекс) в knSet. По расстоянию мы можем легко найти самую далёкого ближайшего соседа, а по индексу - непосредственно саму точку, которая хранится в массиве pointsOriginal.

Если необходимого количества ближайших соседей ещё нет, то просто добавляем пару (расстояние, индекс) в knSet.

Дальше нам необходимо перейти ниже по ballTree для рассмотрения следующих вершин. Для этого вычислим индексы левой и правой вершины.

Центроид вершины - то же точка, которая может быть ближайшим соседом. Рассмотрим её аналогичным образом (как точки внутри шара на 4 - 8 строках).

Дальше проверим условие, описываемое в последнем абзаце раздела Теория. То есть заранее узнаем по левому центроиду, могут ли в его шаре лежать точки, среди которых есть ближайшие соседи. Для этого вообще проверим, есть ли там какие либо точки и набрали ли мы нужное количество ближайших соседей. Если там нет точек - то не будем рассматривать этот шар. Если мы не набрали нужное количество ближайших соседей - то будем рассматривать этот шар. Если у нас уже есть нужное количество ближайших соседей, и самый дальний ближайший сосед лежит ближе, чем граница этого шара, то все точки шара находятся дальше него.

Аналогичные операции проделываем с правой вершины.

Далее рекурсивно запускаем функцию, где меняем только индекс текущей вершины.

  • Инициализация пространства точек

def fit(self, x, y):
    if (len(x) != len(y)):
        print("Ошибка - не у всех точек определены классы")
        return("fit Error")
    self.points = []
    self.classes = []
    self.pointsOriginal = x
    for i in range(len(x)):
        self.points.append(self.Point(x[i], i))
    self.classes = y

    self.ConstructTree(self.leaf_size, 1, 0, len(self.points))
    return self

Нам необходимо передать классификатору множество точек, среди которых он будет искать ближайших соседей. Эта процедура называется fit, так как мы как бы скармливаем классификатору множество объектов, с которыми он будет работать. В качестве аргументов мы принимаем пространство точек x и массив их классов y.

В fit необходимо сделать все нужные действия по обработке этого множества.

Поэтому проверяем, есть ли классы у всех точек, чистим массивы классов и точек, сохраняем все точки в pointsOriginal (далее по индексу мы сможем найти в нём нужную точку.

Дальше создаём массив из классов Point (которые хранят оригинальный индекс и собственно сам вектор координат) и записываем массив классов в classes.

Теперь мы готовы построить дерево.

  • Предсказание класса

def predict(self, x):
    predictions = []
    for i in range(len(x)):
        self.kNSet.clear()
        self.searchBallSubtree(1, x[i], self.n_neighbors)
        results = []
        for index in self.kNSet: #kNSet хранит пару (дистанция, индекс)
            results.append(self.classes[index[1]])

        resultClass = Counter(results).most_common()[0][0]
        predictions.append(resultClass)
    return np.array(predictions)

Смысл нашего классификатора - предсказывать класс точки, основываясь на информации о ближайших соседях. Поэтому в этой функции мы предсказываем классы всех точек, поданных в аргументе как массив x.

Для этого создаём массив predictions, ищем ближайших соседей для каждой точки из x. Наиболее вероятный класс точки - это самый часто встречающийся класс среди точек ближайших соседей. Узнаём мы его на 10 строчке с помощью метода most_common.

  • Предсказание вероятности класса

def predict_proba(self, X):    
    classesCount = np.unique(self.classes)
    results = []
    for point in X:
        neighborsClasses = self.get_kneighbors_classes(point)
        result = []
        for clas in classesCount:
            count = 0
            for neighborClass in neighborsClasses:
                if(clas == neighborClass):
                    count += 1
            result.append(count/self.n_neighbors * 100)
        results.append(result)
    return(np.array(results))

Если мы не просто хотим знать, какому классу наиболее вероятно принадлежит точка, а хотим знать насколько вероятно что она принадлежит к этому или какому либо другому классу по результату работы нашего алгоритма, то воспользуемся этим методом.

Чтобы найти вероятности, нам необходимо найти количество классов каждого вида среди всех ближайших соседей и поделить на общее количество классов (равное количество ближайших соседей). Выразим ответ в процентах (12 строчка).

  • Просмотр информации о найденных ближайших соседях

def get_kneighbors_classes(self, x):
    self.kNSet.clear()
    self.searchBallSubtree(1, x, self.n_neighbors)
    result = []
    for index in self.kNSet:
        result.append(self.classes[index[1]]) 
    return np.array(result)
        
def kneighbors(self, x, n_neighbors):
    results = []
    for i in range(len(x)):
        self.kNSet.clear()
        self.searchBallSubtree(1, x[i], n_neighbors)
        result = []

        for index in self.kNSet:
            result.append(self.pointsOriginal[index[1]])
        results.append(result)    
    return np.array(np.array(results))

Иногда нам может быть интересно посмотреть информацию о ближайших соседях нашего множества.

get_kneighbors_classes найдёт и выдаст нам классы ближайших соседей точки x.
kneighbors найдёт ближайших соседей для всех точек из массива x.

Код класса KnnBalltreeClassifier
import random
import math
import numpy as np
from collections import Counter

class KnnBalltreeClassifier(object):
    '''Классификатор реализует взвешенное голосование по ближайшим соседям. 
    При подсчете расcтояния используется l2-метрика.
    Поиск ближайшего соседа осуществляется поиском по ball-дереву.
    Параметры
    ----------
    n_neighbors : int, optional
        Число ближайших соседей, учитывающихся в голосовани
        - 'distance' : веса обратно пропорциональны расстоянию д о классифицируемого объекта
        -  функция, которая получает на вход массив расстояний и возвращает массив весов
    leaf_size: int, optional
        Максимально допустимый размер листа дерева
    '''
    class Node(object):        
        def __init__(self, pivot = None, points = [], radius = None): #Индекс корня - 1
            self.pivot = pivot #Индекс левой верщины - 2*v
            self.radius = radius #Индекс правой вершины - 2*v + 1
            self.points = points
        def __str__(self):
            self.stringList = ''.join(str(symbol) for symbol in self.points)
            return("pivot = " + str(self.pivot) + " radius = " + str(self.radius) + " points = " 
                + str(self.stringList))
        
    class Point(object):
            def __init__(self, point, index):
                self.point = point
                self.index = index
            def __str__(self):
                return("Vector " + str(self.point) + " with index " + str(self.index))

    def __init__(self, n_neighbors=1 leaf_size=30):
        if(n_neighbors <= 0):
            print("Ошибка! Соседей не может быть меньше, чем 1")
            return
        self.n_neighbors = n_neighbors
        self.leaf_size = leaf_size
        self.nodes = {}
        self.kNSet = set()
        self.points = []
        self.classes = []
        self.pointsOriginal = []

    def GetnpArrayOfPoints(self, xStruct):
        x = []
        for i in range(len(xStruct)):
            x.append(xStruct[i].point)
        return(np.array(x))
    
    def Distance(self, a, b):
        if(len(a) != len(b)):
            print("Размерности объектов не одинаковые!")
            return("Error")
        return np.linalg.norm(a-b)

    def GetMaxDistance(self, point, points):
        points = self.GetnpArrayOfPoints(points)
        point = point.point
        #print(points)

        max = 0
        for i in range(len(points)):
            currentDist = self.Distance(point, points[i])
            if(currentDist > max):
                max = currentDist
        return(max)

    def GetMaxSpreadDimension(self, x): #Нужно передать подотрезок array[a:b]
        x = self.GetnpArrayOfPoints(x)       
        difference = []
        for i in range(x.shape[1]):
            maximum = x[0][i]
            minimum = x[0][i]
            for j in range(x.shape[0]):
                maximum = max(maximum, x[j][i])
                minimum = min(minimum, x[j][i])
            difference.append((maximum - minimum, i))
        return(max(difference)[1])

    def SortingByDimension(self, x, startPoint, endPoint, dimension): #Нужно передать начало и 
        left = x[:startPoint]                                  
        rigth = x[endPoint:]
        x = np.array(sorted(x[startPoint:endPoint], key = lambda Point: Point.point[dimension]))
        x = np.concatenate((left, x, rigth), 0)
        return(x)

    def GetCentroidIndex(self, x, dimension): #Нужно передать подотрезок array[a:b]
        x = self.GetnpArrayOfPoints(x)
        if(x.shape[0] == 0):
            print("Ошибка! Передан нулевой массив в поиск центроида")
        sum = 0
        for i in range(x.shape[0]): #Размер x не нулевой
            sum = sum + x[i][dimension]
        sum = sum / x.shape[0]
        minimum = abs(sum - x[0][dimension])
        index = 0
        for i in range(x.shape[0]):
            if(abs(x[i][dimension] - sum) < minimum):
                minimum = abs(x[i][dimension] - sum)
                index = i
        return(index)

    def ConstructTree(self, leafSize, vertexIndex, startPoint, endPoint):
        if(endPoint - startPoint <= 0): 
            return #endPoint не включается в отрезок [ )
        

        dimension = self.GetMaxSpreadDimension(self.points[startPoint:endPoint])
        self.points = self.SortingByDimension(self.points, startPoint, endPoint, dimension)
        centroidIndex = self.GetCentroidIndex(self.points[startPoint:endPoint], dimension) + startPoint 

        radius = self.GetMaxDistance(self.points[centroidIndex], self.points[startPoint:endPoint])
        self.nodes[vertexIndex] = self.Node(self.points[centroidIndex], [], radius)
        if(endPoint - startPoint <= leafSize):
            for i in range(startPoint, endPoint):
                self.nodes[vertexIndex].points.append(self.points[i])
            return
        else:
            self.nodes[vertexIndex].points.append(self.points[centroidIndex])
        self.ConstructTree(leafSize, vertexIndex * 2, startPoint, centroidIndex)
        self.ConstructTree(leafSize, vertexIndex * 2 + 1, centroidIndex + 1, endPoint)

    def searchBallSubtree(self, vertexIndex, newPoint, kN):
        if(len(self.nodes[vertexIndex].points) > 1):
            for point in self.nodes[vertexIndex].points:
                distance = self.Distance(newPoint, point.point)
                if(len(self.kNSet) >= kN and max(self.kNSet)[0] > distance):
                    self.kNSet.remove(max(self.kNSet))
                if(len(self.kNSet) < kN):
                    self.kNSet.add((distance, point.index))
            return
        leftChild = vertexIndex * 2
        rightChild = vertexIndex * 2 + 1
        distance = self.Distance(newPoint, self.nodes[vertexIndex].pivot.point)

        if(len(self.kNSet) >= kN and max(self.kNSet)[0] > distance):
            self.kNSet.remove(max(self.kNSet))
        if(len(self.kNSet) < kN):
            self.kNSet.add((distance, self.nodes[vertexIndex].pivot.index))
        
        if(len(self.nodes[leftChild].points) != 0):
            distance = self.Distance(newPoint, self.nodes[leftChild].pivot.point)
            if(len(self.kNSet) < kN or max(self.kNSet)[0] > distance - self.nodes[leftChild].radius):
                self.searchBallSubtree(leftChild, newPoint, kN)

        if(len(self.nodes[rightChild].points) != 0):
            distance = self.Distance(newPoint, self.nodes[rightChild].pivot.point)
            if(len(self.kNSet) < kN or max(self.kNSet)[0] > distance - self.nodes[rightChild].radius):
                self.searchBallSubtree(rightChild, newPoint, kN)
        
     
    def fit(self, x, y):
        if (len(x) != len(y)):
            print("Ошибка - не у всех точек определены классы")
            return("fit Error")
        self.points = []
        self.classes = []
        self.pointsOriginal = x
        for i in range(len(x)):
            self.points.append(self.Point(x[i], i))
        self.classes = y

        self.ConstructTree(self.leaf_size, 1, 0, len(self.points))
        return self
    
    def predict(self, x):
        predictions = []
        for i in range(len(x)):
            self.kNSet.clear()
            self.searchBallSubtree(1, x[i], self.n_neighbors)
            results = []
            for index in self.kNSet: #kNSet хранит пару (дистанция, индекс)
                results.append(self.classes[index[1]])
            resultClass = Counter(results).most_common()[0][0]
            predictions.append(resultClass)
        return np.array(predictions)
        
    def predict_proba(self, X):    
        classesCount = np.unique(self.classes)
        results = []
        for point in X:
            neighborsClasses = self.get_kneighbors_classes(point)
            result = []
            for clas in classesCount:
                count = 0
                for neighborClass in neighborsClasses:
                    if(clas == neighborClass):
                        count += 1
                result.append(count/self.n_neighbors * 100)
            results.append(result)
        return(np.array(results))



    def get_kneighbors_classes(self, x):
        self.kNSet.clear()
        self.searchBallSubtree(1, x, self.n_neighbors)
        result = []
        for index in self.kNSet:
            result.append(self.classes[index[1]]) 
        return np.array(result)
        
    def kneighbors(self, x, n_neighbors):
        results = []
        for i in range(len(x)):
            self.kNSet.clear()
            self.searchBallSubtree(1, x[i], n_neighbors)
            result = []

            for index in self.kNSet:
                result.append(self.pointsOriginal[index[1]])
            results.append(result)    
        return np.array(np.array(results))

Использование кода

Код можно легко проверить на искусственно сгенерированном пространстве точек.

Сгенерируем пространство точек на плоскости. Класс каждой точки ассоциируем с четвертью плоскости, на которой точка располагается.

import pandas as pd
import csv
import matplotlib.pyplot as plt

import random
import math
import pylab as pl
import numpy as np
import matplotlib.pyplot as plt
import pdb
from collections import Counter
import matplotlib.colors as mcolors

def Gen2dDataset(self, points_num):
    points = []
    classes = []
    for i in range(points_num):
        tempX = np.random.uniform(-5,5)
        tempY = np.random.uniform(-5,5)
        points.append((tempX, tempY))
        if tempX <= 0 and tempY <= 0:
            classes.append(1)
        elif tempX <= 0 and tempY >= 0:
            classes.append(2)
        elif tempX >= 0 and tempY <= 0:
            classes.append(3)
        elif tempX >= 0 and tempY >= 0:
            classes.append(4)
    return np.array(points), np.array(classes)

И запустим наш классификатор на этом множестве.

def Visual2dPointClassTest(model):
    np.random.seed(42)
    points, classes = Gen2dDataset(100)

    fig, ax = plt.subplots()

    ax.set_title('Объекты и классы')
    ax.scatter(points[:,0], points[:,1], c = classes, 
            norm = plt.Normalize(0, 5),
            s = 10)


    model.fit(points, classes)
    checkPoints = [[2, 1], [-2, -2], [-3, 4.5], [0, -1]]
    print("Точки для классфикации:")
    print(np.array(checkPoints))
    checkClasses = self.model.predict(checkPoints)
    print("Ближайшие соседи:")
    print(self.model.kneighbors(checkPoints, 3))
    print("Их классы:")
    print(checkClasses)
    print("Вероятности (по классам):")
    print(self.model.predict_proba(checkPoints))
    print("Жирные точки - тестовые точки для классфикации:")
    ax.scatter(np.array(checkPoints)[:,0], np.array(checkPoints)[:,1], 
            c = checkClasses, norm = plt.Normalize(0, 5), 
            s = 40)

model - это экземпляр нашего класса KnnBalltreeClassifier.

Можем так же протестировать работу по построению ballTree (с рисованием графиков).

def ConstructTreeTest(model):
    np.random.seed(42)
    points, classes = Gen2dDataset(100)
    fig, ax = plt.subplots()
    ax.axis('equal')
    model.fit(points, classes)
    print("Узлы(вершины) дерева:")
    for node in model.nodes.values():
        print(node)
    for node in model.nodes.values():
        arrayOfPoints = []
        for point in node.points:
            arrayOfPoints.append(point)
            ax.scatter(points[:,0], points[:,1], c = classes, 
            norm = plt.Normalize(0, 5),
            s = 10)
        ax.scatter(node.pivot.point[0], node.pivot.point[1], c = 'red', 
            norm = plt.Normalize(0, 5),
            s = 25)
        circle1 = plt.Circle(node.pivot.point, node.radius, 
                             color='r', fill = False)
        ax.add_patch(circle1)

Запускать функции советую в jupiter notebook. Всем добра и успехов в учёбе.

Комментарии (5)


  1. densss2
    29.08.2021 19:51

    Так так так. Я не помню, когда я написал эту статью, но плюс статье поставил)))


    1. SherAlex Автор
      30.08.2021 17:39
      +1

      Не думал, что найду человека в интернетах с такой же авой. Спасибо большое! Я рад, что кто то прочёл эту статью.


  1. mvakhmenin
    30.08.2021 17:13

    Круто

    Можно добавить сравнение точности и скорости с другими классификаторами?


    1. SherAlex Автор
      30.08.2021 17:38

      Спасибо :)

      Точность классификатора можно легко сравнить с аналогичным из sklearn или с любым другим. Интерфейсы похожи, написать не сложно. А так, если что, можете посмотреть проект на моём GitHub. Там в файле run.ipynb есть необходимый код, который запускает сравнение. Точность на уровне - точно такая же, как и у sklearn соседей.

      Скорость работы, за ненадобностью, я не сравнивал. Но, конечно, можно сравнить. Теоретически, скорость работы этого алгоритма будет выше, чем у простого перебора. Можно посмотреть научные результаты тут или вот тут (смотрите ссылки на источники). Практически же это зависит от реализации.