@NikitaDen

Почему не сходится градиентный спуск?

Добрый день! При реализации линейной регрессии возникла проблема: при градиентном спуске loss увеличивается причем на несколько порядков за раз. Loss - MSE, градиентный спуск - обычный. В качестве датасета выбрал California Housing dataset . При написании кода опирался на статью
Код класса
class LinearRegression():

  w = None
  alpha = None

  def __init__(self, lr, E=20):
    self.lr = lr
    self.w = np.zeros(X.shape[1] + 1)
    self.E = E

  def loss(self, X, y):
    return np.sum((y - X @ self.w) ** 2) / X.shape[0]

  def grad(self, X, y):
    print(self.w)
    grad_basic = np.transpose(X) @ (X @ self.w - y)    
    assert grad_basic.shape == (X.shape[1],) , "Градиенты должны быть столбцом из k_features + 1 элементов"
    return grad_basic / X.shape[0]

  def sgd(self, X, y, E=20):
    X = np.concatenate((np.ones((X.shape[0], 1)), X), axis = 1)
    print(X)
    self.loss_arr = [self.loss(X, y)]
    for _ in tqdm(range(E)):
      if abs(self.loss_arr[-1]) < 0.1:
        break
      self.w -= self.lr * self.grad(X, y)
      self.loss_arr.append(self.loss(X, y))

  def fit(self, X, y):
    self.sgd(X, y, self.E)

  def get_params(self):
    return self.w

  def get_loss(self):
    return self.loss_arr

  def predict(self, X):
    X = np.concatenate((np.ones((X.shape[0], 1)), X), axis = 1)
    return X.dot(self.w)

Вот так ведет себя loss:
5.610483198987253,
71185949512.90901,
1.9667789518677714e+21,
5.433978711941763e+31,
1.5013443485392878e+42...

Остальной код по загрузке датасета находится в ноутбуке
Подскажите, пожалуйста, где я совершил ошибку
  • Вопрос задан
  • 43 просмотра
Пригласить эксперта
Ответы на вопрос 1
@SeptiM
Попробуй градиент нормализовать.
Ответ написан
Комментировать
Ваш ответ на вопрос

Войдите, чтобы написать ответ

Войти через центр авторизации
Похожие вопросы
American Hunters Мурманск
До 200 000 ₽
KazanExpress Москва
от 300 000 до 400 000 ₽
KazanExpress Казань
от 300 000 до 400 000 ₽
25 июл. 2022, в 07:44
3000 руб./за проект
16 авг. 2022, в 08:05
3000 руб./за проект
16 авг. 2022, в 04:20
15000 руб./за проект