Commit 32eee798 authored by Alison Carrera's avatar Alison Carrera
Browse files

Added a way to export and load all params from network in a json way. It is...

Added a way to export and load all params from network in a json way. It is useful to save network state at redis.
parent f6583d7f
import collections
import json
import random import random
import numpy as np import numpy as np
...@@ -150,6 +152,21 @@ class ONN(nn.Module): ...@@ -150,6 +152,21 @@ class ONN(nn.Module):
pred = self.predict_(X_data) pred = self.predict_(X_data)
return pred return pred
def export_params_to_json(self):
state_dict = self.state_dict()
params_gp = {}
for key, tensor in state_dict.items():
params_gp[key] = tensor.cpu().numpy().tolist()
return json.dumps(params_gp)
def load_params_from_json(self, json_data):
params = json.loads(json_data)
o_dict = collections.OrderedDict()
for key, tensor in params.items():
o_dict[key] = torch.tensor(tensor).to(self.device)
self.load_state_dict(o_dict)
class ONN_THS(ONN): class ONN_THS(ONN):
def __init__(self, features_size, max_num_hidden_layers, qtd_neuron_per_hidden_layer, n_classes, b=0.99, n=0.01, def __init__(self, features_size, max_num_hidden_layers, qtd_neuron_per_hidden_layer, n_classes, b=0.99, n=0.01,
......
...@@ -6,7 +6,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: ...@@ -6,7 +6,7 @@ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
long_description = f.read() long_description = f.read()
setup(name='onn', setup(name='onn',
version='0.1.4', version='0.1.5',
description='Online Neural Network', description='Online Neural Network',
url='https://github.com/alison-carrera/onn', url='https://github.com/alison-carrera/onn',
author='Alison Carrera', author='Alison Carrera',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment