diff --git a/bn_functions.py b/bn_functions.py
index 77005f908381d0958aa1e2b8283c80e0f33a8091..a8502a0c8b41845fa424db5ca542c28660e88cea 100644
--- a/bn_functions.py
+++ b/bn_functions.py
@@ -44,12 +44,35 @@ def compute_prob(cpds_table):
     This function checks if any 
     '''
 
-    for val in range(len(cpds_table)):
-            cpds_table[val] = list(map(lambda x: x / (sum(cpds_table[val])+0.00001), cpds_table[val]))
-            cpds_table[val] = check_zero_occurrences(cpds_table[val])
+    cpds_table_array = np.array(cpds_table)
+    cpds_table_array_len = cpds_table_array.ndim
+
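+    # The counter table may be 2-, 3- or 4-dimensional; in every case the innermost rows are
+    # normalised into probabilities, with a small epsilon so rows with no observations do not
+    # divide by zero.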
+    if cpds_table_array_len == 4:
+
+        # attempt
+        for elem1 in range(cpds_table_array.shape[0]):
+            # game_state
+            for elem2 in range(cpds_table_array.shape[1]):
+                # assistance
+                for elem3 in range(cpds_table_array.shape[2]):
+                    cpds_table[elem1][elem2][elem3] = list(
+                        map(lambda x: x / (sum(cpds_table[elem1][elem2][elem3]) + 0.00001), cpds_table[elem1][elem2][elem3]))
+                    cpds_table[elem1][elem2][elem3] = check_zero_occurrences(cpds_table[elem1][elem2][elem3])
+
+    elif cpds_table_array_len == 3:
+        # attempt
+        for elem1 in range(cpds_table_array.shape[0]):
+            # game_state
+            for elem2 in range(cpds_table_array.shape[1]):
+                cpds_table[elem1][elem2] = list(map(lambda x: x / (sum(cpds_table[elem1][elem2]) + 0.00001), cpds_table[elem1][elem2]))
+                cpds_table[elem1][elem2] = check_zero_occurrences(cpds_table[elem1][elem2])
+    else:
+        for val in range(len(cpds_table)):
+            cpds_table[val] = list(map(lambda x: x / (sum(cpds_table[val]) + 0.00001), cpds_table[val]))
+            cpds_table[val] = check_zero_occurrences(cpds_table[val])
     return cpds_table
 
-def average_prob(ref_cpds_table, current_cpds_table):
+def average_prob(ref_cpds_table, current_cpds_table, alpha):
     '''
     Args:
         ref_cpds_table: table from bnlearn
@@ -58,12 +81,13 @@ def average_prob(ref_cpds_table, current_cpds_table):
         avg from both tables
     '''
     res_cpds_table = ref_cpds_table.copy()
+    current_cpds_table_np_array = np.array(current_cpds_table)
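+    # Weighted blend of the two tables: alpha weights the probabilities estimated from the current
+    # counters, (1 - alpha) the reference table (an exponential-moving-average style update).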
     for elem1 in range(len(ref_cpds_table)):
         for elem2 in range(len(ref_cpds_table[0])):
-            res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]+current_cpds_table[elem1][elem2])/2
+            res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]*(1-alpha))+(current_cpds_table_np_array[elem1][elem2]*alpha)
     return res_cpds_table
 
-def update_cpds_tables(bn_model, variables_tables):
+def update_cpds_tables(bn_model, variables_tables, alpha=0.1):
     '''
     This function updates the bn model with the variables_tables provided in input
     Args:
@@ -80,7 +104,7 @@ def update_cpds_tables(bn_model, variables_tables):
         cpds_table_from_counter = compute_prob(val)
         updated_prob = average_prob(
             np.transpose(cpds_table),
-            cpds_table_from_counter)
+            cpds_table_from_counter, alpha)
         bn_model['model'].cpds[index].values = np.transpose(updated_prob)
 
     return bn_model
diff --git a/bn_persona_model/persona_model_test.bif b/bn_persona_model/persona_model_test.bif
index bbdacc7199fbedeecba0eefb9b3a0ac88720b450..28727a89c3dae174cb90be06ea3cc7b102dd5160 100644
--- a/bn_persona_model/persona_model_test.bif
+++ b/bn_persona_model/persona_model_test.bif
@@ -3,7 +3,7 @@ network persona_model {
 
 %VARIABLES DEFINITION
 
-variable robot_assistance {
+variable agent_assistance {
   type discrete [ 6 ] { lev_0, lev_1, lev_2, lev_3, lev_4, lev_5 };
 }
 variable attempt_t0 {
@@ -13,18 +13,20 @@ variable game_state_t0 {
   type discrete [ 3 ] { beg, mid, end };
 }
 variable attempt_t1 {
-  type discrete [ 4 ] { att_1, att_2, att_3, att_4 };
+  type discrete [ 4 ] { att_1, att_2, att_3, att_4 };
 }
 variable game_state_t1 {
   type discrete [ 3 ] { beg, mid, end };
 }
-%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 variable user_action {
   type discrete [ 3 ] { correct, wrong, timeout };
 }
+variable agent_feedback {
+  type discrete [ 2 ] { no, yes };
+}
 
 %INDIVIDUAL PROBABILITIES DEFINITION
-probability ( robot_assistance ) {
+probability ( agent_assistance ) {
   table 0.17, 0.16, 0.16, 0.17, 0.17, 0.17;
 }
 probability ( game_state_t0 ) {
@@ -39,124 +41,104 @@ probability ( game_state_t1 ) {
 probability ( attempt_t1 ) {
   table 0.25, 0.25, 0.25, 0.25;
 }
-
 probability ( user_action ) {
   table 0.33, 0.33, 0.34;
 }
-
-probability (robot_assistance | game_state_t0, attempt_t0){
- (beg, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
- (beg, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
- (beg, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
- (beg, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
- (mid, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
- (mid, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
- (mid, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
- (mid, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
- (end, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
- (end, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
- (end, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
- (end, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
+probability ( agent_feedback ) {
+  table 0.5, 0.5;
+}
+probability (agent_assistance | agent_feedback) {
+  (yes) 0.4, 0.3, 0.2, 0.1, 0.0, 0.0;
+  (no) 0.0, 0.0, 0.1, 0.2, 0.3, 0.4;
 }
 
-
-
-probability (user_action | game_state_t0, attempt_t0, robot_assistance){
+probability (user_action | game_state_t0, attempt_t0, agent_assistance){
   (beg, att_1, lev_0) 0.1, 0.9, 0.0;
   (beg, att_2, lev_0) 0.2, 0.8, 0.0;
   (beg, att_3, lev_0) 0.3, 0.7, 0.0;
   (beg, att_4, lev_0) 0.4, 0.6, 0.0;
-  (beg, att_1, lev_1) 0.1, 0.9, 0.0;
-  (beg, att_2, lev_1) 0.2, 0.8, 0.0;
-  (beg, att_3, lev_1) 0.3, 0.7, 0.0;
-  (beg, att_4, lev_1) 0.4, 0.6, 0.0;
-  (beg, att_1, lev_2) 0.1, 0.9, 0.0;
-  (beg, att_2, lev_2) 0.2, 0.8, 0.0;
-  (beg, att_3, lev_2) 0.3, 0.7, 0.0;
-  (beg, att_4, lev_2) 0.4, 0.6, 0.0;
-  (beg, att_1, lev_3) 0.1, 0.9, 0.0;
-  (beg, att_2, lev_3) 0.2, 0.8, 0.0;
-  (beg, att_3, lev_3) 0.3, 0.7, 0.0;
-  (beg, att_4, lev_3) 0.4, 0.6, 0.0;
-  (beg, att_1, lev_4) 0.1, 0.9, 0.0;
-  (beg, att_2, lev_4) 0.2, 0.8, 0.0;
-  (beg, att_3, lev_4) 0.3, 0.7, 0.0;
-  (beg, att_4, lev_4) 0.4, 0.6, 0.0;
-  (beg, att_1, lev_5) 0.1, 0.9, 0.0;
-  (beg, att_2, lev_5) 0.2, 0.8, 0.0;
-  (beg, att_3, lev_5) 0.3, 0.7, 0.0;
-  (beg, att_4, lev_5) 0.4, 0.6, 0.0;
+  (beg, att_1, lev_1) 0.2, 0.8, 0.0;
+  (beg, att_2, lev_1) 0.3, 0.7, 0.0;
+  (beg, att_3, lev_1) 0.4, 0.6, 0.0;
+  (beg, att_4, lev_1) 0.5, 0.5, 0.0;
+  (beg, att_1, lev_2) 0.3, 0.7, 0.0;
+  (beg, att_2, lev_2) 0.4, 0.6, 0.0;
+  (beg, att_3, lev_2) 0.5, 0.5, 0.0;
+  (beg, att_4, lev_2) 0.6, 0.4, 0.0;
+  (beg, att_1, lev_3) 0.4, 0.6, 0.0;
+  (beg, att_2, lev_3) 0.5, 0.5, 0.0;
+  (beg, att_3, lev_3) 0.6, 0.4, 0.0;
+  (beg, att_4, lev_3) 0.7, 0.3, 0.0;
+  (beg, att_1, lev_4) 1.0, 0.0, 0.0;
+  (beg, att_2, lev_4) 1.0, 0.0, 0.0;
+  (beg, att_3, lev_4) 1.0, 0.0, 0.0;
+  (beg, att_4, lev_4) 1.0, 0.0, 0.0;
+  (beg, att_1, lev_5) 1.0, 0.0, 0.0;
+  (beg, att_2, lev_5) 1.0, 0.0, 0.0;
+  (beg, att_3, lev_5) 1.0, 0.0, 0.0;
+  (beg, att_4, lev_5) 1.0, 0.0, 0.0;
 
   (mid, att_1, lev_0) 0.1, 0.9, 0.0;
   (mid, att_2, lev_0) 0.2, 0.8, 0.0;
   (mid, att_3, lev_0) 0.3, 0.7, 0.0;
   (mid, att_4, lev_0) 0.4, 0.6, 0.0;
-  (mid, att_1, lev_1) 0.1, 0.9, 0.0;
-  (mid, att_2, lev_1) 0.2, 0.8, 0.0;
-  (mid, att_3, lev_1) 0.3, 0.7, 0.0;
-  (mid, att_4, lev_1) 0.4, 0.6, 0.0;
-  (mid, att_1, lev_2) 0.1, 0.9, 0.0;
-  (mid, att_2, lev_2) 0.2, 0.8, 0.0;
-  (mid, att_3, lev_2) 0.3, 0.7, 0.0;
-  (mid, att_4, lev_2) 0.4, 0.6, 0.0;
-  (mid, att_1, lev_3) 0.1, 0.9, 0.0;
-  (mid, att_2, lev_3) 0.2, 0.8, 0.0;
-  (mid, att_3, lev_3) 0.3, 0.7, 0.0;
-  (mid, att_4, lev_3) 0.4, 0.6, 0.0;
-  (mid, att_1, lev_4) 0.1, 0.9, 0.0;
-  (mid, att_2, lev_4) 0.2, 0.8, 0.0;
-  (mid, att_3, lev_4) 0.3, 0.7, 0.0;
-  (mid, att_4, lev_4) 0.4, 0.6, 0.0;
-  (mid, att_1, lev_5) 0.1, 0.9, 0.0;
-  (mid, att_2, lev_5) 0.2, 0.8, 0.0;
-  (mid, att_3, lev_5) 0.3, 0.7, 0.0;
-  (mid, att_4, lev_5) 0.4, 0.6, 0.0;
+  (mid, att_1, lev_1) 0.2, 0.8, 0.0;
+  (mid, att_2, lev_1) 0.3, 0.7, 0.0;
+  (mid, att_3, lev_1) 0.4, 0.6, 0.0;
+  (mid, att_4, lev_1) 0.5, 0.5, 0.0;
+  (mid, att_1, lev_2) 0.3, 0.7, 0.0;
+  (mid, att_2, lev_2) 0.4, 0.6, 0.0;
+  (mid, att_3, lev_2) 0.5, 0.5, 0.0;
+  (mid, att_4, lev_2) 0.6, 0.4, 0.0;
+  (mid, att_1, lev_3) 0.4, 0.6, 0.0;
+  (mid, att_2, lev_3) 0.5, 0.5, 0.0;
+  (mid, att_3, lev_3) 0.6, 0.4, 0.0;
+  (mid, att_4, lev_3) 0.7, 0.3, 0.0;
+  (mid, att_1, lev_4) 1.0, 0.0, 0.0;
+  (mid, att_2, lev_4) 1.0, 0.0, 0.0;
+  (mid, att_3, lev_4) 1.0, 0.0, 0.0;
+  (mid, att_4, lev_4) 1.0, 0.0, 0.0;
+  (mid, att_1, lev_5) 1.0, 0.0, 0.0;
+  (mid, att_2, lev_5) 1.0, 0.0, 0.0;
+  (mid, att_3, lev_5) 1.0, 0.0, 0.0;
+  (mid, att_4, lev_5) 1.0, 0.0, 0.0;
 
   (end, att_1, lev_0) 0.1, 0.9, 0.0;
   (end, att_2, lev_0) 0.2, 0.8, 0.0;
-  (end, att_3, lev_0) 0.2, 0.8, 0.0;
+  (end, att_3, lev_0) 0.3, 0.7, 0.0;
   (end, att_4, lev_0) 0.4, 0.6, 0.0;
-  (end, att_1, lev_1) 0.1, 0.9, 0.0;
-  (end, att_2, lev_1) 0.2, 0.8, 0.0;
+  (end, att_1, lev_1) 0.2, 0.8, 0.0;
+  (end, att_2, lev_1) 0.3, 0.7, 0.0;
   (end, att_3, lev_1) 0.4, 0.6, 0.0;
-  (end, att_4, lev_1) 0.4, 0.6, 0.0;
-  (end, att_1, lev_2) 0.1, 0.9, 0.0;
-  (end, att_2, lev_2) 0.2, 0.8, 0.0;
-  (end, att_3, lev_2) 0.4, 0.6, 0.0;
-  (end, att_4, lev_2) 0.4, 0.6, 0.0;
-  (end, att_1, lev_3) 0.1, 0.9, 0.0;
-  (end, att_2, lev_3) 0.2, 0.8, 0.0;
-  (end, att_3, lev_3) 0.5, 0.5, 0.0;
-  (end, att_4, lev_3) 0.4, 0.6, 0.0;
-  (end, att_1, lev_4) 0.1, 0.9, 0.0;
-  (end, att_2, lev_4) 0.2, 0.8, 0.0;
-  (end, att_3, lev_4) 0.7, 0.3, 0.0;
-  (end, att_4, lev_4) 0.4, 0.6, 0.0;
-  (end, att_1, lev_5) 0.1, 0.9, 0.0;
-  (end, att_2, lev_5) 0.2, 0.8, 0.0;
-  (end, att_3, lev_5) 0.3, 0.7, 0.0;
-  (end, att_4, lev_5) 0.4, 0.6, 0.0;
-
+  (end, att_4, lev_1) 0.5, 0.5, 0.0;
+  (end, att_1, lev_2) 0.3, 0.7, 0.0;
+  (end, att_2, lev_2) 0.4, 0.6, 0.0;
+  (end, att_3, lev_2) 0.5, 0.5, 0.0;
+  (end, att_4, lev_2) 0.6, 0.4, 0.0;
+  (end, att_1, lev_3) 0.4, 0.6, 0.0;
+  (end, att_2, lev_3) 0.5, 0.5, 0.0;
+  (end, att_3, lev_3) 0.6, 0.4, 0.0;
+  (end, att_4, lev_3) 0.7, 0.3, 0.0;
+  (end, att_1, lev_4) 1.0, 0.0, 0.0;
+  (end, att_2, lev_4) 1.0, 0.0, 0.0;
+  (end, att_3, lev_4) 1.0, 0.0, 0.0;
+  (end, att_4, lev_4) 1.0, 0.0, 0.0;
+  (end, att_1, lev_5) 1.0, 0.0, 0.0;
+  (end, att_2, lev_5) 1.0, 0.0, 0.0;
+  (end, att_3, lev_5) 1.0, 0.0, 0.0;
+  (end, att_4, lev_5) 1.0, 0.0, 0.0;
 }
 
 
 probability (game_state_t1 | user_action)  {
-   (correct) 0.2, 0.3, 0.5;
-   (wrong) 0.5, 0.3, 0.2;
-   (timeout) 0.33, 0.34, 0.33;
+   (correct) 0.25, 0.3, 0.45;
+   (wrong) 0.33, 0.34, 0.33;
+   (timeout) 0.33, 0.34, 0.33;
 }
 probability (attempt_t1 | user_action)  {
-   (correct) 0.1, 0.2, 0.3, 0.4;
-   (wrong) 0.4, 0.3, 0.2, 0.1;
+   (correct) 0.1, 0.2, 0.25, 0.45;
+   (wrong) 0.25, 0.25, 0.25, 0.25;
    (timeout) 0.25, 0.25, 0.25, 0.25;
 }
 
 
-probability (user_action | robot_assistance){
-  (lev_0) 0.1, 0.6, 0.3;
-  (lev_1) 0.2, 0.5, 0.3;
-  (lev_2) 0.3, 0.5, 0.2;
-  (lev_3) 0.5, 0.3, 0.2;
-  (lev_4) 0.9, 0.1, 0.0;
-  (lev_5) 0.9, 0.1, 0.0;
-}
\ No newline at end of file
diff --git a/simulation.py b/simulation.py
index 37036d3cf2537b8cb526ff7d6d7a4181c5b7dd0f..33aa674f8854ee9d789b9f277d81763909380129 100644
--- a/simulation.py
+++ b/simulation.py
@@ -2,12 +2,102 @@ import itertools
 import os
 import bnlearn
 import numpy as np
+import random
 #import classes and modules
 from bn_variables import Agent_Assistance, Agent_Feedback, User_Action, User_React_time, Game_State, Attempt
 import bn_functions
 import utils
 from episode import Episode
 
+#
+# def choose_next_states(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object,
+#                        selected_agent_assistance_action,
+#                        bn_model_user_action, var_user_action_target_action):
+#
+#     def get_next_state(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object):
+#
+#         next_state = []
+#
+#         #correct move on the last state of the bin
+#         if (task_progress == 1 or task_progress == 3 or task_progress == 4) and n_attempt_t0<max_attempt_per_object:
+#             next_state.append((game_state_t0+1, n_attempt_t0+1))
+#         #correct state but still in the bin
+#         elif task_progress == 0 or task_progress == 2 and n_attempt_t0<max_attempt_per_object:
+#             next_state.append((game_state_t0, n_attempt_t0+1))
+#         elif (task_progress == 1 or task_progress == 3 or task_progress == 4) and n_attempt_t0>=max_attempt_per_object:
+#            assert "you reached the maximum number of attempts; the agent will move it for you"
+#         elif task_progress == 0 or task_progress == 2 and n_attempt_t0>=max_attempt_per_object:
+#             assert "you reached the maximum number of attempts; the agent will move it for you"
+#
+#         return next_state
+#
+#     next_state = get_next_state(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object)
+#     query_answer_probs = []
+#     for t in next_state:
+#         vars_user_evidence = {"game_state_t0": game_state_t0,
+#                           "attempt_t0": n_attempt_t0 - 1,
+#                           "robot_assistance": selected_agent_assistance_action,
+#                           "game_state_t1": t[0],
+#                           "attempt_t1": t[1],
+#                           }
+#
+#         query_user_action_prob = bn_functions.infer_prob_from_state(bn_model_user_action,
+#                                                                 infer_variable=var_user_action_target_action,
+#                                                                 evidence_variables=vars_user_evidence)
+#         query_answer_probs.append(query_user_action_prob)
+#
+#
+#     #do the inference here
+#     #1. check given the current_state which are the possible states
+#     #2. for each of the possible states get the probability of user_action
+#     #3. select the state with the most higher action and execute it
+#     #4. return user_action
+#
+
+def generate_agent_assistance(preferred_assistance, agent_behaviour, n_game_state, n_attempt, alpha_action=0.1):
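+    """
+    Build an (n_game_state x n_attempt) grid of assistance levels around preferred_assistance.
+    A "challenge" behaviour biases the level one step down, any other behaviour one step up;
+    alpha_action sets the probability of random deviations, and a level that would repeat the
+    previous one may be nudged one further step. Levels are clamped to the range [0, 5].
+    """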
+    agent_policy = [[0 for j in range(n_attempt)] for i in range(n_game_state)]
+    previous_assistance = -1
+    def get_alternative_action(agent_assistance, previous_assistance, agent_behaviour, alpha_action):
+        agent_assistance_res = agent_assistance
+        if previous_assistance == agent_assistance:
+            if agent_behaviour == "challenge":
+                if random.random() > alpha_action:
+                    agent_assistance_res = min(max(0, agent_assistance-1), 5)
+                else:
+                    agent_assistance_res = min(max(0, agent_assistance), 5)
+            else:
+                if random.random() > alpha_action:
+                    agent_assistance_res = min(max(0, agent_assistance + 1), 5)
+                else:
+                    agent_assistance_res = min(max(0, agent_assistance), 5)
+        return agent_assistance_res
+
+
+    for gs in range(n_game_state):
+        for att in range(n_attempt):
+            if att == 0:
+                if random.random()>alpha_action:
+                    agent_policy[gs][att] = preferred_assistance
+                    previous_assistance = agent_policy[gs][att]
+                else:
+                    if random.random()>0.5:
+                        agent_policy[gs][att] = min(max(0, preferred_assistance-1),5)
+                        previous_assistance = agent_policy[gs][att]
+                    else:
+                        agent_policy[gs][att] = min(max(0, preferred_assistance+1), 5)
+                        previous_assistance = agent_policy[gs][att]
+            else:
+                if agent_behaviour == "challenge":
+                    agent_policy[gs][att] = min(max(0, preferred_assistance-1), 5)
+                    agent_policy[gs][att] = get_alternative_action(agent_policy[gs][att], previous_assistance, agent_behaviour, alpha_action)
+                    previous_assistance = agent_policy[gs][att]
+                else:
+                    agent_policy[gs][att] = min(max(0, preferred_assistance+1), 5)
+                    agent_policy[gs][att] = get_alternative_action(agent_policy[gs][att], previous_assistance, agent_behaviour, alpha_action)
+                    previous_assistance = agent_policy[gs][att]
+
+    return agent_policy
+
 
 def compute_next_state(user_action, task_progress_counter, attempt_counter, correct_move_counter,
                        wrong_move_counter, timeout_counter, max_attempt_counter, max_attempt_per_object
@@ -80,11 +170,10 @@ def compute_next_state(user_action, task_progress_counter, attempt_counter, corr
 def simulation(bn_model_user_action, var_user_action_target_action, bn_model_user_react_time, var_user_react_time_target_action,
                user_memory_name, user_memory_value, user_attention_name, user_attention_value,
                user_reactivity_name, user_reactivity_value,
-               task_progress_name, game_attempt_name, agent_assistance_name, agent_feedback_name,
-               bn_model_agent_assistance, var_agent_assistance_target_action, bn_model_agent_feedback,
-               var_agent_feedback_target_action, agent_policy,
+               task_progress_t0_name, task_progress_t1_name, game_attempt_t0_name, game_attempt_t1_name,
+               agent_assistance_name, agent_policy,
                state_space, action_space,
-               epochs=50, task_complexity=5, max_attempt_per_object=4):
+               epochs=50, task_complexity=5, max_attempt_per_object=4, alpha_learning=0):
     '''
     Args:
 
@@ -98,8 +187,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 
     #metrics we need, in order to compute afterwords the belief
 
-    attempt_counter_per_action = [[0 for i in range(Attempt.counter.value)]  for j in range(User_Action.counter.value)]
-    game_state_counter_per_action = [[0 for i in range(Game_State.counter.value)]  for j in range(User_Action.counter.value)]
     agent_feedback_per_action = [[0 for i in range(Agent_Feedback.counter.value)] for j in range(User_Action.counter.value)]
     agent_assistance_per_action = [[0 for i in range(Agent_Assistance.counter.value)] for j in range(User_Action.counter.value)]
 
@@ -108,12 +195,21 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
     agent_feedback_per_react_time = [[0 for i in range(Agent_Feedback.counter.value)] for j in  range(User_React_time.counter.value)]
     agent_assistance_per_react_time = [[0 for i in range(Agent_Assistance.counter.value)] for j in   range(User_React_time.counter.value)]
 
-    game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in   range(Agent_Assistance.counter.value)]
-    attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in   range(Agent_Assistance.counter.value)]
-
     game_state_counter_per_agent_feedback = [[0 for i in range(Game_State.counter.value)] for j in   range(Agent_Feedback.counter.value)]
     attempt_counter_per_agent_feedback = [[0 for i in range(Attempt.counter.value)] for j in   range(Agent_Feedback.counter.value)]
-
+    game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in
+                                             range(Agent_Assistance.counter.value)]
+    attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in
+                                          range(Agent_Assistance.counter.value)]
+
+
+    user_action_per_game_state_attempt_counter_agent_assistance = [[[[0 for i in range(User_Action.counter.value)] for l in range(Game_State.counter.value)] for j in
+                                               range(Attempt.counter.value)] for k in range(Agent_Assistance.counter.value)]
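+    # Indexed as [agent_assistance][attempt][game_state][user_action]; at the end of each episode it is
+    # normalised and blended into the CPD P(user_action | game_state_t0, attempt_t0, agent_assistance).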
+    user_action_per_agent_assistance = [[0 for i in range(User_Action.counter.value)] for j in
+                                            range(Agent_Assistance.counter.value)]
+    attempt_counter_per_user_action = [[0 for i in range(Attempt.counter.value)] for j in range(User_Action.counter.value)]
+    game_state_counter_per_user_action = [[0 for i in range(Game_State.counter.value)] for j in
+                                     range(User_Action.counter.value)]
 
     #output variables:
     n_correct_per_episode = [0]*epochs
@@ -131,6 +227,10 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
     ep = Episode()
 
     for e in range(epochs):
+        print("##########################################################")
+        print("EPISODE ",e)
+        print("##########################################################")
+
         '''Simulation framework'''
         #counters
         game_state_counter = 0
@@ -142,10 +242,11 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
         max_attempt_counter = 0
 
         #The following variables are used to update the BN at the end of the episode
-        user_action_dynamic_variables = {'attempt': attempt_counter_per_action,
-                             'game_state': game_state_counter_per_action,
-                             'agent_assistance': agent_assistance_per_action,
-                             'agent_feedback': agent_feedback_per_action}
+        user_action_dynamic_variables = {
+                                        'attempt_t1': attempt_counter_per_user_action,
+                                        'game_state_t1': game_state_counter_per_user_action,
+                                        'user_action': user_action_per_game_state_attempt_counter_agent_assistance
+                                        }
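+        # Each key names a CPD of the user-action model; the matching counter table is converted to
+        # probabilities and blended into that CPD (weighted by alpha_learning) at the end of the episode.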
 
         user_react_time_dynamic_variables = {'attempt': attempt_counter_per_react_time,
                              'game_state': game_state_counter_per_react_time,
@@ -167,38 +268,8 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 
             current_state = (game_state_counter, attempt_counter, selected_user_action)
 
-            if type(agent_policy) is not np.ndarray:
-                ##################QUERY FOR THE ROBOT ASSISTANCE AND FEEDBACK##################
-                vars_agent_evidence = {
-                                       user_reactivity_name: user_reactivity_value,
-                                       user_memory_name: user_memory_value,
-                                       task_progress_name: game_state_counter,
-                                       game_attempt_name: attempt_counter-1,
-                                       }
-
-                query_agent_assistance_prob = bn_functions.infer_prob_from_state(bn_model_agent_assistance,
-                                                                       infer_variable=var_agent_assistance_target_action,
-                                                                       evidence_variables=vars_agent_evidence)
-                if bn_model_agent_feedback != None:
-                    query_agent_feedback_prob = bn_functions.infer_prob_from_state(bn_model_agent_feedback,
-                                                                          infer_variable=var_agent_feedback_target_action,
-                                                                          evidence_variables=vars_agent_evidence)
-                    selected_agent_feedback_action = bn_functions.get_stochastic_action(query_agent_feedback_prob.values)
-                else:
-                    selected_agent_feedback_action = 0
-
-
-                selected_agent_assistance_action = bn_functions.get_stochastic_action(query_agent_assistance_prob.values)
-            else:
-                idx_state = ep.state_from_point_to_index(state_space, current_state)
-                if agent_policy[idx_state]>=Agent_Assistance.counter.value:
-                    selected_agent_assistance_action = agent_policy[idx_state]-Agent_Assistance.counter.value
-                    selected_agent_feedback_action = 1
-                else:
-                    selected_agent_assistance_action = agent_policy[idx_state]
-                    selected_agent_feedback_action = 0
-
-            n_feedback_per_episode[e][selected_agent_feedback_action] += 1
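+            # The agent action now comes from the precomputed agent_policy grid (game_state x attempt)
+            # instead of being inferred from the caregiver models; feedback is disabled for this run.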
+            selected_agent_assistance_action = agent_policy[game_state_counter][attempt_counter-1]  # random.randint(0, 5)
+            selected_agent_feedback_action = 0  # random.randint(0, 1)
 
             #counters for plots
             n_assistance_lev_per_episode[e][selected_agent_assistance_action] += 1
@@ -211,38 +282,39 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
             #compare the real user with the estimated Persona and returns a user action (0, 1a, 2)
 
             #return the user action in this state based on the Persona profile
-            vars_user_evidence = {user_attention_name: user_attention_value,
-                                  user_reactivity_name: user_reactivity_value,
-                                  user_memory_name: user_memory_value,
-                                  task_progress_name: game_state_counter,
-                                  game_attempt_name: attempt_counter-1,
-                                  agent_assistance_name: selected_agent_assistance_action,
-                                  agent_feedback_name: selected_agent_feedback_action
-                                  }
+            vars_user_evidence = {task_progress_t0_name: game_state_counter,
+                                  game_attempt_t0_name: attempt_counter - 1,
+                                  task_progress_t1_name: game_state_counter,
+                                  game_attempt_t1_name: attempt_counter - 1,
+                                  agent_assistance_name: selected_agent_assistance_action,
+                                  }
+
             query_user_action_prob = bn_functions.infer_prob_from_state(bn_model_user_action,
                                                                         infer_variable=var_user_action_target_action,
                                                                         evidence_variables=vars_user_evidence)
-            query_user_react_time_prob = bn_functions.infer_prob_from_state(bn_model_user_react_time,
-                                                                            infer_variable=var_user_react_time_target_action,
-                                                                            evidence_variables=vars_user_evidence)
-
-
+            # query_user_react_time_prob = bn_functions.infer_prob_from_state(bn_model_user_react_time,
+            #                                                                 infer_variable=var_user_react_time_target_action,
+            #                                                                 evidence_variables=vars_user_evidence)
+            #
+            #
 
             selected_user_action = bn_functions.get_stochastic_action(query_user_action_prob.values)
-            selected_user_react_time = bn_functions.get_stochastic_action(query_user_react_time_prob.values)
+            # selected_user_react_time = bn_functions.get_stochastic_action(query_user_react_time_prob.values)
             # counters for plots
-            n_react_time_per_episode[e][selected_user_react_time] += 1
+            # n_react_time_per_episode[e][selected_user_react_time] += 1
 
             #updates counters for user action
-            agent_assistance_per_action[selected_user_action][selected_agent_assistance_action] += 1
-            attempt_counter_per_action[selected_user_action][attempt_counter-1] += 1
-            game_state_counter_per_action[selected_user_action][game_state_counter] += 1
-            agent_feedback_per_action[selected_user_action][selected_agent_feedback_action] += 1
+
+            user_action_per_game_state_attempt_counter_agent_assistance[selected_agent_assistance_action][attempt_counter-1][game_state_counter][selected_user_action] += 1
+            attempt_counter_per_user_action[selected_user_action][attempt_counter-1] += 1
+            game_state_counter_per_user_action[selected_user_action][game_state_counter] += 1
+            user_action_per_agent_assistance[selected_agent_assistance_action][selected_user_action] += 1
+
             #update counter for user react time
-            agent_assistance_per_react_time[selected_user_react_time][selected_agent_assistance_action] += 1
-            attempt_counter_per_react_time[selected_user_react_time][attempt_counter-1] += 1
-            game_state_counter_per_react_time[selected_user_react_time][game_state_counter] += 1
-            agent_feedback_per_react_time[selected_user_react_time][selected_agent_feedback_action] += 1
+            # agent_assistance_per_react_time[selected_user_react_time][selected_agent_assistance_action] += 1
+            # attempt_counter_per_react_time[selected_user_react_time][attempt_counter-1] += 1
+            # game_state_counter_per_react_time[selected_user_react_time][game_state_counter] += 1
+            # agent_feedback_per_react_time[selected_user_react_time][selected_agent_feedback_action] += 1
             #update counter for agent feedback
             game_state_counter_per_agent_feedback[selected_agent_feedback_action][game_state_counter] += 1
             attempt_counter_per_agent_feedback[selected_agent_feedback_action][attempt_counter-1] += 1
@@ -269,8 +341,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
                                                                         timeout_counter, max_attempt_counter,
                                                                         max_attempt_per_object)
 
-
-
             # store the (state, action, next_state)
             episode.append((ep.state_from_point_to_index(state_space, current_state),
                             ep.state_from_point_to_index(action_space, current_agent_action),
@@ -286,22 +356,25 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
         episodes.append(Episode(episode))
 
         #update user models
-        bn_model_user_action = bn_functions.update_cpds_tables(bn_model_user_action, user_action_dynamic_variables)
+        bn_model_user_action = bn_functions.update_cpds_tables(bn_model_user_action, user_action_dynamic_variables, alpha_learning)
         bn_model_user_react_time = bn_functions.update_cpds_tables(bn_model_user_react_time, user_react_time_dynamic_variables)
         #update agent models
-        bn_model_agent_assistance = bn_functions.update_cpds_tables(bn_model_agent_assistance, agent_assistance_dynamic_variables)
-        if bn_model_agent_feedback !=None:
-            bn_model_agent_feedback = bn_functions.update_cpds_tables(bn_model_agent_feedback, agent_feedback_dynamic_variables)
+
+        print("user_given_game_attempt:", bn_model_user_action['model'].cpds[0].values)
+        print("user_given_robot:", bn_model_user_action['model'].cpds[5].values)
+        print("game_user:", bn_model_user_action['model'].cpds[3].values)
+        print("attempt_user:", bn_model_user_action['model'].cpds[2].values)
 
         #reset counter
-        agent_assistance_per_action = [[0 for i in range(Agent_Assistance.counter.value)] for j in
-                                         range(User_Action.counter.value)]
-        agent_feedback_per_action = [[0 for i in range(Agent_Feedback.counter.value)] for j in
-                                     range(User_Action.counter.value)]
-        game_state_counter_per_action = [[0 for i in range(Game_State.counter.value)] for j in
-                                         range(User_Action.counter.value)]
-        attempt_counter_per_action = [[0 for i in range(Attempt.counter.value)] for j in
-                                      range(User_Action.counter.value)]
+        user_action_per_game_state_attempt_counter_agent_assistance = [[[[0 for i in range(User_Action.counter.value)]
+                                                                         for l in range(Game_State.counter.value)] for j in
+                                                                        range(Attempt.counter.value)] for k in range(Agent_Assistance.counter.value)]
+        user_action_per_agent_assistance = [[0 for i in range(User_Action.counter.value)] for j in
+                                            range(Agent_Assistance.counter.value)]
+        attempt_counter_per_user_action = [[0 for i in range(Attempt.counter.value)] for j in
+                                           range(User_Action.counter.value)]
+        game_state_counter_per_user_action = [[0 for i in range(Game_State.counter.value)] for j in
+                                              range(User_Action.counter.value)]
 
         attempt_counter_per_react_time = [[0 for i in range(Attempt.counter.value)] for j in
                                           range(User_React_time.counter.value)]
@@ -312,15 +385,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
         agent_assistance_per_react_time = [[0 for i in range(Agent_Assistance.counter.value)] for j in
                                            range(User_React_time.counter.value)]
 
-        game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in
-                                                   range(Agent_Assistance.counter.value)]
-        attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in
-                                                range(Agent_Assistance.counter.value)]
-
-        game_state_counter_per_agent_feedback = [[0 for i in range(Game_State.counter.value)] for j in
-                                                 range(Agent_Feedback.counter.value)]
-        attempt_counter_per_agent_feedback = [[0 for i in range(Attempt.counter.value)] for j in
-                                              range(Agent_Feedback.counter.value)]
 
         #for plots
         n_correct_per_episode[e] = correct_move_counter
@@ -343,126 +407,58 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 #############################################################################
 #############################################################################
 
-# #SIMULATION PARAMS
-# epochs = 100
-#
-# #initialise the agent
-# bn_model_caregiver_assistance = bnlearn.import_DAG('bn_agent_model/agent_assistive_model.bif')
-# bn_model_caregiver_feedback = bnlearn.import_DAG('bn_agent_model/agent_feedback_model.bif')
-# bn_model_user_action = bnlearn.import_DAG('bn_persona_model/user_action_model.bif')
-# bn_model_user_react_time = bnlearn.import_DAG('bn_persona_model/user_react_time_model.bif')
-# bn_model_other_user_action = None#bnlearn.import_DAG('bn_persona_model/other_user_action_model.bif')
-# bn_model_other_user_react_time = None#bnlearn.import_DAG('bn_persona_model/other_user_react_time_model.bif')
-#
-# #initialise memory, attention and reactivity varibles
-# persona_memory = 0; persona_attention = 0; persona_reactivity = 1;
-# #initialise memory, attention and reactivity varibles
-# other_user_memory = 2; other_user_attention = 2; other_user_reactivity = 2;
-#
-# #define state space struct for the irl algorithm
-# attempt = [i for i in range(1, Attempt.counter.value+1)]
-# #+1a (3,_,_) absorbing state
-# game_state = [i for i in range(0, Game_State.counter.value+1)]
-# user_action = [i for i in range(-1, User_Action.counter.value-1)]
-# state_space = (game_state, attempt, user_action)
-# states_space_list = list(itertools.product(*state_space))
-# agent_assistance_action = [i for i in range(Agent_Assistance.counter.value)]
-# agent_feedback_action = [i for i in range(Agent_Feedback.counter.value)]
-# action_space = (agent_assistance_action, agent_feedback_action)
-# action_space_list = list(itertools.product(*action_space))
-#
-# ##############BEFORE RUNNING THE SIMULATION UPDATE THE BELIEF IF YOU HAVE DATA####################
-# log_directory = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/1/0"
-# if os.path.exists(log_directory):
-#     bn_belief_user_action_file = log_directory+"/bn_belief_user_action.pkl"
-#     bn_belief_user_react_time_file = log_directory+"/bn_belief_user_react_time.pkl"
-#     bn_belief_caregiver_assistance_file = log_directory+"/bn_belief_caregiver_assistive_action.pkl"
-#     bn_belief_caregiver_feedback_file = log_directory+"/bn_belief_caregiver_feedback_action.pkl"
-#
-#     bn_belief_user_action = utils.read_user_statistics_from_pickle(bn_belief_user_action_file)
-#     bn_belief_user_react_time = utils.read_user_statistics_from_pickle(bn_belief_user_react_time_file)
-#     bn_belief_caregiver_assistance = utils.read_user_statistics_from_pickle(bn_belief_caregiver_assistance_file)
-#     bn_belief_caregiver_feedback = utils.read_user_statistics_from_pickle(bn_belief_caregiver_feedback_file)
-#     bn_model_user_action = bn_functions.update_cpds_tables(bn_model=bn_model_user_action, variables_tables=bn_belief_user_action)
-#     bn_model_user_react_time = bn_functions.update_cpds_tables(bn_model=bn_model_user_react_time, variables_tables=bn_belief_user_react_time)
-#     bn_model_caregiver_assistance = bn_functions.update_cpds_tables(bn_model=bn_model_caregiver_assistance, variables_tables=bn_belief_caregiver_assistance)
-#     bn_model_caregiver_feedback = bn_functions.update_cpds_tables(bn_model=bn_model_caregiver_feedback, variables_tables=bn_belief_caregiver_feedback)
-#
-# else:
-#     assert("You're not using the user information")
-#     question = input("Are you sure you don't want to load user's belief information?")
-#
-# game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, generated_episodes = \
-#         simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'],
-#                    bn_model_user_react_time=bn_model_user_react_time,
-#                    var_user_react_time_target_action=['user_react_time'],
-#                    user_memory_name="memory", user_memory_value=persona_memory,
-#                    user_attention_name="attention", user_attention_value=persona_attention,
-#                    user_reactivity_name="reactivity", user_reactivity_value=persona_reactivity,
-#                    task_progress_name="game_state", game_attempt_name="attempt",
-#                    agent_assistance_name="agent_assistance", agent_feedback_name="agent_feedback",
-#                    bn_model_agent_assistance=bn_model_caregiver_assistance,
-#                    var_agent_assistance_target_action=["agent_assistance"],
-#                    bn_model_agent_feedback=bn_model_caregiver_feedback, var_agent_feedback_target_action=["agent_feedback"],
-#                    bn_model_other_user_action=bn_model_other_user_action,
-#                    var_other_user_action_target_action=['user_action'],
-#                    bn_model_other_user_react_time=bn_model_other_user_react_time,
-#                    var_other_user_target_react_time_action=["user_react_time"], other_user_memory_name="memory",
-#                    other_user_memory_value=other_user_memory, other_user_attention_name="attention",
-#                    other_user_attention_value=other_user_attention, other_user_reactivity_name="reactivity",
-#                    other_user_reactivity_value=other_user_reactivity,
-#                    state_space=states_space_list, action_space=action_space_list,
-#                    epochs=epochs, task_complexity=5, max_attempt_per_object=4)
-#
-#
-#
-# plot_game_performance_path = ""
-# plot_agent_assistance_path = ""
-# episodes_path = "episodes.npy"
-#
-# if bn_model_other_user_action != None:
-#     plot_game_performance_path = "game_performance_"+"_epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg"
-#     plot_agent_assistance_path = "agent_assistance_"+"epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg"
-#     plot_agent_feedback_path = "agent_feedback_"+"epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg"
-#
-# else:
-#     plot_game_performance_path = "game_performance_"+"epoch_" + str(epochs) + "_persona_memory_" + str(persona_memory) + "_persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(persona_reactivity) + ".jpg"
-#     plot_agent_assistance_path = "agent_assistance_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg"
-#     plot_agent_feedback_path = "agent_feedback_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg"
-#
-# dir_name = input("Please insert the name of the directory:")
-# full_path = os.getcwd()+"/results/"+dir_name+"/"
-# if not os.path.exists(full_path):
-#   os.mkdir(full_path)
-#   print("Directory ", full_path, " created.")
-# else:
-#   dir_name = input("The directory already exist please insert a new name:")
-#   print("Directory ", full_path, " created.")
-#   if os.path.exists(full_path):
-#     assert("Directory already exists ... start again")
-#     exit(0)
-#
-# with open(full_path+episodes_path, "ab") as f:
-#   np.save(full_path+episodes_path, generated_episodes)
-#   f.close()
-#
-#
-# utils.plot2D_game_performance(full_path+plot_game_performance_path, epochs, game_performance_per_episode)
-# utils.plot2D_assistance(full_path+plot_agent_assistance_path, epochs, agent_assistance_per_episode)
-# utils.plot2D_feedback(full_path+plot_agent_feedback_path, epochs, agent_feedback_per_episode)
-
-
-
-'''
-With the current simulator we can generate a list of episodes
-the episodes will be used to generate the trans probabilities and as input to the IRL algo 
-'''
-#TODO
-# - include reaction time as output
-# - average mistakes, average timeout, average assistance, average_react_time
-# - include real time episodes into the simulation:
-#   - counters for agent_assistance, agent_feedback, attempt, game_state, attention and reactivity
-#   - using the function update probability to generate the new user model and use it as input to the simulator
-
-
 
+agent_policy = generate_agent_assistance(preferred_assistance=2, agent_behaviour="help", n_game_state=Game_State.counter.value, n_attempt=Attempt.counter.value, alpha_action=0.5)
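+# agent_policy is a fixed (game_state x attempt) grid of assistance levels generated outside the BN:
+# preferred level 2, "help" behaviour (biased towards one level more assistance), alpha_action=0.5.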
+
+# SIMULATION PARAMS
+epochs = 20
+scaling_factor = 1
+# initialise the agent
+bn_model_caregiver_assistance = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_assistive_model.bif')
+bn_model_caregiver_feedback = None  # bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_feedback_model.bif')
+bn_model_user_action = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/persona_model_test.bif')
+bn_model_user_react_time = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/user_react_time_model.bif')
+
+# initialise memory, attention and reactivity variables
+persona_memory = 0
+persona_attention = 0
+persona_reactivity = 1
+
+# define state space struct for the irl algorithm
+episode_instance = Episode()
+# DEFINITION OF THE MDP
+attempt = [i for i in range(1, Attempt.counter.value + 1)]
+# +1 (3,_,_) absorbing state
+game_state = [i for i in range(0, Game_State.counter.value + 1)]
+user_action = [i for i in range(-1, User_Action.counter.value - 1)]
+state_space = (game_state, attempt, user_action)
+states_space_list = list(itertools.product(*state_space))
+state_space_index = [episode_instance.state_from_point_to_index(states_space_list, s) for s in states_space_list]
+agent_assistance_action = [i for i in range(Agent_Assistance.counter.value)]
+agent_feedback_action = [i for i in range(Agent_Feedback.counter.value)]
+action_space = (agent_feedback_action, agent_assistance_action)
+action_space_list = list(itertools.product(*action_space))
+action_space_index = [episode_instance.state_from_point_to_index(action_space_list, a) for a in action_space_list]
+terminal_state = [(Game_State.counter.value, i, user_action[j]) for i in range(1, Attempt.counter.value + 1) for j in
+                  range(len(user_action))]
+initial_state = (1, 1, 0)
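+# Terminal states are all (game_state, attempt, user_action) points whose game_state equals the
+# absorbing value Game_State.counter.value; each episode starts from game_state 1, attempt 1, user_action 0.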
+
+#1. RUN THE SIMULATION WITH THE PARAMS SET BY THE CAREGIVER
+
+
+game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, episodes_list = \
+    simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'],
+                   bn_model_user_react_time=bn_model_user_react_time,
+                   var_user_react_time_target_action=['user_react_time'],
+                   user_memory_name="memory", user_memory_value=persona_memory,
+                   user_attention_name="attention", user_attention_value=persona_attention,
+                   user_reactivity_name="reactivity", user_reactivity_value=persona_reactivity,
+                   task_progress_t0_name="game_state_t0", task_progress_t1_name="game_state_t1",
+                   game_attempt_t0_name="attempt_t0", game_attempt_t1_name="attempt_t1",
+                   agent_assistance_name="agent_assistance", agent_policy=agent_policy,
+                   state_space=states_space_list, action_space=action_space_list,
+                   epochs=epochs, task_complexity=5, max_attempt_per_object=4, alpha_learning=0.1)
+
+utils.plot2D_game_performance("/home/pal/Documents/Framework/bn_generative_model/results/user_performance.png", epochs, scaling_factor, game_performance_per_episode)
+utils.plot2D_assistance("/home/pal/Documents/Framework/bn_generative_model/results/agent_assistance.png", epochs, scaling_factor, agent_assistance_per_episode)
diff --git a/test.py b/test.py
index 0a232014330323e1afb30dbde192df29af620de1..66bd57cac6a41c2575e609287938cb842d8dd8b0 100644
--- a/test.py
+++ b/test.py
@@ -8,14 +8,45 @@ DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif')
 #DAGnew = bn.parameter_learning.fit(model, df, methodtype="bayes")
 #bn.print_CPD(DAGnew)
 q1 = bn.inference.fit(DAG, variables=['user_action'], evidence={
-                                                                'game_state_t0': 1,
-                                                                'attempt_t0':0,
-                                                                'robot_assistance':5,
-                                                                'game_state_t1': 1,
+                                                                'game_state_t0': 0,
                                                                 'attempt_t0':1,
-
-
+                                                                'game_state_t1': 0,
+                                                                'attempt_t1':2,
+                                                                'agent_assistance':0,
 })
-
-print(q1.variables)
 print(q1.values)
+
+# robot_assistance = [0, 1, 2, 3, 4, 5]
+# attempt_t0 = [0, 1, 2, 3]
+# game_state_t0 = [0, 1, 2]
+# attempt_t1 = [0]
+# game_state_t1 = [0, 1, 2]
+#
+# query_result = [[[0 for j in range(len(attempt_t0))] for i in range(len(robot_assistance))] for k in range(len(game_state_t0))]
+# for k in range(len(game_state_t0)):
+#     for i in range(len(robot_assistance)):
+#         for j in range(len(attempt_t0)-1):
+#             if j == 0:
+#                 query = bn.inference.fit(DAG, variables=['user_action'],
+#                 evidence={'game_state_t0': k,
+#                 'attempt_t0': j,
+#                 'agent_assistance': i,
+#                 'game_state_t1': k,
+#                 'attempt_t1': j})
+#                 query_result[k][i][j] = query.values
+#             else:
+#                 query = bn.inference.fit(DAG, variables=['user_action'],
+#                                              evidence={'game_state_t0': k,
+#                                                        'attempt_t0': j,
+#                                                        'agent_assistance': i,
+#                                                        'game_state_t1': k,
+#                                                        'attempt_t1': j + 1})
+#                 query_result[k][i][j] = query.values
+# for k in range(len(game_state_t0)):
+#     for i in range(len(robot_assistance)):
+#         for j in range(len(attempt_t0)):
+#             if j == 0:
+#                 print("game_state:",k, "attempt_from:", j," attempt_to:",j, " robot_ass:",i, " prob:", query_result[k][i][j])
+#             else:
+#                 print("game_state:", k, "attempt_from:", j, " attempt_to:", j+1, " robot_ass:", i, " prob:",
+#                       query_result[k][i][j])
diff --git a/utils.py b/utils.py
index a1aa1d495a9e7176deb2f51757f02ea076cefd1f..41d788bacb6523191606f13cbaea7ba8515ecca6 100644
--- a/utils.py
+++ b/utils.py
@@ -43,6 +43,7 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y):
     lev_2 = list(map(lambda x:x[2], y[0]))[1::scaling_factor]
     lev_3 = list(map(lambda x:x[3], y[0]))[1::scaling_factor]
     lev_4 = list(map(lambda x:x[4], y[0]))[1::scaling_factor]
+    lev_5 = list(map(lambda x:x[5], y[0]))[1::scaling_factor]
 
     # plot bars
     plt.figure(figsize=(10, 7))
@@ -54,7 +55,8 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y):
             width=barWidth, label='lev_3')
     plt.bar(r, lev_4, bottom=np.array(lev_0) + np.array(lev_1)+ np.array(lev_2)+ np.array(lev_3), edgecolor='white',
             width=barWidth, label='lev_4')
-
+    plt.bar(r, lev_5, bottom=np.array(lev_0) + np.array(lev_1) + np.array(lev_2) + np.array(lev_3)+ np.array(lev_4), edgecolor='white',
+            width=barWidth, label='lev_5')
 
     plt.legend()
     # Custom X axis