Вместо бинарного классификатора в данном случае оказалось лучше использовать категориальный.
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score
from keras.models import Sequential
from keras.layers import Dense, Dropout
from sklearn.cross_validation import train_test_split
# Fix the RNG seed so the synthetic data set is identical on every run.
SEED = 12345
np.random.seed(SEED)

# Number of synthetic samples.
count = 30

# Feature matrix: `count` rows of 10 uniform random values in [0, 1).
Xtrain = np.random.rand(count * 10).reshape(count, 10)

# Binary labels kept as floats: 1.0 where the uniform draw is below 0.8,
# otherwise 0.0.  Vectorized, so it no longer relies on the Python-2-only
# `xrange`; the draw order from the RNG is the same as before.
dummy = np.where(np.random.rand(count) < 0.8, 1.0, 0.0)

# One-hot integer targets for the two-unit softmax output:
# column 0 flags class 0, column 1 flags class 1.
dummy_y = np.column_stack((dummy == 0, dummy == 1)).astype(int)
def baseline_model():
    """Build and compile the two-class softmax network.

    The input width is taken from the second dimension of the
    module-level ``Xtrain`` array.

    Returns:
        A compiled Keras ``Sequential`` model.
    """
    n_features = Xtrain.shape[1]
    # Two ReLU layers, each followed by dropout, then a 2-way softmax head.
    layers = [
        Dense(1024, input_dim=n_features, init='normal', activation='relu'),
        Dropout(0.6),
        Dense(64, init='normal', activation='relu'),
        Dropout(0.6),
        Dense(2, init='normal', activation='softmax'),
    ]
    net = Sequential(layers)
    # Categorical cross-entropy, i.e. log loss.
    net.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])
    return net
# Build the network and hold out 20% of the samples for validation.
model = baseline_model()
X_train, X_val, y_train, y_val = train_test_split(
    Xtrain, dummy_y, test_size=0.2, random_state=42)

# Per-class sample counts of the raw labels.
# NOTE(review): computed but never passed to fit() -- confirm whether it was
# meant to be used as class weights.
cl_weight = pd.Series(dummy).value_counts()

fit = model.fit(
    X_train, y_train,
    nb_epoch=50,
    # validation_data=(X_val, y_val),
    verbose=True,
)

# NOTE(review): both AUC values below are computed on the flattened one-hot
# targets and the flattened softmax outputs, i.e. the two class columns are
# pooled into one score list.  Confirm this pooled AUC is what is wanted
# rather than the AUC of the class-1 column alone.

# Evaluate on the training split.  The label used to read "val", which made
# the two printed metrics indistinguishable.
scores_train = model.predict(X_train)
print('roc_auc_score train {}'.format(
    roc_auc_score(y_train.flatten(), scores_train.flatten())))

# Evaluate on the held-out split and show raw predictions next to targets.
scores_val = model.predict(X_val)
print (scores_val.flatten())
print (y_val.flatten())
print('roc_auc_score val {}'.format(
    roc_auc_score(y_val.flatten(), scores_val.flatten())))
python ./test.py
Using Theano backend.
/usr/lib64/python2.7/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
Epoch 1/50
24/24 [==============================] - 0s - loss: 0.6895 - acc: 0.4583
Epoch 2/50
24/24 [==============================] - 0s - loss: 0.6849 - acc: 0.5417
Epoch 3/50
24/24 [==============================] - 0s - loss: 0.6436 - acc: 0.8333
Epoch 4/50
24/24 [==============================] - 0s - loss: 0.6684 - acc: 0.8333
Epoch 5/50
24/24 [==============================] - 0s - loss: 0.6346 - acc: 0.8750
Epoch 6/50
24/24 [==============================] - 0s - loss: 0.6492 - acc: 0.7500
Epoch 7/50
24/24 [==============================] - 0s - loss: 0.6372 - acc: 0.8333
Epoch 8/50
24/24 [==============================] - 0s - loss: 0.6169 - acc: 0.8333
Epoch 9/50
24/24 [==============================] - 0s - loss: 0.6219 - acc: 0.8333
Epoch 10/50
24/24 [==============================] - 0s - loss: 0.6079 - acc: 0.8333
Epoch 11/50
24/24 [==============================] - 0s - loss: 0.5828 - acc: 0.8333
Epoch 12/50
24/24 [==============================] - 0s - loss: 0.5869 - acc: 0.8333
Epoch 13/50
24/24 [==============================] - 0s - loss: 0.5633 - acc: 0.8333
Epoch 14/50
24/24 [==============================] - 0s - loss: 0.5732 - acc: 0.8333
Epoch 15/50
24/24 [==============================] - 0s - loss: 0.5874 - acc: 0.8333
Epoch 16/50
24/24 [==============================] - 0s - loss: 0.5641 - acc: 0.8333
Epoch 17/50
24/24 [==============================] - 0s - loss: 0.5293 - acc: 0.8333
Epoch 18/50
24/24 [==============================] - 0s - loss: 0.5501 - acc: 0.8333
Epoch 19/50
24/24 [==============================] - 0s - loss: 0.5411 - acc: 0.8333
Epoch 20/50
24/24 [==============================] - 0s - loss: 0.5245 - acc: 0.8333
Epoch 21/50
24/24 [==============================] - 0s - loss: 0.5410 - acc: 0.8333
Epoch 22/50
24/24 [==============================] - 0s - loss: 0.5301 - acc: 0.8333
Epoch 23/50
24/24 [==============================] - 0s - loss: 0.5027 - acc: 0.8333
Epoch 24/50
24/24 [==============================] - 0s - loss: 0.4847 - acc: 0.8333
Epoch 25/50
24/24 [==============================] - 0s - loss: 0.4581 - acc: 0.8333
Epoch 26/50
24/24 [==============================] - 0s - loss: 0.4954 - acc: 0.8333
Epoch 27/50
24/24 [==============================] - 0s - loss: 0.4515 - acc: 0.8333
Epoch 28/50
24/24 [==============================] - 0s - loss: 0.5454 - acc: 0.8333
Epoch 29/50
24/24 [==============================] - 0s - loss: 0.5059 - acc: 0.8333
Epoch 30/50
24/24 [==============================] - 0s - loss: 0.4639 - acc: 0.8333
Epoch 31/50
24/24 [==============================] - 0s - loss: 0.4487 - acc: 0.8333
Epoch 32/50
24/24 [==============================] - 0s - loss: 0.4825 - acc: 0.8333
Epoch 33/50
24/24 [==============================] - 0s - loss: 0.4811 - acc: 0.8333
Epoch 34/50
24/24 [==============================] - 0s - loss: 0.4678 - acc: 0.8333
Epoch 35/50
24/24 [==============================] - 0s - loss: 0.4447 - acc: 0.8333
Epoch 36/50
24/24 [==============================] - 0s - loss: 0.4182 - acc: 0.8333
Epoch 37/50
24/24 [==============================] - 0s - loss: 0.4401 - acc: 0.8333
Epoch 38/50
24/24 [==============================] - 0s - loss: 0.3984 - acc: 0.8333
Epoch 39/50
24/24 [==============================] - 0s - loss: 0.4578 - acc: 0.8333
Epoch 40/50
24/24 [==============================] - 0s - loss: 0.4232 - acc: 0.8333
Epoch 41/50
24/24 [==============================] - 0s - loss: 0.4517 - acc: 0.8333
Epoch 42/50
24/24 [==============================] - 0s - loss: 0.4737 - acc: 0.8333
Epoch 43/50
24/24 [==============================] - 0s - loss: 0.3670 - acc: 0.8333
Epoch 44/50
24/24 [==============================] - 0s - loss: 0.4164 - acc: 0.8333
Epoch 45/50
24/24 [==============================] - 0s - loss: 0.4500 - acc: 0.8333
Epoch 46/50
24/24 [==============================] - 0s - loss: 0.4243 - acc: 0.8333
Epoch 47/50
24/24 [==============================] - 0s - loss: 0.3858 - acc: 0.8333
Epoch 48/50
24/24 [==============================] - 0s - loss: 0.4698 - acc: 0.8333
Epoch 49/50
24/24 [==============================] - 0s - loss: 0.4444 - acc: 0.8333
Epoch 50/50
24/24 [==============================] - 0s - loss: 0.4683 - acc: 0.8333
roc_auc_score val 0.930555555556
[ 0.16673177 0.83326823 0.1581582 0.84184182 0.29502279 0.70497721
0.1026172 0.8973828 0.20092192 0.79907811 0.10272423 0.89727575]
[0 1 0 1 0 1 0 1 1 0 0 1]
roc_auc_score val 0.916666666667