Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
张 建国
onn在线训练
Commits
3831ecd4
Commit
3831ecd4
authored
Jun 17, 2019
by
alison-carrera
Browse files
Added an exploration factor parameter for testing.
parent
98404e54
Changes
3
Show whitespace changes
Inline
Side-by-side
.gitignore
View file @
3831ecd4
...
...
@@ -102,3 +102,5 @@ venv.bak/
# mypy
.mypy_cache/
/.idea
onn/OnlineNeuralNetwork.py
View file @
3831ecd4
import
random
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
...
...
@@ -6,7 +9,7 @@ from torch.nn.parameter import Parameter
class
ONN
(
nn
.
Module
):
def
__init__
(
self
,
features_size
,
max_num_hidden_layers
,
qtd_neuron_per_hidden_layer
,
n_classes
,
batch_size
=
1
,
b
=
0.99
,
n
=
0.01
,
s
=
0.2
,
use_cuda
=
False
):
b
=
0.99
,
n
=
0.01
,
s
=
0.2
,
e
=
0.5
,
use_cuda
=
False
,
use_exploration
=
False
):
super
(
ONN
,
self
).
__init__
()
if
torch
.
cuda
.
is_available
()
and
use_cuda
:
...
...
@@ -22,6 +25,9 @@ class ONN(nn.Module):
self
.
b
=
Parameter
(
torch
.
tensor
(
b
)).
to
(
self
.
device
)
self
.
n
=
Parameter
(
torch
.
tensor
(
n
)).
to
(
self
.
device
)
self
.
s
=
Parameter
(
torch
.
tensor
(
s
)).
to
(
self
.
device
)
self
.
e
=
Parameter
(
torch
.
tensor
(
e
)).
to
(
self
.
device
)
self
.
arms_values
=
np
.
arange
(
n_classes
).
tolist
()
self
.
use_exploration
=
use_exploration
self
.
hidden_layers
=
[]
self
.
output_layers
=
[]
...
...
@@ -143,4 +149,10 @@ class ONN(nn.Module):
self
.
max_num_hidden_layers
,
len
(
X_data
),
1
),
self
.
forward
(
X_data
)),
0
),
dim
=
1
).
cpu
().
numpy
()
def
predict
(
self
,
X_data
):
return
self
.
predict_
(
X_data
)
pred
=
self
.
predict_
(
X_data
)
if
self
.
use_exploration
and
np
.
random
.
uniform
()
<
self
.
e
:
removed_arms
=
self
.
arms_values
.
copy
()
removed_arms
.
remove
(
pred
)
return
random
.
choice
(
removed_arms
)
return
pred
setup.py
View file @
3831ecd4
...
...
@@ -6,7 +6,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description
=
f
.
read
()
setup
(
name
=
'onn'
,
version
=
'0.0.
7
'
,
version
=
'0.0.
8
'
,
description
=
'Online Neural Network'
,
url
=
'https://github.com/alison-carrera/onn'
,
author
=
'Alison Carrera'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment