Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
张 建国
onn在线训练
Commits
98404e54
Commit
98404e54
authored
Jun 17, 2019
by
Alison Carrera
Browse files
Changes in network params.
parent
2589b988
Changes
2
Show whitespace changes
Inline
Side-by-side
onn/OnlineNeuralNetwork.py
View file @
98404e54
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
class
ONN
():
class
ONN
(
nn
.
Module
):
def
__init__
(
self
,
features_size
,
max_num_hidden_layers
,
qtd_neuron_per_hidden_layer
,
n_classes
,
batch_size
=
1
,
b
=
0.99
,
n
=
0.01
,
s
=
0.2
,
use_cuda
=
False
):
super
(
ONN
,
self
).
__init__
()
...
...
@@ -18,9 +19,9 @@ class ONN():
self
.
qtd_neuron_per_hidden_layer
=
qtd_neuron_per_hidden_layer
self
.
n_classes
=
n_classes
self
.
batch_size
=
batch_size
self
.
b
=
torch
.
tensor
(
b
).
to
(
self
.
device
)
self
.
n
=
torch
.
tensor
(
n
).
to
(
self
.
device
)
self
.
s
=
torch
.
tensor
(
s
).
to
(
self
.
device
)
self
.
b
=
Parameter
(
torch
.
tensor
(
b
)
)
.
to
(
self
.
device
)
self
.
n
=
Parameter
(
torch
.
tensor
(
n
)
)
.
to
(
self
.
device
)
self
.
s
=
Parameter
(
torch
.
tensor
(
s
)
)
.
to
(
self
.
device
)
self
.
hidden_layers
=
[]
self
.
output_layers
=
[]
...
...
@@ -36,7 +37,7 @@ class ONN():
self
.
hidden_layers
=
nn
.
ModuleList
(
self
.
hidden_layers
).
to
(
self
.
device
)
self
.
output_layers
=
nn
.
ModuleList
(
self
.
output_layers
).
to
(
self
.
device
)
self
.
alpha
=
torch
.
Tensor
(
self
.
max_num_hidden_layers
).
fill_
(
1
/
(
self
.
max_num_hidden_layers
+
1
)).
to
(
self
.
alpha
=
Parameter
(
torch
.
Tensor
(
self
.
max_num_hidden_layers
).
fill_
(
1
/
(
self
.
max_num_hidden_layers
+
1
))
)
.
to
(
self
.
device
)
self
.
loss_array
=
[]
...
...
@@ -82,7 +83,7 @@ class ONN():
z_t
=
torch
.
sum
(
self
.
alpha
)
self
.
alpha
=
self
.
alpha
/
z_t
self
.
alpha
=
Parameter
(
self
.
alpha
/
z_t
).
to
(
self
.
device
)
if
show_loss
:
real_output
=
torch
.
sum
(
torch
.
mul
(
...
...
setup.py
View file @
98404e54
...
...
@@ -6,7 +6,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description
=
f
.
read
()
setup
(
name
=
'onn'
,
version
=
'0.0.
6
'
,
version
=
'0.0.
7
'
,
description
=
'Online Neural Network'
,
url
=
'https://github.com/alison-carrera/onn'
,
author
=
'Alison Carrera'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment