Commit 42571017 authored by Antonio Andriella

Working version of the entire framework

parent 404ad9f7
@@ -155,7 +155,7 @@ def main():
     bn_model_user_action = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/user_action_model.bif')
     bn_model_user_react_time = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/user_react_time_model.bif')
-    # initialise memory, attention and reactivity varibles
+    # initialise memory, attention and reactivity variables
     persona_memory = 0;
     persona_attention = 0;
     persona_reactivity = 1;
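For reference, the bnlearn calls above load .bif models that can then be queried for a conditional distribution. A minimal sketch, assuming a relative path and illustrative evidence node names ('agent_assistance' and 'agent_feedback' are not taken from the model files):

import bnlearn

# Minimal sketch: load the persona model and query the user-action distribution.
# The relative path and the evidence node names below are assumptions for illustration.
bn_model_user_action = bnlearn.import_DAG('bn_persona_model/user_action_model.bif')
query = bnlearn.inference.fit(bn_model_user_action,
                              variables=['user_action'],
                              evidence={'agent_assistance': 1, 'agent_feedback': 0})
print(query)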
@@ -180,28 +180,9 @@ def main():
                                        range(len(user_action))]
     initial_state = (1, 1, 0)
-    # attempt = [i for i in range(1, Attempt.counter.value + 1)]
-    # # +1a (3,_,_) absorbing state
-    # game_state = [i for i in range(0, Game_State.counter.value + 1)]
-    # user_action = [i for i in range(-1, User_Action.counter.value - 1)]
-    # state_space = (game_state, attempt, user_action)
-    # states_space_list = list(itertools.product(*state_space))
-    # agent_assistance_action = [i for i in range(Agent_Assistance.counter.value)]
-    # agent_feedback_action = [i for i in range(Agent_Feedback.counter.value)]
-    # action_space = (agent_assistance_action, agent_feedback_action)
-    # action_space_list = list(itertools.product(*action_space))
-    ##############BEFORE RUNNING THE SIMULATION UPDATE THE BELIEF IF YOU HAVE DATA####################
-    log_directory = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/1/0"
-    if os.path.exists(log_directory):
-        bn_functions.update_episodes_batch(bn_model_user_action, bn_model_user_react_time, bn_model_caregiver_assistance,
-                                           bn_model_caregiver_feedback, folder_filename=log_directory,
-                                           with_caregiver=True)
-    else:
-        assert ("You're not using the user information")
-        question = input("Are you sure you don't want to load user's belief information?")
+    #1. RUN THE SIMULATION WITH THE PARAMS SET BY THE CAREGIVER
     game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, episodes_list = \
         Sim.simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'],
                        bn_model_user_react_time=bn_model_user_react_time,
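A side note on the guard removed above (and re-added later in this commit): `assert ("You're not using the user information")` asserts a non-empty string literal, which is always truthy, so the else-branch never stops the run. A possible stricter version is sketched below, under the assumption that the run should abort when the caregiver declines; bn_functions and the BN models come from the surrounding script:

import os
import sys

log_directory = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/1/0"
if os.path.exists(log_directory):
    # update the BN beliefs from the logged episodes, as in the commit
    bn_functions.update_episodes_batch(bn_model_user_action, bn_model_user_react_time,
                                       bn_model_caregiver_assistance, bn_model_caregiver_feedback,
                                       folder_filename=log_directory, with_caregiver=True)
else:
    # assert on a non-empty string never fails, so it cannot act as a guard;
    # ask explicitly and abort instead (sys.exit is an assumption about the desired behaviour).
    answer = input("No log found. Continue without the user's belief information? [y/N] ")
    if answer.strip().lower() != 'y':
        sys.exit("Aborting: user belief information not loaded.")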
@@ -218,6 +199,7 @@ def main():
                        state_space=states_space_list, action_space=action_space_list,
                        epochs=epochs, task_complexity=5, max_attempt_per_object=4)
+    #2. GIVEN THE EPISODES ESTIMATE R(S) and PI(S)
     format = "%a%b%d-%H:%M:%S %Y"
     today_id = datetime.datetime.today()
@@ -225,9 +207,9 @@ def main():
     if not os.path.exists(full_path):
         os.mkdir(full_path)
-    plot_game_performance_path = "BEFORE_game_performance_"+"epoch_" + str(epochs) + "_persona_memory_" + str(persona_memory) + "_persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(persona_reactivity) + ".jpg"
-    plot_agent_assistance_path = "BEFORE_agent_assistance_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg"
-    plot_agent_feedback_path = "BEFORE_agent_feedback_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg"
+    plot_game_performance_path = "SIM_game_performance_"+"epoch_" + str(epochs) + "_persona_memory_" + str(persona_memory) + "_persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(persona_reactivity) + ".jpg"
+    plot_agent_assistance_path = "SIM_agent_assistance_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg"
+    plot_agent_feedback_path = "SIM_agent_feedback_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg"
     utils.plot2D_game_performance(full_path +plot_game_performance_path, epochs, scaling_factor, game_performance_per_episode)
     utils.plot2D_assistance(full_path + plot_agent_assistance_path, epochs, scaling_factor, agent_assistance_per_episode)
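The three plot paths above differ only in the prefix and the plot type. A small f-string helper (a refactoring sketch, not code from the commit) keeps the naming scheme prefix_kind_epoch_N_persona_memory_M_persona_attention_A_persona_reactivity_R.jpg in one place:

def plot_filename(prefix, kind, epochs, persona_memory, persona_attention, persona_reactivity):
    # e.g. plot_filename("SIM", "game_performance", 20, 0, 0, 1) ->
    # "SIM_game_performance_epoch_20_persona_memory_0_persona_attention_0_persona_reactivity_1.jpg"
    return (f"{prefix}_{kind}_epoch_{epochs}"
            f"_persona_memory_{persona_memory}"
            f"_persona_attention_{persona_attention}"
            f"_persona_reactivity_{persona_reactivity}.jpg")

# variables taken from the surrounding script
plot_game_performance_path = plot_filename("SIM", "game_performance", epochs,
                                           persona_memory, persona_attention, persona_reactivity)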
@@ -239,41 +221,35 @@ def main():
     state_tuple_indexed = [states_space_list.index(tuple(s)) for s in (states_space_list)]
+    #Dirty way to represent the state space in a graphical way
     states_space_list_string = [[str(states_space_list[j*12+i]) for i in range(12)] for j in range(3)]
     build_2dtable(states_space_list_string, 3, 12)
-    # exp_V, exp_P = vi.value_iteration(world.p_transition, reward, gamma=0.9, error=1e-3, deterministic=True)
-    # plt.figure(figsize=(12, 4), num="state_space")
-    # sns.heatmap(np.reshape(state_tuple_indexed, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    # plt.savefig(full_path+"state_space.jpg")
-    # #PLOTS EXPERT
-    # plt.figure(figsize=(12, 4), num="exp_rew")
-    # sns.heatmap(np.reshape(reward, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    # plt.savefig(full_path+"exp_rew.jpg")
-    # plt.figure(figsize=(12, 4), num="exp_V")
-    # sns.heatmap(np.reshape(exp_V, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    # plt.savefig(full_path+"exp_V.jpg")
-    # plt.figure(figsize=(12, 4), num="exp_P")
-    # sns.heatmap(np.reshape(exp_P, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    # plt.savefig(full_path+"exp_P.jpg")
-    maxent_R = maxent(world, terminals, episodes_list)
-    maxent_V, maxent_P = vi.value_iteration(world.p_transition, maxent_R, gamma=0.9, error=1e-3, deterministic=True)
+    #R(s) and pi(s) generated from the first sim
+    maxent_R_sim = maxent(world, terminals, episodes_list)
+    maxent_V_sim, maxent_P_sim = vi.value_iteration(world.p_transition, maxent_R_sim, gamma=0.9, error=1e-3, deterministic=True)
     plt.figure(figsize=(12, 4), num="maxent_rew")
-    sns.heatmap(np.reshape(maxent_R, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "maxent_rew.jpg")
+    sns.heatmap(np.reshape(maxent_R_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
+    plt.savefig(full_path + "sim_maxent_R.jpg")
     plt.figure(figsize=(12, 4), num="maxent_V")
-    sns.heatmap(np.reshape(maxent_V, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "maxent_V.jpg")
+    sns.heatmap(np.reshape(maxent_V_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
+    plt.savefig(full_path + "sim_maxent_V.jpg")
     plt.figure(figsize=(12, 4), num="maxent_P")
-    sns.heatmap(np.reshape(maxent_P, (4, 12)), cmap="Spectral", annot=True, cbar=False)
-    plt.savefig(full_path + "maxent_P.jpg")
-    #Compute entropy between two policies
-    # policies = [exp_P, maxent_P]
-    # entropy = get_entropy(policies, state_space_index, action_space_index)
+    sns.heatmap(np.reshape(maxent_P_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
+    plt.savefig(full_path + "sim_maxent_P.jpg")
+    #####################################################################################
+    #3.WE GOT SOME REAL DATA UPDATE THE BELIEF OF THE BN
+    log_directory = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/1/0"
+    if os.path.exists(log_directory):
+        bn_functions.update_episodes_batch(bn_model_user_action, bn_model_user_react_time,
+                                           bn_model_caregiver_assistance,
+                                           bn_model_caregiver_feedback, folder_filename=log_directory,
+                                           with_caregiver=True)
+    else:
+        assert ("You're not using the user information")
+        question = input("Are you sure you don't want to load user's belief information?")
     game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, episodes_list = \
         Sim.simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'],
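The implementation of vi.value_iteration is not part of this diff. As a rough reference only, deterministic value iteration over a tabular MDP typically looks like the sketch below, assuming p_transition has shape (n_states, n_states, n_actions) indexed as p[s, s_next, a] and the MaxEnt reward is a per-state vector; this is not the project's own solver:

import numpy as np

def value_iteration(p_transition, reward, gamma=0.9, error=1e-3):
    """Tabular value iteration returning state values and a deterministic greedy policy."""
    # p_transition[s, s_next, a] = P(s_next | s, a); reward[s] = state reward from MaxEnt IRL.
    reward = np.asarray(reward, dtype=float)
    n_states, _, n_actions = p_transition.shape
    v = np.zeros(n_states)
    while True:
        # Q[s, a] = r(s) + gamma * sum_{s'} P(s' | s, a) * v(s')   (one common convention)
        q = reward[:, None] + gamma * np.einsum('ska,k->sa', p_transition, v)
        v_new = q.max(axis=1)
        if np.max(np.abs(v_new - v)) < error:
            v = v_new
            break
        v = v_new
    return v, q.argmax(axis=1)

# Usage sketch (shapes assumed): V_sim, P_sim = value_iteration(world.p_transition, maxent_R_sim)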
@@ -288,17 +264,17 @@ def main():
                        var_agent_assistance_target_action=["agent_assistance"],
                        bn_model_agent_feedback=bn_model_caregiver_feedback,
                        var_agent_feedback_target_action=["agent_feedback"],
-                       agent_policy=maxent_P,
+                       agent_policy=None,
                        state_space=states_space_list, action_space=action_space_list,
                        epochs=epochs, task_complexity=5, max_attempt_per_object=4)
-    plot_game_performance_path = "AFTER_game_performance_" + "epoch_" + str(epochs) + "_persona_memory_" + str(
-        persona_memory) + "_persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(
+    plot_game_performance_path = "REAL_SIM_game_performance_" + "epoch_" + str(epochs) + "_persona_memory_" + str(
+        persona_memory) + "persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(
         persona_reactivity) + ".jpg"
-    plot_agent_assistance_path = "AFTER_agent_assistance_" + "epoch_" + str(epochs) + "_persona_memory_" + str(
+    plot_agent_assistance_path = "REAL_SIM_agent_assistance_" + "epoch_" + str(epochs) + "_persona_memory_" + str(
         persona_memory) + "_persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(
         persona_reactivity) + ".jpg"
-    plot_agent_feedback_path = "AFTER_agent_feedback_" + "epoch_" + str(epochs) + "_persona_memory_" + str(
+    plot_agent_feedback_path = "REAL_SIM_agent_feedback_" + "epoch_" + str(epochs) + "_persona_memory_" + str(
         persona_memory) + "_persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(
         persona_reactivity) + ".jpg"
@@ -306,6 +282,25 @@ def main():
     utils.plot2D_assistance(full_path + plot_agent_assistance_path, epochs, scaling_factor, agent_assistance_per_episode)
     utils.plot2D_feedback(full_path + plot_agent_feedback_path, epochs, scaling_factor, agent_feedback_per_episode)
+    # R(s) and pi(s) generated from the second sim (after the BN update with real data)
+    maxent_R_real_sim = maxent(world, terminals, episodes_list)
+    maxent_V_real_sim, maxent_P_real_sim = vi.value_iteration(world.p_transition, maxent_R_real_sim, gamma=0.9, error=1e-3,
+                                                              deterministic=True)
+    plt.figure(figsize=(12, 4), num="maxent_rew")
+    sns.heatmap(np.reshape(maxent_R_real_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
+    plt.savefig(full_path + "real_sim_maxent_R.jpg")
+    plt.figure(figsize=(12, 4), num="maxent_V")
+    sns.heatmap(np.reshape(maxent_V_real_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
+    plt.savefig(full_path + "real_sim_maxent_V.jpg")
+    plt.figure(figsize=(12, 4), num="maxent_P")
+    sns.heatmap(np.reshape(maxent_P_real_sim, (4, 12)), cmap="Spectral", annot=True, cbar=False)
+    plt.savefig(full_path + "real_sim_maxent_P.jpg")
+    # Compute entropy between two policies
+    policies = [maxent_P_sim, maxent_P_real_sim]
+    s_preferences, s_constraints = get_entropy(policies, state_space_index, action_space_index)
+    print("Preferences:", s_preferences, " Constraints:", s_constraints)

 if __name__ == '__main__':
     main()
\ No newline at end of file
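get_entropy is likewise not shown in this commit, so the exact meaning of the returned preferences and constraints is not visible here. Purely as an illustration of one way to compare two deterministic policies, the sketch below computes the per-state entropy of the actions the policies choose and splits states into agreement and disagreement sets; the names and the split are assumptions, not the project's semantics:

import numpy as np

def compare_policies(policies, n_states):
    """Per-state entropy of the actions chosen by several deterministic policies."""
    agree, disagree = [], []
    for s in range(n_states):
        actions = [int(pi[s]) for pi in policies]           # action each policy picks in state s
        _, counts = np.unique(actions, return_counts=True)
        p = counts / counts.sum()
        entropy = -np.sum(p * np.log2(p))                    # 0.0 when every policy agrees
        (agree if entropy == 0 else disagree).append(s)
    return agree, disagree

# Usage sketch:
# states_agree, states_disagree = compare_policies([maxent_P_sim, maxent_P_real_sim], len(states_space_list))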