from keras.datasets import mnist
import numpy as np
# Load the MNIST train/test splits once, at module import time.
_mnist_splits = mnist.load_data()
(x_train, y_train), (x_test, y_test) = _mnist_splits
del _mnist_splits  # leave only the four split arrays in the module namespace
def get_one_digit(digit, images=None, labels=None):
    """Collect every image belonging to a single digit class.

    Args:
        digit: Class label to select; must satisfy ``digit in range(10)``.
        images: Images indexed along axis 0. Defaults to the module-level
            ``x_train`` array.
        labels: Labels aligned element-for-element with *images*. Defaults
            to the module-level ``y_train`` array.

    Returns:
        A tuple ``(X, y)``. ``X`` is a list of the images whose label equals
        *digit*, in their original order. ``y`` is an array of the same
        length filled with *digit*.

    Raises:
        ValueError: If *digit* is not in ``range(10)``.
    """
    # Raise rather than assert: assert statements are stripped under `python -O`.
    if digit not in range(10):
        raise ValueError(f"digit must be in range(10), got {digit!r}")
    if images is None:
        images = x_train
    if labels is None:
        labels = y_train
    # Vectorised selection replaces the per-element Python loop.
    # NOTE(review): unlike zip(), a boolean mask raises if the lengths differ,
    # so a misaligned images/labels pair is reported instead of silently
    # truncated.
    mask = np.asarray(labels) == digit
    X = list(np.asarray(images)[mask])
    y = np.full(len(X), digit)
    return X, y
# Gather all training images labelled 2, plus a matching label array.
X_2, y_2 = get_one_digit(digit=2)