Commit 366cf0cd authored by Alison Carrera's avatar Alison Carrera
Browse files

Added ONN THS params to the pytorch global params.

parent a3fb25c9
......@@ -156,10 +156,10 @@ class ONN_THS(ONN):
s=0.2, e=[0.5, 0.35, 0.2, 0.1, 0.05], use_cuda=False):
super().__init__(features_size, max_num_hidden_layers, qtd_neuron_per_hidden_layer, n_classes, b=b, n=n, s=s,
use_cuda=use_cuda)
self.e = e
self.arms_values = np.arange(n_classes).tolist()
self.n_impressions = np.ones(len(e))
self.n_rewards = np.ones(len(e))
self.e = Parameter(torch.tensor(e), requires_grad=False)
self.arms_values = Parameter(torch.arange(n_classes), requires_grad=False)
self.n_impressions = Parameter(torch.ones(len(e)), requires_grad=False)
self.n_rewards = Parameter(torch.ones(len(e)), requires_grad=False)
def partial_fit(self, X_data, Y_data, exp_factor, show_loss=True):
self.partial_fit_(X_data, Y_data, show_loss)
......@@ -170,10 +170,10 @@ class ONN_THS(ONN):
rewards_0 = self.n_impressions - self.n_rewards
theta_value = np.random.beta(self.n_rewards, rewards_0 + 1)
ranked_arms = np.flip(np.argsort(theta_value), axis=0)
chosen_arm = ranked_arms[0]
chosen_arm = ranked_arms[0].item()
self.n_impressions[chosen_arm] += 1
if np.random.uniform() < self.e[chosen_arm]:
removed_arms = self.arms_values.copy()
removed_arms = self.arms_values.clone().numpy().tolist()
removed_arms.remove(pred)
return random.choice(removed_arms), chosen_arm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment