diff --git a/bn_functions.py b/bn_functions.py index 887def745abd8ad5c6eb3d7fcc9b6db90d87bf41..5137463051def035e41458d3f5f7245c90dfa00c 100644 --- a/bn_functions.py +++ b/bn_functions.py @@ -60,7 +60,7 @@ def average_prob(ref_cpds_table, current_cpds_table): res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]+current_cpds_table[elem1][elem2])/2 return res_cpds_table -def update_cpds_tables(bn_model,variables_tables): +def update_cpds_tables(bn_model, variables_tables): ''' This function updates the bn model with the variables_tables provided in input Args: @@ -144,6 +144,7 @@ def get_stochastic_action(actions_distr_prob): index = i return index + actions_distr_prob_scaled = [0]*len(actions_distr_prob) accum = 0 for i in range(len(actions_distr_prob)): @@ -155,27 +156,60 @@ def get_stochastic_action(actions_distr_prob): return action_id -def interpret_user_output(action_id): +def flat_action_probs(action_probs): + flat_array_user_action_prob = None + column = 0 + if len(action_probs.values.shape)==2: + flat_array_user_action_prob = [action_probs.values[j][i] for j in range(action_probs.values.shape[0]) for i in range(action_probs.values.shape[1])] + column = action_probs.values.shape[0] + row = action_probs.values.shape[1] + else: + assert "Did you forget to add the additional target, only one has been detected" + flat_array_user_action_prob = action_probs.values + column = 1 + row = 0 + return flat_array_user_action_prob, column, row + + + +def interpret_action_output(action_id, col, row, targets): + ''' + Given the id of the action selected from the probabilistic inference model and the target variables + return the action (user act + react time) or (robot ass and robot feedback) + Args + action_id 1d array of probs + col: #col of array action id + row: #row of array action id + targets the targets we aim to evaluate + Return: + user_action + user_react_time + ''' + #N.B it assumes that the query is performed passing as first argument robot_assistance + # and as a second robot_feedback + robot_assistance = 0 + robot_feedback = 0 user_action = 0 user_react_time = 0 - if action_id == 0: - user_action = 0; user_react_time = 0 - elif action_id == 1: - user_action = 1; user_react_time = 0 - elif action_id == 2: - user_action = 2; user_react_time = 0 - elif action_id == 3: - user_action = 0; user_react_time = 1 - elif action_id == 4: - user_action = 1; user_react_time = 1 - elif action_id == 5: - user_action = 2; user_react_time = 1 - elif action_id == 6: - user_action = 0; user_react_time = 2 - elif action_id == 7: - user_action = 1; user_react_time = 2 - elif action_id == 8: - user_action = 2; user_react_time = 2 - - return user_action, user_react_time \ No newline at end of file + if targets[0] == 'user_action': + user_action = int(action_id / row) + user_react_time = int(action_id % row) + print("user_action ", user_action, ' user_react ', user_react_time) + return user_action, user_react_time + elif targets[1] == 'user_action': + user_action = int(action_id % row) + user_react_time = int(action_id / row) + print("user_action ", user_action, ' user_react ', user_react_time) + return user_action, user_react_time + elif targets[0] == "robot_assistance": + robot_assistance = int(action_id / row) + robot_feedback = int(action_id % row) + print("robot_ass ", robot_assistance, 'robot_feed ', robot_feedback) + return robot_assistance, robot_feedback + elif targets[1] == "robot_assistance": + robot_assistance = int(action_id % row) + robot_feedback = int(action_id / row) + print("robot_feed ", robot_assistance, 'robot_ass ', robot_feedback) + return robot_assistance, robot_feedback +