Skip to content
Snippets Groups Projects
Commit a019f432 authored by Antonio Andriella
Browse files

Save simulation data and call plt.clf() so that more than one figure can be plotted at a time

parent cc1419f5
No related branches found
No related tags found
No related merge requests found
...@@ -282,6 +282,7 @@ def main(): ...@@ -282,6 +282,7 @@ def main():
learned_policy_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_policy_filename learned_policy_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_policy_filename
learned_reward_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_reward_filename learned_reward_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_reward_filename
learned_value_f_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_value_filename learned_value_f_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_value_filename
therapist_patient_interaction_folder = args.therapist_patient_interaction_folder therapist_patient_interaction_folder = args.therapist_patient_interaction_folder
agent_patient_interaction_folder = args.agent_patient_interaction_folder agent_patient_interaction_folder = args.agent_patient_interaction_folder
scaling_factor = 1 scaling_factor = 1
...@@ -308,7 +309,7 @@ def main(): ...@@ -308,7 +309,7 @@ def main():
action_space_list = action_space action_space_list = action_space
terminal_state = [(Game_State.counter.value, i, user_action[j]) for i in range(1, Attempt.counter.value + 1) for j in terminal_state = [(Game_State.counter.value, i, user_action[j]) for i in range(1, Attempt.counter.value + 1) for j in
range(len(user_action))] range(len(user_action))]
initial_state = (1, 1, 0) initial_state = (1, 1, 1)
#output folders #output folders
...@@ -350,7 +351,6 @@ def main(): ...@@ -350,7 +351,6 @@ def main():
bn_model_agent_behaviour_from_data_and_therapist = None bn_model_agent_behaviour_from_data_and_therapist = None
if os.path.exists(output_folder_data_path): if os.path.exists(output_folder_data_path):
bn_model_user_action_from_data_and_therapist = Sim.build_model_from_data( bn_model_user_action_from_data_and_therapist = Sim.build_model_from_data(
csv_filename=output_folder_data_path + "/summary_bn_variables_from_data.csv", dag_filename=bn_user_model_filename, csv_filename=output_folder_data_path + "/summary_bn_variables_from_data.csv", dag_filename=bn_user_model_filename,
...@@ -385,6 +385,14 @@ def main(): ...@@ -385,6 +385,14 @@ def main():
utils.plot2D_game_performance(plot_game_performance_path, epochs, scaling_factor, game_performance_per_episode) utils.plot2D_game_performance(plot_game_performance_path, epochs, scaling_factor, game_performance_per_episode)
utils.plot2D_assistance(plot_agent_assistance_path, epochs, scaling_factor, agent_assistance_per_episode) utils.plot2D_assistance(plot_agent_assistance_path, epochs, scaling_factor, agent_assistance_per_episode)
sim_patient_performance_filename = "sim_patient_performance.pkl"
sim_agent_assistance_filename = "sim_agent_assistance.pkl"
with open(output_folder_data_path+"/"+sim_agent_assistance_filename, 'wb') as f:
pickle.dump(game_performance_per_episode, f, protocol=2)
with open(output_folder_data_path + "/" + sim_patient_performance_filename, 'wb') as f:
pickle.dump(agent_assistance_per_episode, f, protocol=2)
# add episodes from different policies # add episodes from different policies
# for e in range(len(episodes)): # for e in range(len(episodes)):
# episodes_from_different_policies.append(Episode(episodes[e]._t)) # episodes_from_different_policies.append(Episode(episodes[e]._t))
...@@ -402,7 +410,7 @@ def main(): ...@@ -402,7 +410,7 @@ def main():
# R(s) and pi(s) generated from the first sim # R(s) and pi(s) generated from the first sim
maxent_R = maxent(world=cognitive_game_world, terminal=terminals, trajectories=episodes) maxent_R = maxent(world=cognitive_game_world, terminal=terminals, trajectories=episodes)
maxent_V, maxent_P = vi.value_iteration(cognitive_game_world.p_transition, maxent_R, gamma=0.99, error=1e-2, maxent_V, maxent_P = vi.value_iteration(cognitive_game_world.p_transition, maxent_R, gamma=0.9, error=1e-4,
deterministic=False) deterministic=False)
print(maxent_P) print(maxent_P)
with open(learned_policy_filename, 'wb') as f: with open(learned_policy_filename, 'wb') as f:
...@@ -425,18 +433,22 @@ def main(): ...@@ -425,18 +433,22 @@ def main():
# else: # else:
# maxent_P_real_sim[state_index][action_index] = 0.02 # maxent_P_real_sim[state_index][action_index] = 0.02
# maxent_P_real_sim[state_index] = list(map(lambda x:x/sum(maxent_P_real_sim[state_index]), maxent_P_real_sim[state_index])) # maxent_P_real_sim[state_index] = list(map(lambda x:x/sum(maxent_P_real_sim[state_index]), maxent_P_real_sim[state_index]))
plt.clf()
sns.heatmap(np.reshape(maxent_R, (4, 12)), cmap="Spectral", annot=True, cbar=False) sns.heatmap(np.reshape(maxent_R, (4, 12)), cmap="Spectral", annot=True, cbar=False)
plt.savefig(output_folder_data_path + "/maxent_R.jpg") plt.savefig(output_folder_data_path + "/maxent_R.jpg")
plt.show() plt.clf()
sns.heatmap(np.reshape(maxent_V, (4, 12)), cmap="Spectral", annot=True, cbar=False) sns.heatmap(np.reshape(maxent_V, (4, 12)), cmap="Spectral", annot=True, cbar=False)
plt.savefig(output_folder_data_path + "/maxent_V.jpg") plt.savefig(output_folder_data_path + "/maxent_V.jpg")
plt.show() plt.clf()
maxent_P_det = list(map(lambda x: np.argmax(x), maxent_P)) maxent_P_det = list(map(lambda x: np.argmax(x), maxent_P))
sns.heatmap(np.reshape(maxent_P_det, (4, 12)), cmap="Spectral", annot=True, cbar=False) sns.heatmap(np.reshape(maxent_P_det, (4, 12)), cmap="Spectral", annot=True, cbar=False)
plt.savefig(output_folder_data_path + "/maxent_P.jpg") plt.savefig(output_folder_data_path + "/maxent_P.jpg")
plt.show() plt.clf()
f = open(output_folder_data_path+"/"+sim_agent_assistance_filename,'rb')
mydic = pickle.load(f)
f.close()
print(mydic)
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment