diff --git a/bn_functions.py b/bn_functions.py
index 77005f908381d0958aa1e2b8283c80e0f33a8091..a8502a0c8b41845fa424db5ca542c28660e88cea 100644
--- a/bn_functions.py
+++ b/bn_functions.py
@@ -44,12 +44,35 @@ def compute_prob(cpds_table):
     This function checks if any
     '''
-    for val in range(len(cpds_table)):
-        cpds_table[val] = list(map(lambda x: x / (sum(cpds_table[val])+0.00001), cpds_table[val]))
-        cpds_table[val] = check_zero_occurrences(cpds_table[val])
+    cpds_table_array = np.array(cpds_table)
+    cpds_table_array_len = len(cpds_table_array.shape)
+
+    if cpds_table_array_len == 4:
+
+        # attempt
+        for elem1 in range(cpds_table_array.shape[0]):
+            # game_state
+            for elem2 in range(cpds_table_array.shape[1]):
+                # assistance
+                for elem3 in range(cpds_table_array.shape[2]):
+                    cpds_table[elem1][elem2][elem3] = list(
+                        map(lambda x: x / (sum(cpds_table[elem1][elem2][elem3]) + 0.00001), cpds_table[elem1][elem2][elem3]))
+                    cpds_table[elem1][elem2][elem3] = check_zero_occurrences(cpds_table[elem1][elem2][elem3])
+
+    elif cpds_table_array_len == 3:
+        # attempt
+        for elem1 in range(cpds_table_array.shape[0]):
+            # game_state
+            for elem2 in range(cpds_table_array.shape[1]):
+                cpds_table[elem1][elem2] = list(map(lambda x: x / (sum(cpds_table[elem1][elem2]) + 0.00001), cpds_table[elem1][elem2]))
+                cpds_table[elem1][elem2] = check_zero_occurrences(cpds_table[elem1][elem2])
+    else:
+        for val in range(len(cpds_table)):
+            cpds_table[val] = list(map(lambda x: x / (sum(cpds_table[val])+0.00001), cpds_table[val]))
+            cpds_table[val] = check_zero_occurrences(cpds_table[val])
     return cpds_table
 
-def average_prob(ref_cpds_table, current_cpds_table):
+def average_prob(ref_cpds_table, current_cpds_table, alpha):
     '''
     Args:
         ref_cpds_table: table from bnlearn
@@ -58,12 +81,13 @@ def average_prob(ref_cpds_table, current_cpds_table):
         avg from both tables
     '''
     res_cpds_table = ref_cpds_table.copy()
+    current_cpds_table_np_array = np.array(current_cpds_table)
     for elem1 in range(len(ref_cpds_table)):
         for elem2 in range(len(ref_cpds_table[0])):
-            res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]+current_cpds_table[elem1][elem2])/2
+            res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]*(1-alpha))+(current_cpds_table_np_array[elem1][elem2]*alpha)
     return res_cpds_table
 
-def update_cpds_tables(bn_model, variables_tables):
+def update_cpds_tables(bn_model, variables_tables, alpha=0.1):
     '''
     This function updates the bn model with the variables_tables provided in input
     Args:
@@ -80,7 +104,7 @@ def update_cpds_tables(bn_model, variables_tables):
             cpds_table_from_counter = compute_prob(val)
             updated_prob = average_prob(
                 np.transpose(cpds_table),
-                cpds_table_from_counter)
+                cpds_table_from_counter, alpha)
             bn_model['model'].cpds[index].values = np.transpose(updated_prob)
     return bn_model
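For reference, the alpha-weighted update that `average_prob` now applies (replacing the previous fixed 50/50 average) can be reproduced in isolation. This is a toy illustration with made-up numbers, not part of the patch; `check_zero_occurrences`, the repo helper that redistributes mass to unseen outcomes, is omitted here.

```python
import numpy as np

# alpha controls how strongly the freshly observed counters pull the CPD
# away from the reference model (alpha=0 keeps the reference unchanged).
alpha = 0.1

# Hypothetical 2x3 reference CPD (rows already sum to 1).
ref_cpd = np.array([[0.2, 0.3, 0.5],
                    [0.6, 0.3, 0.1]])

# Hypothetical raw counters collected during one episode.
counters = np.array([[1.0, 4.0, 5.0],
                     [0.0, 2.0, 2.0]])

# Row-wise normalisation with the same epsilon used in compute_prob.
observed_cpd = counters / (counters.sum(axis=1, keepdims=True) + 0.00001)

# The blend implemented by average_prob.
updated_cpd = ref_cpd * (1 - alpha) + observed_cpd * alpha
print(updated_cpd)
```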
diff --git a/bn_persona_model/persona_model_test.bif b/bn_persona_model/persona_model_test.bif
index bbdacc7199fbedeecba0eefb9b3a0ac88720b450..28727a89c3dae174cb90be06ea3cc7b102dd5160 100644
--- a/bn_persona_model/persona_model_test.bif
+++ b/bn_persona_model/persona_model_test.bif
@@ -3,7 +3,7 @@ network persona_model {
 %VARIABLES DEFINITION
-variable robot_assistance {
+variable agent_assistance {
     type discrete [ 6 ] { lev_0, lev_1, lev_2, lev_3, lev_4, lev_5 };
 }
 variable attempt_t0 {
@@ -13,18 +13,20 @@ variable game_state_t0 {
     type discrete [ 3 ] { beg, mid, end };
 }
 variable attempt_t1 {
-    type discrete [ 4 ] { att_1, att_2, att_3, att_4 };
+    type discrete [ 4 ] { att_1, att_2, att_3, att_4};
 }
 variable game_state_t1 {
     type discrete [ 3 ] { beg, mid, end };
 }
-%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 variable user_action {
     type discrete [ 3 ] { correct, wrong, timeout };
 }
+variable agent_feedback {
+    type discrete [ 2 ] { no, yes };
+}
 %INDIVIDUAL PROBABILITIES DEFINITION
-probability ( robot_assistance ) {
+probability ( agent_assistance ) {
     table 0.17, 0.16, 0.16, 0.17, 0.17, 0.17;
 }
 probability ( game_state_t0 ) {
@@ -39,124 +41,104 @@ probability ( game_state_t1 ) {
 probability ( attempt_t1 ) {
     table 0.25, 0.25, 0.25, 0.25;
 }
-
 probability ( user_action ) {
     table 0.33, 0.33, 0.34;
 }
-
-probability (robot_assistance | game_state_t0, attempt_t0){
-    (beg, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
-    (beg, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
-    (beg, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
-    (beg, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
-    (mid, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
-    (mid, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
-    (mid, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
-    (mid, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
-    (end, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
-    (end, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
-    (end, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
-    (end, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
+probability ( agent_feedback ) {
+    table 0.5, 0.5;
+}
+probability (agent_assistance | agent_feedback) {
+    (yes) 0.4, 0.3, 0.2, 0.1, 0.0, 0.0;
+    (no) 0.0, 0.0, 0.1, 0.2, 0.3, 0.4;
 }
-
-
-probability (user_action | game_state_t0, attempt_t0, robot_assistance){
+probability (user_action | game_state_t0, attempt_t0, agent_assistance){
     (beg, att_1, lev_0) 0.1, 0.9, 0.0;
     (beg, att_2, lev_0) 0.2, 0.8, 0.0;
     (beg, att_3, lev_0) 0.3, 0.7, 0.0;
     (beg, att_4, lev_0) 0.4, 0.6, 0.0;
-    (beg, att_1, lev_1) 0.1, 0.9, 0.0;
-    (beg, att_2, lev_1) 0.2, 0.8, 0.0;
-    (beg, att_3, lev_1) 0.3, 0.7, 0.0;
-    (beg, att_4, lev_1) 0.4, 0.6, 0.0;
-    (beg, att_1, lev_2) 0.1, 0.9, 0.0;
-    (beg, att_2, lev_2) 0.2, 0.8, 0.0;
-    (beg, att_3, lev_2) 0.3, 0.7, 0.0;
-    (beg, att_4, lev_2) 0.4, 0.6, 0.0;
-    (beg, att_1, lev_3) 0.1, 0.9, 0.0;
-    (beg, att_2, lev_3) 0.2, 0.8, 0.0;
-    (beg, att_3, lev_3) 0.3, 0.7, 0.0;
-    (beg, att_4, lev_3) 0.4, 0.6, 0.0;
-    (beg, att_1, lev_4) 0.1, 0.9, 0.0;
-    (beg, att_2, lev_4) 0.2, 0.8, 0.0;
-    (beg, att_3, lev_4) 0.3, 0.7, 0.0;
-    (beg, att_4, lev_4) 0.4, 0.6, 0.0;
-    (beg, att_1, lev_5) 0.1, 0.9, 0.0;
-    (beg, att_2, lev_5) 0.2, 0.8, 0.0;
-    (beg, att_3, lev_5) 0.3, 0.7, 0.0;
-    (beg, att_4, lev_5) 0.4, 0.6, 0.0;
+    (beg, att_1, lev_1) 0.2, 0.8, 0.0;
+    (beg, att_2, lev_1) 0.3, 0.7, 0.0;
+    (beg, att_3, lev_1) 0.4, 0.6, 0.0;
+    (beg, att_4, lev_1) 0.5, 0.5, 0.0;
+    (beg, att_1, lev_2) 0.3, 0.7, 0.0;
+    (beg, att_2, lev_2) 0.4, 0.6, 0.0;
+    (beg, att_3, lev_2) 0.5, 0.5, 0.0;
+    (beg, att_4, lev_2) 0.6, 0.4, 0.0;
+    (beg, att_1, lev_3) 0.4, 0.6, 0.0;
+    (beg, att_2, lev_3) 0.5, 0.5, 0.0;
+    (beg, att_3, lev_3) 0.6, 0.4, 0.0;
+    (beg, att_4, lev_3) 0.7, 0.3, 0.0;
+    (beg, att_1, lev_4) 1.0, 0.0, 0.0;
+    (beg, att_2, lev_4) 1.0, 0.0, 0.0;
+    (beg, att_3, lev_4) 1.0, 0.0, 0.0;
+    (beg, att_4, lev_4) 1.0, 0.0, 0.0;
+    (beg, att_1, lev_5) 1.0, 0.0, 0.0;
+    (beg, att_2, lev_5) 1.0, 0.0, 0.0;
+    (beg, att_3, lev_5) 1.0, 0.0, 0.0;
+    (beg, att_4, lev_5) 1.0, 0.0, 0.0;
     (mid, att_1, lev_0) 0.1, 0.9, 0.0;
     (mid, att_2, lev_0) 0.2, 0.8, 0.0;
     (mid, att_3, lev_0) 0.3, 0.7, 0.0;
     (mid, att_4, lev_0) 0.4, 0.6, 0.0;
-    (mid, att_1, lev_1) 0.1, 0.9, 0.0;
-    (mid, att_2, lev_1) 0.2, 0.8, 0.0;
-    (mid, att_3, lev_1) 0.3, 0.7, 0.0;
-    (mid, att_4, lev_1) 0.4, 0.6, 0.0;
-    (mid, att_1, lev_2) 0.1, 0.9, 0.0;
-    (mid, att_2, lev_2) 0.2, 0.8, 0.0;
-    (mid, att_3, lev_2) 0.3, 0.7, 0.0;
-    (mid, att_4, lev_2) 0.4, 0.6, 0.0;
-    (mid, att_1, lev_3) 0.1, 0.9, 0.0;
-    (mid, att_2, lev_3) 0.2, 0.8, 0.0;
-    (mid, att_3, lev_3) 0.3, 0.7, 0.0;
-    (mid, att_4, lev_3) 0.4, 0.6, 0.0;
-    (mid, att_1, lev_4) 0.1, 0.9, 0.0;
-    (mid, att_2, lev_4) 0.2, 0.8, 0.0;
-    (mid, att_3, lev_4) 0.3, 0.7, 0.0;
-    (mid, att_4, lev_4) 0.4, 0.6, 0.0;
-    (mid, att_1, lev_5) 0.1, 0.9, 0.0;
-    (mid, att_2, lev_5) 0.2, 0.8, 0.0;
-    (mid, att_3, lev_5) 0.3, 0.7, 0.0;
-    (mid, att_4, lev_5) 0.4, 0.6, 0.0;
+    (mid, att_1, lev_1) 0.2, 0.8, 0.0;
+    (mid, att_2, lev_1) 0.3, 0.7, 0.0;
+    (mid, att_3, lev_1) 0.4, 0.6, 0.0;
+    (mid, att_4, lev_1) 0.5, 0.5, 0.0;
+    (mid, att_1, lev_2) 0.3, 0.7, 0.0;
+    (mid, att_2, lev_2) 0.4, 0.6, 0.0;
+    (mid, att_3, lev_2) 0.5, 0.5, 0.0;
+    (mid, att_4, lev_2) 0.6, 0.4, 0.0;
+    (mid, att_1, lev_3) 0.4, 0.6, 0.0;
+    (mid, att_2, lev_3) 0.5, 0.5, 0.0;
+    (mid, att_3, lev_3) 0.6, 0.4, 0.0;
+    (mid, att_4, lev_3) 0.7, 0.3, 0.0;
+    (mid, att_1, lev_4) 1.0, 0.0, 0.0;
+    (mid, att_2, lev_4) 1.0, 0.0, 0.0;
+    (mid, att_3, lev_4) 1.0, 0.0, 0.0;
+    (mid, att_4, lev_4) 1.0, 0.0, 0.0;
+    (mid, att_1, lev_5) 1.0, 0.0, 0.0;
+    (mid, att_2, lev_5) 1.0, 0.0, 0.0;
+    (mid, att_3, lev_5) 1.0, 0.0, 0.0;
+    (mid, att_4, lev_5) 1.0, 0.0, 0.0;
     (end, att_1, lev_0) 0.1, 0.9, 0.0;
     (end, att_2, lev_0) 0.2, 0.8, 0.0;
-    (end, att_3, lev_0) 0.2, 0.8, 0.0;
+    (end, att_3, lev_0) 0.3, 0.7, 0.0;
     (end, att_4, lev_0) 0.4, 0.6, 0.0;
-    (end, att_1, lev_1) 0.1, 0.9, 0.0;
-    (end, att_2, lev_1) 0.2, 0.8, 0.0;
+    (end, att_1, lev_1) 0.2, 0.8, 0.0;
+    (end, att_2, lev_1) 0.3, 0.7, 0.0;
     (end, att_3, lev_1) 0.4, 0.6, 0.0;
-    (end, att_4, lev_1) 0.4, 0.6, 0.0;
-    (end, att_1, lev_2) 0.1, 0.9, 0.0;
-    (end, att_2, lev_2) 0.2, 0.8, 0.0;
-    (end, att_3, lev_2) 0.4, 0.6, 0.0;
-    (end, att_4, lev_2) 0.4, 0.6, 0.0;
-    (end, att_1, lev_3) 0.1, 0.9, 0.0;
-    (end, att_2, lev_3) 0.2, 0.8, 0.0;
-    (end, att_3, lev_3) 0.5, 0.5, 0.0;
-    (end, att_4, lev_3) 0.4, 0.6, 0.0;
-    (end, att_1, lev_4) 0.1, 0.9, 0.0;
-    (end, att_2, lev_4) 0.2, 0.8, 0.0;
-    (end, att_3, lev_4) 0.7, 0.3, 0.0;
-    (end, att_4, lev_4) 0.4, 0.6, 0.0;
-    (end, att_1, lev_5) 0.1, 0.9, 0.0;
-    (end, att_2, lev_5) 0.2, 0.8, 0.0;
-    (end, att_3, lev_5) 0.3, 0.7, 0.0;
-    (end, att_4, lev_5) 0.4, 0.6, 0.0;
-
+    (end, att_4, lev_1) 0.5, 0.5, 0.0;
+    (end, att_1, lev_2) 0.3, 0.7, 0.0;
+    (end, att_2, lev_2) 0.4, 0.6, 0.0;
+    (end, att_3, lev_2) 0.5, 0.5, 0.0;
+    (end, att_4, lev_2) 0.6, 0.4, 0.0;
+    (end, att_1, lev_3) 0.4, 0.6, 0.0;
+    (end, att_2, lev_3) 0.5, 0.5, 0.0;
+    (end, att_3, lev_3) 0.6, 0.4, 0.0;
+    (end, att_4, lev_3) 0.7, 0.3, 0.0;
+    (end, att_1, lev_4) 1.0, 0.0, 0.0;
+    (end, att_2, lev_4) 1.0, 0.0, 0.0;
+    (end, att_3, lev_4) 1.0, 0.0, 0.0;
+    (end, att_4, lev_4) 1.0, 0.0, 0.0;
+    (end, att_1, lev_5) 1.0, 0.0, 0.0;
+    (end, att_2, lev_5) 1.0, 0.0, 0.0;
+    (end, att_3, lev_5) 1.0, 0.0, 0.0;
+    (end, att_4, lev_5) 1.0, 0.0, 0.0;
 }
 probability (game_state_t1 | user_action) {
-    (correct) 0.2, 0.3, 0.5;
-    (wrong) 0.5, 0.3, 0.2;
-    (timeout) 0.33, 0.34, 0.33;
+    (correct) 0.25, 0.3, 0.45;
+    (wrong) 0.33, 0.34, 0.33;
+    (timeout) 0.33, 0.34, 0.33;
 }
 probability (attempt_t1 | user_action) {
-    (correct) 0.1, 0.2, 0.3, 0.4;
-    (wrong) 0.4, 0.3, 0.2, 0.1;
+    (correct) 0.1, 0.2, 0.25, 0.45;
+    (wrong) 0.25, 0.25, 0.25, 0.25;
     (timeout) 0.25, 0.25, 0.25, 0.25;
 }
-probability (user_action | robot_assistance){
-    (lev_0) 0.1, 0.6, 0.3;
-    (lev_1) 0.2, 0.5, 0.3;
-    (lev_2) 0.3, 0.5, 0.2;
-    (lev_3) 0.5, 0.3, 0.2;
-    (lev_4) 0.9, 0.1, 0.0;
-    (lev_5) 0.9, 0.1, 0.0;
-}
\ No newline at end of file
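The renamed variables can be exercised directly with bnlearn, in the same way test.py below does; this is only an illustration, and the evidence values are arbitrary.

```python
import bnlearn as bn

# Load the updated persona model and ask for the user_action distribution
# given one (illustrative) combination of game state, attempt and assistance.
DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif')

query = bn.inference.fit(DAG, variables=['user_action'],
                         evidence={'game_state_t0': 0,
                                   'attempt_t0': 1,
                                   'game_state_t1': 0,
                                   'attempt_t1': 2,
                                   'agent_assistance': 0})
print(query.values)  # order: correct, wrong, timeout
```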
diff --git a/simulation.py b/simulation.py
index 37036d3cf2537b8cb526ff7d6d7a4181c5b7dd0f..33aa674f8854ee9d789b9f277d81763909380129 100644
--- a/simulation.py
+++ b/simulation.py
@@ -2,12 +2,102 @@ import itertools
 import os
 import bnlearn
 import numpy as np
+import random
 
 #import classes and modules
 from bn_variables import Agent_Assistance, Agent_Feedback, User_Action, User_React_time, Game_State, Attempt
 import bn_functions
 import utils
 from episode import Episode
 
+#
+# def choose_next_states(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object,
+#                        selected_agent_assistance_action,
+#                        bn_model_user_action, var_user_action_target_action):
+#
+#     def get_next_state(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object):
+#
+#         next_state = []
+#
+#         #correct move on the last state of the bin
+#         if (task_progress == 1 or task_progress == 3 or task_progress == 4) and n_attempt_t0<max_attempt_per_object:
+#             next_state.append((game_state_t0+1, n_attempt_t0+1))
+#         #correct state but still in the bin
+#         elif task_progress == 0 or task_progress == 2 and n_attempt_t0<max_attempt_per_object:
+#             next_state.append((game_state_t0, n_attempt_t0+1))
+#         elif (task_progress == 1 or task_progress == 3 or task_progress == 4) and n_attempt_t0>=max_attempt_per_object:
+#             assert "you reach the maximum number of attempt the agent will move it for you"
+#         elif task_progress == 0 or task_progress == 2 and n_attempt_t0>=max_attempt_per_object:
+#             assert "you reach the maximum number of attempt the agent will move it for you"
+#
+#         return next_state
+#
+#     next_state = get_next_state(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object)
+#     query_answer_probs = []
+#     for t in next_state:
+#         vars_user_evidence = {"game_state_t0": game_state_t0,
+#                               "attempt_t0": n_attempt_t0 - 1,
+#                               "robot_assistance": selected_agent_assistance_action,
+#                               "game_state_t1": t[0],
+#                               "attempt_t1": t[1],
+#                               }
+#
+#         query_user_action_prob = bn_functions.infer_prob_from_state(bn_model_user_action,
+#                                                                     infer_variable=var_user_action_target_action,
+#                                                                     evidence_variables=vars_user_evidence)
+#         query_answer_probs.append(query_user_action_prob)
+#
+#
+#     #do the inference here
+#     #1. check given the current_state which are the possible states
+#     #2. for each of the possible states get the probability of user_action
+#     #3. select the state with the highest-probability action and execute it
+#     #4. return user_action
+#
+
+def generate_agent_assistance(preferred_assistance, agent_behaviour, n_game_state, n_attempt, alpha_action=0.1):
+    agent_policy = [[0 for j in range(n_attempt)] for i in range(n_game_state)]
+    previous_assistance = -1
+
+    def get_alternative_action(agent_assistance, previous_assistance, agent_behaviour, alpha_action):
+        agent_assistance_res = agent_assistance
+        if previous_assistance == agent_assistance:
+            if agent_behaviour == "challenge":
+                if random.random() > alpha_action:
+                    agent_assistance_res = min(max(0, agent_assistance-1), 5)
+                else:
+                    agent_assistance_res = min(max(0, agent_assistance), 5)
+            else:
+                if random.random() > alpha_action:
+                    agent_assistance_res = min(max(0, agent_assistance + 1), 5)
+                else:
+                    agent_assistance_res = min(max(0, agent_assistance), 5)
+        return agent_assistance_res
+
+    for gs in range(n_game_state):
+        for att in range(n_attempt):
+            if att == 0:
+                if random.random() > alpha_action:
+                    agent_policy[gs][att] = preferred_assistance
+                    previous_assistance = agent_policy[gs][att]
+                else:
+                    if random.random() > 0.5:
+                        agent_policy[gs][att] = min(max(0, preferred_assistance-1), 5)
+                        previous_assistance = agent_policy[gs][att]
+                    else:
+                        agent_policy[gs][att] = min(max(0, preferred_assistance+1), 5)
+                        previous_assistance = agent_policy[gs][att]
+            else:
+                if agent_behaviour == "challenge":
+                    agent_policy[gs][att] = min(max(0, preferred_assistance-1), 5)
+                    agent_policy[gs][att] = get_alternative_action(agent_policy[gs][att], previous_assistance, agent_behaviour, alpha_action)
+                    previous_assistance = agent_policy[gs][att]
+                else:
+                    agent_policy[gs][att] = min(max(0, preferred_assistance+1), 5)
+                    agent_policy[gs][att] = get_alternative_action(agent_policy[gs][att], previous_assistance, agent_behaviour, alpha_action)
+                    previous_assistance = agent_policy[gs][att]
+
+    return agent_policy
+
 def compute_next_state(user_action, task_progress_counter, attempt_counter, correct_move_counter,
                        wrong_move_counter, timeout_counter, max_attempt_counter, max_attempt_per_object
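As a usage sketch (not part of the patch): assuming `generate_agent_assistance` from the hunk above is in scope, and assuming 3 game states and 4 attempts as defined in the .bif, the returned table is indexed by game state and 1-based attempt, which is exactly how the simulation loop later consumes it.

```python
# Illustrative only: build a policy table and read one entry the way the
# simulation loop does (attempt_counter is 1-based, hence the -1).
agent_policy = generate_agent_assistance(preferred_assistance=2,
                                         agent_behaviour="help",
                                         n_game_state=3,
                                         n_attempt=4,
                                         alpha_action=0.5)

game_state_counter, attempt_counter = 0, 1
selected_agent_assistance_action = agent_policy[game_state_counter][attempt_counter - 1]
print(selected_agent_assistance_action)  # an assistance level between 0 and 5
```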
@@ -80,11 +170,10 @@ def compute_next_state(user_action, task_progress_counter, attempt_counter, corr
 def simulation(bn_model_user_action, var_user_action_target_action, bn_model_user_react_time, var_user_react_time_target_action,
                user_memory_name, user_memory_value, user_attention_name, user_attention_value, user_reactivity_name, user_reactivity_value,
-               task_progress_name, game_attempt_name, agent_assistance_name, agent_feedback_name,
-               bn_model_agent_assistance, var_agent_assistance_target_action, bn_model_agent_feedback,
-               var_agent_feedback_target_action, agent_policy,
+               task_progress_t0_name, task_progress_t1_name, game_attempt_t0_name, game_attempt_t1_name,
+               agent_assistance_name, agent_policy,
                state_space, action_space,
-               epochs=50, task_complexity=5, max_attempt_per_object=4):
+               epochs=50, task_complexity=5, max_attempt_per_object=4, alpha_learning=0):
     '''
 
     Args:
@@ -98,8 +187,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 
     #metrics we need, in order to compute afterwords the belief
-    attempt_counter_per_action = [[0 for i in range(Attempt.counter.value)] for j in range(User_Action.counter.value)]
-    game_state_counter_per_action = [[0 for i in range(Game_State.counter.value)] for j in range(User_Action.counter.value)]
     agent_feedback_per_action = [[0 for i in range(Agent_Feedback.counter.value)] for j in range(User_Action.counter.value)]
     agent_assistance_per_action = [[0 for i in range(Agent_Assistance.counter.value)] for j in
                                    range(User_Action.counter.value)]
@@ -108,12 +195,21 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
     agent_feedback_per_react_time = [[0 for i in range(Agent_Feedback.counter.value)] for j in range(User_React_time.counter.value)]
     agent_assistance_per_react_time = [[0 for i in range(Agent_Assistance.counter.value)] for j in range(User_React_time.counter.value)]
-    game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in range(Agent_Assistance.counter.value)]
-    attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in range(Agent_Assistance.counter.value)]
-    game_state_counter_per_agent_feedback = [[0 for i in range(Game_State.counter.value)] for j in range(Agent_Feedback.counter.value)]
     attempt_counter_per_agent_feedback = [[0 for i in range(Attempt.counter.value)] for j in range(Agent_Feedback.counter.value)]
-
+    game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in
+                                               range(Agent_Assistance.counter.value)]
+    attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in
+                                            range(Agent_Assistance.counter.value)]
+
+
+    user_action_per_game_state_attempt_counter_agent_assistance = [[[[0 for i in range(User_Action.counter.value)] for l in range(Game_State.counter.value)] for j in
+                                                                    range(Attempt.counter.value)] for k in range(Agent_Assistance.counter.value)]
+    user_action_per_agent_assistance = [[0 for i in range(User_Action.counter.value)] for j in
+                                        range(Agent_Assistance.counter.value)]
+    attempt_counter_per_user_action = [[0 for i in range(Attempt.counter.value)] for j in range(User_Action.counter.value)]
+    game_state_counter_per_user_action = [[0 for i in range(Game_State.counter.value)] for j in
+                                          range(User_Action.counter.value)]
 
     #output variables:
     n_correct_per_episode = [0]*epochs
@@ -131,6 +227,10 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 
     ep = Episode()
     for e in range(epochs):
+        print("##########################################################")
+        print("EPISODE ", e)
+        print("##########################################################")
+
         '''Simulation framework'''
         #counters
         game_state_counter = 0
@@ -142,10 +242,11 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
         max_attempt_counter = 0
 
         #The following variables are used to update the BN at the end of the episode
-        user_action_dynamic_variables = {'attempt': attempt_counter_per_action,
-                                         'game_state': game_state_counter_per_action,
-                                         'agent_assistance': agent_assistance_per_action,
-                                         'agent_feedback': agent_feedback_per_action}
+        user_action_dynamic_variables = {
+            'attempt_t1': attempt_counter_per_user_action,
+            'game_state_t1': game_state_counter_per_user_action,
+            'user_action': user_action_per_game_state_attempt_counter_agent_assistance
+        }
 
         user_react_time_dynamic_variables = {'attempt': attempt_counter_per_react_time,
                                              'game_state': game_state_counter_per_react_time,
@@ -167,38 +268,8 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 
             current_state = (game_state_counter, attempt_counter, selected_user_action)
 
-            if type(agent_policy) is not np.ndarray:
-                ##################QUERY FOR THE ROBOT ASSISTANCE AND FEEDBACK##################
-                vars_agent_evidence = {
-                                       user_reactivity_name: user_reactivity_value,
-                                       user_memory_name: user_memory_value,
-                                       task_progress_name: game_state_counter,
-                                       game_attempt_name: attempt_counter-1,
-                                       }
-                query_agent_assistance_prob = bn_functions.infer_prob_from_state(bn_model_agent_assistance,
-                                                                                 infer_variable=var_agent_assistance_target_action,
-                                                                                 evidence_variables=vars_agent_evidence)
-                if bn_model_agent_feedback != None:
-                    query_agent_feedback_prob = bn_functions.infer_prob_from_state(bn_model_agent_feedback,
-                                                                                   infer_variable=var_agent_feedback_target_action,
-                                                                                   evidence_variables=vars_agent_evidence)
-                    selected_agent_feedback_action = bn_functions.get_stochastic_action(query_agent_feedback_prob.values)
-                else:
-                    selected_agent_feedback_action = 0
-
-
-                selected_agent_assistance_action = bn_functions.get_stochastic_action(query_agent_assistance_prob.values)
-            else:
-                idx_state = ep.state_from_point_to_index(state_space, current_state)
-                if agent_policy[idx_state]>=Agent_Assistance.counter.value:
-                    selected_agent_assistance_action = agent_policy[idx_state]-Agent_Assistance.counter.value
-                    selected_agent_feedback_action = 1
-                else:
-                    selected_agent_assistance_action = agent_policy[idx_state]
-                    selected_agent_feedback_action = 0
-
-            n_feedback_per_episode[e][selected_agent_feedback_action] += 1
+            selected_agent_assistance_action = agent_policy[game_state_counter][attempt_counter-1]#random.randint(0,5)
+            selected_agent_feedback_action = 0#random.randint(0,1)
 
             #counters for plots
             n_assistance_lev_per_episode[e][selected_agent_assistance_action] += 1
@@ -211,38 +282,39 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 
             #compare the real user with the estimated Persona and returns a user action (0, 1a, 2)
             #return the user action in this state based on the Persona profile
-            vars_user_evidence = {user_attention_name: user_attention_value,
-                                  user_reactivity_name: user_reactivity_value,
-                                  user_memory_name: user_memory_value,
-                                  task_progress_name: game_state_counter,
-                                  game_attempt_name: attempt_counter-1,
-                                  agent_assistance_name: selected_agent_assistance_action,
-                                  agent_feedback_name: selected_agent_feedback_action
-                                  }
+            vars_user_evidence = { task_progress_t0_name: game_state_counter,
+                                   game_attempt_t0_name: attempt_counter - 1,
+                                   task_progress_t1_name: game_state_counter,
+                                   game_attempt_t1_name: attempt_counter - 1,
+                                   agent_assistance_name: selected_agent_assistance_action,
+                                   }
+
             query_user_action_prob = bn_functions.infer_prob_from_state(bn_model_user_action,
                                                                         infer_variable=var_user_action_target_action,
                                                                         evidence_variables=vars_user_evidence)
-            query_user_react_time_prob = bn_functions.infer_prob_from_state(bn_model_user_react_time,
-                                                                            infer_variable=var_user_react_time_target_action,
-                                                                            evidence_variables=vars_user_evidence)
-
-
+            # query_user_react_time_prob = bn_functions.infer_prob_from_state(bn_model_user_react_time,
+            #                                                                 infer_variable=var_user_react_time_target_action,
+            #                                                                 evidence_variables=vars_user_evidence)
+            #
+            #
             selected_user_action = bn_functions.get_stochastic_action(query_user_action_prob.values)
-            selected_user_react_time = bn_functions.get_stochastic_action(query_user_react_time_prob.values)
+            # selected_user_react_time = bn_functions.get_stochastic_action(query_user_react_time_prob.values)
 
             # counters for plots
-            n_react_time_per_episode[e][selected_user_react_time] += 1
+            # n_react_time_per_episode[e][selected_user_react_time] += 1
 
             #updates counters for user action
-            agent_assistance_per_action[selected_user_action][selected_agent_assistance_action] += 1
-            attempt_counter_per_action[selected_user_action][attempt_counter-1] += 1
-            game_state_counter_per_action[selected_user_action][game_state_counter] += 1
-            agent_feedback_per_action[selected_user_action][selected_agent_feedback_action] += 1
+
+            user_action_per_game_state_attempt_counter_agent_assistance[selected_agent_assistance_action][attempt_counter-1][game_state_counter][selected_user_action] += 1
+            attempt_counter_per_user_action[selected_user_action][attempt_counter-1] += 1
+            game_state_counter_per_user_action[selected_user_action][game_state_counter] += 1
+            user_action_per_agent_assistance[selected_agent_assistance_action][selected_user_action] += 1
+
             #update counter for user react time
-            agent_assistance_per_react_time[selected_user_react_time][selected_agent_assistance_action] += 1
-            attempt_counter_per_react_time[selected_user_react_time][attempt_counter-1] += 1
-            game_state_counter_per_react_time[selected_user_react_time][game_state_counter] += 1
-            agent_feedback_per_react_time[selected_user_react_time][selected_agent_feedback_action] += 1
+            # agent_assistance_per_react_time[selected_user_react_time][selected_agent_assistance_action] += 1
+            # attempt_counter_per_react_time[selected_user_react_time][attempt_counter-1] += 1
+            # game_state_counter_per_react_time[selected_user_react_time][game_state_counter] += 1
+            # agent_feedback_per_react_time[selected_user_react_time][selected_agent_feedback_action] += 1
             #update counter for agent feedback
             game_state_counter_per_agent_feedback[selected_agent_feedback_action][game_state_counter] += 1
             attempt_counter_per_agent_feedback[selected_agent_feedback_action][attempt_counter-1] += 1
@@ -269,8 +341,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
                                                                    timeout_counter, max_attempt_counter,
                                                                    max_attempt_per_object)
-
-
             # store the (state, action, next_state)
             episode.append((ep.state_from_point_to_index(state_space, current_state),
                             ep.state_from_point_to_index(action_space, current_agent_action),
@@ -286,22 +356,25 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
         episodes.append(Episode(episode))
 
         #update user models
-        bn_model_user_action = bn_functions.update_cpds_tables(bn_model_user_action, user_action_dynamic_variables)
+        bn_model_user_action = bn_functions.update_cpds_tables(bn_model_user_action, user_action_dynamic_variables, alpha_learning)
         bn_model_user_react_time = bn_functions.update_cpds_tables(bn_model_user_react_time, user_react_time_dynamic_variables)
         #update agent models
-        bn_model_agent_assistance = bn_functions.update_cpds_tables(bn_model_agent_assistance, agent_assistance_dynamic_variables)
-        if bn_model_agent_feedback !=None:
-            bn_model_agent_feedback = bn_functions.update_cpds_tables(bn_model_agent_feedback, agent_feedback_dynamic_variables)
+
+        print("user_given_game_attempt:", bn_model_user_action['model'].cpds[0].values)
+        print("user_given_robot:", bn_model_user_action['model'].cpds[5].values)
+        print("game_user:", bn_model_user_action['model'].cpds[3].values)
+        print("attempt_user:", bn_model_user_action['model'].cpds[2].values)
 
         #reset counter
-        agent_assistance_per_action = [[0 for i in range(Agent_Assistance.counter.value)] for j in
-                                       range(User_Action.counter.value)]
-        agent_feedback_per_action = [[0 for i in range(Agent_Feedback.counter.value)] for j in
-                                     range(User_Action.counter.value)]
-        game_state_counter_per_action = [[0 for i in range(Game_State.counter.value)] for j in
-                                         range(User_Action.counter.value)]
-        attempt_counter_per_action = [[0 for i in range(Attempt.counter.value)] for j in
-                                      range(User_Action.counter.value)]
+        user_action_per_game_state_attempt_counter_agent_assistance = [[[[0 for i in range(User_Action.counter.value)]
+                                                                         for l in range(Game_State.counter.value)] for j in
+                                                                        range(Attempt.counter.value)] for k in range(Agent_Assistance.counter.value)]
+        user_action_per_agent_assistance = [[0 for i in range(User_Action.counter.value)] for j in
+                                            range(Agent_Assistance.counter.value)]
+        attempt_counter_per_user_action = [[0 for i in range(Attempt.counter.value)] for j in
+                                           range(User_Action.counter.value)]
+        game_state_counter_per_user_action = [[0 for i in range(Game_State.counter.value)] for j in
+                                              range(User_Action.counter.value)]
 
         attempt_counter_per_react_time = [[0 for i in range(Attempt.counter.value)] for j in
                                           range(User_React_time.counter.value)]
@@ -312,15 +385,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
         agent_assistance_per_react_time = [[0 for i in range(Agent_Assistance.counter.value)] for j in
                                            range(User_React_time.counter.value)]
 
-        game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in
-                                                   range(Agent_Assistance.counter.value)]
-        attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in
-                                                range(Agent_Assistance.counter.value)]
-
-        game_state_counter_per_agent_feedback = [[0 for i in range(Game_State.counter.value)] for j in
-                                                 range(Agent_Feedback.counter.value)]
-        attempt_counter_per_agent_feedback = [[0 for i in range(Attempt.counter.value)] for j in
-                                              range(Agent_Feedback.counter.value)]
 
         #for plots
         n_correct_per_episode[e] = correct_move_counter
@@ -343,126 +407,58 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
 #############################################################################
 #############################################################################
-# #SIMULATION PARAMS
-# epochs = 100
-#
-# #initialise the agent
-# bn_model_caregiver_assistance = bnlearn.import_DAG('bn_agent_model/agent_assistive_model.bif')
-# bn_model_caregiver_feedback = bnlearn.import_DAG('bn_agent_model/agent_feedback_model.bif')
-# bn_model_user_action = bnlearn.import_DAG('bn_persona_model/user_action_model.bif')
-# bn_model_user_react_time = bnlearn.import_DAG('bn_persona_model/user_react_time_model.bif')
-# bn_model_other_user_action = None#bnlearn.import_DAG('bn_persona_model/other_user_action_model.bif')
-# bn_model_other_user_react_time = None#bnlearn.import_DAG('bn_persona_model/other_user_react_time_model.bif')
-#
-# #initialise memory, attention and reactivity varibles
-# persona_memory = 0; persona_attention = 0; persona_reactivity = 1;
-# #initialise memory, attention and reactivity varibles
-# other_user_memory = 2; other_user_attention = 2; other_user_reactivity = 2;
-#
-# #define state space struct for the irl algorithm
-# attempt = [i for i in range(1, Attempt.counter.value+1)]
-# #+1a (3,_,_) absorbing state
-# game_state = [i for i in range(0, Game_State.counter.value+1)]
-# user_action = [i for i in range(-1, User_Action.counter.value-1)]
-# state_space = (game_state, attempt, user_action)
-# states_space_list = list(itertools.product(*state_space))
-# agent_assistance_action = [i for i in range(Agent_Assistance.counter.value)]
-# agent_feedback_action = [i for i in range(Agent_Feedback.counter.value)]
-# action_space = (agent_assistance_action, agent_feedback_action)
-# action_space_list = list(itertools.product(*action_space))
-#
-# ##############BEFORE RUNNING THE SIMULATION UPDATE THE BELIEF IF YOU HAVE DATA####################
"/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/1/0" -# if os.path.exists(log_directory): -# bn_belief_user_action_file = log_directory+"/bn_belief_user_action.pkl" -# bn_belief_user_react_time_file = log_directory+"/bn_belief_user_react_time.pkl" -# bn_belief_caregiver_assistance_file = log_directory+"/bn_belief_caregiver_assistive_action.pkl" -# bn_belief_caregiver_feedback_file = log_directory+"/bn_belief_caregiver_feedback_action.pkl" -# -# bn_belief_user_action = utils.read_user_statistics_from_pickle(bn_belief_user_action_file) -# bn_belief_user_react_time = utils.read_user_statistics_from_pickle(bn_belief_user_react_time_file) -# bn_belief_caregiver_assistance = utils.read_user_statistics_from_pickle(bn_belief_caregiver_assistance_file) -# bn_belief_caregiver_feedback = utils.read_user_statistics_from_pickle(bn_belief_caregiver_feedback_file) -# bn_model_user_action = bn_functions.update_cpds_tables(bn_model=bn_model_user_action, variables_tables=bn_belief_user_action) -# bn_model_user_react_time = bn_functions.update_cpds_tables(bn_model=bn_model_user_react_time, variables_tables=bn_belief_user_react_time) -# bn_model_caregiver_assistance = bn_functions.update_cpds_tables(bn_model=bn_model_caregiver_assistance, variables_tables=bn_belief_caregiver_assistance) -# bn_model_caregiver_feedback = bn_functions.update_cpds_tables(bn_model=bn_model_caregiver_feedback, variables_tables=bn_belief_caregiver_feedback) -# -# else: -# assert("You're not using the user information") -# question = input("Are you sure you don't want to load user's belief information?") -# -# game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, generated_episodes = \ -# simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'], -# bn_model_user_react_time=bn_model_user_react_time, -# var_user_react_time_target_action=['user_react_time'], -# user_memory_name="memory", user_memory_value=persona_memory, -# user_attention_name="attention", user_attention_value=persona_attention, -# user_reactivity_name="reactivity", user_reactivity_value=persona_reactivity, -# task_progress_name="game_state", game_attempt_name="attempt", -# agent_assistance_name="agent_assistance", agent_feedback_name="agent_feedback", -# bn_model_agent_assistance=bn_model_caregiver_assistance, -# var_agent_assistance_target_action=["agent_assistance"], -# bn_model_agent_feedback=bn_model_caregiver_feedback, var_agent_feedback_target_action=["agent_feedback"], -# bn_model_other_user_action=bn_model_other_user_action, -# var_other_user_action_target_action=['user_action'], -# bn_model_other_user_react_time=bn_model_other_user_react_time, -# var_other_user_target_react_time_action=["user_react_time"], other_user_memory_name="memory", -# other_user_memory_value=other_user_memory, other_user_attention_name="attention", -# other_user_attention_value=other_user_attention, other_user_reactivity_name="reactivity", -# other_user_reactivity_value=other_user_reactivity, -# state_space=states_space_list, action_space=action_space_list, -# epochs=epochs, task_complexity=5, max_attempt_per_object=4) -# -# -# -# plot_game_performance_path = "" -# plot_agent_assistance_path = "" -# episodes_path = "episodes.npy" -# -# if bn_model_other_user_action != None: -# plot_game_performance_path = 
"game_performance_"+"_epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg" -# plot_agent_assistance_path = "agent_assistance_"+"epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg" -# plot_agent_feedback_path = "agent_feedback_"+"epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg" -# -# else: -# plot_game_performance_path = "game_performance_"+"epoch_" + str(epochs) + "_persona_memory_" + str(persona_memory) + "_persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(persona_reactivity) + ".jpg" -# plot_agent_assistance_path = "agent_assistance_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg" -# plot_agent_feedback_path = "agent_feedback_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg" -# -# dir_name = input("Please insert the name of the directory:") -# full_path = os.getcwd()+"/results/"+dir_name+"/" -# if not os.path.exists(full_path): -# os.mkdir(full_path) -# print("Directory ", full_path, " created.") -# else: -# dir_name = input("The directory already exist please insert a new name:") -# print("Directory ", full_path, " created.") -# if os.path.exists(full_path): -# assert("Directory already exists ... start again") -# exit(0) -# -# with open(full_path+episodes_path, "ab") as f: -# np.save(full_path+episodes_path, generated_episodes) -# f.close() -# -# -# utils.plot2D_game_performance(full_path+plot_game_performance_path, epochs, game_performance_per_episode) -# utils.plot2D_assistance(full_path+plot_agent_assistance_path, epochs, agent_assistance_per_episode) -# utils.plot2D_feedback(full_path+plot_agent_feedback_path, epochs, agent_feedback_per_episode) - - - -''' -With the current simulator we can generate a list of episodes -the episodes will be used to generate the trans probabilities and as input to the IRL algo -''' -#TODO -# - include reaction time as output -# - average mistakes, average timeout, average assistance, average_react_time -# - include real time episodes into the simulation: -# - counters for agent_assistance, agent_feedback, attempt, game_state, attention and reactivity -# - using the function update probability to generate the new user model and use it as input to the simulator - - +agent_policy = generate_agent_assistance(preferred_assistance=2, agent_behaviour="help", n_game_state=Game_State.counter.value, n_attempt=Attempt.counter.value, alpha_action=0.5) + +# SIMULATION PARAMS +epochs = 20 +scaling_factor = 1 +# initialise the agent +bn_model_caregiver_assistance = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_assistive_model.bif') +bn_model_caregiver_feedback = None#bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_feedback_model.bif') +bn_model_user_action = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/persona_model_test.bif') +bn_model_user_react_time = 
+bn_model_user_react_time = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/user_react_time_model.bif')
+
+# initialise memory, attention and reactivity variables
+persona_memory = 0;
+persona_attention = 0;
+persona_reactivity = 1;
+
+# define state space struct for the irl algorithm
+episode_instance = Episode()
+# DEFINITION OF THE MDP
+# define state space struct for the irl algorithm
+attempt = [i for i in range(1, Attempt.counter.value + 1)]
+# +1 (3,_,_) absorbing state
+game_state = [i for i in range(0, Game_State.counter.value + 1)]
+user_action = [i for i in range(-1, User_Action.counter.value - 1)]
+state_space = (game_state, attempt, user_action)
+states_space_list = list(itertools.product(*state_space))
+state_space_index = [episode_instance.state_from_point_to_index(states_space_list, s) for s in states_space_list]
+agent_assistance_action = [i for i in range(Agent_Assistance.counter.value)]
+agent_feedback_action = [i for i in range(Agent_Feedback.counter.value)]
+action_space = (agent_feedback_action, agent_assistance_action)
+action_space_list = list(itertools.product(*action_space))
+action_space_index = [episode_instance.state_from_point_to_index(action_space_list, a) for a in action_space_list]
+terminal_state = [(Game_State.counter.value, i, user_action[j]) for i in range(1, Attempt.counter.value + 1) for j in
+                  range(len(user_action))]
+initial_state = (1, 1, 0)
+
+#1. RUN THE SIMULATION WITH THE PARAMS SET BY THE CAREGIVER
+
+
+game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, episodes_list = \
+    simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'],
+               bn_model_user_react_time=bn_model_user_react_time,
+               var_user_react_time_target_action=['user_react_time'],
+               user_memory_name="memory", user_memory_value=persona_memory,
+               user_attention_name="attention", user_attention_value=persona_attention,
+               user_reactivity_name="reactivity", user_reactivity_value=persona_reactivity,
+               task_progress_t0_name="game_state_t0", task_progress_t1_name="game_state_t1",
+               game_attempt_t0_name="attempt_t0", game_attempt_t1_name="attempt_t1",
+               agent_assistance_name="agent_assistance", agent_policy=agent_policy,
+               state_space=states_space_list, action_space=action_space_list,
+               epochs=epochs, task_complexity=5, max_attempt_per_object=4, alpha_learning=0.1)
+
+utils.plot2D_game_performance("/home/pal/Documents/Framework/bn_generative_model/results/user_performance.png", epochs, scaling_factor, game_performance_per_episode)
+utils.plot2D_assistance("/home/pal/Documents/Framework/bn_generative_model/results/agent_assistance.png", epochs, scaling_factor, agent_assistance_per_episode)
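The shape contract between the new 4-D counter and the 4-D branch of compute_prob is easy to miss across the hunks above, so here is a standalone sketch. The sizes (6 assistance levels, 4 attempts, 3 game states, 3 user actions) are assumptions taken from the model definition, and the numpy normalisation only mimics what compute_prob does list by list.

```python
import numpy as np

# Counter layout used by the simulation:
# [agent_assistance][attempt][game_state][user_action]
counts = np.zeros((6, 4, 3, 3))

# e.g. one observed "correct" move (index 0) with assistance level 2,
# first attempt (index 0), game state "beg" (index 0):
counts[2][0][0][0] += 1

# compute_prob normalises the innermost (user_action) axis with an epsilon,
# which is equivalent to:
probs = counts / (counts.sum(axis=-1, keepdims=True) + 0.00001)
print(probs[2][0][0])  # distribution over correct / wrong / timeout
```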
diff --git a/test.py b/test.py
index 0a232014330323e1afb30dbde192df29af620de1..66bd57cac6a41c2575e609287938cb842d8dd8b0 100644
--- a/test.py
+++ b/test.py
@@ -8,14 +8,45 @@ DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif')
 #DAGnew = bn.parameter_learning.fit(model, df, methodtype="bayes")
 #bn.print_CPD(DAGnew)
 q1 = bn.inference.fit(DAG, variables=['user_action'], evidence={
-    'game_state_t0': 1,
-    'attempt_t0':0,
-    'robot_assistance':5,
-    'game_state_t1': 1,
+    'game_state_t0': 0,
     'attempt_t0':1,
-
-
+    'game_state_t1': 0,
+    'attempt_t1':2,
+    'agent_assistance':0,
 })
-
-print(q1.variables)
 print(q1.values)
+
+# robot_assistance = [0, 1, 2, 3, 4, 5]
+# attempt_t0 = [0, 1, 2, 3]
+# game_state_t0 = [0, 1, 2]
+# attempt_t1 = [0]
+# game_state_t1 = [0, 1, 2]
+#
+# query_result = [[[0 for j in range(len(attempt_t0))] for i in range(len(robot_assistance))] for k in range(len(game_state_t0))]
+# for k in range(len(game_state_t0)):
+#     for i in range(len(robot_assistance)):
+#         for j in range(len(attempt_t0)-1):
+#             if j == 0:
+#                 query = bn.inference.fit(DAG, variables=['user_action'],
+#                                          evidence={'game_state_t0': k,
+#                                                    'attempt_t0': j,
+#                                                    'agent_assistance': i,
+#                                                    'game_state_t1': k,
+#                                                    'attempt_t1': j})
+#                 query_result[k][i][j] = query.values
+#             else:
+#                 query = bn.inference.fit(DAG, variables=['user_action'],
+#                                          evidence={'game_state_t0': k,
+#                                                    'attempt_t0': j,
+#                                                    'agent_assistance': i,
+#                                                    'game_state_t1': k,
+#                                                    'attempt_t1': j + 1})
+#                 query_result[k][i][j] = query.values
+# for k in range(len(game_state_t0)):
+#     for i in range(len(robot_assistance)):
+#         for j in range(len(attempt_t0)):
+#             if j == 0:
+#                 print("game_state:",k, "attempt_from:", j," attempt_to:",j, " robot_ass:",i, " prob:", query_result[k][i][j])
+#             else:
+#                 print("game_state:", k, "attempt_from:", j, " attempt_to:", j+1, " robot_ass:", i, " prob:",
+#                       query_result[k][i][j])
diff --git a/utils.py b/utils.py
index a1aa1d495a9e7176deb2f51757f02ea076cefd1f..41d788bacb6523191606f13cbaea7ba8515ecca6 100644
--- a/utils.py
+++ b/utils.py
@@ -43,6 +43,7 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y):
     lev_2 = list(map(lambda x:x[2], y[0]))[1::scaling_factor]
     lev_3 = list(map(lambda x:x[3], y[0]))[1::scaling_factor]
     lev_4 = list(map(lambda x:x[4], y[0]))[1::scaling_factor]
+    lev_5 = list(map(lambda x:x[5], y[0]))[1::scaling_factor]
 
     # plot bars
     plt.figure(figsize=(10, 7))
@@ -54,7 +55,8 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y):
            width=barWidth, label='lev_3')
     plt.bar(r, lev_4, bottom=np.array(lev_0) + np.array(lev_1)+ np.array(lev_2)+ np.array(lev_3), edgecolor='white',
            width=barWidth, label='lev_4')
-
+    plt.bar(r, lev_5, bottom=np.array(lev_0) + np.array(lev_1) + np.array(lev_2) + np.array(lev_3)+ np.array(lev_4), edgecolor='white',
+            width=barWidth, label='lev_5')
     plt.legend()
 
     # Custom X axis
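For completeness, the stacking logic that the new lev_5 bar extends can be reduced to a few lines. The data below is made up, and the loop is just an equivalent way of writing the per-level plt.bar calls in plot2D_assistance: each new segment sits on the cumulative sum of the previous ones.

```python
import numpy as np
import matplotlib.pyplot as plt

# Fake counts: rows are assistance levels 0..5, columns are episodes.
levels = np.array([[2, 1, 0],
                   [1, 2, 1],
                   [0, 1, 2],
                   [1, 0, 1],
                   [0, 1, 0],
                   [1, 0, 1]])

r = np.arange(levels.shape[1])
bottom = np.zeros(levels.shape[1])
for i, lev in enumerate(levels):
    # each level is stacked on top of the levels already drawn
    plt.bar(r, lev, bottom=bottom, edgecolor='white', label='lev_%d' % i)
    bottom += lev

plt.legend()
plt.savefig('assistance_per_episode.png')
```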