diff --git a/main.py b/main.py
index c65e6d3b16f3d2a5ae38e318fb52b98f41e1ee4a..09624e73702b6ccca5c7dac1070bee0aa4de4bd4 100644
--- a/main.py
+++ b/main.py
@@ -23,6 +23,7 @@ import os
 import math
 import operator
 import datetime
+import pickle
 import bnlearn
 
 from cognitive_game_env import CognitiveGame
@@ -124,7 +125,7 @@ def maxent(world, terminal, trajectories):
         estimation of the reward based on the MEIRL
     """
     # set up features: we use one feature vector per state
-    features = world.state_features()#assistive_feature(trajectories)
+    features = world.assistive_feature(trajectories)
 
     # choose our parameter initialization strategy:
     #   initialize parameters with constant
@@ -139,6 +140,23 @@ def maxent(world, terminal, trajectories):
 
     return reward
 
+def build_policy_from_therapist(bn_model_path, action_space, n_game_state, n_attempt):
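+    """Query the therapist BN for P(agent_assistance | game_state, attempt) at every (game_state, attempt) pair."""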
+    DAG = bnlearn.import_DAG(bn_model_path)
+    therapist_policy = [[[ 0 for k in range(len(action_space))] for j in range(n_attempt)] for i in range(n_game_state)]
+    for g in range(n_game_state):
+        for a in range(n_attempt):
+            q_origin = bnlearn.inference.fit(DAG, variables=['agent_assistance'], evidence={
+                    'game_state': g,
+                    'attempt': a})
+            therapist_policy[g][a] = q_origin.values.tolist()
+    return therapist_policy
+
+def merge_agent_policy(policy_from_data, policy_from_therapist):
+    """Combine the data-driven and therapist-defined policies state by state (element-wise average)."""
+    merged_policy = policy_from_therapist[:]
+    for index in range(len(policy_from_therapist)):
+        merged_policy[index] = list(map(lambda probs: sum(probs) / 2.0,
+                                        zip(policy_from_data[index], policy_from_therapist[index])))
+    return merged_policy
+
 def merge_user_log(folder_pathname, user_id, with_feedback, column_to_remove):
     absolute_path = folder_pathname+"/"+str(+user_id)+"/"+str(with_feedback)
     df = pd.DataFrame()
@@ -177,15 +195,15 @@ def compute_agent_policy(folder_pathname, user_id, with_feedback, state_space, a
         if index == 0 or index in episode_length:
             state_point = (row['game_state'], row['attempt'], 0)
             state_index = ep.state_from_point_to_index(state_space, state_point)
-            action_point = (row['agent_feedback'], row['agent_assistance'])
-            action_index = ep.state_from_point_to_index(action_space, action_point)
+            # the assistance level itself indexes the action space
+            action_index = int(row['agent_assistance'])
             agent_policy_counter[state_index][action_index] += 1
             row_t_0 = row['user_action']
         else:
             state_point = (row['game_state'], row['attempt'], row_t_0)
             state_index = ep.state_from_point_to_index(state_space, state_point)
-            action_point = (row['agent_feedback'], row['agent_assistance'])
-            action_index = ep.state_from_point_to_index(action_space, action_point)
+            # the assistance level itself indexes the action space
+            action_index = int(row['agent_assistance'])
             agent_policy_counter[state_index][action_index] += 1
             row_t_0 = row['user_action']
     for s in range(len(state_space)):
@@ -194,16 +212,17 @@ def compute_agent_policy(folder_pathname, user_id, with_feedback, state_space, a
     return agent_policy_prob
 
 def main():
-    df, episode_length = merge_user_log(folder_pathname="/home/pal/Documents/Framework/GenerativeMutualShapingRL/data",
-                   user_id=1, with_feedback=True, column_to_remove=None)
 
     #################GENERATE SIMULATION################################
     # SIMULATION PARAMS
-    epochs = 20
+    epochs = 10
     scaling_factor = 1
     # initialise the agent
     bn_model_user_action_filename = '/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/persona_model_test.bif'
+    bn_model_agent_behaviour_filename = '/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_assistive_model.bif'
+    learned_policy_filename = ""
     bn_model_user_action = bnlearn.import_DAG(bn_model_user_action_filename)
+    bn_model_agent_behaviour = bnlearn.import_DAG(bn_model_agent_behaviour_filename)
 
     #setup by the caregiver
     user_pref_assistance = 2
@@ -222,123 +241,127 @@ def main():
     state_space_index = [episode_instance.state_from_point_to_index(states_space_list, s) for s in states_space_list]
     agent_assistance_action = [i for i in range(Agent_Assistance.counter.value)]
     agent_feedback_action = [i for i in range(Agent_Feedback.counter.value)]
-    action_space = (agent_feedback_action, agent_assistance_action)
-    action_space_list = list(itertools.product(*action_space))
-    action_space_index = [episode_instance.state_from_point_to_index(action_space_list, a) for a in action_space_list]
+    action_space = agent_assistance_action
+    action_space_list = action_space
+    action_space_index = action_space_list
     terminal_state = [(Game_State.counter.value, i, user_action[j]) for i in range(1, Attempt.counter.value + 1) for j in
                       range(len(user_action))]
     initial_state = (1, 1, 0)
     agent_policy = [0 for s in state_space]
 
-    compute_agent_policy(folder_pathname="/home/pal/Documents/Framework/GenerativeMutualShapingRL/data/",
-                         user_id=1, with_feedback=True, state_space=states_space_list,
-                         action_space=action_space_list, episode_length=episode_length)
 
 
-    #1. RUN THE SIMULATION WITH THE PARAMS SET BY THE CAREGIVER
-    game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, episodes_list = \
-    Sim.simulation(bn_model_user_action=bn_model_user_action,
-                   var_user_action_target_action=['user_action'],
-                   game_state_bn_name="game_state",
-                   attempt_bn_name="attempt",
-                   agent_assistance_bn_name="agent_assistance",
-                   agent_feedback_bn_name="agent_feedback",
-                   user_pref_assistance=user_pref_assistance,
-                   agent_behaviour=agent_behaviour,
-                   agent_policy=[],
-                   state_space=states_space_list,
-                   action_space=action_space_list,
-                   epochs=epochs, task_complexity=5, max_attempt_per_object=4, alpha_learning=0.1)
-
-    #2. GIVEN THE EPISODES ESTIMATE R(S) and PI(S)
-
-    format = "%a%b%d-%H:%M:%S %Y"
-    today_id = datetime.datetime.today()
-    full_path = os.getcwd() + "/results/" + str(today_id) +"/"
-    if not os.path.exists(full_path):
-        os.mkdir(full_path)
-
-    plot_game_performance_path = "SIM_game_performance_"+"epoch_" + str(epochs) + ".jpg"
-    plot_agent_assistance_path = "SIM_agent_assistance_"+"epoch_"+str(epochs)+".jpg"
-    plot_agent_feedback_path = "SIM_agent_feedback_"+"epoch_"+str(epochs)+".jpg"
-
-    utils.plot2D_game_performance(full_path +plot_game_performance_path, epochs, scaling_factor, game_performance_per_episode)
-    utils.plot2D_assistance(full_path + plot_agent_assistance_path, epochs, scaling_factor, agent_assistance_per_episode)
-    utils.plot2D_feedback(full_path + plot_agent_feedback_path, epochs, scaling_factor, agent_feedback_per_episode)
-
-    world, reward, terminals = setup_mdp(initial_state=initial_state, terminal_state=terminal_state, task_length=Game_State.counter.value,
-                                         n_max_attempt=Attempt.counter.value, action_space=action_space_list, state_space=states_space_list,
-                                         user_action=user_action, timeout=15, episode = episodes_list)
-
-    state_tuple_indexed = [states_space_list.index(tuple(s)) for s in (states_space_list)]
-
-    #Dirty way to represent the state space in a graphical way
-    states_space_list_string = [[str(states_space_list[j*12+i]) for i in range(12)] for j in range(3)]
-    build_2dtable(states_space_list_string, 3, 12)
-
-    #R(s) and pi(s) generated from the first sim
-    maxent_R_sim = maxent(world, terminals, episodes_list)
-    maxent_V_sim, maxent_P_sim = vi.value_iteration(world.p_transition, maxent_R_sim, gamma=0.9, error=1e-3, deterministic=False)
-    plt.figure(figsize=(12, 4), num="maxent_rew")
-    sns.heatmap(np.reshape(maxent_R_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "sim_maxent_R.jpg")
-    plt.figure(figsize=(12, 4), num="maxent_V")
-    sns.heatmap(np.reshape(maxent_V_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "sim_maxent_V.jpg")
-    plt.figure(figsize=(12, 4), num="maxent_P")
-    sns.heatmap(np.reshape(maxent_P_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "sim_maxent_P.jpg")
-    #####################################################################################
+    #####################INPUT AND OUTPUT FOLDER ####################################
+    input_folder_data = "/home/pal/Documents/Framework/GenerativeMutualShapingRL/data"
+    user_id = 1
+    with_feedback = True
 
-    #3.WE GOT SOME REAL DATA UPDATE THE BELIEF OF THE BN
-    log_directory = "/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/cognitive_game.csv"
+    output_folder_data = os.getcwd() + "/results/" + str(user_id)
+    if not os.path.exists(output_folder_data):
+        os.makedirs(output_folder_data)
+    if not os.path.exists(output_folder_data + "/" + str(with_feedback)):
+        os.makedirs(output_folder_data + "/" + str(with_feedback))
 
+    #1. CREATE INITIAL USER COGNITIVE MODEL FROM DATA
+    df_from_data, episode_length = merge_user_log(folder_pathname=input_folder_data,
+                                        user_id=user_id, with_feedback=with_feedback, column_to_remove=None)
+    #2. CREATE POLICY FROM DATA
+    agent_policy_from_data = compute_agent_policy(folder_pathname=input_folder_data,
+                         user_id=user_id, with_feedback=with_feedback, state_space=states_space_list,
+                         action_space=action_space_list, episode_length=episode_length)
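+    # agent_policy_from_data[s][a]: empirical probability of assistance level a in state s, estimated from the logs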
+
+    # 3. RUN THE SIMULATION
+    log_directory = input_folder_data+"/"+str(user_id)+"/"+str(with_feedback)
+    bn_model_user_action_from_data_and_therapist = None
+    bn_model_agent_behaviour_from_data_and_therapist = None
     if os.path.exists(log_directory):
-        bn_model_user_action_from_data = Sim.build_model_from_data(csv_filename=log_directory, dag_filename=bn_model_user_action_filename, dag_model=bn_model_user_action)
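+        # rebuild the user-action and agent-behaviour models from the summary of the logged sessions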
+        bn_model_user_action_from_data_and_therapist = Sim.build_model_from_data(csv_filename=log_directory+"/summary_bn_variables.csv", dag_filename=bn_model_user_action_filename, dag_model=bn_model_user_action)
+        bn_model_agent_behaviour_from_data_and_therapist = Sim.build_model_from_data(csv_filename=log_directory+"/summary_bn_variables.csv", dag_filename=bn_model_agent_behaviour_filename, dag_model=bn_model_agent_behaviour)
     else:
         assert ("You're not using the user information")
         question = input("Are you sure you don't want to load user's belief information?")
 
-    game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, episodes_list = \
-        Sim.simulation(bn_model_user_action=bn_model_user_action,
-                       var_user_action_target_action=['user_action'],
-                       game_state_bn_name="game_state",
-                       attempt_bn_name="attempt",
-                       agent_assistance_bn_name="agent_assistance",
-                       agent_feedback_bn_name="agent_feedback",
-                       user_pref_assistance=user_pref_assistance,
-                       agent_behaviour=agent_behaviour,
-                       agent_policy = maxent_P_sim,
-                       state_space=states_space_list,
-                       action_space=action_space_list,
-                       epochs=epochs, task_complexity=5, max_attempt_per_object=4, alpha_learning=0.1)
-
-    plot_game_performance_path = "REAL_SIM_game_performance_" + "epoch_" + str(epochs) + ".jpg"
-    plot_agent_assistance_path = "REAL_SIM_agent_assistance_" + "epoch_" + str(epochs) + ".jpg"
-    plot_agent_feedback_path = "REAL_SIM_agent_feedback_" + "epoch_" + str(epochs) + ".jpg"
-
-    utils.plot2D_game_performance(full_path + plot_game_performance_path, epochs, scaling_factor, game_performance_per_episode)
-    utils.plot2D_assistance(full_path + plot_agent_assistance_path, epochs, scaling_factor, agent_assistance_per_episode)
-    utils.plot2D_feedback(full_path + plot_agent_feedback_path, epochs, scaling_factor, agent_feedback_per_episode)
-
-    # R(s) and pi(s) generated from the first sim
-    maxent_R_real_sim = maxent(world, terminals, episodes_list)
-    maxent_V_real_sim, maxent_P_real_sim = vi.value_iteration(world.p_transition, maxent_R_real_sim, gamma=0.9, error=1e-3,
-                                                    deterministic=True)
+    diff = dict([(s, [0] * len(action_space_index)) for s in state_space_index])
+    entropy = dict([(s, 0) for s in state_space_index])
+    N = 5
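+    # diff[s][a]: fraction of the N runs in which the learned policy selects action a in state s
+    # entropy[s]: Shannon entropy of that empirical action distribution (0 means a fully consistent choice)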
+
+    for i in range(N):
+        game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, episodes_list = \
+            Sim.simulation(bn_model_user_action=bn_model_user_action_from_data_and_therapist,
+                           bn_model_agent_behaviour = bn_model_agent_behaviour_from_data_and_therapist,
+                           var_user_action_target_action=['user_action'],
+                           var_agent_behaviour_target_action=['agent_assistance'],
+                           game_state_bn_name="game_state",
+                           attempt_bn_name="attempt",
+                           agent_assistance_bn_name="agent_assistance",
+                           agent_feedback_bn_name="agent_feedback",
+                           user_pref_assistance=user_pref_assistance,
+                           agent_behaviour=agent_behaviour,
+                           agent_policy = agent_policy_from_data,
+                           state_space=states_space_list,
+                           action_space=action_space_list,
+                           epochs=epochs, task_complexity=5, max_attempt_per_object=4, alpha_learning=0.1)
+
+        plot_game_performance_path = output_folder_data+"/REAL_SIM_game_performance_" + "epoch_" + str(epochs) + ".jpg"
+        plot_agent_assistance_path = output_folder_data+"/REAL_SIM_agent_assistance_" + "epoch_" + str(epochs) + ".jpg"
+        plot_agent_feedback_path = output_folder_data+"/REAL_SIM_agent_feedback_" + "epoch_" + str(epochs) + ".jpg"
+
+        utils.plot2D_game_performance(plot_game_performance_path, epochs, scaling_factor, game_performance_per_episode)
+        utils.plot2D_assistance(plot_agent_assistance_path, epochs, scaling_factor, agent_assistance_per_episode)
+        utils.plot2D_feedback(plot_agent_feedback_path, epochs, scaling_factor, agent_feedback_per_episode)
+
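+        # rebuild the MDP used by MaxEnt IRL from the episodes simulated in this run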
+        cognitive_game_world, reward, terminals = setup_mdp(initial_state=initial_state, terminal_state=terminal_state,
+                                                            task_length=Game_State.counter.value, n_max_attempt=Attempt.counter.value,
+                                                            action_space=action_space_list, state_space=states_space_list,
+                                                            user_action=user_action, timeout=15, episode=episodes_list)
+
+        state_tuple_indexed = [states_space_list.index(tuple(s)) for s in (states_space_list)]
+        states_space_list_string = [[str(states_space_list[j*12+i]) for i in range(12)] for j in range(4)]
+        build_2dtable(states_space_list_string, 4, 12)
+
+        # R(s) and pi(s) estimated from the episodes simulated in this run
+        maxent_R_real_sim = maxent(world=cognitive_game_world, terminal=terminals, trajectories=episodes_list)
+        maxent_V_real_sim, maxent_P_real_sim = vi.value_iteration(cognitive_game_world.p_transition, maxent_R_real_sim,
+                                                                  gamma=0.9, error=1e-3, deterministic=True)
+
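+        # persist the learned deterministic policy; protocol=2 keeps the pickle loadable from Python 2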
+        learned_policy_filename = output_folder_data + "/" + "learned_policy.pkl"
+        with open(learned_policy_filename, 'wb') as f:
+            pickle.dump(maxent_P_real_sim, f, protocol=2)
+
+
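+        # tally which action the learned policy selects in each state on this run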
+        for s in state_space_index:
+            index = maxent_P_real_sim[s]
+            diff[s][index] += 1.0 / N
+        plt.figure(figsize=(12, 4))
+        sns.heatmap(np.reshape(maxent_P_real_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
+        plt.savefig(output_folder_data + "/maxent_P_iter_" + str(i) + ".jpg")
+
+    for s in state_space_index:
+        E = 0
+        for i in range(len(action_space_index)):
+            if diff[s][i] > 0:
+                E -= diff[s][i] * math.log(diff[s][i])
+        entropy[s] = E
+
+    # highlight high and low entropy states
+    entropy_sort = sorted(entropy.items(), key=operator.itemgetter(1))
+    s_preferences = [episode_instance.state_from_index_to_point(states_space_list, entropy_sort[i][0]) for i in range(-1, -6, -1)]
+    s_constraints = [episode_instance.state_from_index_to_point(states_space_list, entropy_sort[i][0]) for i in range(22, 27)]
+
+    print("S_preferences: ", s_preferences)
+    print("S_constrains: ", s_constraints)
+
+
     plt.figure(figsize=(12, 4), num="maxent_rew")
     sns.heatmap(np.reshape(maxent_R_real_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "real_sim_maxent_R.jpg")
+    plt.savefig(output_folder_data + "/real_sim_maxent_R.jpg")
     plt.figure(figsize=(12, 4), num="maxent_V")
     sns.heatmap(np.reshape(maxent_V_real_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "real_sim_maxent_V.jpg")
+    plt.savefig(output_folder_data + "/real_sim_maxent_V.jpg")
     plt.figure(figsize=(12, 4), num="maxent_P")
     sns.heatmap(np.reshape(maxent_P_real_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "real_sim_maxent_P.jpg")
+    plt.savefig(output_folder_data + "/real_sim_maxent_P.jpg")
+
 
-    # Compute entropy between two policies
-    policies = [maxent_P_sim, maxent_P_real_sim]
-    s_preferences, s_constraints = get_entropy(policies, state_space_index, action_space_index)
-    print("Preferences:", s_preferences, " Constraints:", s_constraints)
 
 
 if __name__ == '__main__':