From 98dea01689ad4e5c391d0d391f12c361406cae73 Mon Sep 17 00:00:00 2001
From: Antonio Andriella <aandriella@iri.upc.edu>
Date: Sat, 14 Nov 2020 23:59:31 +0100
Subject: [PATCH] connect bn therapist model with the simulation irl

---
 main.py | 38 +++++++++++++++++++-------------------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/main.py b/main.py
index 6cd75f2..f81f9f0 100644
--- a/main.py
+++ b/main.py
@@ -26,15 +26,15 @@ import pickle
 import bnlearn
 import argparse
 
-from cognitive_game_env import CognitiveGame
 from episode import Episode
+from cognitive_game_env import CognitiveGame
 from environment import Environment
-import src.maxent as M
-import src.plot as P
-import src.solver as S
-import src.optimizer as O
-import src.img_utils as I
-import src.value_iteration as vi
+import maxent as M
+import plot as P
+import solver as S
+import optimizer as O
+import img_utils as I
+import value_iteration as vi
 
 import simulation as Sim
 import bn_functions as bn_functions
@@ -246,8 +246,8 @@ def main():
 
     #################GENERATE SIMULATION################################
     parser = argparse.ArgumentParser()
-    parser.add_argument('--bn_user_model_filename', '--bn_user_model', type=str,help="file path of the user bn model",
-                        default="/home/pal/Documents/Framework/bn_generative_model/bn_persona_model/persona_test.bif")
+    parser.add_argument('--bn_model_folder', '--bn_model_folder', type=str,help="folder in which all the user and the agent models are stored ",
+                        default="/home/pal/Documents/Framework/GenerativeMutualShapingRL/BN_Models")
     parser.add_argument('--bn_agent_model_filename', '--bn_agent_model', type=str,help="file path of the agent bn model",
                         default="/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_test.bif")
     parser.add_argument('--epoch', '--epoch', type=int,help="number of epochs in the simulation", default=200)
@@ -262,9 +262,9 @@ def main():
                         default="/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log")
     parser.add_argument('--agent_patient_interaction_folder', '--api_path', type=str,help="agent-patient interaction folder",
                         default="/home/pal/carf_ws/src/carf/robot_in_the_loop/log")
-    parser.add_argument('--user_id', '--id', type=int,help="user id")
-    parser.add_argument('--with_feedback', '--f', type=bool,help="offering sociable")
-    parser.add_argument('--session', '--s', type=int, help="session of the agent-human interaction")
+    parser.add_argument('--user_id', '--id', type=int,help="user id", required=True)
+    parser.add_argument('--with_feedback', '--f', type=eval, choices=[True, False], help="offering sociable", required=True)
+    parser.add_argument('--session', '--s', type=int, help="session of the agent-human interaction", required=True)
 
     args = parser.parse_args()
 
@@ -276,8 +276,8 @@ def main():
     epochs = args.epoch
     runs = args.run
     # initialise the agent
-    bn_user_model_filename = args.bn_user_model_filename
-    bn_agent_model_filename = args.bn_agent_model_filename
+    bn_user_model_filename = args.bn_model_folder  +"/"+str(user_id)+"/"+str(with_feedback)+"/user_model.bif"
+    bn_agent_model_filename = args.bn_model_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/agent_model.bif"
     learned_policy_filename = args.output_policy_filename
     learned_reward_filename = args.output_reward_filename
     learned_value_f_filename = args.output_value_filename
@@ -314,10 +314,10 @@ def main():
     output_folder_data_path = os.getcwd() + "/results/" + str(user_id) +"/"+str(with_feedback)+"/"+str(session)
     if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id)):
         os.mkdir(os.getcwd() + "/results"+"/"+str(user_id))
-        if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id) +"/"+str(with_feedback)):
-            os.mkdir(os.getcwd() + "/results" + "/" +str(user_id) +"/"+str(with_feedback))
-            if not os.path.exists(output_folder_data_path):
-                os.mkdir(output_folder_data_path)
+    if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id) +"/"+str(with_feedback)):
+        os.mkdir(os.getcwd() + "/results" + "/" +str(user_id) +"/"+str(with_feedback))
+    if not os.path.exists(output_folder_data_path):
+        os.mkdir(output_folder_data_path)
 
 
 #1. CREATE INITIAL USER COGNITIVE MODEL FROM DATA
@@ -325,7 +325,7 @@ def main():
                                                   file_output=output_folder_data_path+"/summary_bn_variables_from_data.csv",
                                                   user_id=user_id,
                                                   with_feedback=with_feedback,
-                                                  rpi_folder_pathname=agent_patient_interaction_folder,
+                                                  rpi_folder_pathname=None,#agent_patient_interaction_folder,
                                                 column_to_remove=None)
 
     #2. CREATE POLICY FROM DATA
-- 
GitLab