diff --git a/main.py b/main.py index dcbf08465c08966617aa326ac77981471f76f392..ab68aac2423f96d884620d98db1b10244fca5bc7 100644 --- a/main.py +++ b/main.py @@ -15,7 +15,7 @@ def compute_next_state(user_action, task_progress_counter, attempt_counter, corr ''' The function computes given the current state and action of the user, the next state Args: - user_action: -1 wrong, 0 timeout, 1 correct + user_action: -1a wrong, 0 timeout, 1a correct game_state_counter: beg, mid, end correct_move_counter: attempt_counter: @@ -182,25 +182,26 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use query_agent_assistance_prob = bn_functions.infer_prob_from_state(bn_model_agent_assistance, infer_variable=var_agent_assistance_target_action, evidence_variables=vars_agent_evidence) - - query_agent_feedback_prob = bn_functions.infer_prob_from_state(bn_model_agent_feedback, + if bn_model_agent_feedback != None: + query_agent_feedback_prob = bn_functions.infer_prob_from_state(bn_model_agent_feedback, infer_variable=var_agent_feedback_target_action, evidence_variables=vars_agent_evidence) - + selected_agent_feedback_action = bn_functions.get_stochastic_action(query_agent_feedback_prob.values) + else: + selected_agent_feedback_action = 0 selected_agent_assistance_action = bn_functions.get_stochastic_action(query_agent_assistance_prob.values) - selected_agent_feedback_action = bn_functions.get_stochastic_action(query_agent_feedback_prob.values) + n_feedback_per_episode[e][selected_agent_feedback_action] += 1 #counters for plots n_assistance_lev_per_episode[e][selected_agent_assistance_action] += 1 - n_feedback_per_episode[e][selected_agent_feedback_action] += 1 current_agent_action = (selected_agent_assistance_action, selected_agent_feedback_action) print("agent_assistance {}, attempt {}, game {}, agent_feedback {}".format(selected_agent_assistance_action, attempt_counter, game_state_counter, selected_agent_feedback_action)) ##########################QUERY FOR THE USER ACTION AND REACT TIME##################################### - #compare the real user with the estimated Persona and returns a user action (0, 1, 2) + #compare the real user with the estimated Persona and returns a user action (0, 1a, 2) if bn_model_other_user_action!=None and bn_model_user_react_time!=None: #return the user action in this state based on the user profile vars_other_user_evidence = {other_user_attention_name:other_user_attention_value, @@ -289,11 +290,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use print("game_state_counter {}, iter_counter {}, correct_counter {}, wrong_counter {}, " "timeout_counter {}, max_attempt {}".format(game_state_counter, iter_counter, correct_move_counter, wrong_move_counter, timeout_counter, max_attempt_counter)) - # print("agent_assistance_per_action {}".format(agent_assistance_per_action)) - # print("attempt_counter_per_action {}".format(attempt_counter_per_action)) - # print("game_state_counter_per_action {}".format(game_state_counter_per_action)) - # print("agent_feedback_per_action {}".format(agent_feedback_per_action)) - # print("iter {}, correct {}, wrong {}, timeout {}".format(iter_counter, correct_move_counter, wrong_move_counter, timeout_counter)) #save episode episodes.append(episode) @@ -373,7 +369,7 @@ other_user_memory = 2; other_user_attention = 2; other_user_reactivity = 2; #define state space struct for the irl algorithm attempt = [i for i in range(1, Attempt.counter.value+1)] -#+1 (3,_,_) absorbing state +#+1a (3,_,_) absorbing state game_state = [i for i in range(0, Game_State.counter.value+1)] user_action = [i for i in range(-1, User_Action.counter.value-1)] state_space = (game_state, attempt, user_action) @@ -384,12 +380,13 @@ action_space = (agent_assistance_action, agent_feedback_action) action_space_list = list(itertools.product(*action_space)) ##############BEFORE RUNNING THE SIMULATION UPDATE THE BELIEF IF YOU HAVE DATA#################### -bn_belief_user_action_file = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/0/bn_belief_user_action.pkl" -bn_belief_user_react_time_file = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/0/bn_belief_user_react_time.pkl" -bn_belief_caregiver_assistance_file = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/0/bn_belief_caregiver_assistive_action.pkl" -bn_belief_caregiver_feedback_file = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/0/bn_belief_caregiver_feedback_action.pkl" -if bn_belief_user_action_file != None and bn_belief_user_react_time_file!= None and \ - bn_belief_caregiver_assistance_file!=None and bn_belief_caregiver_feedback_file!=None: +log_directory = "" +if os.path.exists(log_directory): + bn_belief_user_action_file = log_directory+"/bn_belief_user_action.pkl" + bn_belief_user_react_time_file = log_directory+"/bn_belief_user_react_time.pkl" + bn_belief_caregiver_assistance_file = log_directory+"/bn_belief_caregiver_assistive_action.pkl" + bn_belief_caregiver_feedback_file = log_directory+"/bn_belief_caregiver_feedback_action.pkl" + bn_belief_user_action = utils.read_user_statistics_from_pickle(bn_belief_user_action_file) bn_belief_user_react_time = utils.read_user_statistics_from_pickle(bn_belief_user_react_time_file) bn_belief_caregiver_assistance = utils.read_user_statistics_from_pickle(bn_belief_caregiver_assistance_file) @@ -399,32 +396,11 @@ if bn_belief_user_action_file != None and bn_belief_user_react_time_file!= None bn_model_caregiver_assistance = bn_functions.update_cpds_tables(bn_model=bn_model_caregiver_assistance, variables_tables=bn_belief_caregiver_assistance) bn_model_caregiver_feedback = bn_functions.update_cpds_tables(bn_model=bn_model_caregiver_feedback, variables_tables=bn_belief_caregiver_feedback) - game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, generated_episodes = \ - simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'], - bn_model_user_react_time=bn_model_user_react_time, - var_user_react_time_target_action=['user_react_time'], - user_memory_name="memory", user_memory_value=persona_memory, - user_attention_name="attention", user_attention_value=persona_attention, - user_reactivity_name="reactivity", user_reactivity_value=persona_reactivity, - task_progress_name="game_state", game_attempt_name="attempt", - agent_assistance_name="agent_assistance", agent_feedback_name="agent_feedback", - bn_model_agent_assistance=bn_model_caregiver_assistance, - var_agent_assistance_target_action=["agent_assistance"], - bn_model_agent_feedback=bn_model_caregiver_feedback, var_agent_feedback_target_action=["agent_feedback"], - bn_model_other_user_action=bn_model_other_user_action, - var_other_user_action_target_action=['user_action'], - bn_model_other_user_react_time=bn_model_other_user_react_time, - var_other_user_target_react_time_action=["user_react_time"], other_user_memory_name="memory", - other_user_memory_value=other_user_memory, other_user_attention_name="attention", - other_user_attention_value=other_user_attention, other_user_reactivity_name="reactivity", - other_user_reactivity_value=other_user_reactivity, - state_space=states_space_list, action_space=action_space_list, - epochs=epochs, task_complexity=5, max_attempt_per_object=4) - else: assert("You're not using the user information") question = input("Are you sure you don't want to load user's belief information?") - game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, generated_episodes = \ + +game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, generated_episodes = \ simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'], bn_model_user_react_time=bn_model_user_react_time, var_user_react_time_target_action=['user_react_time'], @@ -435,8 +411,7 @@ else: agent_assistance_name="agent_assistance", agent_feedback_name="agent_feedback", bn_model_agent_assistance=bn_model_caregiver_assistance, var_agent_assistance_target_action=["agent_assistance"], - bn_model_agent_feedback=bn_model_caregiver_feedback, - var_agent_feedback_target_action=["agent_feedback"], + bn_model_agent_feedback=bn_model_caregiver_feedback, var_agent_feedback_target_action=["agent_feedback"], bn_model_other_user_action=bn_model_other_user_action, var_other_user_action_target_action=['user_action'], bn_model_other_user_react_time=bn_model_other_user_react_time,