From 98dea01689ad4e5c391d0d391f12c361406cae73 Mon Sep 17 00:00:00 2001 From: Antonio Andriella <aandriella@iri.upc.edu> Date: Sat, 14 Nov 2020 23:59:31 +0100 Subject: [PATCH] connect bn therapist model with the simulation irl --- main.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 6cd75f2..f81f9f0 100644 --- a/main.py +++ b/main.py @@ -26,15 +26,15 @@ import pickle import bnlearn import argparse -from cognitive_game_env import CognitiveGame from episode import Episode +from cognitive_game_env import CognitiveGame from environment import Environment -import src.maxent as M -import src.plot as P -import src.solver as S -import src.optimizer as O -import src.img_utils as I -import src.value_iteration as vi +import maxent as M +import plot as P +import solver as S +import optimizer as O +import img_utils as I +import value_iteration as vi import simulation as Sim import bn_functions as bn_functions @@ -246,8 +246,8 @@ def main(): #################GENERATE SIMULATION################################ parser = argparse.ArgumentParser() - parser.add_argument('--bn_user_model_filename', '--bn_user_model', type=str,help="file path of the user bn model", - default="/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/persona_test.bif") + parser.add_argument('--bn_model_folder', '--bn_model_folder', type=str,help="folder in which all the user and the agent models are stored ", + default="/home/pal/Documents/Framework/GenerativeMutualShapingRL/BN_Models") parser.add_argument('--bn_agent_model_filename', '--bn_agent_model', type=str,help="file path of the agent bn model", default="/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_test.bif") parser.add_argument('--epoch', '--epoch', type=int,help="number of epochs in the simulation", default=200) @@ -262,9 +262,9 @@ def main(): default="/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log") parser.add_argument('--agent_patient_interaction_folder', '--api_path', type=str,help="agent-patient interaction folder", default="/home/pal/carf_ws/src/carf/robot_in_the_loop/log") - parser.add_argument('--user_id', '--id', type=int,help="user id") - parser.add_argument('--with_feedback', '--f', type=bool,help="offering sociable") - parser.add_argument('--session', '--s', type=int, help="session of the agent-human interaction") + parser.add_argument('--user_id', '--id', type=int,help="user id", required=True) + parser.add_argument('--with_feedback', '--f', type=eval, choices=[True, False], help="offering sociable", required=True) + parser.add_argument('--session', '--s', type=int, help="session of the agent-human interaction", required=True) args = parser.parse_args() @@ -276,8 +276,8 @@ def main(): epochs = args.epoch runs = args.run # initialise the agent - bn_user_model_filename = args.bn_user_model_filename - bn_agent_model_filename = args.bn_agent_model_filename + bn_user_model_filename = args.bn_model_folder +"/"+str(user_id)+"/"+str(with_feedback)+"/user_model.bif" + bn_agent_model_filename = args.bn_model_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/agent_model.bif" learned_policy_filename = args.output_policy_filename learned_reward_filename = args.output_reward_filename learned_value_f_filename = args.output_value_filename @@ -314,10 +314,10 @@ def main(): output_folder_data_path = os.getcwd() + "/results/" + str(user_id) +"/"+str(with_feedback)+"/"+str(session) if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id)): os.mkdir(os.getcwd() + "/results"+"/"+str(user_id)) - if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id) +"/"+str(with_feedback)): - os.mkdir(os.getcwd() + "/results" + "/" +str(user_id) +"/"+str(with_feedback)) - if not os.path.exists(output_folder_data_path): - os.mkdir(output_folder_data_path) + if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id) +"/"+str(with_feedback)): + os.mkdir(os.getcwd() + "/results" + "/" +str(user_id) +"/"+str(with_feedback)) + if not os.path.exists(output_folder_data_path): + os.mkdir(output_folder_data_path) #1. CREATE INITIAL USER COGNITIVE MODEL FROM DATA @@ -325,7 +325,7 @@ def main(): file_output=output_folder_data_path+"/summary_bn_variables_from_data.csv", user_id=user_id, with_feedback=with_feedback, - rpi_folder_pathname=agent_patient_interaction_folder, + rpi_folder_pathname=None,#agent_patient_interaction_folder, column_to_remove=None) #2. CREATE POLICY FROM DATA -- GitLab