diff --git a/main.py b/main.py
index 0eb7e52ceb83facdcb1940f20c28ba01b16eff85..dcbf08465c08966617aa326ac77981471f76f392 100644
--- a/main.py
+++ b/main.py
@@ -6,7 +6,7 @@
 import numpy as np
 from bn_variables import Agent_Assistance, Agent_Feedback, User_Action, User_React_time, Game_State, Attempt
 import bn_functions
 import utils
-import episode as ep
+from episode import Episode
 
 def compute_next_state(user_action, task_progress_counter, attempt_counter, correct_move_counter,
@@ -32,38 +32,43 @@ def compute_next_state(user_action, task_progress_counter, attempt_counter, corr
         max_attempt_counter
     '''
+    if task_progress_counter >= 0 and task_progress_counter < 2:
+        game_state_counter = 0
+    elif task_progress_counter >= 2 and task_progress_counter < 4:
+        game_state_counter = 1
+    elif task_progress_counter >= 4 and task_progress_counter < 5:
+        game_state_counter = 2
+    else:
+        game_state_counter = 3
     # if then else are necessary to classify the task game state into beg, mid, end
-    if user_action == 1:
-        attempt_counter = 1
-        correct_move_counter += 1
-        task_progress_counter += 1
+    if user_action == 1 and game_state_counter < 3:
+        attempt_counter = 1
+        correct_move_counter += 1
+        task_progress_counter += 1
     # if the user made a wrong move and still did not reach the maximum number of attempts
-    elif user_action == -1 and attempt_counter < max_attempt_per_object:
-        attempt_counter += 1
-        wrong_move_counter += 1
+    elif user_action == -1 and attempt_counter < max_attempt_per_object and game_state_counter < 3:
+        attempt_counter += 1
+        wrong_move_counter += 1
     # if the user did not move any token and still did not reach the maximum number of attempts
-    elif user_action == 0 and attempt_counter < max_attempt_per_object:
-        attempt_counter += 1
-        timeout_counter += 1
+    elif user_action == 0 and attempt_counter < max_attempt_per_object and game_state_counter < 3:
+        attempt_counter += 1
+        timeout_counter += 1
     # the agent or therapist makes the correct move on the patient's behalf
-    else:
-        attempt_counter = 1
-        max_attempt_counter += 1
-        task_progress_counter +=1
+    elif attempt_counter >= max_attempt_per_object and game_state_counter < 3:
+        attempt_counter = 1
+        max_attempt_counter += 1
+        task_progress_counter += 1
+    if game_state_counter == 3:
+        attempt_counter = 1
+        task_progress_counter += 1
+        print("Reached the end of the episode")
 
     # TODO call the function to compute the state of the game (beg, mid, end)
-    if correct_move_counter >= 0 and correct_move_counter <= 2:
-        game_state_counter = 0
-    elif correct_move_counter > 2 and correct_move_counter <= 4:
-        game_state_counter = 1
-    elif correct_move_counter>4 and correct_move_counter<=5:
-        game_state_counter = 2
-    else:
-        game_state_counter = 3
+
     next_state = (game_state_counter, attempt_counter, user_action)
@@ -128,7 +133,7 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 
     #data structure to memorise a sequence of episode
     episodes = []
-
+    ep = Episode()
     for e in range(epochs):
 
         '''Simulation framework'''
@@ -275,9 +280,9 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
                                                              max_attempt_per_object)
 
             # store the (state, action, next_state)
-            episode.append((ep.point_to_index(current_state, state_space),
-                            ep.point_to_index(current_agent_action, action_space),
-                            ep.point_to_index(next_state, state_space)))
+            episode.append((ep.state_from_point_to_index(state_space, current_state),
+                            ep.state_from_point_to_index(action_space, current_agent_action),
+                            ep.state_from_point_to_index(state_space, next_state)))
print("current_state ", current_state, " next_state ", next_state) ####################################END of EPISODE####################################### @@ -351,7 +356,7 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use ############################################################################# #SIMULATION PARAMS -epochs = 10 +epochs = 100 #initialise the agent bn_model_caregiver_assistance = bnlearn.import_DAG('bn_agent_model/agent_assistive_model.bif') @@ -418,7 +423,29 @@ if bn_belief_user_action_file != None and bn_belief_user_react_time_file!= None else: assert("You're not using the user information") - question = raw_input("Are you sure you don't want to load user's belief information?") + question = input("Are you sure you don't want to load user's belief information?") + game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, generated_episodes = \ + simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'], + bn_model_user_react_time=bn_model_user_react_time, + var_user_react_time_target_action=['user_react_time'], + user_memory_name="memory", user_memory_value=persona_memory, + user_attention_name="attention", user_attention_value=persona_attention, + user_reactivity_name="reactivity", user_reactivity_value=persona_reactivity, + task_progress_name="game_state", game_attempt_name="attempt", + agent_assistance_name="agent_assistance", agent_feedback_name="agent_feedback", + bn_model_agent_assistance=bn_model_caregiver_assistance, + var_agent_assistance_target_action=["agent_assistance"], + bn_model_agent_feedback=bn_model_caregiver_feedback, + var_agent_feedback_target_action=["agent_feedback"], + bn_model_other_user_action=bn_model_other_user_action, + var_other_user_action_target_action=['user_action'], + bn_model_other_user_react_time=bn_model_other_user_react_time, + var_other_user_target_react_time_action=["user_react_time"], other_user_memory_name="memory", + other_user_memory_value=other_user_memory, other_user_attention_name="attention", + other_user_attention_value=other_user_attention, other_user_reactivity_name="reactivity", + other_user_reactivity_value=other_user_reactivity, + state_space=states_space_list, action_space=action_space_list, + epochs=epochs, task_complexity=5, max_attempt_per_object=4)