Commit bef7c58c authored by Antonio Andriella

Working version of new model

parent 4aaa0934
@@ -44,12 +44,35 @@ def compute_prob(cpds_table):
    This function checks if any
    '''
-   for val in range(len(cpds_table)):
-       cpds_table[val] = list(map(lambda x: x / (sum(cpds_table[val])+0.00001), cpds_table[val]))
-       cpds_table[val] = check_zero_occurrences(cpds_table[val])
    cpds_table_array = np.array(cpds_table)
    cpds_table_array_len = cpds_table_array.shape.__len__()

    if cpds_table_array_len == 4:
        # attempt
        for elem1 in range(cpds_table_array.shape[0]):
            # game_state
            for elem2 in range(cpds_table_array.shape[1]):
                # assistance
                for elem3 in range(cpds_table_array.shape[2]):
                    cpds_table[elem1][elem2][elem3] = list(
                        map(lambda x: x / (sum(cpds_table[elem1][elem2][elem3]) + 0.00001), cpds_table[elem1][elem2][elem3]))
                    cpds_table[elem1][elem2][elem3] = check_zero_occurrences(cpds_table[elem1][elem2][elem3])
    elif cpds_table_array_len == 3:
        # attempt
        for elem1 in range(cpds_table_array.shape[0]):
            # game_state
            for elem2 in range(cpds_table_array.shape[1]):
                cpds_table[elem1][elem2] = list(map(lambda x: x / (sum(cpds_table[elem1][elem2]) + 0.00001), cpds_table[elem1][elem2]))
                cpds_table[elem1][elem2] = check_zero_occurrences(cpds_table[elem1][elem2])
    else:
        for val in range(len(cpds_table)):
            cpds_table[val] = list(map(lambda x: x / (sum(cpds_table[val]) + 0.00001), cpds_table[val]))
            cpds_table[val] = check_zero_occurrences(cpds_table[val])
    return cpds_table
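A minimal sketch of the same normalization idea in plain numpy (illustrative only, not part of the commit): whatever the conditioning depth of the counter table, each row along the last axis is turned into a distribution, with a small epsilon to avoid dividing by zero.

    import numpy as np

    def normalize_counts(counts, eps=1e-5):
        # Normalize the last axis of a (possibly nested) counter table into probabilities.
        counts = np.asarray(counts, dtype=float)
        return counts / (counts.sum(axis=-1, keepdims=True) + eps)

    # Counts indexed as [agent_assistance][attempt][game_state][user_action], values made up.
    counts = np.random.randint(0, 10, size=(6, 4, 3, 3))
    probs = normalize_counts(counts)
    print(np.allclose(probs.sum(axis=-1), 1.0, atol=1e-3))  # rows sum to ~1 up to the eps smoothing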
-def average_prob(ref_cpds_table, current_cpds_table):
def average_prob(ref_cpds_table, current_cpds_table, alpha):
    '''
    Args:
    ref_cpds_table: table from bnlearn
@@ -58,12 +81,13 @@ def average_prob(ref_cpds_table, current_cpds_table):
    avg from both tables
    '''
    res_cpds_table = ref_cpds_table.copy()
    current_cpds_table_np_array = np.array(current_cpds_table)
    for elem1 in range(len(ref_cpds_table)):
        for elem2 in range(len(ref_cpds_table[0])):
-           res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]+current_cpds_table[elem1][elem2])/2
            res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]*(1-alpha))+(current_cpds_table_np_array[elem1][elem2]*alpha)
    return res_cpds_table

-def update_cpds_tables(bn_model, variables_tables):
def update_cpds_tables(bn_model, variables_tables, alpha=0.1):
    '''
    This function updates the bn model with the variables_tables provided in input
    Args:
@@ -80,7 +104,7 @@ def update_cpds_tables(bn_model, variables_tables):
        cpds_table_from_counter = compute_prob(val)
        updated_prob = average_prob(
            np.transpose(cpds_table),
-           cpds_table_from_counter)
            cpds_table_from_counter, alpha)
        bn_model['model'].cpds[index].values = np.transpose(updated_prob)
    return bn_model
...
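A small numeric illustration of the blending rule introduced above: each CPT entry becomes the convex combination (1 - alpha) * reference + alpha * empirical, so alpha = 0 keeps the prior model untouched and alpha = 1 replaces it with the freshly counted distribution. The values below are made up for the example.

    import numpy as np

    alpha = 0.1                                 # default used by update_cpds_tables
    ref_row = np.array([0.6, 0.3, 0.1])         # a CPT row from the reference model (.bif)
    empirical_row = np.array([0.2, 0.5, 0.3])   # the same row estimated from episode counters

    updated_row = (1 - alpha) * ref_row + alpha * empirical_row
    print(updated_row, updated_row.sum())       # still a valid distribution (sums to 1)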
@@ -3,7 +3,7 @@ network persona_model {
%VARIABLES DEFINITION
-variable robot_assistance {
variable agent_assistance {
   type discrete [ 6 ] { lev_0, lev_1, lev_2, lev_3, lev_4, lev_5 };
}
variable attempt_t0 {
@@ -13,18 +13,20 @@ variable game_state_t0 {
   type discrete [ 3 ] { beg, mid, end };
}
variable attempt_t1 {
-  type discrete [ 4 ] { att_1, att_2, att_3, att_4 };
   type discrete [ 1 ] { att_1, att_2, att_3, att_4};
}
variable game_state_t1 {
   type discrete [ 3 ] { beg, mid, end };
}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
variable user_action {
   type discrete [ 3 ] { correct, wrong, timeout };
}
variable agent_feedback {
   type discrete [ 2 ] { no, yes };
}
%INDIVIDUAL PROBABILITIES DEFINITION
-probability ( robot_assistance ) {
probability ( agent_assistance ) {
   table 0.17, 0.16, 0.16, 0.17, 0.17, 0.17;
}
probability ( game_state_t0 ) {
@@ -39,124 +41,104 @@ probability ( game_state_t1 ) {
probability ( attempt_t1 ) {
   table 0.25, 0.25, 0.25, 0.25;
}
probability ( user_action ) {
   table 0.33, 0.33, 0.34;
}
-probability (robot_assistance | game_state_t0, attempt_t0){
-   (beg, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
-   (beg, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
-   (beg, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
-   (beg, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
-   (mid, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
-   (mid, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
-   (mid, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
-   (mid, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
-   (end, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
-   (end, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
-   (end, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
-   (end, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
-}
probability ( agent_feedback ) {
   table 0.5, 0.5;
}
probability(agent_assistance | agent_feedback) {
   (yes) 0.4, 0.3, 0.2, 0.1, 0.0, 0.0
   (no) 0.0, 0.0, 0.1, 0.2, 0.3, 0.4
}
-probability (user_action | game_state_t0, attempt_t0, robot_assistance){
probability (user_action | game_state_t0, attempt_t0, agent_assistance){
   (beg, att_1, lev_0) 0.1, 0.9, 0.0;
   (beg, att_2, lev_0) 0.2, 0.8, 0.0;
   (beg, att_3, lev_0) 0.3, 0.7, 0.0;
   (beg, att_4, lev_0) 0.4, 0.6, 0.0;
-  (beg, att_1, lev_1) 0.1, 0.9, 0.0;
   (beg, att_1, lev_1) 0.2, 0.8, 0.0;
-  (beg, att_2, lev_1) 0.2, 0.8, 0.0;
   (beg, att_2, lev_1) 0.3, 0.7, 0.0;
-  (beg, att_3, lev_1) 0.3, 0.7, 0.0;
   (beg, att_3, lev_1) 0.4, 0.6, 0.0;
-  (beg, att_4, lev_1) 0.4, 0.6, 0.0;
   (beg, att_4, lev_1) 0.5, 0.5, 0.0;
-  (beg, att_1, lev_2) 0.1, 0.9, 0.0;
   (beg, att_1, lev_2) 0.3, 0.7, 0.0;
-  (beg, att_2, lev_2) 0.2, 0.8, 0.0;
   (beg, att_2, lev_2) 0.4, 0.6, 0.0;
-  (beg, att_3, lev_2) 0.3, 0.7, 0.0;
   (beg, att_3, lev_2) 0.5, 0.5, 0.0;
-  (beg, att_4, lev_2) 0.4, 0.6, 0.0;
   (beg, att_4, lev_2) 0.6, 0.4, 0.0;
-  (beg, att_1, lev_3) 0.1, 0.9, 0.0;
   (beg, att_1, lev_3) 0.4, 0.6, 0.0;
-  (beg, att_2, lev_3) 0.2, 0.8, 0.0;
   (beg, att_2, lev_3) 0.5, 0.5, 0.0;
-  (beg, att_3, lev_3) 0.3, 0.7, 0.0;
   (beg, att_3, lev_3) 0.6, 0.4, 0.0;
-  (beg, att_4, lev_3) 0.4, 0.6, 0.0;
   (beg, att_4, lev_3) 0.7, 0.3, 0.0;
-  (beg, att_1, lev_4) 0.1, 0.9, 0.0;
   (beg, att_1, lev_4) 1.0, 0.0, 0.0;
-  (beg, att_2, lev_4) 0.2, 0.8, 0.0;
   (beg, att_2, lev_4) 1.0, 0.0, 0.0;
-  (beg, att_3, lev_4) 0.3, 0.7, 0.0;
   (beg, att_3, lev_4) 1.0, 0.0, 0.0;
-  (beg, att_4, lev_4) 0.4, 0.6, 0.0;
   (beg, att_4, lev_4) 1.0, 0.0, 0.0;
-  (beg, att_1, lev_5) 0.1, 0.9, 0.0;
   (beg, att_1, lev_5) 1.0, 0.0, 0.0;
-  (beg, att_2, lev_5) 0.2, 0.8, 0.0;
   (beg, att_2, lev_5) 1.0, 0.0, 0.0;
-  (beg, att_3, lev_5) 0.3, 0.7, 0.0;
   (beg, att_3, lev_5) 1.0, 0.0, 0.0;
-  (beg, att_4, lev_5) 0.4, 0.6, 0.0;
   (beg, att_4, lev_5) 1.0, 0.0, 0.0;
   (mid, att_1, lev_0) 0.1, 0.9, 0.0;
   (mid, att_2, lev_0) 0.2, 0.8, 0.0;
   (mid, att_3, lev_0) 0.3, 0.7, 0.0;
   (mid, att_4, lev_0) 0.4, 0.6, 0.0;
-  (mid, att_1, lev_1) 0.1, 0.9, 0.0;
   (mid, att_1, lev_1) 0.2, 0.8, 0.0;
-  (mid, att_2, lev_1) 0.2, 0.8, 0.0;
   (mid, att_2, lev_1) 0.3, 0.7, 0.0;
-  (mid, att_3, lev_1) 0.3, 0.7, 0.0;
   (mid, att_3, lev_1) 0.4, 0.6, 0.0;
-  (mid, att_4, lev_1) 0.4, 0.6, 0.0;
   (mid, att_4, lev_1) 0.5, 0.5, 0.0;
-  (mid, att_1, lev_2) 0.1, 0.9, 0.0;
   (mid, att_1, lev_2) 0.3, 0.7, 0.0;
-  (mid, att_2, lev_2) 0.2, 0.8, 0.0;
   (mid, att_2, lev_2) 0.4, 0.6, 0.0;
-  (mid, att_3, lev_2) 0.3, 0.7, 0.0;
   (mid, att_3, lev_2) 0.5, 0.5, 0.0;
-  (mid, att_4, lev_2) 0.4, 0.6, 0.0;
   (mid, att_4, lev_2) 0.6, 0.4, 0.0;
-  (mid, att_1, lev_3) 0.1, 0.9, 0.0;
   (mid, att_1, lev_3) 0.4, 0.6, 0.0;
-  (mid, att_2, lev_3) 0.2, 0.8, 0.0;
   (mid, att_2, lev_3) 0.5, 0.5, 0.0;
-  (mid, att_3, lev_3) 0.3, 0.7, 0.0;
   (mid, att_3, lev_3) 0.6, 0.4, 0.0;
-  (mid, att_4, lev_3) 0.4, 0.6, 0.0;
   (mid, att_4, lev_3) 0.7, 0.3, 0.0;
-  (mid, att_1, lev_4) 0.1, 0.9, 0.0;
   (mid, att_1, lev_4) 1.0, 0.0, 0.0;
-  (mid, att_2, lev_4) 0.2, 0.8, 0.0;
   (mid, att_2, lev_4) 1.0, 0.0, 0.0;
-  (mid, att_3, lev_4) 0.3, 0.7, 0.0;
   (mid, att_3, lev_4) 1.0, 0.0, 0.0;
-  (mid, att_4, lev_4) 0.4, 0.6, 0.0;
   (mid, att_4, lev_4) 1.0, 0.0, 0.0;
-  (mid, att_1, lev_5) 0.1, 0.9, 0.0;
   (mid, att_1, lev_5) 1.0, 0.0, 0.0;
-  (mid, att_2, lev_5) 0.2, 0.8, 0.0;
   (mid, att_2, lev_5) 1.0, 0.0, 0.0;
-  (mid, att_3, lev_5) 0.3, 0.7, 0.0;
   (mid, att_3, lev_5) 1.0, 0.0, 0.0;
-  (mid, att_4, lev_5) 0.4, 0.6, 0.0;
   (mid, att_4, lev_5) 1.0, 0.0, 0.0;
   (end, att_1, lev_0) 0.1, 0.9, 0.0;
   (end, att_2, lev_0) 0.2, 0.8, 0.0;
-  (end, att_3, lev_0) 0.2, 0.8, 0.0;
   (end, att_3, lev_0) 0.3, 0.7, 0.0;
   (end, att_4, lev_0) 0.4, 0.6, 0.0;
-  (end, att_1, lev_1) 0.1, 0.9, 0.0;
   (end, att_1, lev_1) 0.2, 0.8, 0.0;
-  (end, att_2, lev_1) 0.2, 0.8, 0.0;
   (end, att_2, lev_1) 0.3, 0.7, 0.0;
   (end, att_3, lev_1) 0.4, 0.6, 0.0;
-  (end, att_4, lev_1) 0.4, 0.6, 0.0;
   (end, att_4, lev_1) 0.5, 0.5, 0.0;
-  (end, att_1, lev_2) 0.1, 0.9, 0.0;
   (end, att_1, lev_2) 0.3, 0.7, 0.0;
-  (end, att_2, lev_2) 0.2, 0.8, 0.0;
   (end, att_2, lev_2) 0.4, 0.6, 0.0;
-  (end, att_3, lev_2) 0.4, 0.6, 0.0;
   (end, att_3, lev_2) 0.5, 0.5, 0.0;
-  (end, att_4, lev_2) 0.4, 0.6, 0.0;
   (end, att_4, lev_2) 0.6, 0.4, 0.0;
-  (end, att_1, lev_3) 0.1, 0.9, 0.0;
   (end, att_1, lev_3) 0.4, 0.6, 0.0;
-  (end, att_2, lev_3) 0.2, 0.8, 0.0;
   (end, att_2, lev_3) 0.5, 0.5, 0.0;
-  (end, att_3, lev_3) 0.5, 0.5, 0.0;
   (end, att_3, lev_3) 0.6, 0.4, 0.0;
-  (end, att_4, lev_3) 0.4, 0.6, 0.0;
   (end, att_4, lev_3) 0.7, 0.3, 0.0;
-  (end, att_1, lev_4) 0.1, 0.9, 0.0;
   (end, att_1, lev_4) 1.0, 0.0, 0.0;
-  (end, att_2, lev_4) 0.2, 0.8, 0.0;
   (end, att_2, lev_4) 1.0, 0.0, 0.0;
-  (end, att_3, lev_4) 0.7, 0.3, 0.0;
   (end, att_3, lev_4) 1.0, 0.0, 0.0;
-  (end, att_4, lev_4) 0.4, 0.6, 0.0;
   (end, att_4, lev_4) 1.0, 0.0, 0.0;
-  (end, att_1, lev_5) 0.1, 0.9, 0.0;
   (end, att_1, lev_5) 1.0, 0.0, 0.0;
-  (end, att_2, lev_5) 0.2, 0.8, 0.0;
   (end, att_2, lev_5) 1.0, 0.0, 0.0;
-  (end, att_3, lev_5) 0.3, 0.7, 0.0;
   (end, att_3, lev_5) 1.0, 0.0, 0.0;
-  (end, att_4, lev_5) 0.4, 0.6, 0.0;
   (end, att_4, lev_5) 1.0, 0.0, 0.0;
}
probability (game_state_t1 | user_action) {
-  (correct) 0.2, 0.3, 0.5;
   (correct) 0.25, 0.3, 0.45;
-  (wrong) 0.5, 0.3, 0.2;
   (wrong) 0.33, 0.33, 0.33;
-  (timeout) 0.33, 0.34, 0.33;
   (timeout) 0.33, 0.33, 0.33;
}
probability (attempt_t1 | user_action) {
-  (correct) 0.1, 0.2, 0.3, 0.4;
   (correct) 0.1, 0.2, 0.25, 0.45;
-  (wrong) 0.4, 0.3, 0.2, 0.1;
   (wrong) 0.25, 0.25, 0.25, 0.25;
   (timeout) 0.25, 0.25, 0.25, 0.25;
}
-probability (user_action | robot_assistance){
-   (lev_0) 0.1, 0.6, 0.3;
-   (lev_1) 0.2, 0.5, 0.3;
-   (lev_2) 0.3, 0.5, 0.2;
-   (lev_3) 0.5, 0.3, 0.2;
-   (lev_4) 0.9, 0.1, 0.0;
-   (lev_5) 0.9, 0.1, 0.0;
-}
\ No newline at end of file
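A quick sanity check on the tables above: each conditional row of a .bif CPT should sum to 1, which is what pgmpy's model check expects when the file is imported. A small sketch (the path matches the test script later in this commit; adjust it to your checkout):

    import bnlearn
    import numpy as np

    model = bnlearn.import_DAG('bn_persona_model/persona_model_test.bif')
    for cpd in model['model'].cpds:
        # pgmpy stores the child variable on axis 0, so summing over it gives the row totals.
        print(cpd.variable, np.allclose(cpd.values.sum(axis=0), 1.0))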
@@ -2,12 +2,102 @@ import itertools
import os
import bnlearn
import numpy as np
import random
#import classes and modules
from bn_variables import Agent_Assistance, Agent_Feedback, User_Action, User_React_time, Game_State, Attempt
import bn_functions
import utils
from episode import Episode
#
# def choose_next_states(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object,
# selected_agent_assistance_action,
# bn_model_user_action, var_user_action_target_action):
#
# def get_next_state(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object):
#
# next_state = []
#
# #correct move on the last state of the bin
# if (task_progress == 1 or task_progress == 3 or task_progress == 4) and n_attempt_t0<max_attempt_per_object:
# next_state.append((game_state_t0+1, n_attempt_t0+1))
# #correct state bu still in the bin
# elif task_progress == 0 or task_progress == 2 and n_attempt_t0<max_attempt_per_object:
# next_state.append((game_state_t0, n_attempt_t0+1))
# elif (task_progress == 1 or task_progress == 3 or task_progress == 4) and n_attempt_t0>=max_attempt_per_object:
# assert "you reach the maximum number of attempt the agent will move it for you"
# elif task_progress == 0 or task_progress == 2 and n_attempt_t0>=max_attempt_per_object:
# assert "you reach the maximum number of attempt the agent will move it for you"
#
# return next_state
#
# next_state = get_next_state(task_progress, game_state_t0, n_attempt_t0, max_attempt_per_object)
# query_answer_probs = []
# for t in next_state:
# vars_user_evidence = {"game_state_t0": game_state_t0,
# "attempt_t0": n_attempt_t0 - 1,
# "robot_assistance": selected_agent_assistance_action,
# "game_state_t1": t[0],
# "attempt_t1": t[1],
# }
#
# query_user_action_prob = bn_functions.infer_prob_from_state(bn_model_user_action,
# infer_variable=var_user_action_target_action,
# evidence_variables=vars_user_evidence)
# query_answer_probs.append(query_user_action_prob)
#
#
# #do the inference here
# #1. check given the current_state which are the possible states
# #2. for each of the possible states get the probability of user_action
# #3. select the state with the most higher action and execute it
# #4. return user_action
#
def generate_agent_assistance(preferred_assistance, agent_behaviour, n_game_state, n_attempt, alpha_action=0.1):
    agent_policy = [[0 for j in range(n_attempt)] for i in range(n_game_state)]
    previous_assistance = -1

    def get_alternative_action(agent_assistance, previous_assistance, agent_behaviour, alpha_action):
        agent_assistance_res = agent_assistance
        if previous_assistance == agent_assistance:
            if agent_behaviour == "challenge":
                if random.random() > alpha_action:
                    agent_assistance_res = min(max(0, agent_assistance-1), 5)
                else:
                    agent_assistance_res = min(max(0, agent_assistance), 5)
            else:
                if random.random() > alpha_action:
                    agent_assistance_res = min(max(0, agent_assistance + 1), 5)
                else:
                    agent_assistance_res = min(max(0, agent_assistance), 5)
        return agent_assistance_res

    for gs in range(n_game_state):
        for att in range(n_attempt):
            if att == 0:
                if random.random() > alpha_action:
                    agent_policy[gs][att] = preferred_assistance
                    previous_assistance = agent_policy[gs][att]
                else:
                    if random.random() > 0.5:
                        agent_policy[gs][att] = min(max(0, preferred_assistance-1), 5)
                        previous_assistance = agent_policy[gs][att]
                    else:
                        agent_policy[gs][att] = min(max(0, preferred_assistance+1), 5)
                        previous_assistance = agent_policy[gs][att]
            else:
                if agent_behaviour == "challenge":
                    agent_policy[gs][att] = min(max(0, preferred_assistance-1), 5)
                    agent_policy[gs][att] = get_alternative_action(agent_policy[gs][att], previous_assistance, agent_behaviour, alpha_action)
                    previous_assistance = agent_policy[gs][att]
                else:
                    agent_policy[gs][att] = min(max(0, preferred_assistance+1), 5)
                    agent_policy[gs][att] = get_alternative_action(agent_policy[gs][att], previous_assistance, agent_behaviour, alpha_action)
                    previous_assistance = agent_policy[gs][att]

    return agent_policy
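A quick usage sketch of the policy generator above (parameter values are illustrative, mirroring the 3 game states and 4 attempts used elsewhere in this commit): it returns an n_game_state x n_attempt grid of assistance levels centred on the preferred level and nudged up or down depending on the behaviour.

    random.seed(0)  # only to make the example reproducible
    policy = generate_agent_assistance(preferred_assistance=2, agent_behaviour="help",
                                       n_game_state=3, n_attempt=4, alpha_action=0.5)
    for gs, row in enumerate(policy):
        print("game_state", gs, "->", row)  # one assistance level per attempt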
def compute_next_state(user_action, task_progress_counter, attempt_counter, correct_move_counter,
                       wrong_move_counter, timeout_counter, max_attempt_counter, max_attempt_per_object
@@ -80,11 +170,10 @@ def compute_next_state(user_action, task_progress_counter, attempt_counter, corr
def simulation(bn_model_user_action, var_user_action_target_action, bn_model_user_react_time, var_user_react_time_target_action,
               user_memory_name, user_memory_value, user_attention_name, user_attention_value,
               user_reactivity_name, user_reactivity_value,
-              task_progress_name, game_attempt_name, agent_assistance_name, agent_feedback_name,
-              bn_model_agent_assistance, var_agent_assistance_target_action, bn_model_agent_feedback,
-              var_agent_feedback_target_action, agent_policy,
               task_progress_t0_name, task_progress_t1_name, game_attempt_t0_name, game_attempt_t1_name,
               agent_assistance_name, agent_policy,
               state_space, action_space,
-              epochs=50, task_complexity=5, max_attempt_per_object=4):
               epochs=50, task_complexity=5, max_attempt_per_object=4, alpha_learning=0):
    '''
    Args:
@@ -98,8 +187,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
    #metrics we need, in order to compute afterwords the belief
-   attempt_counter_per_action = [[0 for i in range(Attempt.counter.value)] for j in range(User_Action.counter.value)]
-   game_state_counter_per_action = [[0 for i in range(Game_State.counter.value)] for j in range(User_Action.counter.value)]
    agent_feedback_per_action = [[0 for i in range(Agent_Feedback.counter.value)] for j in range(User_Action.counter.value)]
    agent_assistance_per_action = [[0 for i in range(Agent_Assistance.counter.value)] for j in range(User_Action.counter.value)]
@@ -108,12 +195,21 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
    agent_feedback_per_react_time = [[0 for i in range(Agent_Feedback.counter.value)] for j in range(User_React_time.counter.value)]
    agent_assistance_per_react_time = [[0 for i in range(Agent_Assistance.counter.value)] for j in range(User_React_time.counter.value)]
-   game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in range(Agent_Assistance.counter.value)]
-   attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in range(Agent_Assistance.counter.value)]
    game_state_counter_per_agent_feedback = [[0 for i in range(Game_State.counter.value)] for j in range(Agent_Feedback.counter.value)]
    attempt_counter_per_agent_feedback = [[0 for i in range(Attempt.counter.value)] for j in range(Agent_Feedback.counter.value)]
game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in
range(Agent_Assistance.counter.value)]
attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in
range(Agent_Assistance.counter.value)]
user_action_per_game_state_attempt_counter_agent_assistance = [[[[0 for i in range(User_Action.counter.value)] for l in range(Game_State.counter.value)] for j in
range(Attempt.counter.value)] for k in range(Agent_Assistance.counter.value)]
user_action_per_agent_assistance = [[0 for i in range(User_Action.counter.value)] for j in
range(Agent_Assistance.counter.value)]
attempt_counter_per_user_action = [[0 for i in range(Attempt.counter.value)] for j in range(User_Action.counter.value)]
game_state_counter_per_user_action = [[0 for i in range(Game_State.counter.value)] for j in
range(User_Action.counter.value)]
    #output variables:
    n_correct_per_episode = [0]*epochs
@@ -131,6 +227,10 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
    ep = Episode()

    for e in range(epochs):
print("##########################################################")
print("EPISODE ",e)
print("##########################################################")
        '''Simulation framework'''
        #counters
        game_state_counter = 0
@@ -142,10 +242,11 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
        max_attempt_counter = 0

        #The following variables are used to update the BN at the end of the episode
-       user_action_dynamic_variables = {'attempt': attempt_counter_per_action,
-                                        'game_state': game_state_counter_per_action,
-                                        'agent_assistance': agent_assistance_per_action,
-                                        'agent_feedback': agent_feedback_per_action}
        user_action_dynamic_variables = {
                                         'attempt_t1': attempt_counter_per_user_action,
                                         'game_state_t1': game_state_counter_per_user_action,
                                         'user_action': user_action_per_game_state_attempt_counter_agent_assistance
                                         }

        user_react_time_dynamic_variables = {'attempt': attempt_counter_per_react_time,
                                             'game_state': game_state_counter_per_react_time,
@@ -167,38 +268,8 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
            current_state = (game_state_counter, attempt_counter, selected_user_action)
-           if type(agent_policy) is not np.ndarray:
-               ##################QUERY FOR THE ROBOT ASSISTANCE AND FEEDBACK##################
-               vars_agent_evidence = {
-                   user_reactivity_name: user_reactivity_value,
-                   user_memory_name: user_memory_value,
-                   task_progress_name: game_state_counter,
-                   game_attempt_name: attempt_counter-1,
-               }
-               query_agent_assistance_prob = bn_functions.infer_prob_from_state(bn_model_agent_assistance,
-                                                                                infer_variable=var_agent_assistance_target_action,
-                                                                                evidence_variables=vars_agent_evidence)
-               if bn_model_agent_feedback != None:
-                   query_agent_feedback_prob = bn_functions.infer_prob_from_state(bn_model_agent_feedback,
-                                                                                  infer_variable=var_agent_feedback_target_action,
-                                                                                  evidence_variables=vars_agent_evidence)
-                   selected_agent_feedback_action = bn_functions.get_stochastic_action(query_agent_feedback_prob.values)
-               else:
-                   selected_agent_feedback_action = 0
-               selected_agent_assistance_action = bn_functions.get_stochastic_action(query_agent_assistance_prob.values)
-           else:
-               idx_state = ep.state_from_point_to_index(state_space, current_state)
-               if agent_policy[idx_state]>=Agent_Assistance.counter.value:
-                   selected_agent_assistance_action = agent_policy[idx_state]-Agent_Assistance.counter.value
-                   selected_agent_feedback_action = 1
-               else:
-                   selected_agent_assistance_action = agent_policy[idx_state]
-                   selected_agent_feedback_action = 0
-           n_feedback_per_episode[e][selected_agent_feedback_action] += 1
            selected_agent_assistance_action = agent_policy[game_state_counter][attempt_counter-1]#random.randint(0,5)
            selected_agent_feedback_action = 0#random.randint(0,1)
            #counters for plots
            n_assistance_lev_per_episode[e][selected_agent_assistance_action] += 1
@@ -211,38 +282,39 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
            #compare the real user with the estimated Persona and returns a user action (0, 1a, 2)
            #return the user action in this state based on the Persona profile
-           vars_user_evidence = {user_attention_name: user_attention_value,
-                                 user_reactivity_name: user_reactivity_value,
-                                 user_memory_name: user_memory_value,
-                                 task_progress_name: game_state_counter,
-                                 game_attempt_name: attempt_counter-1,
-                                 agent_assistance_name: selected_agent_assistance_action,
-                                 agent_feedback_name: selected_agent_feedback_action
-                                 }
            vars_user_evidence = { task_progress_t0_name: game_state_counter,
                                   game_attempt_t0_name: attempt_counter - 1,
                                   task_progress_t1_name: game_state_counter,
                                   game_attempt_t1_name: attempt_counter - 1,
                                   agent_assistance_name: selected_agent_assistance_action,
                                   }

            query_user_action_prob = bn_functions.infer_prob_from_state(bn_model_user_action,
                                                                        infer_variable=var_user_action_target_action,
                                                                        evidence_variables=vars_user_evidence)
-           query_user_react_time_prob = bn_functions.infer_prob_from_state(bn_model_user_react_time,
-                                                                           infer_variable=var_user_react_time_target_action,
-                                                                           evidence_variables=vars_user_evidence)
            # query_user_react_time_prob = bn_functions.infer_prob_from_state(bn_model_user_react_time,
            #                                                                 infer_variable=var_user_react_time_target_action,
            #                                                                 evidence_variables=vars_user_evidence)
            #
            #
            selected_user_action = bn_functions.get_stochastic_action(query_user_action_prob.values)
-           selected_user_react_time = bn_functions.get_stochastic_action(query_user_react_time_prob.values)
            # selected_user_react_time = bn_functions.get_stochastic_action(query_user_react_time_prob.values)
            # counters for plots
-           n_react_time_per_episode[e][selected_user_react_time] += 1
            # n_react_time_per_episode[e][selected_user_react_time] += 1

            #updates counters for user action
-           agent_assistance_per_action[selected_user_action][selected_agent_assistance_action] += 1
-           attempt_counter_per_action[selected_user_action][attempt_counter-1] += 1
-           game_state_counter_per_action[selected_user_action][game_state_counter] += 1
-           agent_feedback_per_action[selected_user_action][selected_agent_feedback_action] += 1
            user_action_per_game_state_attempt_counter_agent_assistance[selected_agent_assistance_action][attempt_counter-1][game_state_counter][selected_user_action] += 1
            attempt_counter_per_user_action[selected_user_action][attempt_counter-1] += 1
            game_state_counter_per_user_action[selected_user_action][game_state_counter] += 1
            user_action_per_agent_assistance[selected_agent_assistance_action][selected_user_action] += 1

            #update counter for user react time
-           agent_assistance_per_react_time[selected_user_react_time][selected_agent_assistance_action] += 1
-           attempt_counter_per_react_time[selected_user_react_time][attempt_counter-1] += 1
-           game_state_counter_per_react_time[selected_user_react_time][game_state_counter] += 1
-           agent_feedback_per_react_time[selected_user_react_time][selected_agent_feedback_action] += 1
            # agent_assistance_per_react_time[selected_user_react_time][selected_agent_assistance_action] += 1
            # attempt_counter_per_react_time[selected_user_react_time][attempt_counter-1] += 1
            # game_state_counter_per_react_time[selected_user_react_time][game_state_counter] += 1
            # agent_feedback_per_react_time[selected_user_react_time][selected_agent_feedback_action] += 1

            #update counter for agent feedback
            game_state_counter_per_agent_feedback[selected_agent_feedback_action][game_state_counter] += 1
            attempt_counter_per_agent_feedback[selected_agent_feedback_action][attempt_counter-1] += 1
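The user's move above is drawn from the inferred posterior with bn_functions.get_stochastic_action. A plausible minimal equivalent of that helper (an assumption about its behaviour, not its actual source) is a weighted draw over the action indices:

    import numpy as np

    def stochastic_action(probs):
        # Draw an index with probability proportional to probs (renormalized for safety).
        probs = np.asarray(probs, dtype=float)
        probs = probs / probs.sum()
        return int(np.random.choice(len(probs), p=probs))

    # e.g. stochastic_action([0.7, 0.2, 0.1]) returns index 0 most of the time.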
@@ -269,8 +341,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
                                              timeout_counter, max_attempt_counter,
                                              max_attempt_per_object)

            # store the (state, action, next_state)
            episode.append((ep.state_from_point_to_index(state_space, current_state),
                            ep.state_from_point_to_index(action_space, current_agent_action),
@@ -286,22 +356,25 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
        episodes.append(Episode(episode))

        #update user models
-       bn_model_user_action = bn_functions.update_cpds_tables(bn_model_user_action, user_action_dynamic_variables)
        bn_model_user_action = bn_functions.update_cpds_tables(bn_model_user_action, user_action_dynamic_variables, alpha_learning)
        bn_model_user_react_time = bn_functions.update_cpds_tables(bn_model_user_react_time, user_react_time_dynamic_variables)
        #update agent models
-       bn_model_agent_assistance = bn_functions.update_cpds_tables(bn_model_agent_assistance, agent_assistance_dynamic_variables)
-       if bn_model_agent_feedback !=None:
-           bn_model_agent_feedback = bn_functions.update_cpds_tables(bn_model_agent_feedback, agent_feedback_dynamic_variables)
        print("user_given_game_attempt:", bn_model_user_action['model'].cpds[0].values)
        print("user_given_robot:", bn_model_user_action['model'].cpds[5].values)
        print("game_user:", bn_model_user_action['model'].cpds[3].values)
        print("attempt_user:", bn_model_user_action['model'].cpds[2].values)

        #reset counter
-       agent_assistance_per_action = [[0 for i in range(Agent_Assistance.counter.value)] for j in
-                                      range(User_Action.counter.value)]
-       agent_feedback_per_action = [[0 for i in range(Agent_Feedback.counter.value)] for j in
-                                    range(User_Action.counter.value)]
-       game_state_counter_per_action = [[0 for i in range(Game_State.counter.value)] for j in
-                                        range(User_Action.counter.value)]
-       attempt_counter_per_action = [[0 for i in range(Attempt.counter.value)] for j in
-                                     range(User_Action.counter.value)]
        user_action_per_game_state_attempt_counter_agent_assistance = [[[[0 for i in range(User_Action.counter.value)]
                                                                         for l in range(Game_State.counter.value)] for j in
                                                                        range(Attempt.counter.value)] for k in range(Agent_Assistance.counter.value)]
        user_action_per_agent_assistance = [[0 for i in range(User_Action.counter.value)] for j in
                                            range(Agent_Assistance.counter.value)]
        attempt_counter_per_user_action = [[0 for i in range(Attempt.counter.value)] for j in
                                           range(User_Action.counter.value)]
        game_state_counter_per_user_action = [[0 for i in range(Game_State.counter.value)] for j in
                                              range(User_Action.counter.value)]

        attempt_counter_per_react_time = [[0 for i in range(Attempt.counter.value)] for j in
                                          range(User_React_time.counter.value)]
@@ -312,15 +385,6 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
        agent_assistance_per_react_time = [[0 for i in range(Agent_Assistance.counter.value)] for j in
                                           range(User_React_time.counter.value)]
-       game_state_counter_per_agent_assistance = [[0 for i in range(Game_State.counter.value)] for j in
-                                                  range(Agent_Assistance.counter.value)]
-       attempt_counter_per_agent_assistance = [[0 for i in range(Attempt.counter.value)] for j in
-                                               range(Agent_Assistance.counter.value)]
-       game_state_counter_per_agent_feedback = [[0 for i in range(Game_State.counter.value)] for j in
-                                                range(Agent_Feedback.counter.value)]
-       attempt_counter_per_agent_feedback = [[0 for i in range(Attempt.counter.value)] for j in
-                                             range(Agent_Feedback.counter.value)]

        #for plots
        n_correct_per_episode[e] = correct_move_counter
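The prints above address the CPDs by positional index (cpds[0], cpds[2], ...), which silently breaks if the node order changes. Assuming the 'model' entry returned by bnlearn is a pgmpy network, as the rest of this code implies, looking a CPD up by variable name is a more robust alternative:

    user_action_cpd = bn_model_user_action['model'].get_cpds('user_action')
    print(user_action_cpd.variables)      # the child variable followed by its parents
    print(user_action_cpd.values.shape)   # one axis per variable listed above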
@@ -343,126 +407,58 @@ def simulation(bn_model_user_action, var_user_action_target_action, bn_model_use
#############################################################################
#############################################################################
# #SIMULATION PARAMS
# epochs = 100
#
# #initialise the agent
# bn_model_caregiver_assistance = bnlearn.import_DAG('bn_agent_model/agent_assistive_model.bif')
# bn_model_caregiver_feedback = bnlearn.import_DAG('bn_agent_model/agent_feedback_model.bif')
# bn_model_user_action = bnlearn.import_DAG('bn_persona_model/user_action_model.bif')
# bn_model_user_react_time = bnlearn.import_DAG('bn_persona_model/user_react_time_model.bif')
# bn_model_other_user_action = None#bnlearn.import_DAG('bn_persona_model/other_user_action_model.bif')
# bn_model_other_user_react_time = None#bnlearn.import_DAG('bn_persona_model/other_user_react_time_model.bif')
#
# #initialise memory, attention and reactivity varibles
# persona_memory = 0; persona_attention = 0; persona_reactivity = 1;
# #initialise memory, attention and reactivity varibles
# other_user_memory = 2; other_user_attention = 2; other_user_reactivity = 2;
#
# #define state space struct for the irl algorithm
# attempt = [i for i in range(1, Attempt.counter.value+1)]
# #+1a (3,_,_) absorbing state
# game_state = [i for i in range(0, Game_State.counter.value+1)]
# user_action = [i for i in range(-1, User_Action.counter.value-1)]
# state_space = (game_state, attempt, user_action)
# states_space_list = list(itertools.product(*state_space))
# agent_assistance_action = [i for i in range(Agent_Assistance.counter.value)]
# agent_feedback_action = [i for i in range(Agent_Feedback.counter.value)]
# action_space = (agent_assistance_action, agent_feedback_action)
# action_space_list = list(itertools.product(*action_space))
#
# ##############BEFORE RUNNING THE SIMULATION UPDATE THE BELIEF IF YOU HAVE DATA####################
# log_directory = "/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log/1/0"
# if os.path.exists(log_directory):
# bn_belief_user_action_file = log_directory+"/bn_belief_user_action.pkl"
# bn_belief_user_react_time_file = log_directory+"/bn_belief_user_react_time.pkl"
# bn_belief_caregiver_assistance_file = log_directory+"/bn_belief_caregiver_assistive_action.pkl"
# bn_belief_caregiver_feedback_file = log_directory+"/bn_belief_caregiver_feedback_action.pkl"
#
# bn_belief_user_action = utils.read_user_statistics_from_pickle(bn_belief_user_action_file)
# bn_belief_user_react_time = utils.read_user_statistics_from_pickle(bn_belief_user_react_time_file)
# bn_belief_caregiver_assistance = utils.read_user_statistics_from_pickle(bn_belief_caregiver_assistance_file)
# bn_belief_caregiver_feedback = utils.read_user_statistics_from_pickle(bn_belief_caregiver_feedback_file)
# bn_model_user_action = bn_functions.update_cpds_tables(bn_model=bn_model_user_action, variables_tables=bn_belief_user_action)
# bn_model_user_react_time = bn_functions.update_cpds_tables(bn_model=bn_model_user_react_time, variables_tables=bn_belief_user_react_time)
# bn_model_caregiver_assistance = bn_functions.update_cpds_tables(bn_model=bn_model_caregiver_assistance, variables_tables=bn_belief_caregiver_assistance)
# bn_model_caregiver_feedback = bn_functions.update_cpds_tables(bn_model=bn_model_caregiver_feedback, variables_tables=bn_belief_caregiver_feedback)
#
# else:
# assert("You're not using the user information")
# question = input("Are you sure you don't want to load user's belief information?")
#
# game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, generated_episodes = \
# simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'],
# bn_model_user_react_time=bn_model_user_react_time,
# var_user_react_time_target_action=['user_react_time'],
# user_memory_name="memory", user_memory_value=persona_memory,
# user_attention_name="attention", user_attention_value=persona_attention,
# user_reactivity_name="reactivity", user_reactivity_value=persona_reactivity,
# task_progress_name="game_state", game_attempt_name="attempt",
# agent_assistance_name="agent_assistance", agent_feedback_name="agent_feedback",
# bn_model_agent_assistance=bn_model_caregiver_assistance,
# var_agent_assistance_target_action=["agent_assistance"],
# bn_model_agent_feedback=bn_model_caregiver_feedback, var_agent_feedback_target_action=["agent_feedback"],
# bn_model_other_user_action=bn_model_other_user_action,
# var_other_user_action_target_action=['user_action'],
# bn_model_other_user_react_time=bn_model_other_user_react_time,
# var_other_user_target_react_time_action=["user_react_time"], other_user_memory_name="memory",
# other_user_memory_value=other_user_memory, other_user_attention_name="attention",
# other_user_attention_value=other_user_attention, other_user_reactivity_name="reactivity",
# other_user_reactivity_value=other_user_reactivity,
# state_space=states_space_list, action_space=action_space_list,
# epochs=epochs, task_complexity=5, max_attempt_per_object=4)
#
#
#
# plot_game_performance_path = ""
# plot_agent_assistance_path = ""
# episodes_path = "episodes.npy"
#
# if bn_model_other_user_action != None:
# plot_game_performance_path = "game_performance_"+"_epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg"
# plot_agent_assistance_path = "agent_assistance_"+"epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg"
# plot_agent_feedback_path = "agent_feedback_"+"epoch_"+str(epochs)+"_real_user_memory_"+str(real_user_memory)+"_real_user_attention_"+str(real_user_attention)+"_real_user_reactivity_"+str(real_user_reactivity)+".jpg"
#
# else:
# plot_game_performance_path = "game_performance_"+"epoch_" + str(epochs) + "_persona_memory_" + str(persona_memory) + "_persona_attention_" + str(persona_attention) + "_persona_reactivity_" + str(persona_reactivity) + ".jpg"
# plot_agent_assistance_path = "agent_assistance_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg"
# plot_agent_feedback_path = "agent_feedback_"+"epoch_"+str(epochs)+"_persona_memory_"+str(persona_memory)+"_persona_attention_"+str(persona_attention)+"_persona_reactivity_"+str(persona_reactivity)+".jpg"
#
# dir_name = input("Please insert the name of the directory:")
# full_path = os.getcwd()+"/results/"+dir_name+"/"
# if not os.path.exists(full_path):
# os.mkdir(full_path)
# print("Directory ", full_path, " created.")
# else:
# dir_name = input("The directory already exist please insert a new name:")
# print("Directory ", full_path, " created.")
# if os.path.exists(full_path):
# assert("Directory already exists ... start again")
# exit(0)
#
# with open(full_path+episodes_path, "ab") as f:
# np.save(full_path+episodes_path, generated_episodes)
# f.close()
#
#
# utils.plot2D_game_performance(full_path+plot_game_performance_path, epochs, game_performance_per_episode)
# utils.plot2D_assistance(full_path+plot_agent_assistance_path, epochs, agent_assistance_per_episode)
# utils.plot2D_feedback(full_path+plot_agent_feedback_path, epochs, agent_feedback_per_episode)
'''
With the current simulator we can generate a list of episodes
the episodes will be used to generate the trans probabilities and as input to the IRL algo
'''
#TODO
# - include reaction time as output
# - average mistakes, average timeout, average assistance, average_react_time
# - include real time episodes into the simulation:
# - counters for agent_assistance, agent_feedback, attempt, game_state, attention and reactivity
# - using the function update probability to generate the new user model and use it as input to the simulator
agent_policy = generate_agent_assistance(preferred_assistance=2, agent_behaviour="help", n_game_state=Game_State.counter.value, n_attempt=Attempt.counter.value, alpha_action=0.5)
# SIMULATION PARAMS
epochs = 20
scaling_factor = 1
# initialise the agent
bn_model_caregiver_assistance = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_assistive_model.bif')
bn_model_caregiver_feedback = None#bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_feedback_model.bif')
bn_model_user_action = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/persona_model_test.bif')
bn_model_user_react_time = bnlearn.import_DAG('/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/user_react_time_model.bif')
# initialise memory, attention and reactivity variables
persona_memory = 0;
persona_attention = 0;
persona_reactivity = 1;
# define state space struct for the irl algorithm
episode_instance = Episode()
# DEFINITION OF THE MDP
# define state space struct for the irl algorithm
attempt = [i for i in range(1, Attempt.counter.value + 1)]
# +1 (3,_,_) absorbing state
game_state = [i for i in range(0, Game_State.counter.value + 1)]
user_action = [i for i in range(-1, User_Action.counter.value - 1)]
state_space = (game_state, attempt, user_action)
states_space_list = list(itertools.product(*state_space))
state_space_index = [episode_instance.state_from_point_to_index(states_space_list, s) for s in states_space_list]
agent_assistance_action = [i for i in range(Agent_Assistance.counter.value)]
agent_feedback_action = [i for i in range(Agent_Feedback.counter.value)]
action_space = (agent_feedback_action, agent_assistance_action)
action_space_list = list(itertools.product(*action_space))
action_space_index = [episode_instance.state_from_point_to_index(action_space_list, a) for a in action_space_list]
terminal_state = [(Game_State.counter.value, i, user_action[j]) for i in range(1, Attempt.counter.value + 1) for j in
range(len(user_action))]
initial_state = (1, 1, 0)
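A minimal illustration of how a (game_state, attempt, user_action) point maps onto the flat index used by the episodes; Episode.state_from_point_to_index is assumed here to do something equivalent over the product lists built above:

    def point_to_index(space_list, point):
        # index of the tuple inside the enumerated product of the state factors
        return space_list.index(point)

    print(point_to_index(states_space_list, initial_state))   # index of (1, 1, 0)
    print(point_to_index(action_space_list, (0, 2)))          # e.g. (feedback=no, assistance=lev_2)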
#1. RUN THE SIMULATION WITH THE PARAMS SET BY THE CAREGIVER
game_performance_per_episode, react_time_per_episode, agent_assistance_per_episode, agent_feedback_per_episode, episodes_list = \
simulation(bn_model_user_action=bn_model_user_action, var_user_action_target_action=['user_action'],
bn_model_user_react_time=bn_model_user_react_time,
var_user_react_time_target_action=['user_react_time'],
user_memory_name="memory", user_memory_value=persona_memory,
user_attention_name="attention", user_attention_value=persona_attention,
user_reactivity_name="reactivity", user_reactivity_value=persona_reactivity,
task_progress_t0_name="game_state_t0", task_progress_t1_name="game_state_t1",
game_attempt_t0_name="attempt_t0", game_attempt_t1_name="attempt_t1",
agent_assistance_name="agent_assistance", agent_policy=agent_policy,
state_space=states_space_list, action_space=action_space_list,
epochs=epochs, task_complexity=5, max_attempt_per_object=4, alpha_learning=0.1)
utils.plot2D_game_performance("/home/pal/Documents/Framework/bn_generative_model/results/user_performance.png", epochs, scaling_factor, game_performance_per_episode)
utils.plot2D_assistance("/home/pal/Documents/Framework/bn_generative_model/results/agent_assistance.png", epochs, scaling_factor, agent_assistance_per_episode)
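The commented-out main block above persisted the generated episodes for the IRL step; a minimal equivalent for this new script would be (the output path is illustrative):

    import numpy as np

    np.save("/home/pal/Documents/Framework/bn_generative_model/results/episodes.npy",
            np.array(episodes_list, dtype=object))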
@@ -8,14 +8,45 @@ DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif')
#DAGnew = bn.parameter_learning.fit(model, df, methodtype="bayes")
#bn.print_CPD(DAGnew)
q1 = bn.inference.fit(DAG, variables=['user_action'], evidence={
-                     'game_state_t0': 1,
-                     'attempt_t0':0,
-                     'robot_assistance':5,
-                     'game_state_t1': 1,
                      'game_state_t0': 0,
                      'attempt_t0':1,
                      'game_state_t1': 0,
                      'attempt_t1':2,
                      'agent_assistance':0,
                      })
print(q1.variables)
print(q1.values)
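To see how the updated model reacts to the assistance level, the same query can be swept over agent_assistance (evidence names as defined in the .bif above; this mirrors the commented-out loop that follows):

    for lev in range(6):
        q = bn.inference.fit(DAG, variables=['user_action'],
                             evidence={'game_state_t0': 0, 'attempt_t0': 1,
                                       'game_state_t1': 0, 'attempt_t1': 2,
                                       'agent_assistance': lev})
        print("agent_assistance =", lev, "->", q.values)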
# robot_assistance = [0, 1, 2, 3, 4, 5]
# attempt_t0 = [0, 1, 2, 3]
# game_state_t0 = [0, 1, 2]
# attempt_t1 = [0]
# game_state_t1 = [0, 1, 2]
#
# query_result = [[[0 for j in range(len(attempt_t0))] for i in range(len(robot_assistance))] for k in range(len(game_state_t0))]
# for k in range(len(game_state_t0)):
# for i in range(len(robot_assistance)):
# for j in range(len(attempt_t0)-1):
# if j == 0:
# query = bn.inference.fit(DAG, variables=['user_action'],
# evidence={'game_state_t0': k,
# 'attempt_t0': j,
# 'agent_assistance': i,
# 'game_state_t1': k,
# 'attempt_t1': j})
# query_result[k][i][j] = query.values
# else:
# query = bn.inference.fit(DAG, variables=['user_action'],
# evidence={'game_state_t0': k,
# 'attempt_t0': j,
# 'agent_assistance': i,
# 'game_state_t1': k,
# 'attempt_t1': j + 1})
# query_result[k][i][j] = query.values
# for k in range(len(game_state_t0)):
# for i in range(len(robot_assistance)):
# for j in range(len(attempt_t0)):
# if j == 0:
# print("game_state:",k, "attempt_from:", j," attempt_to:",j, " robot_ass:",i, " prob:", query_result[k][i][j])
# else:
# print("game_state:", k, "attempt_from:", j, " attempt_to:", j+1, " robot_ass:", i, " prob:",
# query_result[k][i][j])
@@ -43,6 +43,7 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y):
    lev_2 = list(map(lambda x:x[2], y[0]))[1::scaling_factor]
    lev_3 = list(map(lambda x:x[3], y[0]))[1::scaling_factor]
    lev_4 = list(map(lambda x:x[4], y[0]))[1::scaling_factor]
    lev_5 = list(map(lambda x:x[5], y[0]))[1::scaling_factor]

    # plot bars
    plt.figure(figsize=(10, 7))
@@ -54,7 +55,8 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y):
               width=barWidth, label='lev_3')
    plt.bar(r, lev_4, bottom=np.array(lev_0) + np.array(lev_1)+ np.array(lev_2)+ np.array(lev_3), edgecolor='white',
            width=barWidth, label='lev_4')
    plt.bar(r, lev_5, bottom=np.array(lev_0) + np.array(lev_1) + np.array(lev_2) + np.array(lev_3)+ np.array(lev_4), edgecolor='white',
            width=barWidth, label='lev_5')
    plt.legend()
    # Custom X axis
...