{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "ONN_Usage_Example.ipynb", "version": "0.3.2", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" } }, "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "# ONN Usage Example\n", "\n", "This notebook trains an Online Neural Network (`ONN` from the `onn` package) one sample at a time with `partial_fit` on a synthetic 10-class dataset, and periodically reports its accuracy on a held-out test set." ] }, { "cell_type": "markdown", "metadata": { "id": "ugTEiBmBV3Cc", "colab_type": "text" }, "source": [ "## Installing Dependencies\n", "\n", "`onn` is pinned to the version this example was run with, so a newer release cannot silently change the API used below." ] }, { "cell_type": "code", "metadata": { "id": "FbEX3YFwKCxy", "colab_type": "code", "colab": { "base_uri": "https://localhost:8080/", "height": 187 }, "outputId": "acb57a9f-52e2-46b4-de1e-f4d2c0fc0d94" }, "source": [ "!pip install onn==0.0.4" ], "execution_count": 1, "outputs": [ { "output_type": "stream", "text": [ "Collecting onn\n", " Downloading https://files.pythonhosted.org/packages/a2/98/4900eee97e9df045febcbce0b330f28145db5c05a83da0384666228466a8/onn-0.0.4.tar.gz\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from onn) (1.16.4)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from onn) (1.1.0)\n", "Building wheels for collected packages: onn\n", " Building wheel for onn (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", " Stored in directory: /root/.cache/pip/wheels/71/9f/46/14b8f23801d65ea3edd4eb9a165651526d4f12ff1d6efe8257\n", "Successfully built onn\n", "Installing collected packages: onn\n", "Successfully installed onn-0.0.4\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "XbYfcIpkV-SA", "colab_type": "text" }, "source": [ "## Importing Dependencies" ] }, { "cell_type": "code", "metadata": { "id": "9xT-ZdVsKH6z", "colab_type": "code", "colab": {} }, "source": [ "import numpy as np\n", "import torch\n", "from sklearn.datasets import make_classification\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.model_selection import train_test_split\n", "\n", "from onn.OnlineNeuralNetwork import ONN" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "colab_type": "text" }, "source": [ "## Configuration\n", "\n", "Shared constants and random seeds. Seeding makes the synthetic dataset, the train/test split and (presumably, since `onn` is built on `torch`) the network's initial weights the same on every run." ] }, { "cell_type": "code", "metadata": { "colab_type": "code", "colab": {} }, "source": [ "SEED = 42\n", "N_FEATURES = 10  # used for both make_classification(n_features=...) and ONN(features_size=...)\n", "N_CLASSES = 10   # used for both make_classification(n_classes=...) and ONN(n_classes=...)\n", "\n", "np.random.seed(SEED)\n", "torch.manual_seed(SEED);" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "N9H22wxBWB9f", "colab_type": "text" }, "source": [ "## Initializing Network" ] }, { "cell_type": "code", "metadata": { "id": "1wOHgHL1LieT", "colab_type": "code", "colab": {} }, "source": [ "onn_network = ONN(\n", "    features_size=N_FEATURES,\n", "    max_num_hidden_layers=5,\n", "    qtd_neuron_per_hidden_layer=40,\n", "    n_classes=N_CLASSES,\n", ")" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "eYqkIyxyWI2h", "colab_type": "text" }, "source": [ "## Creating a Synthetic Classification Dataset" ] }, { "cell_type": "code", "metadata": { "id": "WXgNSF9gL69F", "colab_type": "code", "colab": {} }, "source": [ "X, Y = make_classification(\n", "    n_samples=50000,\n", "    n_features=N_FEATURES,\n", "    n_informative=4,\n", "    n_redundant=0,\n", "    n_classes=N_CLASSES,\n", "    n_clusters_per_class=1,\n", "    class_sep=3,\n", "    random_state=SEED,\n", ")\n", "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=SEED, shuffle=True)" ], "execution_count": 0, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "THGPFSWJWPEm", "colab_type": "text" }, "source": [ "## Learning and Predicting at the Same Time\n", "\n", "Training samples are fed to `partial_fit` one at a time (online learning), and every 1000 samples the current network is evaluated on the held-out test set. The `WARNING`, `Alpha` and `Training Loss` lines in the output are printed by the `onn` library itself, not by this notebook's code." ] }, { "cell_type": "code", "metadata": { "id": "70J3ZYtmL-Zm", "colab_type": "code", "colab": { "base_uri": 
"https://localhost:8080/", "height": 2397 }, "outputId": "55e294ed-61e6-49d6-8e4d-c5f9c4d4b901" }, "source": [ "EVAL_EVERY = 1000  # evaluate on the held-out test set every EVAL_EVERY training samples\n", "\n", "for i in range(len(X_train)):\n", "    # one sample per call: X is shaped (1, n_features) and y is shaped (1,)\n", "    onn_network.partial_fit(np.asarray([X_train[i, :]]), np.asarray([y_train[i]]))\n", "\n", "    if i % EVAL_EVERY == 0:\n", "        predictions = onn_network.predict(X_test)\n", "        print(\"Online Accuracy: {}\".format(accuracy_score(y_test, predictions)))\n", "\n", "# The last periodic evaluation usually happens before the final training samples\n", "# are seen, so also report the accuracy of the fully trained network.\n", "predictions = onn_network.predict(X_test)\n", "print(\"Final Accuracy: {}\".format(accuracy_score(y_test, predictions)))" ], "execution_count": 26, "outputs": [ { "output_type": "stream", "text": [ "Online Accuracy: 0.05466666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83948636 0.04012838 0.04012838 0.04012838 0.04012838]\n", "Training Loss: 1.3012835\n", "Online Accuracy: 0.9549333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8393235 0.0401691 0.0401691 0.0401691 0.0401691]\n", "Training Loss: 0.43253973\n", "Online Accuracy: 0.9649333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398507 0.0400373 0.0400373 0.0400373 0.0400373]\n", "Training Loss: 0.35385495\n", "Online Accuracy: 0.9657333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398936 0.04002658 0.04002658 0.04002658 0.04002658]\n", "Training Loss: 0.26460257\n", "Online Accuracy: 0.9705333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83983797 0.04004049 0.04004049 0.04004049 0.04004049]\n", "Training Loss: 0.2707248\n", "Online Accuracy: 0.9728666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399278 0.04001803 0.04001803 0.04001803 0.04001803]\n", "Training Loss: 0.23593011\n", "Online Accuracy: 0.9732666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83984274 0.04003929 0.04003929 0.04003929 0.04003929]\n", "Training Loss: 0.22243583\n", "Online Accuracy: 0.9753333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83947986 0.04013 0.04013 0.04013 0.04013 ]\n", "Training Loss: 0.22021466\n", "Online Accuracy: 0.9766666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399528 0.04001177 0.04001177 0.04001177 0.04001177]\n", "Training Loss: 0.19627546\n", "Online Accuracy: 0.9751333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83973163 0.04006707 0.04006707 0.04006707 0.04006707]\n", "Training Loss: 0.23148519\n", "Online Accuracy: 0.9762666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398566 0.04003582 0.04003582 0.04003582 0.04003582]\n", "Training Loss: 0.19100852\n", "Online Accuracy: 0.9776\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.839655 0.04008622 0.04008622 0.04008622 0.04008622]\n", "Training Loss: 0.19581231\n", "Online Accuracy: 0.9774\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8313776 0.04210529 0.04216187 0.0421773 0.0421779 ]\n", "Training Loss: 0.21734852\n", "Online Accuracy: 0.9776\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399382 0.04001542 0.04001542 0.04001542 0.04001542]\n", "Training Loss: 0.23634802\n", "Online Accuracy: 0.9773333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83993286 0.04001676 0.04001676 0.04001676 0.04001676]\n", "Training Loss: 0.18388465\n", "Online Accuracy: 0.9785333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83987343 0.04003161 0.04003161 0.04003161 0.04003161]\n", "Training Loss: 0.1682849\n", "Online Accuracy: 0.9784\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8398625 0.04003434 0.04003434 0.04003434 0.04003434]\n", "Training Loss: 0.21148238\n", "Online Accuracy: 0.9794\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83989906 0.04002522 0.04002522 0.04002522 0.04002522]\n", "Training Loss: 0.14424549\n", "Online Accuracy: 0.9782\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83997196 0.04000699 0.04000699 0.04000699 0.04000699]\n", "Training Loss: 0.17481777\n", "Online Accuracy: 0.9798666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399623 0.0400094 0.0400094 0.0400094 0.0400094]\n", "Training Loss: 0.14369117\n", "Online Accuracy: 0.9787333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399842 0.04000393 0.04000393 0.04000393 0.04000393]\n", "Training Loss: 0.21523248\n", "Online Accuracy: 0.9789333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.83994156 0.04001458 0.04001458 0.04001458 0.04001458]\n", "Training Loss: 0.1409239\n", "Online Accuracy: 0.9803333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83994156 0.04001459 0.04001459 0.04001459 0.04001459]\n", "Training Loss: 0.14340058\n", "Online Accuracy: 0.9812\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399909 0.04000225 0.04000225 0.04000225 0.04000225]\n", "Training Loss: 0.1858233\n", "Online Accuracy: 0.9808666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399545 0.04001136 0.04001136 0.04001136 0.04001136]\n", "Training Loss: 0.12744027\n", "Online Accuracy: 0.9806\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83992904 0.04001772 0.04001772 0.04001772 0.04001772]\n", "Training Loss: 0.13657537\n", "Online Accuracy: 0.9815333333333334\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399902 0.04000243 0.04000243 0.04000243 0.04000243]\n", "Training Loss: 0.14688537\n", "Online Accuracy: 0.9813333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399494 0.04001262 0.04001262 0.04001262 0.04001262]\n", "Training Loss: 0.17533518\n", "Online Accuracy: 0.9790666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8390823 0.0402294 0.0402294 0.0402294 0.0402294]\n", "Training Loss: 0.17110217\n", "Online Accuracy: 0.9812666666666666\n", "WARNING: Set 'show_loss' to 'False' when not debugging. 
It will deteriorate the fitting performance.\n", "Alpha:[0.8399557 0.04001106 0.04001106 0.04001106 0.04001106]\n", "Training Loss: 0.16559279\n", "Online Accuracy: 0.9814666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83982074 0.04004479 0.04004479 0.04004479 0.04004479]\n", "Training Loss: 0.13938081\n", "Online Accuracy: 0.9813333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399806 0.04000482 0.04000482 0.04000482 0.04000482]\n", "Training Loss: 0.18100421\n", "Online Accuracy: 0.9804666666666667\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83995837 0.04001039 0.04001039 0.04001039 0.04001039]\n", "Training Loss: 0.15145776\n", "Online Accuracy: 0.9819333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.8399076 0.04002308 0.04002308 0.04002308 0.04002308]\n", "Training Loss: 0.14798096\n", "Online Accuracy: 0.9817333333333333\n", "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n", "Alpha:[0.83984584 0.04003852 0.04003852 0.04003852 0.04003852]\n", "Training Loss: 0.13722865\n" ], "name": "stdout" } ] } ] }