Commit 957484a8 authored by alison-carrera's avatar alison-carrera
Browse files

Changes in exploration.

parent f0f74d94
......@@ -150,9 +150,12 @@ class ONN(nn.Module):
def predict(self, X_data):
pred = self.predict_(X_data)
if self.use_exploration and np.random.uniform() < self.e and pred.shape[0] == 1:
if self.use_exploration and np.random.uniform() < self.e and self.batch_size == 1:
removed_arms = self.arms_values.copy()
removed_arms.remove(pred)
removed_arms.remove(pred[0])
return random.choice(removed_arms)
if self.batch_size == 1:
return pred[0]
return pred
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment