DKay7
@DKay7
Юный студент

Python matplotlib scatter. Как динамически установить цвет точки?

Пишу визуализацию алгоритма-классификатора knn через sklearn и matplotlb. Последняя функция, которую хотел бы добавить: по нажатию мышки на график отображалась бы точка и предсказывался бы ее класс. Точка отображается, класс предсказывается верно, но цвет точки не меняется (т.е. 0 класс - фиолетовый, 1 - красный и т.д.). Использую цветовую карту. Не могу понять в чем проблема, ведь то же самое сработало раньше, когда я окрашивал фон за точками. Вот функция, которая вызывается при нажатии кнопки мыши:
def onclick(self, event):
        axes = event.inaxes

        if axes is None:
            return

        if (self.dot or self.dot_text) is not None:
            self.dot.remove()
            self.dot_text.remove()

        x = event.xdata
        y = event.ydata

        label, prob = self.knn.get_prediction([[x, y]], prob=True) #предсказание класса и вероятность точности

        self.dot = plt.scatter(x, y, c=label, s=60, #пытаюсь установить цвет точки, как цвет ее класса
                               cmap=self.cmap, edgecolors='black', linewidth=2)
        print(label)

        self.dot_text = plt.text(x, y, 'Класс {0}\n С вероятностью {1}'.format(
            label[0], round(
                            prob[0][int(label[0])], 3
                            )
        ))

        plt.draw()

На картинке ниже точка должна быть зеленого цвета, но она всегда фиолетового (что соответствует 0 классу)5e931a7b5251c440754157.png
UPD:Спасибо за помощь, прикрепляю исправленный код, если кому-то пригодится. Исправления коснулись только функции scatter, в нее нужно было добавить нормализацию, а в параметр цвета вместо списка одного элемента labels передавать сам элемент - labels[0].
self.dot = plt.scatter(x, y,
                               s=60,
                               c=int(label[0]),
                               norm=plt.Normalize(vmin=0, vmax=self.num_classes-1),
                               edgecolors='black',
                               cmap=self.cmap,
                               linewidth=2)
  • Вопрос задан
  • 475 просмотров
Решения вопроса 1
@dmshar
Уже отвечал вам:
plt.scatter(x, y,
c = label,
norm = plt.Normalize(vmin=min(label), vmax=max(label)+1),
cmap = "nipy_spectral")
Ответ написан
Пригласить эксперта
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Войти через центр авторизации
Похожие вопросы