Skip to content
Snippets Groups Projects
Commit bef7c58c authored by Antonio Andriella's avatar Antonio Andriella
Browse files

Working version of new model

parent 4aaa0934
No related branches found
No related tags found
No related merge requests found
...@@ -44,12 +44,35 @@ def compute_prob(cpds_table): ...@@ -44,12 +44,35 @@ def compute_prob(cpds_table):
This function checks if any This function checks if any
''' '''
for val in range(len(cpds_table)): cpds_table_array = np.array(cpds_table)
cpds_table[val] = list(map(lambda x: x / (sum(cpds_table[val])+0.00001), cpds_table[val])) cpds_table_array_len = cpds_table_array.shape.__len__()
cpds_table[val] = check_zero_occurrences(cpds_table[val])
if cpds_table_array_len == 4:
# attempt
for elem1 in range(cpds_table_array.shape[0]):
# game_state
for elem2 in range(cpds_table_array.shape[1]):
#assistance
for elem3 in range(cpds_table_array.shape[2]):
cpds_table[elem1][elem2][elem3] = list(
map(lambda x: x / (sum(cpds_table[elem1][elem2][elem3]) + 0.00001), cpds_table[elem1][elem2][elem3]))
cpds_table[elem1][elem2][elem3] = check_zero_occurrences(cpds_table[elem1][elem2][elem3])
elif cpds_table_array_len ==3:
#attempt
for elem1 in range(cpds_table_array.shape[0]):
#game_state
for elem2 in range(cpds_table_array.shape[1]):
cpds_table[elem1][elem2] = list(map(lambda x: x / (sum(cpds_table[elem1][elem2]) + 0.00001), cpds_table[elem1][elem2]))
cpds_table[elem1][elem2] = check_zero_occurrences(cpds_table[elem1][elem2])
else:
for val in range(len(cpds_table)):
cpds_table[val] = list(map(lambda x: x / (sum(cpds_table[val])+0.00001), cpds_table[val]))
cpds_table[val] = check_zero_occurrences(cpds_table[val])
return cpds_table return cpds_table
def average_prob(ref_cpds_table, current_cpds_table): def average_prob(ref_cpds_table, current_cpds_table, alpha):
''' '''
Args: Args:
ref_cpds_table: table from bnlearn ref_cpds_table: table from bnlearn
...@@ -58,12 +81,13 @@ def average_prob(ref_cpds_table, current_cpds_table): ...@@ -58,12 +81,13 @@ def average_prob(ref_cpds_table, current_cpds_table):
avg from both tables avg from both tables
''' '''
res_cpds_table = ref_cpds_table.copy() res_cpds_table = ref_cpds_table.copy()
current_cpds_table_np_array = np.array(current_cpds_table)
for elem1 in range(len(ref_cpds_table)): for elem1 in range(len(ref_cpds_table)):
for elem2 in range(len(ref_cpds_table[0])): for elem2 in range(len(ref_cpds_table[0])):
res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]+current_cpds_table[elem1][elem2])/2 res_cpds_table[elem1][elem2] = (ref_cpds_table[elem1][elem2]*(1-alpha))+(current_cpds_table_np_array[elem1][elem2]*alpha)
return res_cpds_table return res_cpds_table
def update_cpds_tables(bn_model, variables_tables): def update_cpds_tables(bn_model, variables_tables, alpha=0.1):
''' '''
This function updates the bn model with the variables_tables provided in input This function updates the bn model with the variables_tables provided in input
Args: Args:
...@@ -80,7 +104,7 @@ def update_cpds_tables(bn_model, variables_tables): ...@@ -80,7 +104,7 @@ def update_cpds_tables(bn_model, variables_tables):
cpds_table_from_counter = compute_prob(val) cpds_table_from_counter = compute_prob(val)
updated_prob = average_prob( updated_prob = average_prob(
np.transpose(cpds_table), np.transpose(cpds_table),
cpds_table_from_counter) cpds_table_from_counter, alpha)
bn_model['model'].cpds[index].values = np.transpose(updated_prob) bn_model['model'].cpds[index].values = np.transpose(updated_prob)
return bn_model return bn_model
......
...@@ -3,7 +3,7 @@ network persona_model { ...@@ -3,7 +3,7 @@ network persona_model {
%VARIABLES DEFINITION %VARIABLES DEFINITION
variable robot_assistance { variable agent_assistance {
type discrete [ 6 ] { lev_0, lev_1, lev_2, lev_3, lev_4, lev_5 }; type discrete [ 6 ] { lev_0, lev_1, lev_2, lev_3, lev_4, lev_5 };
} }
variable attempt_t0 { variable attempt_t0 {
...@@ -13,18 +13,20 @@ variable game_state_t0 { ...@@ -13,18 +13,20 @@ variable game_state_t0 {
type discrete [ 3 ] { beg, mid, end }; type discrete [ 3 ] { beg, mid, end };
} }
variable attempt_t1 { variable attempt_t1 {
type discrete [ 4 ] { att_1, att_2, att_3, att_4 }; type discrete [ 4 ] { att_1, att_2, att_3, att_4};
} }
variable game_state_t1 { variable game_state_t1 {
type discrete [ 3 ] { beg, mid, end }; type discrete [ 3 ] { beg, mid, end };
} }
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
variable user_action { variable user_action {
type discrete [ 3 ] { correct, wrong, timeout }; type discrete [ 3 ] { correct, wrong, timeout };
} }
variable agent_feedback {
type discrete [ 2 ] { no, yes };
}
%INDIVIDUAL PROBABILITIES DEFINITION %INDIVIDUAL PROBABILITIES DEFINITION
probability ( robot_assistance ) { probability ( agent_assistance ) {
table 0.17, 0.16, 0.16, 0.17, 0.17, 0.17; table 0.17, 0.16, 0.16, 0.17, 0.17, 0.17;
} }
probability ( game_state_t0 ) { probability ( game_state_t0 ) {
...@@ -39,124 +41,104 @@ probability ( game_state_t1 ) { ...@@ -39,124 +41,104 @@ probability ( game_state_t1 ) {
probability ( attempt_t1 ) { probability ( attempt_t1 ) {
table 0.25, 0.25, 0.25, 0.25; table 0.25, 0.25, 0.25, 0.25;
} }
probability ( user_action ) { probability ( user_action ) {
table 0.33, 0.33, 0.34; table 0.33, 0.33, 0.34;
} }
probability ( agent_feedback ) {
probability (robot_assistance | game_state_t0, attempt_t0){ table 0.5, 0.5;
(beg, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2; }
(beg, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2; probability(agent_assistance | agent_feedback) {
(beg, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1; (yes) 0.4, 0.3, 0.2, 0.1, 0.0, 0.0;
(beg, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2; (no) 0.0, 0.0, 0.1, 0.2, 0.3, 0.4;
(mid, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
(mid, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
(mid, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
(mid, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
(end, att_1) 0.0, 0.2, 0.2, 0.2, 0.2, 0.2;
(end, att_2) 0.05, 0.15, 0.2, 0.2, 0.2, 0.2;
(end, att_3) 0.1, 0.2, 0.2, 0.2, 0.2, 0.1;
(end, att_4) 0.15, 0.05, 0.2, 0.2, 0.2, 0.2;
} }
probability (user_action | game_state_t0, attempt_t0, agent_assistance){
probability (user_action | game_state_t0, attempt_t0, robot_assistance){
(beg, att_1, lev_0) 0.1, 0.9, 0.0; (beg, att_1, lev_0) 0.1, 0.9, 0.0;
(beg, att_2, lev_0) 0.2, 0.8, 0.0; (beg, att_2, lev_0) 0.2, 0.8, 0.0;
(beg, att_3, lev_0) 0.3, 0.7, 0.0; (beg, att_3, lev_0) 0.3, 0.7, 0.0;
(beg, att_4, lev_0) 0.4, 0.6, 0.0; (beg, att_4, lev_0) 0.4, 0.6, 0.0;
(beg, att_1, lev_1) 0.1, 0.9, 0.0; (beg, att_1, lev_1) 0.2, 0.8, 0.0;
(beg, att_2, lev_1) 0.2, 0.8, 0.0; (beg, att_2, lev_1) 0.3, 0.7, 0.0;
(beg, att_3, lev_1) 0.3, 0.7, 0.0; (beg, att_3, lev_1) 0.4, 0.6, 0.0;
(beg, att_4, lev_1) 0.4, 0.6, 0.0; (beg, att_4, lev_1) 0.5, 0.5, 0.0;
(beg, att_1, lev_2) 0.1, 0.9, 0.0; (beg, att_1, lev_2) 0.3, 0.7, 0.0;
(beg, att_2, lev_2) 0.2, 0.8, 0.0; (beg, att_2, lev_2) 0.4, 0.6, 0.0;
(beg, att_3, lev_2) 0.3, 0.7, 0.0; (beg, att_3, lev_2) 0.5, 0.5, 0.0;
(beg, att_4, lev_2) 0.4, 0.6, 0.0; (beg, att_4, lev_2) 0.6, 0.4, 0.0;
(beg, att_1, lev_3) 0.1, 0.9, 0.0; (beg, att_1, lev_3) 0.4, 0.6, 0.0;
(beg, att_2, lev_3) 0.2, 0.8, 0.0; (beg, att_2, lev_3) 0.5, 0.5, 0.0;
(beg, att_3, lev_3) 0.3, 0.7, 0.0; (beg, att_3, lev_3) 0.6, 0.4, 0.0;
(beg, att_4, lev_3) 0.4, 0.6, 0.0; (beg, att_4, lev_3) 0.7, 0.3, 0.0;
(beg, att_1, lev_4) 0.1, 0.9, 0.0; (beg, att_1, lev_4) 1.0, 0.0, 0.0;
(beg, att_2, lev_4) 0.2, 0.8, 0.0; (beg, att_2, lev_4) 1.0, 0.0, 0.0;
(beg, att_3, lev_4) 0.3, 0.7, 0.0; (beg, att_3, lev_4) 1.0, 0.0, 0.0;
(beg, att_4, lev_4) 0.4, 0.6, 0.0; (beg, att_4, lev_4) 1.0, 0.0, 0.0;
(beg, att_1, lev_5) 0.1, 0.9, 0.0; (beg, att_1, lev_5) 1.0, 0.0, 0.0;
(beg, att_2, lev_5) 0.2, 0.8, 0.0; (beg, att_2, lev_5) 1.0, 0.0, 0.0;
(beg, att_3, lev_5) 0.3, 0.7, 0.0; (beg, att_3, lev_5) 1.0, 0.0, 0.0;
(beg, att_4, lev_5) 0.4, 0.6, 0.0; (beg, att_4, lev_5) 1.0, 0.0, 0.0;
(mid, att_1, lev_0) 0.1, 0.9, 0.0; (mid, att_1, lev_0) 0.1, 0.9, 0.0;
(mid, att_2, lev_0) 0.2, 0.8, 0.0; (mid, att_2, lev_0) 0.2, 0.8, 0.0;
(mid, att_3, lev_0) 0.3, 0.7, 0.0; (mid, att_3, lev_0) 0.3, 0.7, 0.0;
(mid, att_4, lev_0) 0.4, 0.6, 0.0; (mid, att_4, lev_0) 0.4, 0.6, 0.0;
(mid, att_1, lev_1) 0.1, 0.9, 0.0; (mid, att_1, lev_1) 0.2, 0.8, 0.0;
(mid, att_2, lev_1) 0.2, 0.8, 0.0; (mid, att_2, lev_1) 0.3, 0.7, 0.0;
(mid, att_3, lev_1) 0.3, 0.7, 0.0; (mid, att_3, lev_1) 0.4, 0.6, 0.0;
(mid, att_4, lev_1) 0.4, 0.6, 0.0; (mid, att_4, lev_1) 0.5, 0.5, 0.0;
(mid, att_1, lev_2) 0.1, 0.9, 0.0; (mid, att_1, lev_2) 0.3, 0.7, 0.0;
(mid, att_2, lev_2) 0.2, 0.8, 0.0; (mid, att_2, lev_2) 0.4, 0.6, 0.0;
(mid, att_3, lev_2) 0.3, 0.7, 0.0; (mid, att_3, lev_2) 0.5, 0.5, 0.0;
(mid, att_4, lev_2) 0.4, 0.6, 0.0; (mid, att_4, lev_2) 0.6, 0.4, 0.0;
(mid, att_1, lev_3) 0.1, 0.9, 0.0; (mid, att_1, lev_3) 0.4, 0.6, 0.0;
(mid, att_2, lev_3) 0.2, 0.8, 0.0; (mid, att_2, lev_3) 0.5, 0.5, 0.0;
(mid, att_3, lev_3) 0.3, 0.7, 0.0; (mid, att_3, lev_3) 0.6, 0.4, 0.0;
(mid, att_4, lev_3) 0.4, 0.6, 0.0; (mid, att_4, lev_3) 0.7, 0.3, 0.0;
(mid, att_1, lev_4) 0.1, 0.9, 0.0; (mid, att_1, lev_4) 1.0, 0.0, 0.0;
(mid, att_2, lev_4) 0.2, 0.8, 0.0; (mid, att_2, lev_4) 1.0, 0.0, 0.0;
(mid, att_3, lev_4) 0.3, 0.7, 0.0; (mid, att_3, lev_4) 1.0, 0.0, 0.0;
(mid, att_4, lev_4) 0.4, 0.6, 0.0; (mid, att_4, lev_4) 1.0, 0.0, 0.0;
(mid, att_1, lev_5) 0.1, 0.9, 0.0; (mid, att_1, lev_5) 1.0, 0.0, 0.0;
(mid, att_2, lev_5) 0.2, 0.8, 0.0; (mid, att_2, lev_5) 1.0, 0.0, 0.0;
(mid, att_3, lev_5) 0.3, 0.7, 0.0; (mid, att_3, lev_5) 1.0, 0.0, 0.0;
(mid, att_4, lev_5) 0.4, 0.6, 0.0; (mid, att_4, lev_5) 1.0, 0.0, 0.0;
(end, att_1, lev_0) 0.1, 0.9, 0.0; (end, att_1, lev_0) 0.1, 0.9, 0.0;
(end, att_2, lev_0) 0.2, 0.8, 0.0; (end, att_2, lev_0) 0.2, 0.8, 0.0;
(end, att_3, lev_0) 0.2, 0.8, 0.0; (end, att_3, lev_0) 0.3, 0.7, 0.0;
(end, att_4, lev_0) 0.4, 0.6, 0.0; (end, att_4, lev_0) 0.4, 0.6, 0.0;
(end, att_1, lev_1) 0.1, 0.9, 0.0; (end, att_1, lev_1) 0.2, 0.8, 0.0;
(end, att_2, lev_1) 0.2, 0.8, 0.0; (end, att_2, lev_1) 0.3, 0.7, 0.0;
(end, att_3, lev_1) 0.4, 0.6, 0.0; (end, att_3, lev_1) 0.4, 0.6, 0.0;
(end, att_4, lev_1) 0.4, 0.6, 0.0; (end, att_4, lev_1) 0.5, 0.5, 0.0;
(end, att_1, lev_2) 0.1, 0.9, 0.0; (end, att_1, lev_2) 0.3, 0.7, 0.0;
(end, att_2, lev_2) 0.2, 0.8, 0.0; (end, att_2, lev_2) 0.4, 0.6, 0.0;
(end, att_3, lev_2) 0.4, 0.6, 0.0; (end, att_3, lev_2) 0.5, 0.5, 0.0;
(end, att_4, lev_2) 0.4, 0.6, 0.0; (end, att_4, lev_2) 0.6, 0.4, 0.0;
(end, att_1, lev_3) 0.1, 0.9, 0.0; (end, att_1, lev_3) 0.4, 0.6, 0.0;
(end, att_2, lev_3) 0.2, 0.8, 0.0; (end, att_2, lev_3) 0.5, 0.5, 0.0;
(end, att_3, lev_3) 0.5, 0.5, 0.0; (end, att_3, lev_3) 0.6, 0.4, 0.0;
(end, att_4, lev_3) 0.4, 0.6, 0.0; (end, att_4, lev_3) 0.7, 0.3, 0.0;
(end, att_1, lev_4) 0.1, 0.9, 0.0; (end, att_1, lev_4) 1.0, 0.0, 0.0;
(end, att_2, lev_4) 0.2, 0.8, 0.0; (end, att_2, lev_4) 1.0, 0.0, 0.0;
(end, att_3, lev_4) 0.7, 0.3, 0.0; (end, att_3, lev_4) 1.0, 0.0, 0.0;
(end, att_4, lev_4) 0.4, 0.6, 0.0; (end, att_4, lev_4) 1.0, 0.0, 0.0;
(end, att_1, lev_5) 0.1, 0.9, 0.0; (end, att_1, lev_5) 1.0, 0.0, 0.0;
(end, att_2, lev_5) 0.2, 0.8, 0.0; (end, att_2, lev_5) 1.0, 0.0, 0.0;
(end, att_3, lev_5) 0.3, 0.7, 0.0; (end, att_3, lev_5) 1.0, 0.0, 0.0;
(end, att_4, lev_5) 0.4, 0.6, 0.0; (end, att_4, lev_5) 1.0, 0.0, 0.0;
} }
probability (game_state_t1 | user_action) { probability (game_state_t1 | user_action) {
(correct) 0.2, 0.3, 0.5; (correct) 0.25, 0.3, 0.45;
(wrong) 0.5, 0.3, 0.2; (wrong) 0.33, 0.33, 0.33;
(timeout) 0.33, 0.34, 0.33; (timeout) 0.33, 0.33, 0.33;
} }
probability (attempt_t1 | user_action) { probability (attempt_t1 | user_action) {
(correct) 0.1, 0.2, 0.3, 0.4; (correct) 0.1, 0.2, 0.25, 0.45;
(wrong) 0.4, 0.3, 0.2, 0.1; (wrong) 0.25, 0.25, 0.25, 0.25;
(timeout) 0.25, 0.25, 0.25, 0.25; (timeout) 0.25, 0.25, 0.25, 0.25;
} }
probability (user_action | robot_assistance){
(lev_0) 0.1, 0.6, 0.3;
(lev_1) 0.2, 0.5, 0.3;
(lev_2) 0.3, 0.5, 0.2;
(lev_3) 0.5, 0.3, 0.2;
(lev_4) 0.9, 0.1, 0.0;
(lev_5) 0.9, 0.1, 0.0;
}
\ No newline at end of file
This diff is collapsed.
...@@ -8,14 +8,45 @@ DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif') ...@@ -8,14 +8,45 @@ DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif')
#DAGnew = bn.parameter_learning.fit(model, df, methodtype="bayes") #DAGnew = bn.parameter_learning.fit(model, df, methodtype="bayes")
#bn.print_CPD(DAGnew) #bn.print_CPD(DAGnew)
q1 = bn.inference.fit(DAG, variables=['user_action'], evidence={ q1 = bn.inference.fit(DAG, variables=['user_action'], evidence={
'game_state_t0': 1, 'game_state_t0': 0,
'attempt_t0':0,
'robot_assistance':5,
'game_state_t1': 1,
'attempt_t0':1, 'attempt_t0':1,
'game_state_t1': 0,
'attempt_t1':2,
'agent_assistance':0,
}) })
print(q1.variables)
print(q1.values) print(q1.values)
# robot_assistance = [0, 1, 2, 3, 4, 5]
# attempt_t0 = [0, 1, 2, 3]
# game_state_t0 = [0, 1, 2]
# attempt_t1 = [0]
# game_state_t1 = [0, 1, 2]
#
# query_result = [[[0 for j in range(len(attempt_t0))] for i in range(len(robot_assistance))] for k in range(len(game_state_t0))]
# for k in range(len(game_state_t0)):
# for i in range(len(robot_assistance)):
# for j in range(len(attempt_t0)-1):
# if j == 0:
# query = bn.inference.fit(DAG, variables=['user_action'],
# evidence={'game_state_t0': k,
# 'attempt_t0': j,
# 'agent_assistance': i,
# 'game_state_t1': k,
# 'attempt_t1': j})
# query_result[k][i][j] = query.values
# else:
# query = bn.inference.fit(DAG, variables=['user_action'],
# evidence={'game_state_t0': k,
# 'attempt_t0': j,
# 'agent_assistance': i,
# 'game_state_t1': k,
# 'attempt_t1': j + 1})
# query_result[k][i][j] = query.values
# for k in range(len(game_state_t0)):
# for i in range(len(robot_assistance)):
# for j in range(len(attempt_t0)):
# if j == 0:
# print("game_state:",k, "attempt_from:", j," attempt_to:",j, " robot_ass:",i, " prob:", query_result[k][i][j])
# else:
# print("game_state:", k, "attempt_from:", j, " attempt_to:", j+1, " robot_ass:", i, " prob:",
# query_result[k][i][j])
...@@ -43,6 +43,7 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y): ...@@ -43,6 +43,7 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y):
lev_2 = list(map(lambda x:x[2], y[0]))[1::scaling_factor] lev_2 = list(map(lambda x:x[2], y[0]))[1::scaling_factor]
lev_3 = list(map(lambda x:x[3], y[0]))[1::scaling_factor] lev_3 = list(map(lambda x:x[3], y[0]))[1::scaling_factor]
lev_4 = list(map(lambda x:x[4], y[0]))[1::scaling_factor] lev_4 = list(map(lambda x:x[4], y[0]))[1::scaling_factor]
lev_5 = list(map(lambda x:x[5], y[0]))[1::scaling_factor]
# plot bars # plot bars
plt.figure(figsize=(10, 7)) plt.figure(figsize=(10, 7))
...@@ -54,7 +55,8 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y): ...@@ -54,7 +55,8 @@ def plot2D_assistance(save_path, n_episodes, scaling_factor=1, *y):
width=barWidth, label='lev_3') width=barWidth, label='lev_3')
plt.bar(r, lev_4, bottom=np.array(lev_0) + np.array(lev_1)+ np.array(lev_2)+ np.array(lev_3), edgecolor='white', plt.bar(r, lev_4, bottom=np.array(lev_0) + np.array(lev_1)+ np.array(lev_2)+ np.array(lev_3), edgecolor='white',
width=barWidth, label='lev_4') width=barWidth, label='lev_4')
plt.bar(r, lev_5, bottom=np.array(lev_0) + np.array(lev_1) + np.array(lev_2) + np.array(lev_3)+ np.array(lev_4), edgecolor='white',
width=barWidth, label='lev_5')
plt.legend() plt.legend()
# Custom X axis # Custom X axis
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment