{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "ONN_Usage_Example.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "ugTEiBmBV3Cc", "colab_type": "text" }, "source": [ "#Installing Dependencies" ] }, { "cell_type": "code", "metadata": { "id": "FbEX3YFwKCxy", "colab_type": "code", "outputId": "ce12c48c-690e-49cb-c8a1-b31f66872022", "colab": { "base_uri": "https://localhost:8080/", "height": 85 } }, "source": [ "!pip install --upgrade onn" ], "execution_count": 7, "outputs": [ { "output_type": "stream", "text": [ "Requirement already up-to-date: onn in /usr/local/lib/python3.6/dist-packages (0.1.8)\n", "Requirement already satisfied, skipping upgrade: torch in /usr/local/lib/python3.6/dist-packages (from onn) (1.1.0)\n", "Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from onn) (1.16.4)\n", "Requirement already satisfied, skipping upgrade: mabalgs in /usr/local/lib/python3.6/dist-packages (from onn) (0.6.4)\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "CCBv8CqCe5Ao", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 102 }, "outputId": "a9c5bc86-af25-4b3a-b827-a5e22c1b2a1d" }, "source": [ "!pip install -U imbalanced-learn" ], "execution_count": 8, "outputs": [ { "output_type": "stream", "text": [ "Requirement already up-to-date: imbalanced-learn in /usr/local/lib/python3.6/dist-packages (0.5.0)\n", "Requirement already satisfied, skipping upgrade: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from imbalanced-learn) (0.13.2)\n", "Requirement already satisfied, skipping upgrade: numpy>=1.11 in /usr/local/lib/python3.6/dist-packages (from imbalanced-learn) (1.16.4)\n", "Requirement already satisfied, skipping upgrade: scipy>=0.17 in /usr/local/lib/python3.6/dist-packages (from imbalanced-learn) (1.3.0)\n", "Requirement already satisfied, skipping upgrade: scikit-learn>=0.21 in /usr/local/lib/python3.6/dist-packages (from imbalanced-learn) (0.21.2)\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "XbYfcIpkV-SA", "colab_type": "text" }, "source": [ "##Importing Dependencies" ] }, { "cell_type": "code", "metadata": { "id": "9xT-ZdVsKH6z", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "outputId": "899e5370-ad08-40cc-8241-9a577717d828" }, "source": [ "from onn.OnlineNeuralNetwork import ONN\n", "from onn.OnlineNeuralNetwork import ONN_THS\n", "from sklearn.datasets import make_classification, make_circles\n", "from sklearn.model_selection import train_test_split\n", "import torch\n", "from sklearn.metrics import accuracy_score, balanced_accuracy_score\n", "from imblearn.datasets import make_imbalance\n", "import numpy as np" ], "execution_count": 34, "outputs": [ { "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ], "name": "stderr" } ] }, { "cell_type": "markdown", "metadata": { "id": "N9H22wxBWB9f", "colab_type": "text" }, "source": [ "## Initializing Network" ] }, { "cell_type": "code", "metadata": { "id": "1wOHgHL1LieT", "colab_type": "code", "colab": {} }, "source": [ "onn_network = ONN(features_size=10, max_num_hidden_layers=5, qtd_neuron_per_hidden_layer=40, n_classes=10)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "eYqkIyxyWI2h", "colab_type": "text" }, 
"source": [ "##Creating Fake Classification Dataset" ] }, { "cell_type": "code", "metadata": { "id": "WXgNSF9gL69F", "colab_type": "code", "colab": {} }, "source": [ "X, Y = make_classification(n_samples=50000, n_features=10, n_informative=4, n_redundant=0, n_classes=10,\n", " n_clusters_per_class=1, class_sep=3)\n", "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=42, shuffle=True)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "THGPFSWJWPEm", "colab_type": "text" }, "source": [ "##Learning and predicting at the same time" ] }, { "cell_type": "code", "metadata": { "id": "70J3ZYtmL-Zm", "colab_type": "code", "outputId": "15babeef-1128-44b7-8dcf-92d4ef3d295a", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 } }, "source": [ "for i in range(len(X_train)):\n", " onn_network.partial_fit(np.asarray([X_train[i, :]]), np.asarray([y_train[i]]))\n", " \n", " if i % 1000 == 0:\n", " predictions = onn_network.predict(X_test)\n", " print(\"Online Accuracy: {}\".format(balanced_accuracy_score(y_test, predictions)))" ], "execution_count": 25, "outputs": [ { "output_type": "stream", "text": [ "Online Accuracy: 0.14337461746314914\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83890086 0.04027476 0.04027476 0.04027476 0.04027476]\n", "Training Loss: 1.296051\n", "Online Accuracy: 0.9606844639729234\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.839583 0.04010423 0.04010423 0.04010423 0.04010423]\n", "Training Loss: 0.42200255\n", "Online Accuracy: 0.9587807527275185\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8393374 0.04016563 0.04016563 0.04016563 0.04016563]\n", "Training Loss: 0.3702089\n", "Online Accuracy: 0.967292853621438\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83981675 0.0400458 0.0400458 0.0400458 0.0400458 ]\n", "Training Loss: 0.26317188\n", "Online Accuracy: 0.9654591078178111\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83988035 0.04002989 0.04002989 0.04002989 0.04002989]\n", "Training Loss: 0.25282928\n", "Online Accuracy: 0.9665110941847939\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398668 0.04003327 0.04003327 0.04003327 0.04003327]\n", "Training Loss: 0.21191452\n", "Online Accuracy: 0.9710897029043102\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83997285 0.04000676 0.04000676 0.04000676 0.04000676]\n", "Training Loss: 0.21445769\n", "Online Accuracy: 0.9719590088952164\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8397224 0.04006939 0.04006939 0.04006939 0.04006939]\n", "Training Loss: 0.19888349\n", "Online Accuracy: 0.9704734746855797\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399309 0.04001725 0.04001725 0.04001725 0.04001725]\n", "Training Loss: 0.19780262\n", "Online Accuracy: 0.9682571684488185\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8398869 0.04002826 0.04002826 0.04002826 0.04002826]\n", "Training Loss: 0.21107836\n", "Online Accuracy: 0.9717444211680417\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399914 0.04000212 0.04000212 0.04000212 0.04000212]\n", "Training Loss: 0.18996303\n", "Online Accuracy: 0.9731342027566192\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.839978 0.04000548 0.04000548 0.04000548 0.04000548]\n", "Training Loss: 0.19749445\n", "Online Accuracy: 0.9732475060493468\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.839696 0.04007598 0.04007598 0.04007598 0.04007598]\n", "Training Loss: 0.19404268\n", "Online Accuracy: 0.9723553377085323\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398478 0.04003802 0.04003802 0.04003802 0.04003802]\n", "Training Loss: 0.18190624\n", "Online Accuracy: 0.9743039644477742\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83997566 0.04000606 0.04000606 0.04000606 0.04000606]\n", "Training Loss: 0.1828396\n", "Online Accuracy: 0.9738457498877151\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8392171 0.0401957 0.0401957 0.0401957 0.0401957]\n", "Training Loss: 0.19953582\n", "Online Accuracy: 0.9732231868948086\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399872 0.04000318 0.04000318 0.04000318 0.04000318]\n", "Training Loss: 0.16184519\n", "Online Accuracy: 0.9735012003132812\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8394533 0.04013667 0.04013667 0.04013667 0.04013667]\n", "Training Loss: 0.1797122\n", "Online Accuracy: 0.9746659982580329\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399609 0.04000976 0.04000976 0.04000976 0.04000976]\n", "Training Loss: 0.19841474\n", "Online Accuracy: 0.9751102633737601\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398771 0.0400307 0.0400307 0.0400307 0.0400307]\n", "Training Loss: 0.1930369\n", "Online Accuracy: 0.9740513941554421\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399562 0.04001092 0.04001092 0.04001092 0.04001092]\n", "Training Loss: 0.21371904\n", "Online Accuracy: 0.9755335557702551\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399458 0.04001352 0.04001352 0.04001352 0.04001352]\n", "Training Loss: 0.19402196\n", "Online Accuracy: 0.9758116920399523\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.839987 0.04000323 0.04000323 0.04000323 0.04000323]\n", "Training Loss: 0.17364302\n", "Online Accuracy: 0.9754555741342109\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.83992887 0.04001775 0.04001775 0.04001775 0.04001775]\n", "Training Loss: 0.14595963\n", "Online Accuracy: 0.9754666701388206\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399415 0.04001459 0.04001459 0.04001459 0.04001459]\n", "Training Loss: 0.16576274\n", "Online Accuracy: 0.9737672492447718\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399788 0.04000528 0.04000528 0.04000528 0.04000528]\n", "Training Loss: 0.17223129\n", "Online Accuracy: 0.9755318323234476\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398563 0.0400359 0.0400359 0.0400359 0.0400359]\n", "Training Loss: 0.1655618\n", "Online Accuracy: 0.9763829709632855\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83996785 0.04000802 0.04000802 0.04000802 0.04000802]\n", "Training Loss: 0.15412873\n", "Online Accuracy: 0.976408073580625\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83986807 0.04003297 0.04003297 0.04003297 0.04003297]\n", "Training Loss: 0.19262518\n", "Online Accuracy: 0.9757346501029929\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83998775 0.04000303 0.04000303 0.04000303 0.04000303]\n", "Training Loss: 0.16019145\n", "Online Accuracy: 0.9764235355129427\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399023 0.04002441 0.04002441 0.04002441 0.04002441]\n", "Training Loss: 0.15920378\n", "Online Accuracy: 0.9768755609369231\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83994466 0.04001381 0.04001381 0.04001381 0.04001381]\n", "Training Loss: 0.12754413\n", "Online Accuracy: 0.9754246065697441\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83994985 0.04001251 0.04001251 0.04001251 0.04001251]\n", "Training Loss: 0.1729848\n", "Online Accuracy: 0.9768689984007384\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83995646 0.04001086 0.04001086 0.04001086 0.04001086]\n", "Training Loss: 0.21214448\n", "Online Accuracy: 0.9769247559047265\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.83999026 0.04000241 0.04000241 0.04000241 0.04000241]\n", "Training Loss: 0.133252\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "_K01fvy7Tj7Z", "colab_type": "text" }, "source": [ "# Learning in batch with CUDA" ] }, { "cell_type": "code", "metadata": { "id": "Jg1UJWvDTjJc", "colab_type": "code", "outputId": "b917847a-45a3-445f-c527-3d449e1de377", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "onn_network = ONN(features_size=10, max_num_hidden_layers=5, qtd_neuron_per_hidden_layer=40, n_classes=10, batch_size=10, use_cuda=True)" ], "execution_count": 26, "outputs": [ { "output_type": "stream", "text": [ "Using CUDA :]\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "Uoxk5Jk_UYKJ", "colab_type": "code", "colab": {} }, "source": [ "from torch.utils.data import Dataset, DataLoader\n", "# Wraps the numpy arrays so a DataLoader can serve them in mini-batches\n", "class ONNDataset(Dataset):\n", "\n", " def __init__(self, X, Y):\n", " self.X = X\n", " self.Y = Y\n", "\n", " def __len__(self):\n", " return len(self.X)\n", "\n", " def __getitem__(self, idx):\n", " X = self.X[idx],  # the trailing comma wraps X in a tuple; torch.stack(local_X) below relies on it\n", " Y = self.Y[idx]\n", "\n", " return X, Y\n", " \n", "transformed_dataset = ONNDataset(X_train, y_train)\n", "dataloader = DataLoader(transformed_dataset, batch_size=10, shuffle=True, num_workers=1)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "8jz4S-SVUAMx", "colab_type": "code", "outputId": "eb25d974-e8af-4e51-e9f0-5f2501b02b56", "colab": { "base_uri": "https://localhost:8080/", "height": 170 } }, "source": [ "for local_X, local_y in dataloader: \n", " onn_network.partial_fit(np.squeeze(torch.stack(local_X).numpy()), local_y.numpy())" ], "execution_count": 28, "outputs": [ { "output_type": "stream", "text": [ "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83902884 0.04024278 0.04024278 0.04024278 0.04024278]\n", "Training Loss: 1.577592\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8394673 0.04013318 0.04013318 0.04013318 0.04013318]\n", "Training Loss: 0.5495532\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8397412 0.0400647 0.0400647 0.0400647 0.0400647]\n", "Training Loss: 0.40501082\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "u7XwSO6zVpsz", "colab_type": "code", "outputId": "6f19825b-4435-442c-c710-4003ef7b555c", "colab": { "base_uri": "https://localhost:8080/", "height": 34 } }, "source": [ "predictions = onn_network.predict(X_test)\n", "print(\"Accuracy: {}\".format(balanced_accuracy_score(y_test, predictions)))" ], "execution_count": 29, "outputs": [ { "output_type": "stream", "text": [ "Accuracy: 0.9517950352535276\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "DSk4Fl4GV9NG", "colab_type": "text" }, "source": [ "#Using contextual bandit - ONN_THS" ] }, { "cell_type": "markdown", "metadata": { "id": "vBweUP_CZaFz", "colab_type": "text" }, "source": [ "In this example the ONN acts as a contextual bandit, a type of reinforcement learning algorithm: each class is treated as an arm, and the network is updated only when the arm it selects matches the true label, i.e., when it receives a reward. 
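\n", "\n", "A minimal sketch of a single interaction step, reusing the same objects that the cells below define (the bandit gp, a context x and its label y); this is the pattern used inside the training loops that follow:\n", "\n", "```python\n", "arm, exp = gp.predict(x)       # choose an arm (a predicted class) and get the exploration factor\n", "if arm == y[0]:                # reward: the chosen arm matches the true label\n", "    gp.partial_fit(x, y, exp)  # the network is updated only when it is rewarded\n", "```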
" ] }, { "cell_type": "code", "metadata": { "id": "TeFBn4erUDY3", "colab_type": "code", "colab": {} }, "source": [ "X_linear, Y_linear = make_classification(n_samples=10000, n_features=2, n_informative=2, n_redundant=0, n_classes=2, n_clusters_per_class=1, class_sep=200, shuffle=True)\n", "X_non_linear, Y_non_linear = make_circles(n_samples=10000, noise=0.1, factor=0.3, shuffle=True)\n", "X_linear_2, Y_linear_2 = make_classification(n_samples=10000, n_features=2, n_informative=2, n_redundant=0, n_classes=2, n_clusters_per_class=1, class_sep=200, shuffle=True)\n", "\n", "X_linear_train = X_linear[:5000]\n", "Y_linear_train = Y_linear[:5000]\n", "\n", "X_linear_test = X_linear[5000:]\n", "Y_linear_test = Y_linear[5000:]\n", "\n", "X_non_linear_train = X_non_linear[:5000]\n", "Y_non_linear_train = Y_non_linear[:5000]\n", "\n", "X_non_linear_test = X_non_linear[5000:]\n", "Y_non_linear_test = Y_non_linear[5000:]\n", "\n", "X_linear_train_2 = X_linear_2[:5000]\n", "Y_linear_train_2 = Y_linear_2[:5000]\n", "\n", "X_linear_test_2 = X_linear_2[5000:]\n", "Y_linear_test_2 = Y_linear_2[5000:]" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "lfnN1xG7WtF7", "colab_type": "code", "colab": {} }, "source": [ "gp = ONN_THS(2, 5, 100, 2)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "7EmKF-y8Wytt", "colab_type": "code", "outputId": "b6b7b943-d466-4274-cd11-e1325290412b", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 } }, "source": [ "for epoch in range(5):\n", "\n", " for i in range(len(X_linear_train)):\n", " x = np.asarray([X_linear_train[i, :]])\n", " y = np.asarray([Y_linear_train[i]])\n", "\n", " arm, exp = gp.predict(x)\n", " \n", " if arm == y[0]: \n", " gp.partial_fit(x, y, exp)\n", " \n", " if i % 2000 == 1999:\n", " pred = []\n", " print(\"======================================================\")\n", " for i in range(len(X_linear_test)): \n", " pred.append(gp.predict(np.asarray([X_linear_test[i, :]]))[0])\n", " print(\"Accuracy: \" + str(balanced_accuracy_score(Y_linear_test, pred)))\n", " print(\"======================================================\")\n", "\n", "print('Finished Training')" ], "execution_count": 32, "outputs": [ { "output_type": "stream", "text": [ "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 7.708073e-07\n", "======================================================\n", "Accuracy: 0.9320135479448636\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.66480255 0.08414558 0.08314086 0.08373952 0.0841715 ]\n", "Training Loss: 1.4113159\n", "======================================================\n", "Accuracy: 0.7803689523703463\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.4597675 0.1319707 0.13887632 0.13379806 0.13558747]\n", "Training Loss: 2.1171865\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.8254292\n", "======================================================\n", "Accuracy: 0.8239034376118035\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 1.9921635\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.09616726 0.22576231 0.22466676 0.22605874 0.2273449 ]\n", "Training Loss: 1.2512604\n", "======================================================\n", "Accuracy: 0.9194272332380344\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.7483906 0.06154997 0.0657564 0.0615586 0.06274442]\n", "Training Loss: 1.2462758\n", "======================================================\n", "Accuracy: 0.9089153835864292\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 1.1584326\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 1.0460584\n", "======================================================\n", "Accuracy: 0.9091254701747447\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.04581921 0.22431232 0.26757908 0.22434013 0.23794925]\n", "Training Loss: 1.4130515\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.00093426596\n", "======================================================\n", "Accuracy: 0.9230967883152815\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 1.7778293\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9200642132137263\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.1015042 0.22482882 0.22428374 0.22482881 0.22455448]\n", "Training Loss: 1.1210397\n", "======================================================\n", "Accuracy: 0.9140752253990685\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.00011777431\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.13615824 0.21326509 0.22357734 0.21326567 0.21373361]\n", "Training Loss: 1.4623655\n", "======================================================\n", "Accuracy: 0.9162808945688536\n", "======================================================\n", "Finished Training\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "G6myrkQjW1kj", "colab_type": "code", "outputId": "873af04f-482d-4ae5-89fe-8df04f21cc38", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 } }, "source": [ "for epoch in range(5):\n", "\n", " for i in range(len(X_non_linear_train)):\n", " x = np.asarray([X_non_linear_train[i, :]])\n", " y = np.asarray([Y_non_linear_train[i]])\n", "\n", " arm, exp = gp.predict(x)\n", " \n", " if arm == y[0]: \n", " gp.partial_fit(x, y, exp)\n", " \n", " if i % 2000 == 1999:\n", " pred = []\n", " print(\"======================================================\")\n", " for i in range(len(X_linear_test)): \n", " pred.append(gp.predict(np.asarray([X_non_linear_test[i, :]]))[0])\n", " print(\"Accuracy: \" + str(balanced_accuracy_score(Y_non_linear_test, pred)))\n", " print(\"======================================================\")\n", "\n", "print('Finished Training')" ], "execution_count": 33, "outputs": [ { "output_type": "stream", "text": [ "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.7324713 0.06278172 0.07632934 0.0627818 0.06563586]\n", "Training Loss: 0.565434\n", "======================================================\n", "Accuracy: 0.4978415196546431\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83837765 0.04006734 0.04134858 0.04006734 0.04013911]\n", "Training Loss: 0.32934776\n", "======================================================\n", "Accuracy: 0.6431012228961956\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83944225 0.04013941 0.04013941 0.04013941 0.04013941]\n", "Training Loss: 0.27433354\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83982235 0.04004439 0.04004439 0.04004439 0.04004439]\n", "Training Loss: 0.2834035\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83984315 0.04003919 0.04003919 0.04003919 0.04003919]\n", "Training Loss: 0.2099059\n", "======================================================\n", "Accuracy: 0.7659370825499332\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8395842 0.04010393 0.04010393 0.04010393 0.04010393]\n", "Training Loss: 0.21292916\n", "======================================================\n", "Accuracy: 0.9151972664315626\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83984536 0.04003864 0.04003864 0.04003864 0.04003864]\n", "Training Loss: 0.19620447\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398821 0.04002946 0.04002946 0.04002946 0.04002946]\n", "Training Loss: 0.16505778\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83968943 0.04007762 0.04007762 0.04007762 0.04007762]\n", "Training Loss: 0.14461489\n", "======================================================\n", "Accuracy: 0.9197904671664747\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399311 0.04001722 0.04001722 0.04001722 0.04001722]\n", "Training Loss: 0.124461606\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8396594 0.04008514 0.04008514 0.04008514 0.04008514]\n", "Training Loss: 0.115024015\n", "======================================================\n", "Accuracy: 0.9301949488311918\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398307 0.04004231 0.04004231 0.04004231 0.04004231]\n", "Training Loss: 0.10779811\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398185 0.04004535 0.04004535 0.04004535 0.04004535]\n", "Training Loss: 0.099823594\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399364 0.04001588 0.04001588 0.04001588 0.04001588]\n", "Training Loss: 0.092471905\n", "======================================================\n", "Accuracy: 0.8999838239974118\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.839942 0.04001448 0.04001448 0.04001448 0.04001448]\n", "Training Loss: 0.08162724\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83992773 0.04001804 0.04001804 0.04001804 0.04001804]\n", "Training Loss: 0.0805736\n", "======================================================\n", "Accuracy: 0.9387956702073073\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83965665 0.04008581 0.04008581 0.04008581 0.04008581]\n", "Training Loss: 0.081078894\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8397386 0.04006533 0.04006533 0.04006533 0.04006533]\n", "Training Loss: 0.072354406\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399574 0.04001063 0.04001063 0.04001063 0.04001063]\n", "Training Loss: 0.06984673\n", "======================================================\n", "Accuracy: 0.9489987918398066\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399927 0.0400018 0.0400018 0.0400018 0.0400018]\n", "Training Loss: 0.06360293\n", "======================================================\n", "Accuracy: 0.9428014308482289\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83996755 0.0400081 0.0400081 0.0400081 0.0400081 ]\n", "Training Loss: 0.06630126\n", "Finished Training\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "G27Vh6mI6LTi", "colab_type": "text" }, "source": [ "# Imbalanced Dataset" ] }, { "cell_type": "code", "metadata": { "id": "E-IHNagYXgap", "colab_type": "code", "colab": {} }, "source": [ "X, Y = make_classification(n_samples=110000, n_features=20, n_classes=10, n_informative=8, n_redundant=0, n_clusters_per_class=1, class_sep=900)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "O5Mh3jR-6TWR", "colab_type": "code", "colab": {} }, "source": [ "X_t, Y_t = make_imbalance(X, Y, sampling_strategy={0: 800, 1: 5000, 2: 10000, 3: 10000, 4: 1000, 5: 1000, 6: 500, 7: 10000, 8: 5000, 9:5000})" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "ryU12keU6WOp", "colab_type": "code", "colab": {} }, "source": [ "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)\n", "\n", "X_train_t, X_test_t, y_train_t, y_test_t = train_test_split(X_t, Y_t, test_size=0.2)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Si9uRxRu7Ecw", "colab_type": "code", "colab": {} }, "source": [ "gp = ONN_THS(20, 5, 100, 10)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "a21-OBBC6YrK", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "outputId": "9285bd62-10a5-4f71-b8ec-22097ae3a808" }, "source": [ "for i in range(len(X_train)):\n", " x = np.asarray([X_train[i, :]])\n", " y = np.asarray([y_train[i]])\n", "\n", " arm, exp = gp.predict(x)\n", "\n", " if arm == y[0]: \n", " gp.partial_fit(x, y, exp)\n", "\n", " if i % 2000 == 1999:\n", " pred = []\n", " print(\"======================================================\")\n", " for i in range(len(X_test)): \n", " pred.append(gp.predict(np.asarray([X_test[i, :]]))[0])\n", " print(\"Accuracy: \" + str(balanced_accuracy_score(y_test, pred)))\n", " print(\"======================================================\")\n", "\n", "print('Finished Training')" ], "execution_count": 49, "outputs": [ { "output_type": "stream", "text": [ "======================================================\n", "Accuracy: 0.4353221256829899\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.06869241 0.16809322 0.3725306 0.17724589 0.21343793]\n", "Training Loss: 9.325638\n", "======================================================\n", "Accuracy: 0.4386085948153132\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.11643752 0.20428126 0.23810413 0.20844749 0.23272966]\n", "Training Loss: 11.25701\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 2.6811903\n", "======================================================\n", "Accuracy: 0.6928950497293518\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 10.146265\n", "======================================================\n", "Accuracy: 0.6838257488564606\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 8.005989\n", "======================================================\n", "Accuracy: 0.6115088279471166\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.79145914 0.05235159 0.05036381 0.05197832 0.05384716]\n", "Training Loss: 3.7819355\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 8.093227\n", "======================================================\n", "Accuracy: 0.7768358673631839\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.7728323692342458\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.7750311850166554\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.23982884 0.19055171 0.18932813 0.19059701 0.18969429]\n", "Training Loss: 13.883859\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8685216303769551\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8026281358448346\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 17.250746\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 2.5108016\n", "======================================================\n", "Accuracy: 0.8674449681421013\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8713122081601006\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.86805856622274\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8673653690806932\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8709340126479518\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 26.645742\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.5200308 0.12016081 0.11971524 0.12033709 0.11975605]\n", "Training Loss: 5.1447864\n", "======================================================\n", "Accuracy: 0.7919937075157115\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 16.11202\n", "======================================================\n", "Accuracy: 0.8733754651900446\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8717504131399775\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.873564110770424\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 15.066921\n", "======================================================\n", "Accuracy: 0.7295902241565401\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 2.8158276\n", "======================================================\n", "Accuracy: 0.8713796789076305\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 22.253164\n", "======================================================\n", "Accuracy: 0.8771001088342661\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.53331 0.11533871 0.11965422 0.11555222 0.1161448 ]\n", "Training Loss: 13.497225\n", "======================================================\n", "Accuracy: 0.8746053463042536\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 4.758543\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8741468052285819\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8795648644911169\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 16.302721\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 3.1496372\n", "======================================================\n", "Accuracy: 0.8837012263713927\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.902540151229202\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9021177354976991\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 8.878667\n", "======================================================\n", "Accuracy: 0.8079567207927605\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8039658939000326\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 6.192695\n", "======================================================\n", "Accuracy: 0.9009930184879561\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 10.917292\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.900008182367175\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.7559461 0.06102359 0.0612076 0.0608501 0.06097254]\n", "Training Loss: 10.377898\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9055045133877494\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9021639784287858\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 10.628485\n", "======================================================\n", "Accuracy: 0.9016796783336091\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 8.878321\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8994004273359604\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9050839928229234\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9044701066116776\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9047566712793573\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9063798424927105\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9008264580869225\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9001214892044537\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.902945101154504\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9009983483663652\n", "======================================================\n", "Finished Training\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "eqBRyZxh-6_h", "colab_type": "code", "colab": {} }, "source": [ "gp = ONN_THS(20, 5, 100, 10)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "mFFz2xWN6o_4", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "outputId": "ee7d349a-2efb-44b3-d44c-7a1976c903a3" }, "source": [ "for i in range(len(X_train_t)):\n", " x = np.asarray([X_train_t[i, :]])\n", " y = np.asarray([y_train_t[i]])\n", "\n", " arm, exp = gp.predict(x)\n", "\n", " if arm == y[0]: \n", " gp.partial_fit(x, y, exp)\n", "\n", " if i % 2000 == 1999:\n", " pred = []\n", " print(\"======================================================\")\n", " for i in range(len(X_test_t)): \n", " pred.append(gp.predict(np.asarray([X_test_t[i, :]]))[0])\n", " print(\"Accuracy: \" + str(balanced_accuracy_score(y_test_t, pred)))\n", " print(\"======================================================\")\n", "\n", "print('Finished Training')" ], "execution_count": 51, "outputs": [ { "output_type": "stream", "text": [ "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.31641227 0.12936862 0.17681654 0.18663488 0.19076766]\n", "Training Loss: 4.9937906\n", "======================================================\n", "Accuracy: 0.5121877357778365\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.11614554 0.1702754 0.2646138 0.1885041 0.2604611 ]\n", "Training Loss: 9.42681\n", "======================================================\n", "Accuracy: 0.60000532007458\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 3.760187\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.6370007505514537\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.6198465258532846\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 1.579348\n", "======================================================\n", "Accuracy: 0.5121875529066693\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 9.70304\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.6330198059531864\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 10.2148285\n", "======================================================\n", "Accuracy: 0.7084962849287858\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.7265182796040035\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.7239405864519876\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.12717529\n", "======================================================\n", "Accuracy: 0.810706833334715\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 9.516822\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 2.7376916\n", "======================================================\n", "Accuracy: 0.8972895582247269\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 8.428832\n", "======================================================\n", "Accuracy: 0.90217251261295\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 11.3074665\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9124533480316206\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9035376710667149\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9045845614725639\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.6930886 0.07199327 0.08099582 0.07299026 0.08093201]\n", "Training Loss: 12.486347\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8146383732762196\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.8232097488414242\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 6.5098047\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9154437246746809\n", "======================================================\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399999 0.04 0.04 0.04 0.04 ]\n", "Training Loss: 0.0\n", "======================================================\n", "Accuracy: 0.9127977126517685\n", "======================================================\n", "Finished Training\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "X3IBIV0WAqKI", "colab_type": "code", "colab": {} }, "source": [ "" ], "execution_count": 0, "outputs": [] } ] }