Skip to content
Snippets Groups Projects
Commit c5f38954 authored by Antonio Andriella
Browse files

working version

parent 1ab098db
No related branches found
No related tags found
No related merge requests found
...@@ -29,12 +29,12 @@ import argparse ...@@ -29,12 +29,12 @@ import argparse
from episode import Episode from episode import Episode
from cognitive_game_env import CognitiveGame from cognitive_game_env import CognitiveGame
from environment import Environment from environment import Environment
import maxent as M import src.maxent as M
import plot as P import src.plot as P
import solver as S import src.solver as S
import optimizer as O import src.optimizer as O
import img_utils as I import src.img_utils as I
import value_iteration as vi import src.value_iteration as vi
import simulation as Sim import simulation as Sim
import bn_functions as bn_functions import bn_functions as bn_functions
...@@ -237,8 +237,9 @@ def compute_agent_policy(training_set_filename, state_space, action_space, episo ...@@ -237,8 +237,9 @@ def compute_agent_policy(training_set_filename, state_space, action_space, episo
action_index = action_point action_index = action_point
agent_policy_counter[state_index][action_index] += 1 agent_policy_counter[state_index][action_index] += 1
row_t_0 = row['user_action'] row_t_0 = row['user_action']
min_val = np.finfo(float).eps
for s in range(len(state_space)): for s in range(len(state_space)):
agent_policy_prob[s] = list(map(lambda x:x/(sum(agent_policy_counter[s])+0.001), agent_policy_counter[s])) agent_policy_prob[s] = list(map(lambda x:x/(sum(agent_policy_counter[s])+min_val), agent_policy_counter[s]))
return agent_policy_prob return agent_policy_prob
...@@ -248,8 +249,8 @@ def main(): ...@@ -248,8 +249,8 @@ def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--bn_model_folder', '--bn_model_folder', type=str,help="folder in which all the user and the agent models are stored ", parser.add_argument('--bn_model_folder', '--bn_model_folder', type=str,help="folder in which all the user and the agent models are stored ",
default="/home/pal/Documents/Framework/GenerativeMutualShapingRL/BN_Models") default="/home/pal/Documents/Framework/GenerativeMutualShapingRL/BN_Models")
parser.add_argument('--bn_agent_model_filename', '--bn_agent_model', type=str,help="file path of the agent bn model", # parser.add_argument('--bn_agent_model_filename', '--bn_agent_model', type=str,help="file path of the agent bn model",
default="/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_test.bif") # default="/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_test.bif")
parser.add_argument('--epoch', '--epoch', type=int,help="number of epochs in the simulation", default=200) parser.add_argument('--epoch', '--epoch', type=int,help="number of epochs in the simulation", default=200)
parser.add_argument('--run', '--run', type=int, help="number of runs in the simulation", default=50) parser.add_argument('--run', '--run', type=int, help="number of runs in the simulation", default=50)
parser.add_argument('--output_policy_filename', '--p', type=str,help="output policy from the simulation", parser.add_argument('--output_policy_filename', '--p', type=str,help="output policy from the simulation",
...@@ -259,9 +260,9 @@ def main(): ...@@ -259,9 +260,9 @@ def main():
parser.add_argument('--output_value_filename', '--v', type=str, help="output value function from the simulation", parser.add_argument('--output_value_filename', '--v', type=str, help="output value function from the simulation",
default="value_function.pkl") default="value_function.pkl")
parser.add_argument('--therapist_patient_interaction_folder', '--tpi_path', type=str,help="therapist-patient interaction folder", parser.add_argument('--therapist_patient_interaction_folder', '--tpi_path', type=str,help="therapist-patient interaction folder",
default="/home/pal/carf_ws/src/carf/caregiver_in_the_loop/log") default="/home/pal/Documents/Framework/GenerativeMutualShapingRL/therapist-patient-interaction")
parser.add_argument('--agent_patient_interaction_folder', '--api_path', type=str,help="agent-patient interaction folder", parser.add_argument('--agent_patient_interaction_folder', '--api_path', type=str,help="agent-patient interaction folder",
default="/home/pal/carf_ws/src/carf/robot_in_the_loop/log") default="/home/pal/carf_ws/src/carf/robot-patient-interaction")
parser.add_argument('--user_id', '--id', type=int,help="user id", required=True) parser.add_argument('--user_id', '--id', type=int,help="user id", required=True)
parser.add_argument('--with_feedback', '--f', type=eval, choices=[True, False], help="offering sociable", required=True) parser.add_argument('--with_feedback', '--f', type=eval, choices=[True, False], help="offering sociable", required=True)
parser.add_argument('--session', '--s', type=int, help="session of the agent-human interaction", required=True) parser.add_argument('--session', '--s', type=int, help="session of the agent-human interaction", required=True)
...@@ -278,9 +279,9 @@ def main(): ...@@ -278,9 +279,9 @@ def main():
# initialise the agent # initialise the agent
bn_user_model_filename = args.bn_model_folder +"/"+str(user_id)+"/"+str(with_feedback)+"/user_model.bif" bn_user_model_filename = args.bn_model_folder +"/"+str(user_id)+"/"+str(with_feedback)+"/user_model.bif"
bn_agent_model_filename = args.bn_model_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/agent_model.bif" bn_agent_model_filename = args.bn_model_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/agent_model.bif"
learned_policy_filename = args.output_policy_filename learned_policy_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_policy_filename
learned_reward_filename = args.output_reward_filename learned_reward_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_reward_filename
learned_value_f_filename = args.output_value_filename learned_value_f_filename = args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)+"/"+str(session+1)+"/"+args.output_value_filename
therapist_patient_interaction_folder = args.therapist_patient_interaction_folder therapist_patient_interaction_folder = args.therapist_patient_interaction_folder
agent_patient_interaction_folder = args.agent_patient_interaction_folder agent_patient_interaction_folder = args.agent_patient_interaction_folder
scaling_factor = 1 scaling_factor = 1
...@@ -312,6 +313,7 @@ def main(): ...@@ -312,6 +313,7 @@ def main():
#output folders #output folders
output_folder_data_path = os.getcwd() + "/results/" + str(user_id) +"/"+str(with_feedback)+"/"+str(session) output_folder_data_path = os.getcwd() + "/results/" + str(user_id) +"/"+str(with_feedback)+"/"+str(session)
if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id)): if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id)):
os.mkdir(os.getcwd() + "/results"+"/"+str(user_id)) os.mkdir(os.getcwd() + "/results"+"/"+str(user_id))
if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id) +"/"+str(with_feedback)): if not os.path.exists(os.getcwd() + "/results"+"/"+str(user_id) +"/"+str(with_feedback)):
...@@ -319,13 +321,19 @@ def main(): ...@@ -319,13 +321,19 @@ def main():
if not os.path.exists(output_folder_data_path): if not os.path.exists(output_folder_data_path):
os.mkdir(output_folder_data_path) os.mkdir(output_folder_data_path)
if not os.path.exists(args.agent_patient_interaction_folder+"/"+str(user_id)):
os.mkdir(args.agent_patient_interaction_folder+"/"+str(user_id))
if not os.path.exists(args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback)):
os.mkdir(args.agent_patient_interaction_folder+"/"+str(user_id)+"/"+str(with_feedback))
if not os.path.exists(args.agent_patient_interaction_folder + "/" + str(user_id) + "/" + str(with_feedback) + "/" + str(session + 1)):
os.mkdir(args.agent_patient_interaction_folder + "/" + str(user_id) + "/" + str(with_feedback) + "/" + str(session + 1))
#1. CREATE INITIAL USER COGNITIVE MODEL FROM DATA #1. CREATE INITIAL USER COGNITIVE MODEL FROM DATA
df_from_data, episode_length = merge_user_log(tpi_folder_pathname=therapist_patient_interaction_folder, df_from_data, episode_length = merge_user_log(tpi_folder_pathname=therapist_patient_interaction_folder,
file_output=output_folder_data_path+"/summary_bn_variables_from_data.csv", file_output=output_folder_data_path+"/summary_bn_variables_from_data.csv",
user_id=user_id, user_id=user_id,
with_feedback=with_feedback, with_feedback=with_feedback,
rpi_folder_pathname=None,#agent_patient_interaction_folder, rpi_folder_pathname=agent_patient_interaction_folder,
column_to_remove=None) column_to_remove=None)
#2. CREATE POLICY FROM DATA #2. CREATE POLICY FROM DATA
...@@ -397,11 +405,11 @@ def main(): ...@@ -397,11 +405,11 @@ def main():
maxent_V, maxent_P = vi.value_iteration(cognitive_game_world.p_transition, maxent_R, gamma=0.99, error=1e-2, maxent_V, maxent_P = vi.value_iteration(cognitive_game_world.p_transition, maxent_R, gamma=0.99, error=1e-2,
deterministic=False) deterministic=False)
print(maxent_P) print(maxent_P)
with open(output_folder_data_path+"/"+learned_policy_filename, 'wb') as f: with open(learned_policy_filename, 'wb') as f:
pickle.dump(maxent_P, f, protocol=2) pickle.dump(maxent_P, f, protocol=2)
with open(output_folder_data_path+"/"+learned_reward_filename, 'wb') as f: with open(learned_reward_filename, 'wb') as f:
pickle.dump(maxent_R, f, protocol=2) pickle.dump(maxent_R, f, protocol=2)
with open(output_folder_data_path+"/"+learned_value_f_filename, 'wb') as f: with open(learned_value_f_filename, 'wb') as f:
pickle.dump(maxent_V, f, protocol=2) pickle.dump(maxent_V, f, protocol=2)
# if n>0: # if n>0:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment