Skip to content
Snippets Groups Projects
Commit a019f432 authored by Antonio Andriella's avatar Antonio Andriella
Browse files

Save simulation data and plt.clf() to plot more than figures at the time

parent cc1419f5
No related branches found
No related tags found
No related merge requests found
......@@ -282,6 +282,7 @@ def main():
learned_policy_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_policy_filename
learned_reward_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_reward_filename
learned_value_f_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_value_filename
therapist_patient_interaction_folder = args.therapist_patient_interaction_folder
agent_patient_interaction_folder = args.agent_patient_interaction_folder
scaling_factor = 1
......@@ -308,7 +309,7 @@ def main():
action_space_list = action_space
terminal_state = [(Game_State.counter.value, i, user_action[j]) for i in range(1, Attempt.counter.value + 1) for j in
range(len(user_action))]
initial_state = (1, 1, 0)
initial_state = (1, 1, 1)
#output folders
......@@ -350,7 +351,6 @@ def main():
bn_model_agent_behaviour_from_data_and_therapist = None
if os.path.exists(output_folder_data_path):
bn_model_user_action_from_data_and_therapist = Sim.build_model_from_data(
csv_filename=output_folder_data_path + "/summary_bn_variables_from_data.csv", dag_filename=bn_user_model_filename,
......@@ -385,6 +385,14 @@ def main():
utils.plot2D_game_performance(plot_game_performance_path, epochs, scaling_factor, game_performance_per_episode)
utils.plot2D_assistance(plot_agent_assistance_path, epochs, scaling_factor, agent_assistance_per_episode)
sim_patient_performance_filename = "sim_patient_performance.pkl"
sim_agent_assistance_filename = "sim_agent_assistance.pkl"
with open(output_folder_data_path+"/"+sim_agent_assistance_filename, 'wb') as f:
pickle.dump(game_performance_per_episode, f, protocol=2)
with open(output_folder_data_path + "/" + sim_patient_performance_filename, 'wb') as f:
pickle.dump(agent_assistance_per_episode, f, protocol=2)
# add episodes from different policies
# for e in range(len(episodes)):
# episodes_from_different_policies.append(Episode(episodes[e]._t))
......@@ -402,7 +410,7 @@ def main():
# R(s) and pi(s) generated from the first sim
maxent_R = maxent(world=cognitive_game_world, terminal=terminals, trajectories=episodes)
maxent_V, maxent_P = vi.value_iteration(cognitive_game_world.p_transition, maxent_R, gamma=0.99, error=1e-2,
maxent_V, maxent_P = vi.value_iteration(cognitive_game_world.p_transition, maxent_R, gamma=0.9, error=1e-4,
deterministic=False)
print(maxent_P)
with open(learned_policy_filename, 'wb') as f:
......@@ -425,18 +433,22 @@ def main():
# else:
# maxent_P_real_sim[state_index][action_index] = 0.02
# maxent_P_real_sim[state_index] = list(map(lambda x:x/sum(maxent_P_real_sim[state_index]), maxent_P_real_sim[state_index]))
plt.clf()
sns.heatmap(np.reshape(maxent_R, (4, 12)), cmap="Spectral", annot=True, cbar=False)
plt.savefig(output_folder_data_path + "/maxent_R.jpg")
plt.show()
plt.clf()
sns.heatmap(np.reshape(maxent_V, (4, 12)), cmap="Spectral", annot=True, cbar=False)
plt.savefig(output_folder_data_path + "/maxent_V.jpg")
plt.show()
plt.clf()
maxent_P_det = list(map(lambda x: np.argmax(x), maxent_P))
sns.heatmap(np.reshape(maxent_P_det, (4, 12)), cmap="Spectral", annot=True, cbar=False)
plt.savefig(output_folder_data_path + "/maxent_P.jpg")
plt.show()
plt.clf()
f = open(output_folder_data_path+"/"+sim_agent_assistance_filename,'rb')
mydic = pickle.load(f)
f.close()
print(mydic)
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment