diff --git a/bn_persona_model/persona_model_test.bif b/bn_persona_model/persona_model_test.bif index 28727a89c3dae174cb90be06ea3cc7b102dd5160..f3c919a6f240b9da18501e282a7c1dd09a59438a 100644 --- a/bn_persona_model/persona_model_test.bif +++ b/bn_persona_model/persona_model_test.bif @@ -6,39 +6,36 @@ network persona_model { variable agent_assistance { type discrete [ 6 ] { lev_0, lev_1, lev_2, lev_3, lev_4, lev_5 }; } -variable attempt_t0 { +variable attempt { type discrete [ 4 ] { att_1, att_2, att_3, att_4 }; } -variable game_state_t0 { - type discrete [ 3 ] { beg, mid, end }; -} -variable attempt_t1 { - type discrete [ 1 ] { att_1, att_2, att_3, att_4}; -} -variable game_state_t1 { +variable game_state { type discrete [ 3 ] { beg, mid, end }; } + variable user_action { type discrete [ 3 ] { correct, wrong, timeout }; } variable agent_feedback { type discrete [ 2 ] { no, yes }; } +variable user_memory { + type discrete [ 3 ] {low, medium, high}; +} + +variable user_reactivity { + type discrete [ 3 ] {low, medium, high}; +} %INDIVIDUAL PROBABILITIES DEFINITION + probability ( agent_assistance ) { table 0.17, 0.16, 0.16, 0.17, 0.17, 0.17; } -probability ( game_state_t0 ) { +probability ( game_state) { table 0.34, 0.33, 0.33; } -probability ( attempt_t0 ) { - table 0.25, 0.25, 0.25, 0.25; -} -probability ( game_state_t1 ) { - table 0.34, 0.33, 0.33; -} -probability ( attempt_t1 ) { +probability ( attempt ) { table 0.25, 0.25, 0.25, 0.25; } probability ( user_action ) { @@ -47,98 +44,65 @@ probability ( user_action ) { probability ( agent_feedback ) { table 0.5, 0.5; } -probability(agent_assistance | agent_feedback) { - (yes) 0.4, 0.3, 0.2, 0.1, 0.0, 0.0 - (no) 0.0, 0.0, 0.1, 0.2, 0.3, 0.4 -} - -probability (user_action | game_state_t0, attempt_t0, agent_assistance){ - (beg, att_1, lev_0) 0.1, 0.9, 0.0; - (beg, att_2, lev_0) 0.2, 0.8, 0.0; - (beg, att_3, lev_0) 0.3, 0.7, 0.0; - (beg, att_4, lev_0) 0.4, 0.6, 0.0; - (beg, att_1, lev_1) 0.2, 0.8, 0.0; - (beg, att_2, lev_1) 0.3, 0.7, 0.0; - (beg, att_3, lev_1) 0.4, 0.6, 0.0; - (beg, att_4, lev_1) 0.5, 0.5, 0.0; - (beg, att_1, lev_2) 0.3, 0.7, 0.0; - (beg, att_2, lev_2) 0.4, 0.6, 0.0; - (beg, att_3, lev_2) 0.5, 0.5, 0.0; - (beg, att_4, lev_2) 0.6, 0.4, 0.0; - (beg, att_1, lev_3) 0.4, 0.6, 0.0; - (beg, att_2, lev_3) 0.5, 0.5, 0.0; - (beg, att_3, lev_3) 0.6, 0.4, 0.0; - (beg, att_4, lev_3) 0.7, 0.3, 0.0; - (beg, att_1, lev_4) 1.0, 0.0, 0.0; - (beg, att_2, lev_4) 1.0, 0.0, 0.0; - (beg, att_3, lev_4) 1.0, 0.0, 0.0; - (beg, att_4, lev_4) 1.0, 0.0, 0.0; - (beg, att_1, lev_5) 1.0, 0.0, 0.0; - (beg, att_2, lev_5) 1.0, 0.0, 0.0; - (beg, att_3, lev_5) 1.0, 0.0, 0.0; - (beg, att_4, lev_5) 1.0, 0.0, 0.0; - - (mid, att_1, lev_0) 0.1, 0.9, 0.0; - (mid, att_2, lev_0) 0.2, 0.8, 0.0; - (mid, att_3, lev_0) 0.3, 0.7, 0.0; - (mid, att_4, lev_0) 0.4, 0.6, 0.0; - (mid, att_1, lev_1) 0.2, 0.8, 0.0; - (mid, att_2, lev_1) 0.3, 0.7, 0.0; - (mid, att_3, lev_1) 0.4, 0.6, 0.0; - (mid, att_4, lev_1) 0.5, 0.5, 0.0; - (mid, att_1, lev_2) 0.3, 0.7, 0.0; - (mid, att_2, lev_2) 0.4, 0.6, 0.0; - (mid, att_3, lev_2) 0.5, 0.5, 0.0; - (mid, att_4, lev_2) 0.6, 0.4, 0.0; - (mid, att_1, lev_3) 0.4, 0.6, 0.0; - (mid, att_2, lev_3) 0.5, 0.5, 0.0; - (mid, att_3, lev_3) 0.6, 0.4, 0.0; - (mid, att_4, lev_3) 0.7, 0.3, 0.0; - (mid, att_1, lev_4) 1.0, 0.0, 0.0; - (mid, att_2, lev_4) 1.0, 0.0, 0.0; - (mid, att_3, lev_4) 1.0, 0.0, 0.0; - (mid, att_4, lev_4) 1.0, 0.0, 0.0; - (mid, att_1, lev_5) 1.0, 0.0, 0.0; - (mid, att_2, lev_5) 1.0, 0.0, 0.0; - (mid, att_3, lev_5) 1.0, 0.0, 0.0; - (mid, att_4, lev_5) 1.0, 0.0, 0.0; - - (end, att_1, lev_0) 0.1, 0.9, 0.0; - (end, att_2, lev_0) 0.2, 0.8, 0.0; - (end, att_3, lev_0) 0.3, 0.7, 0.0; - (end, att_4, lev_0) 0.4, 0.6, 0.0; - (end, att_1, lev_1) 0.2, 0.8, 0.0; - (end, att_2, lev_1) 0.3, 0.7, 0.0; - (end, att_3, lev_1) 0.4, 0.6, 0.0; - (end, att_4, lev_1) 0.5, 0.5, 0.0; - (end, att_1, lev_2) 0.3, 0.7, 0.0; - (end, att_2, lev_2) 0.4, 0.6, 0.0; - (end, att_3, lev_2) 0.5, 0.5, 0.0; - (end, att_4, lev_2) 0.6, 0.4, 0.0; - (end, att_1, lev_3) 0.4, 0.6, 0.0; - (end, att_2, lev_3) 0.5, 0.5, 0.0; - (end, att_3, lev_3) 0.6, 0.4, 0.0; - (end, att_4, lev_3) 0.7, 0.3, 0.0; - (end, att_1, lev_4) 1.0, 0.0, 0.0; - (end, att_2, lev_4) 1.0, 0.0, 0.0; - (end, att_3, lev_4) 1.0, 0.0, 0.0; - (end, att_4, lev_4) 1.0, 0.0, 0.0; - (end, att_1, lev_5) 1.0, 0.0, 0.0; - (end, att_2, lev_5) 1.0, 0.0, 0.0; - (end, att_3, lev_5) 1.0, 0.0, 0.0; - (end, att_4, lev_5) 1.0, 0.0, 0.0; -} - - -probability (game_state_t1 | user_action) { - (correct) 0.25, 0.3, 0.45; - (wrong) 0.33, 0.33, 0.33; - (timeout) 0.33, 0.33, 0.33; -} -probability (attempt_t1 | user_action) { - (correct) 0.1, 0.2, 0.25, 0.45; - (wrong) 0.25, 0.25, 0.25, 0.25; - (timeout) 0.25, 0.25, 0.25, 0.25; +probability ( user_memory ) { + table 0.33, 0.34, 0.33; +} +probability ( user_reactivity ) { + table 0.33, 0.34, 0.33; +} + + +probability (game_state | user_memory){ + (low) 0.1, 0.3, 0.6; + (medium) 0.2, 0.4, 0.4; + (high) 0.33, 0.33, 0.34; +} + +probability (attempt | user_memory){ + (low) 0.1,0.1,0.2, 0.6; + (medium) 0.1, 0.2, 0.3, 0.4; + (high) 0.25, 0.25, 0.25, 0.25; +} + +probability (game_state | user_reactivity){ + (low) 0.1, 0.3, 0.6; + (medium) 0.2, 0.4, 0.4; + (high) 0.33, 0.33, 0.34; +} + +probability (attempt | user_reactivity){ + (low) 0.1,0.1,0.2, 0.6; + (medium) 0.1, 0.2, 0.3, 0.4; + (high) 0.25, 0.25, 0.25, 0.25; } + +probability (game_state | user_action) { + (correct) 0.30, 0.30, 0.4; + (wrong) 0.35, 0.35, 0.3; + (timeout) 0.33, 0.33, 0.34; +} +probability (attempt | user_action) { + (correct) 0.25, 0.25, 0.25, 0.25; + (wrong) 0.4, 0.3, 0.2, 0.1; + (timeout) 0.4, 0.3, 0.2, 0.1; +} + + +probability (user_action | agent_assistance, agent_feedback) { +(lev_0, no) 0.05 0.85 0.1; +(lev_1, no) 0.1 0.8 0.1; +(lev_2, no) 0.2 0.7 0.1; +(lev_3, no) 0.33 0.57 0.1; +(lev_4, no) 0.9 0.1 0.0; +(lev_5, no) 1.0 0.0 0.0; + +(lev_0, yes) 0.5 0.4 0.1; +(lev_1, yes) 0.6 0.3 0.1; +(lev_2, yes) 0.7 0.2 0.1; +(lev_3, yes) 0.8 0.1 0.1; +(lev_4, yes) 0.9 0.1 0.0; +(lev_5, yes) 1.0 0.0 0.0; + +} diff --git a/test.py b/test.py index 66bd57cac6a41c2575e609287938cb842d8dd8b0..a2d2d907fa5a2e3bfa600c16fb579d7e9b8bf4c2 100644 --- a/test.py +++ b/test.py @@ -1,20 +1,97 @@ import bnlearn as bn +import pandas as pd -DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif') + +def import_data_from_csv(csv_filename, dag_filename): + print("/************************************************************/") + print("Init model") + DAG = bn.import_DAG(dag_filename) + df_caregiver = bn.sampling(DAG, n= 10000) + print("/************************************************************/") + print("real_user Model") + DAG_ = bn.import_DAG(dag_filename, CPD=False) + df_real_user = pd.read_csv(csv_filename) + DAG_real_user = bn.parameter_learning.fit(DAG_, df_real_user, methodtype='bayes') + df_real_user = bn.sampling(DAG_real_user, n=10000) + print("/************************************************************/") + print("Shared knowledge") + DAG_ = bn.import_DAG(dag_filename, CPD=False) + shared_knowledge = [df_real_user, df_caregiver] + conc_shared_knowledge = pd.concat(shared_knowledge) + DAG_shared = bn.parameter_learning.fit(DAG_, conc_shared_knowledge) + df_conc_shared_knowledge = bn.sampling(DAG_shared, n=10000) + return DAG_shared + + + +import_data_from_csv(csv_filename='bn_persona_model/cognitive_game.csv', dag_filename='bn_persona_model/persona_model_test.bif') +# DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif') +# G = bn.plot(DAG) +# q1 = bn.inference.fit(DAG, variables=[ 'user_action'], evidence={ +# 'game_state': 0, +# 'attempt':0, +# 'agent_feedback':1, +# 'memory': 0, +# 'reactivity':0, +# 'agent_assistance':0, +# +# }) +# df = pd.read_csv('bn_persona_model/cognitive_game.csv') +# df = bn.sampling(DAG, n=10000) +# #model_sl = bn.structure_learning.fit(df, methodtype='hc', scoretype='bic') +# #bn.plot(model_sl, pos=G['pos']) +# DAG_update = bn.parameter_learning.fit(DAG, df) +# n_game_state = 3 +# n_attempt = 4 +# n_aas = 6 +# n_af = 2 + +# for gs in range(n_game_state): +# for att in range(n_attempt): +# for aas in range(n_aas): +# for af in range(n_af): +# q1 = bn.inference.fit(DAG_update, variables=[ 'user_action'], evidence={ +# 'game_state': gs, +# 'attempt':att, +# 'agent_feedback':af, +# 'user_memory':0, +# 'user_reactivity':0, +# 'agent_assistance':aas}) +# print("GS:", gs, " ATT:", att, " AA", aas, " AF", af) +# +# df.head() +# DAG = bn.import_DAG('bn_persona_model/persona_model_test.bif', CPD=False) +# bn.plot(DAG) +# DAG_update = bn.parameter_learning.fit(DAG, df) +# DAG_true = bn.import_DAG('bn_persona_model/persona_model_test.bif', CPD=True) +# q1 = bn.inference.fit(DAG_update, variables=['user_action'], evidence={ +# 'game_state': 0, +# 'attempt':2, +# 'agent_feedback':0, +# }) +# print("BEFORE") +# print(q1.values) +# df = bn.sampling(DAG_update, n=1000) +# DAG_update = bn.parameter_learning.fit(DAG_update, df) +# q1 = bn.inference.fit(DAG_update, variables=['user_action'], evidence={ +# 'game_state': 0, +# 'attempt':2, +# 'agent_feedback':0, +# }) +# print("AFTER") +# print(q1.values) #df = bn.sampling(DAG, n=1000, verbose=2) #model = bn.structure_learning.fit(df) #G = bn.plot(model) #DAGnew = bn.parameter_learning.fit(model, df, methodtype="bayes") #bn.print_CPD(DAGnew) -q1 = bn.inference.fit(DAG, variables=['user_action'], evidence={ - 'game_state_t0': 0, - 'attempt_t0':1, - 'game_state_t1': 0, - 'attempt_t1':2, - 'agent_assistance':0, -}) -print(q1.values) +# q1 = bn.inference.fit(DAG, variables=['user_action'], evidence={ +# 'game_state': 0, +# 'attempt':1, +# 'agent_feedback':1, +# }) +# print(q1.values) # robot_assistance = [0, 1, 2, 3, 4, 5] # attempt_t0 = [0, 1, 2, 3]