Putting it all together
Now, we are going to perform the same procedure as before, except that we will reset, regroup, and try a new algorithm: K-Nearest Neighbors (KNN).
How to do it...
- Start by importing the model from sklearn, followed by a balanced (stratified) split:
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
Note
The random_state parameter fixes the random seed used by the train_test_split function, so the same split is produced on every run. In the preceding example, random_state is set to zero; it can be set to any integer.
- Construct two different KNN models by varying the n_neighbors parameter. Observe that the number of folds is now 10. Tenfold cross-validation is common in the machine learning community, particularly in data science competitions:
from sklearn.model_selection import cross_val_score

knn_3_clf = KNeighborsClassifier(n_neighbors=3)
knn_5_clf = KNeighborsClassifier(n_neighbors=5)

knn_3_scores = cross_val_score(knn_3_clf, X_train, y_train, cv=10)
knn_5_scores = cross_val_score(knn_5_clf, X_train, y_train, cv=10)
- Score and print out the scores for selection:
print("knn_3 mean scores: ", knn_3_scores.mean(), "knn_3 std: ", knn_3_scores.std())
print("knn_5 mean scores: ", knn_5_scores.mean(), " knn_5 std: ", knn_5_scores.std())

knn_3 mean scores: 0.798333333333 knn_3 std: 0.0908142181722
knn_5 mean scores: 0.806666666667 knn_5 std: 0.0559320575496
Both nearest neighbor models have similar mean scores, yet the KNN with n_neighbors = 5 is a bit more stable: its scores vary less across the 10 folds (a standard deviation of about 0.056 versus 0.091). This is an example of hyperparameter optimization, which we will examine closely throughout the book.
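To see what the stratify and random_state arguments from the first step actually do, the following is a minimal sketch. It assumes the iris dataset as a stand-in, since the recipe's own X and y are defined earlier in the chapter; the variable names with _a and _b suffixes are illustrative only.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Stand-in data (an assumption): the recipe's own X and y come from earlier in the chapter.
X, y = load_iris(return_X_y=True)

# Two splits made with the same random_state select exactly the same rows.
X_train_a, X_test_a, y_train_a, y_test_a = train_test_split(
    X, y, stratify=y, random_state=0)
X_train_b, X_test_b, y_train_b, y_test_b = train_test_split(
    X, y, stratify=y, random_state=0)
same_split = (np.array_equal(X_train_a, X_train_b)
              and np.array_equal(y_test_a, y_test_b))
print("identical splits:", same_split)

# stratify=y keeps the class proportions close to those of the full dataset
# in both the training and the test subset.
train_fractions = np.bincount(y_train_a) / len(y_train_a)
test_fractions = np.bincount(y_test_a) / len(y_test_a)
print("train class fractions:", train_fractions)
print("test class fractions: ", test_fractions)
```

Changing random_state to another integer gives a different, but equally reproducible, split.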
There's more...
You could just as easily have run a simple loop to score several values of n_neighbors at once:
all_scores = []
for n_neighbors in range(3, 9, 1):
    knn_clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    all_scores.append((n_neighbors, cross_val_score(knn_clf, X_train, y_train, cv=10).mean()))
sorted(all_scores, key=lambda x: x[1], reverse=True)
Its output suggests that n_neighbors = 4 is a good choice:
[(4, 0.85111111111111115), (7, 0.82611111111111113), (6, 0.82333333333333347), (5, 0.80666666666666664), (3, 0.79833333333333334), (8, 0.79833333333333334)]
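The loop above is essentially a manual grid search. As a rough sketch of the same idea, scikit-learn's GridSearchCV can run the tenfold cross-validation for each candidate and also report the spread across folds. The iris data and the names param_grid and search are assumptions for illustration, so the scores will not match the numbers printed in this recipe.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Stand-in data (an assumption): substitute the X and y used in the recipe.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0)

# Same candidate values as the loop: n_neighbors from 3 to 8.
param_grid = {"n_neighbors": list(range(3, 9))}
search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=10)
search.fit(X_train, y_train)

# Mean and standard deviation of the 10 fold scores for each candidate.
results = search.cv_results_
for k, mean, std in zip(results["param_n_neighbors"],
                        results["mean_test_score"],
                        results["std_test_score"]):
    print("n_neighbors =", k, "mean:", round(mean, 3), "std:", round(std, 3))

print("best n_neighbors:", search.best_params_["n_neighbors"])
```

Looking at the standard deviation alongside the mean helps when two candidates score almost the same, as with n_neighbors = 3 and n_neighbors = 5 earlier.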