{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "ONN_Usage_Example.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "ugTEiBmBV3Cc", "colab_type": "text" }, "source": [ "#Installing Dependencies" ] }, { "cell_type": "code", "metadata": { "id": "FbEX3YFwKCxy", "colab_type": "code", "outputId": "1e3a42c0-eeec-4c9c-f318-3335e5f228bf", "colab": { "base_uri": "https://localhost:8080/", "height": 68 } }, "source": [ "!pip install --upgrade onn" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "Requirement already up-to-date: onn in /usr/local/lib/python3.6/dist-packages (0.1.2)\n", "Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from onn) (1.16.4)\n", "Requirement already satisfied, skipping upgrade: torch in /usr/local/lib/python3.6/dist-packages (from onn) (1.1.0)\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "XbYfcIpkV-SA", "colab_type": "text" }, "source": [ "##Importing Dependencies" ] }, { "cell_type": "code", "metadata": { "id": "9xT-ZdVsKH6z", "colab_type": "code", "colab": {} }, "source": [ "from onn.OnlineNeuralNetwork import ONN\n", "from onn.OnlineNeuralNetwork import ONN_THS\n", "from sklearn.datasets import make_classification, make_circles\n", "from sklearn.model_selection import train_test_split\n", "import torch\n", "from sklearn.metrics import accuracy_score\n", "import numpy as np" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "N9H22wxBWB9f", "colab_type": "text" }, "source": [ "## Initializing Network" ] }, { "cell_type": "code", "metadata": { "id": "1wOHgHL1LieT", "colab_type": "code", "colab": {} }, "source": [ "onn_network = ONN(features_size=10, max_num_hidden_layers=5, qtd_neuron_per_hidden_layer=40, n_classes=10)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "eYqkIyxyWI2h", "colab_type": "text" }, "source": [ "##Creating Fake Classification Dataset" ] }, { "cell_type": "code", "metadata": { "id": "WXgNSF9gL69F", "colab_type": "code", "colab": {} }, "source": [ "X, Y = make_classification(n_samples=50000, n_features=10, n_informative=4, n_redundant=0, n_classes=10,\n", " n_clusters_per_class=1, class_sep=3)\n", "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=42, shuffle=True)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "THGPFSWJWPEm", "colab_type": "text" }, "source": [ "##Learning and predicting at the same time" ] }, { "cell_type": "code", "metadata": { "id": "70J3ZYtmL-Zm", "colab_type": "code", "outputId": "5a7433d8-7744-488e-8166-c3c310ee93ab", "colab": { "base_uri": "https://localhost:8080/", "height": 2397 } }, "source": [ "for i in range(len(X_train)):\n", " onn_network.partial_fit(np.asarray([X_train[i, :]]), np.asarray([y_train[i]]))\n", " \n", " if i % 1000 == 0:\n", " predictions = onn_network.predict(X_test)\n", " print(\"Online Accuracy: {}\".format(accuracy_score(y_test, predictions)))" ], "execution_count": 5, "outputs": [ { "output_type": "stream", "text": [ "Online Accuracy: 0.14513333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8394298 0.04014252 0.04014252 0.04014252 0.04014252]\n", "Training Loss: 1.3401858\n", "Online Accuracy: 0.9665333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83981496 0.04004624 0.04004624 0.04004624 0.04004624]\n", "Training Loss: 0.47074658\n", "Online Accuracy: 0.9748\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83985054 0.04003733 0.04003733 0.04003733 0.04003733]\n", "Training Loss: 0.3416147\n", "Online Accuracy: 0.9767333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399445 0.04001385 0.04001385 0.04001385 0.04001385]\n", "Training Loss: 0.27886173\n", "Online Accuracy: 0.9772\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.839963 0.04000923 0.04000923 0.04000923 0.04000923]\n", "Training Loss: 0.31251636\n", "Online Accuracy: 0.9772666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8397189 0.04007026 0.04007026 0.04007026 0.04007026]\n", "Training Loss: 0.22847703\n", "Online Accuracy: 0.9803333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83356273 0.04162288 0.04159797 0.04160155 0.04161485]\n", "Training Loss: 0.26110825\n", "Online Accuracy: 0.9798\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8396703 0.0400824 0.0400824 0.0400824 0.0400824]\n", "Training Loss: 0.20787847\n", "Online Accuracy: 0.98\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83994615 0.04001344 0.04001344 0.04001344 0.04001344]\n", "Training Loss: 0.2071073\n", "Online Accuracy: 0.98\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8381557 0.04046106 0.04046106 0.04046106 0.04046106]\n", "Training Loss: 0.21403871\n", "Online Accuracy: 0.9800666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83974886 0.04006276 0.04006276 0.04006276 0.04006276]\n", "Training Loss: 0.17885084\n", "Online Accuracy: 0.9809333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398784 0.04003038 0.04003038 0.04003038 0.04003038]\n", "Training Loss: 0.15060541\n", "Online Accuracy: 0.9815333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399662 0.04000843 0.04000843 0.04000843 0.04000843]\n", "Training Loss: 0.17218196\n", "Online Accuracy: 0.981\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83996993 0.04000749 0.04000749 0.04000749 0.04000749]\n", "Training Loss: 0.18151304\n", "Online Accuracy: 0.9811333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.83979076 0.04005228 0.04005228 0.04005228 0.04005228]\n", "Training Loss: 0.18345681\n", "Online Accuracy: 0.9824\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83995855 0.04001034 0.04001034 0.04001034 0.04001034]\n", "Training Loss: 0.17578402\n", "Online Accuracy: 0.9825333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83998764 0.04000307 0.04000307 0.04000307 0.04000307]\n", "Training Loss: 0.14615656\n", "Online Accuracy: 0.9819333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399176 0.04002057 0.04002057 0.04002057 0.04002057]\n", "Training Loss: 0.19433188\n", "Online Accuracy: 0.9819333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83998096 0.04000473 0.04000473 0.04000473 0.04000473]\n", "Training Loss: 0.121442005\n", "Online Accuracy: 0.982\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83997864 0.04000532 0.04000532 0.04000532 0.04000532]\n", "Training Loss: 0.20980458\n", "Online Accuracy: 0.9822666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399815 0.04000461 0.04000461 0.04000461 0.04000461]\n", "Training Loss: 0.13572827\n", "Online Accuracy: 0.9821333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83946246 0.04013437 0.04013437 0.04013437 0.04013437]\n", "Training Loss: 0.16413513\n", "Online Accuracy: 0.9827333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399841 0.04000396 0.04000396 0.04000396 0.04000396]\n", "Training Loss: 0.17003122\n", "Online Accuracy: 0.9837333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83970624 0.04007342 0.04007342 0.04007342 0.04007342]\n", "Training Loss: 0.15054016\n", "Online Accuracy: 0.9827333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.839985 0.04000372 0.04000372 0.04000372 0.04000372]\n", "Training Loss: 0.15771861\n", "Online Accuracy: 0.9835333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399664 0.04000837 0.04000837 0.04000837 0.04000837]\n", "Training Loss: 0.19387966\n", "Online Accuracy: 0.9825333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.82004756 0.04498808 0.04498808 0.04498808 0.04498808]\n", "Training Loss: 0.11493891\n", "Online Accuracy: 0.9834666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399602 0.04000993 0.04000993 0.04000993 0.04000993]\n", "Training Loss: 0.15605311\n", "Online Accuracy: 0.9832666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399928 0.04000178 0.04000178 0.04000178 0.04000178]\n", "Training Loss: 0.117902525\n", "Online Accuracy: 0.9833333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8397992 0.04005016 0.04005016 0.04005016 0.04005016]\n", "Training Loss: 0.17483151\n", "Online Accuracy: 0.9842\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83968043 0.04007987 0.04007987 0.04007987 0.04007987]\n", "Training Loss: 0.1470215\n", "Online Accuracy: 0.9840666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83984363 0.04003906 0.04003906 0.04003906 0.04003906]\n", "Training Loss: 0.13816771\n", "Online Accuracy: 0.9832\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398159 0.04004599 0.04004599 0.04004599 0.04004599]\n", "Training Loss: 0.14629075\n", "Online Accuracy: 0.9836\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399813 0.04000464 0.04000464 0.04000464 0.04000464]\n", "Training Loss: 0.1586352\n", "Online Accuracy: 0.9838666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83991635 0.04002089 0.04002089 0.04002089 0.04002089]\n", "Training Loss: 0.11979295\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "_K01fvy7Tj7Z", "colab_type": "text" }, "source": [ "# Learning in batch with CUDA" ] }, { "cell_type": "code", "metadata": { "id": "Jg1UJWvDTjJc", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "e966d790-4a05-4eac-9ea3-0f015a5ddd65" }, "source": [ "onn_network = ONN(features_size=10, max_num_hidden_layers=5, qtd_neuron_per_hidden_layer=40, n_classes=10, batch_size=10, use_cuda=True)" ], "execution_count": 6, "outputs": [ { "output_type": "stream", "text": [ "Using CUDA :]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "Uoxk5Jk_UYKJ", "colab_type": "code", "colab": {} }, "source": [ "from torch.utils.data import Dataset, DataLoader\n", "class Dataset(Dataset):\n", "\n", " def __init__(self, X, Y):\n", " self.X = X\n", " self.Y = Y\n", "\n", " def __len__(self):\n", " return len(self.X)\n", "\n", " def __getitem__(self, idx):\n", " X = self.X[idx],\n", " Y = self.Y[idx]\n", "\n", " return X, Y\n", " \n", "transformed_dataset = Dataset(X_train, y_train)\n", "dataloader = DataLoader(transformed_dataset, batch_size=10,shuffle=True, num_workers=1)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "8jz4S-SVUAMx", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 221 }, "outputId": "daca2c03-0ef9-4c36-952f-5663f2ffd965" }, "source": [ "for local_X, local_y in dataloader: \n", " onn_network.partial_fit(np.squeeze(torch.stack(local_X).numpy()), local_y.numpy())" ], "execution_count": 9, "outputs": [ { "output_type": "stream", "text": [ "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83882004 0.04029498 0.04029498 0.04029498 0.04029498]\n", "Training Loss: 0.32651797\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.839726  0.0400685 0.0400685 0.0400685 0.0400685]\n", "Training Loss: 0.27955872\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83984923 0.04003768 0.04003768 0.04003768 0.04003768]\n", "Training Loss: 0.26284653\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398874  0.04002814 0.04002814 0.04002814 0.04002814]\n", "Training Loss: 0.2542592\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "u7XwSO6zVpsz", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "e604a7f9-a5b2-4631-e284-74d9e765e6a0" }, "source": [ "predictions = onn_network.predict(X_test)\n", "print(\"Accuracy: {}\".format(accuracy_score(y_test, predictions)))" ], "execution_count": 10, "outputs": [ { "output_type": "stream", "text": [ "Accuracy: 0.9712\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "DSk4Fl4GV9NG", "colab_type": "text" }, "source": [ "# Using a contextual bandit - ONN_THS" ] }, { "cell_type": "markdown", "metadata": { "id": "vBweUP_CZaFz", "colab_type": "text" }, "source": [ "In this example, the ONN acts as a contextual bandit, a type of reinforcement learning algorithm." ] }, { "cell_type": "code", "metadata": { "id": "TeFBn4erUDY3", "colab_type": "code", "colab": {} }, "source": [ "X_linear, Y_linear = make_classification(n_samples=10000, n_features=2, n_informative=2, n_redundant=0, n_classes=2, n_clusters_per_class=1, class_sep=200, shuffle=True)\n", "X_non_linear, Y_non_linear = make_circles(n_samples=10000, noise=0.1, factor=0.3, shuffle=True)\n", "X_linear_2, Y_linear_2 = make_classification(n_samples=10000, n_features=2, n_informative=2, n_redundant=0, n_classes=2, n_clusters_per_class=1, class_sep=200, shuffle=True)\n", "\n", "X_linear_train = X_linear[:5000]\n", "Y_linear_train = Y_linear[:5000]\n", "\n", "X_linear_test = X_linear[5000:]\n", "Y_linear_test = Y_linear[5000:]\n", "\n", "X_non_linear_train = X_non_linear[:5000]\n", "Y_non_linear_train = Y_non_linear[:5000]\n", "\n", "X_non_linear_test = X_non_linear[5000:]\n", "Y_non_linear_test = Y_non_linear[5000:]\n", "\n", "X_linear_train_2 = X_linear_2[:5000]\n", "Y_linear_train_2 = Y_linear_2[:5000]\n", "\n", "X_linear_test_2 = X_linear_2[5000:]\n", "Y_linear_test_2 = Y_linear_2[5000:]" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "lfnN1xG7WtF7", "colab_type": "code", "colab": {} }, "source": [ "gp = ONN_THS(2, 5, 100, 2)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "7EmKF-y8Wytt", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1615 }, "outputId": "9a472b1d-d9da-4f2a-f3b1-e2d46e493b0f" }, "source": [ "for epoch in range(5):\n", "\n", "    for i in range(len(X_linear_train)):\n", "        x = np.asarray([X_linear_train[i, :]])\n", "        y = np.asarray([Y_linear_train[i]])\n", "\n", "        arm, exp = gp.predict(x)\n", " \n", "        if arm == y[0]: \n", "            gp.partial_fit(x, y, exp)\n", " \n", "        if i % 2000 == 1999:\n", "            pred = []\n", "            print(\"======================================================\")\n", "            for i in range(len(X_linear_test)): \n", "                pred.append(gp.predict(np.asarray([X_linear_test[i, :]]))[0])\n", "            print(\"Accuracy: \" + str(accuracy_score(Y_linear_test, pred)))\n", "            print(\"======================================================\")\n", "\n",
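"# Note: partial_fit is only called on pulls where the sampled arm matches the true label (simulated bandit feedback);\n", "# every 2000 samples the bandit is evaluated on the held-out linear test set.\n",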
"print('Finished Training')" ], "execution_count": 5, "outputs": [ { "output_type": "stream", "text": [ "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.03783625\n", "======================================================\n", "Accuracy: 0.938\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.1704272 0.20562138 0.20954046 0.20842516 0.20598587]\n", "Training Loss: 2.8511312\n", "======================================================\n", "Accuracy: 0.842\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.04529483 0.2372146 0.23887365 0.23777498 0.24084195]\n", "Training Loss: 2.9067016\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.04934153 0.23565038 0.23894446 0.23790586 0.23815775]\n", "Training Loss: 0.89143544\n", "======================================================\n", "Accuracy: 0.7946\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.18905842 0.2011414 0.2039329 0.20293479 0.20293246]\n", "Training Loss: 0.638691\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 1.2998967\n", "======================================================\n", "Accuracy: 0.9372\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.5044\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.20965716 0.15684147 0.3117969 0.15653124 0.1651732 ]\n", "Training Loss: 2.4024343\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9426\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.04311252 0.2404183 0.23603962 0.24040945 0.24002013]\n", "Training Loss: 1.1431745\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.21438715\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.941\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.3749174 0.1546514 0.16150749 0.15464756 0.15427625]\n", "Training Loss: 1.3335986\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.943\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9456\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9434\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "Finished Training\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "G6myrkQjW1kj", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1462 }, "outputId": "1f54e155-988c-45a4-bf33-998e3698df16" }, "source": [ "for epoch in range(5):\n", "\n", " for i in range(len(X_non_linear_train)):\n", " x = np.asarray([X_non_linear_train[i, :]])\n", " y = np.asarray([Y_non_linear_train[i]])\n", "\n", " arm, exp = gp.predict(x)\n", " \n", " if arm == y[0]: \n", " gp.partial_fit(x, y, exp)\n", " \n", " if i % 2000 == 1999:\n", " pred = []\n", " print(\"======================================================\")\n", " for i in range(len(X_linear_test)): \n", " pred.append(gp.predict(np.asarray([X_non_linear_test[i, :]]))[0])\n", " print(\"Accuracy: \" + str(accuracy_score(Y_non_linear_test, pred)))\n", " print(\"======================================================\")\n", "\n", "print('Finished Training')" ], "execution_count": 6, "outputs": [ { "output_type": "stream", "text": [ "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8394417 0.0401068 0.040231 0.04010678 0.04011366]\n", "Training Loss: 0.057519287\n", "======================================================\n", "Accuracy: 0.4964\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.83992 0.04001539 0.04003379 0.04001539 0.04001539]\n", "Training Loss: 0.18476209\n", "======================================================\n", "Accuracy: 0.4928\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8389778 0.04001202 0.04098608 0.04001202 0.04001202]\n", "Training Loss: 0.14300129\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8385863 0.04002043 0.04079157 0.04002043 0.04058118]\n", "Training Loss: 0.14015199\n", "======================================================\n", "Accuracy: 0.4926\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8388008 0.0400123 0.04064631 0.0400123 0.04052823]\n", "Training Loss: 0.1454754\n", "======================================================\n", "Accuracy: 0.49\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83798873 0.04046615 0.04053768 0.04046615 0.04054123]\n", "Training Loss: 0.15644525\n", "======================================================\n", "Accuracy: 0.502\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398817 0.04001858 0.04005715 0.04001858 0.04002398]\n", "Training Loss: 0.16174349\n", "======================================================\n", "Accuracy: 0.8878\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83951735 0.04012064 0.04012064 0.04012064 0.04012064]\n", "Training Loss: 0.19150087\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8395833 0.04010417 0.04010417 0.04010417 0.04010417]\n", "Training Loss: 0.2572823\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8396205 0.04009486 0.04009486 0.04009486 0.04009486]\n", "Training Loss: 0.21250293\n", "======================================================\n", "Accuracy: 0.9468\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83983326 0.04004166 0.04004166 0.04004166 0.04004166]\n", "Training Loss: 0.17772931\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83942896 0.04014274 0.04014274 0.04014274 0.04014274]\n", "Training Loss: 0.15651754\n", "======================================================\n", "Accuracy: 0.9484\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83984494 0.04003874 0.04003874 0.04003874 0.04003874]\n", "Training Loss: 0.1337454\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.839861 0.04003473 0.04003473 0.04003473 0.04003473]\n", "Training Loss: 0.12318811\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399298 0.04001752 0.04001752 0.04001752 0.04001752]\n", "Training Loss: 0.11616875\n", "======================================================\n", "Accuracy: 0.9522\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399066 0.04002333 0.04002333 0.04002333 0.04002333]\n", "Training Loss: 0.10676191\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83982676 0.04004328 0.04004328 0.04004328 0.04004328]\n", "Training Loss: 0.09901671\n", "======================================================\n", "Accuracy: 0.9498\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399164 0.04002087 0.04002087 0.04002087 0.04002087]\n", "Training Loss: 0.09151248\n", "Finished Training\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "E-IHNagYXgap", "colab_type": "code", "colab": {} }, "source": [ "" ], "execution_count": 0, "outputs": [] } ] }