{ "cells": [ { "cell_type": "markdown", "id": "f5c4abc0", "metadata": {}, "source": [ "# Intel® Extension for Scikit-learn KNN for MNIST dataset" ] }, { "cell_type": "code", "execution_count": 1, "id": "23512089", "metadata": {}, "outputs": [], "source": [ "from timeit import default_timer as timer\n", "from IPython.display import HTML\n", "from sklearn import metrics\n", "from sklearn.datasets import fetch_openml\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "id": "b6e359f6", "metadata": {}, "source": [ "### Download the data " ] }, { "cell_type": "code", "execution_count": 2, "id": "27b99b44", "metadata": {}, "outputs": [], "source": [ "x, y = fetch_openml(name=\"mnist_784\", return_X_y=True)" ] }, { "cell_type": "markdown", "id": "6259f584", "metadata": {}, "source": [ "Split the data into train and test sets" ] }, { "cell_type": "code", "execution_count": 3, "id": "96e14dd7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((56000, 784), (14000, 784), (56000,), (14000,))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=72)\n", "x_train.shape, x_test.shape, y_train.shape, y_test.shape" ] }, { "cell_type": "markdown", "id": "0341cac9", "metadata": {}, "source": [ "### Patch original Scikit-learn with Intel® Extension for Scikit-learn\n", "Intel® Extension for Scikit-learn (previously known as daal4py) contains drop-in replacement functionality for the stock Scikit-learn package. 
You can take advantage of the performance optimizations of Intel® Extension for Scikit-learn by adding just two lines of code before the usual Scikit-learn imports:" ] }, { "cell_type": "code", "execution_count": 4, "id": "244c5bc9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Intel(R) Extension for Scikit-learn* enabled (https://github.com/uxlfoundation/scikit-learn-intelex)\n" ] } ], "source": [ "from sklearnex import patch_sklearn\n", "\n", "patch_sklearn()" ] }, { "cell_type": "markdown", "id": "6bb14ac8", "metadata": {}, "source": [ "Intel® Extension for Scikit-learn patching affects the performance of specific Scikit-learn functionality. Refer to the [list of supported algorithms and parameters](https://uxlfoundation.github.io/scikit-learn-intelex/latest/algorithms.html) for details. When unsupported parameters are used, the package falls back to the original Scikit-learn. If the patching does not cover your scenarios, [submit an issue on GitHub](https://github.com/uxlfoundation/scikit-learn-intelex/issues)." 
] }, { "cell_type": "markdown", "id": "693b4e26", "metadata": {}, "source": [ "Train a KNN classifier and predict with Intel® Extension for Scikit-learn on the MNIST dataset" ] }, { "cell_type": "code", "execution_count": 5, "id": "e9b8f06b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Intel® extension for Scikit-learn time: 1.45 s'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "params = {\"n_neighbors\": 40, \"weights\": \"distance\", \"n_jobs\": -1}\n", "start = timer()\n", "knn = KNeighborsClassifier(**params).fit(x_train, y_train)\n", "predicted = knn.predict(x_test)\n", "time_opt = timer() - start\n", "f\"Intel® extension for Scikit-learn time: {time_opt:.2f} s\"" ] }, { "cell_type": "code", "execution_count": 6, "id": "8ca549ae", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Classification report for Intel® extension for Scikit-learn KNN:\n", " precision recall f1-score support\n", "\n", " 0 0.97 0.99 0.98 1365\n", " 1 0.93 0.99 0.96 1637\n", " 2 0.99 0.94 0.96 1401\n", " 3 0.96 0.95 0.96 1455\n", " 4 0.98 0.96 0.97 1380\n", " 5 0.95 0.95 0.95 1219\n", " 6 0.96 0.99 0.97 1317\n", " 7 0.94 0.95 0.95 1420\n", " 8 0.99 0.90 0.94 1379\n", " 9 0.92 0.94 0.93 1427\n", "\n", " accuracy 0.96 14000\n", " macro avg 0.96 0.96 0.96 14000\n", "weighted avg 0.96 0.96 0.96 14000\n", "\n", "\n" ] } ], "source": [ "report = metrics.classification_report(y_test, predicted)\n", "print(f\"Classification report for Intel® extension for Scikit-learn KNN:\\n{report}\\n\")" ] }, { "cell_type": "markdown", "id": "bd8e7b0b", "metadata": {}, "source": [ "*The first column of the classification report above lists the class labels.* \n", " \n", "### Train the same algorithm with original Scikit-learn\n", "To cancel the optimizations, we call *unpatch_sklearn* and reimport the KNeighborsClassifier class." 
] }, { "cell_type": "code", "execution_count": 7, "id": "5bb884d5", "metadata": {}, "outputs": [], "source": [ "from sklearnex import unpatch_sklearn\n", "\n", "unpatch_sklearn()" ] }, { "cell_type": "markdown", "id": "8cfa0dba", "metadata": {}, "source": [ "Train a KNN classifier and predict with the original Scikit-learn library on the MNIST dataset" ] }, { "cell_type": "code", "execution_count": 8, "id": "ae421d8e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Original Scikit-learn time: 36.15 s'" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "\n", "start = timer()\n", "knn = KNeighborsClassifier(**params).fit(x_train, y_train)\n", "predicted = knn.predict(x_test)\n", "time_original = timer() - start\n", "f\"Original Scikit-learn time: {time_original:.2f} s\"" ] }, { "cell_type": "code", "execution_count": 9, "id": "33da9fd1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Classification report for original Scikit-learn KNN:\n", " precision recall f1-score support\n", "\n", " 0 0.97 0.99 0.98 1365\n", " 1 0.93 0.99 0.96 1637\n", " 2 0.99 0.94 0.96 1401\n", " 3 0.96 0.95 0.96 1455\n", " 4 0.98 0.96 0.97 1380\n", " 5 0.95 0.95 0.95 1219\n", " 6 0.96 0.99 0.97 1317\n", " 7 0.94 0.95 0.95 1420\n", " 8 0.99 0.90 0.94 1379\n", " 9 0.92 0.94 0.93 1427\n", "\n", " accuracy 0.96 14000\n", " macro avg 0.96 0.96 0.96 14000\n", "weighted avg 0.96 0.96 0.96 14000\n", "\n", "\n" ] } ], "source": [ "report = metrics.classification_report(y_test, predicted)\n", "print(f\"Classification report for original Scikit-learn KNN:\\n{report}\\n\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "ffd79e96", "metadata": {}, "outputs": [ { "data": { "text/html": [ "