Commit 9da59d47 authored by alison-carrera's avatar alison-carrera
Browse files

Added new conditions in exploration predict.

parent 3831ecd4
......@@ -150,7 +150,7 @@ class ONN(nn.Module):
def predict(self, X_data):
pred = self.predict_(X_data)
if self.use_exploration and np.random.uniform() < self.e:
if self.use_exploration and np.random.uniform() < self.e and pred.shape[0] == 1:
removed_arms = self.arms_values.copy()
removed_arms.remove(pred)
return random.choice(removed_arms)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment