При предсказании нейросеть иногда дает маску с ложноположительными значениями пикселей от 0.5-0.75 итд. Среднее число таких пикселей от 100 до 500 штук на одну ложнопредсказанную маску. Хочу в метрике Dice поставить фильтр, которая фильтровала бы только пиксели с значением от 0.75 в предсказанной маске и если количество таких пикселей меньше 800 штук то можно считать, что в этой маске ничего нет, вся маска заполняется 0ми, если больше, то норм. После обучения, в ручном тестировании на картинках все это работает, после написания ifов для фильтрования по значениям каждого пикселя и min area (800 штук минимум), но как это сделать в метрике, что бы все это работало уже во время обучения?
Loss и метрика:
def dice_loss(y_true, y_pred):
smooth=1e-6
y_true_f = K.flatten(y_true)
y_pred_f = K.cast(y_pred, 'float32')
y_pred_f = K.flatten(y_pred)
intersection = y_true_f * y_pred_f
score = K.mean(1. - (2. * K.sum(intersection) + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))
return score
def dice_metric(preds, trues):
preds = K.cast(preds, 'float32')
return 1 - dice_loss(preds, trues)
Пробовал использовать внутри метрики:
preds_f = K.cast(K.greater(K.flatten(preds), 0.75), 'float32')
Пользы от этого было мало.
Пробовал это:
K.cast(K.greater(K.sum(preds), 800.0), 'float32')
Как сделать в метрике, что бы во время обучения он фильтровал пиксели со значением больше 0.75 и которых минимум 800 штук в одной маске?