Commit 98404e54 authored by Alison Carrera's avatar Alison Carrera
Browse files

Changes in network params.

parent 2589b988
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter
class ONN(): class ONN(nn.Module):
def __init__(self, features_size, max_num_hidden_layers, qtd_neuron_per_hidden_layer, n_classes, batch_size=1, def __init__(self, features_size, max_num_hidden_layers, qtd_neuron_per_hidden_layer, n_classes, batch_size=1,
b=0.99, n=0.01, s=0.2, use_cuda=False): b=0.99, n=0.01, s=0.2, use_cuda=False):
super(ONN, self).__init__() super(ONN, self).__init__()
...@@ -18,9 +19,9 @@ class ONN(): ...@@ -18,9 +19,9 @@ class ONN():
self.qtd_neuron_per_hidden_layer = qtd_neuron_per_hidden_layer self.qtd_neuron_per_hidden_layer = qtd_neuron_per_hidden_layer
self.n_classes = n_classes self.n_classes = n_classes
self.batch_size = batch_size self.batch_size = batch_size
self.b = torch.tensor(b).to(self.device) self.b = Parameter(torch.tensor(b)).to(self.device)
self.n = torch.tensor(n).to(self.device) self.n = Parameter(torch.tensor(n)).to(self.device)
self.s = torch.tensor(s).to(self.device) self.s = Parameter(torch.tensor(s)).to(self.device)
self.hidden_layers = [] self.hidden_layers = []
self.output_layers = [] self.output_layers = []
...@@ -36,7 +37,7 @@ class ONN(): ...@@ -36,7 +37,7 @@ class ONN():
self.hidden_layers = nn.ModuleList(self.hidden_layers).to(self.device) self.hidden_layers = nn.ModuleList(self.hidden_layers).to(self.device)
self.output_layers = nn.ModuleList(self.output_layers).to(self.device) self.output_layers = nn.ModuleList(self.output_layers).to(self.device)
self.alpha = torch.Tensor(self.max_num_hidden_layers).fill_(1 / (self.max_num_hidden_layers + 1)).to( self.alpha = Parameter(torch.Tensor(self.max_num_hidden_layers).fill_(1 / (self.max_num_hidden_layers + 1))).to(
self.device) self.device)
self.loss_array = [] self.loss_array = []
...@@ -82,7 +83,7 @@ class ONN(): ...@@ -82,7 +83,7 @@ class ONN():
z_t = torch.sum(self.alpha) z_t = torch.sum(self.alpha)
self.alpha = self.alpha / z_t self.alpha = Parameter(self.alpha / z_t).to(self.device)
if show_loss: if show_loss:
real_output = torch.sum(torch.mul( real_output = torch.sum(torch.mul(
......
...@@ -6,7 +6,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: ...@@ -6,7 +6,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description = f.read() long_description = f.read()
setup(name='onn', setup(name='onn',
version='0.0.6', version='0.0.7',
description='Online Neural Network', description='Online Neural Network',
url='https://github.com/alison-carrera/onn', url='https://github.com/alison-carrera/onn',
author='Alison Carrera', author='Alison Carrera',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment