Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
G
GOAL
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Antonio Andriella
GOAL
Commits
c5f38954
Commit
c5f38954
authored
4 years ago
by
Antonio Andriella
Browse files
Options
Downloads
Patches
Plain Diff
working version
parent
1ab098db
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
main.py
+27
-19
27 additions, 19 deletions
main.py
with
27 additions
and
19 deletions
main.py
+
27
−
19
View file @
c5f38954
...
@@ -29,12 +29,12 @@ import argparse
...
@@ -29,12 +29,12 @@ import argparse
from
episode
import
Episode
from
episode
import
Episode
from
cognitive_game_env
import
CognitiveGame
from
cognitive_game_env
import
CognitiveGame
from
environment
import
Environment
from
environment
import
Environment
import
maxent
as
M
import
src.
maxent
as
M
import
plot
as
P
import
src.
plot
as
P
import
solver
as
S
import
src.
solver
as
S
import
optimizer
as
O
import
src.
optimizer
as
O
import
img_utils
as
I
import
src.
img_utils
as
I
import
value_iteration
as
vi
import
src.
value_iteration
as
vi
import
simulation
as
Sim
import
simulation
as
Sim
import
bn_functions
as
bn_functions
import
bn_functions
as
bn_functions
...
@@ -237,8 +237,9 @@ def compute_agent_policy(training_set_filename, state_space, action_space, episo
...
@@ -237,8 +237,9 @@ def compute_agent_policy(training_set_filename, state_space, action_space, episo
action_index
=
action_point
action_index
=
action_point
agent_policy_counter
[
state_index
][
action_index
]
+=
1
agent_policy_counter
[
state_index
][
action_index
]
+=
1
row_t_0
=
row
[
'
user_action
'
]
row_t_0
=
row
[
'
user_action
'
]
min_val
=
np
.
finfo
(
float
).
eps
for
s
in
range
(
len
(
state_space
)):
for
s
in
range
(
len
(
state_space
)):
agent_policy_prob
[
s
]
=
list
(
map
(
lambda
x
:
x
/
(
sum
(
agent_policy_counter
[
s
])
+
0.001
),
agent_policy_counter
[
s
]))
agent_policy_prob
[
s
]
=
list
(
map
(
lambda
x
:
x
/
(
sum
(
agent_policy_counter
[
s
])
+
min_val
),
agent_policy_counter
[
s
]))
return
agent_policy_prob
return
agent_policy_prob
...
@@ -248,8 +249,8 @@ def main():
...
@@ -248,8 +249,8 @@ def main():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'
--bn_model_folder
'
,
'
--bn_model_folder
'
,
type
=
str
,
help
=
"
folder in which all the user and the agent models are stored
"
,
parser
.
add_argument
(
'
--bn_model_folder
'
,
'
--bn_model_folder
'
,
type
=
str
,
help
=
"
folder in which all the user and the agent models are stored
"
,
default
=
"
/home/pal/Documents/Framework/GenerativeMutualShapingRL/BN_Models
"
)
default
=
"
/home/pal/Documents/Framework/GenerativeMutualShapingRL/BN_Models
"
)
parser
.
add_argument
(
'
--bn_agent_model_filename
'
,
'
--bn_agent_model
'
,
type
=
str
,
help
=
"
file path of the agent bn model
"
,
#
parser.add_argument('--bn_agent_model_filename', '--bn_agent_model', type=str,help="file path of the agent bn model",
default
=
"
/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_test.bif
"
)
#
default="/home/pal/Documents/Framework/bn_generative_model/bn_agent_model/agent_test.bif")
parser
.
add_argument
(
'
--epoch
'
,
'
--epoch
'
,
type
=
int
,
help
=
"
number of epochs in the simulation
"
,
default
=
200
)
parser
.
add_argument
(
'
--epoch
'
,
'
--epoch
'
,
type
=
int
,
help
=
"
number of epochs in the simulation
"
,
default
=
200
)
parser
.
add_argument
(
'
--run
'
,
'
--run
'
,
type
=
int
,
help
=
"
number of runs in the simulation
"
,
default
=
50
)
parser
.
add_argument
(
'
--run
'
,
'
--run
'
,
type
=
int
,
help
=
"
number of runs in the simulation
"
,
default
=
50
)
parser
.
add_argument
(
'
--output_policy_filename
'
,
'
--p
'
,
type
=
str
,
help
=
"
output policy from the simulation
"
,
parser
.
add_argument
(
'
--output_policy_filename
'
,
'
--p
'
,
type
=
str
,
help
=
"
output policy from the simulation
"
,
...
@@ -259,9 +260,9 @@ def main():
...
@@ -259,9 +260,9 @@ def main():
parser
.
add_argument
(
'
--output_value_filename
'
,
'
--v
'
,
type
=
str
,
help
=
"
output value function from the simulation
"
,
parser
.
add_argument
(
'
--output_value_filename
'
,
'
--v
'
,
type
=
str
,
help
=
"
output value function from the simulation
"
,
default
=
"
value_function.pkl
"
)
default
=
"
value_function.pkl
"
)
parser
.
add_argument
(
'
--therapist_patient_interaction_folder
'
,
'
--tpi_path
'
,
type
=
str
,
help
=
"
therapist-patient interaction folder
"
,
parser
.
add_argument
(
'
--therapist_patient_interaction_folder
'
,
'
--tpi_path
'
,
type
=
str
,
help
=
"
therapist-patient interaction folder
"
,
default
=
"
/home/pal/
carf_ws/src/carf/caregiver_in_the_loop/log
"
)
default
=
"
/home/pal/
Documents/Framework/GenerativeMutualShapingRL/therapist-patient-interaction
"
)
parser
.
add_argument
(
'
--agent_patient_interaction_folder
'
,
'
--api_path
'
,
type
=
str
,
help
=
"
agent-patient interaction folder
"
,
parser
.
add_argument
(
'
--agent_patient_interaction_folder
'
,
'
--api_path
'
,
type
=
str
,
help
=
"
agent-patient interaction folder
"
,
default
=
"
/home/pal/carf_ws/src/carf/robot
_in_the_loop/log
"
)
default
=
"
/home/pal/carf_ws/src/carf/robot
-patient-interaction
"
)
parser
.
add_argument
(
'
--user_id
'
,
'
--id
'
,
type
=
int
,
help
=
"
user id
"
,
required
=
True
)
parser
.
add_argument
(
'
--user_id
'
,
'
--id
'
,
type
=
int
,
help
=
"
user id
"
,
required
=
True
)
parser
.
add_argument
(
'
--with_feedback
'
,
'
--f
'
,
type
=
eval
,
choices
=
[
True
,
False
],
help
=
"
offering sociable
"
,
required
=
True
)
parser
.
add_argument
(
'
--with_feedback
'
,
'
--f
'
,
type
=
eval
,
choices
=
[
True
,
False
],
help
=
"
offering sociable
"
,
required
=
True
)
parser
.
add_argument
(
'
--session
'
,
'
--s
'
,
type
=
int
,
help
=
"
session of the agent-human interaction
"
,
required
=
True
)
parser
.
add_argument
(
'
--session
'
,
'
--s
'
,
type
=
int
,
help
=
"
session of the agent-human interaction
"
,
required
=
True
)
...
@@ -278,9 +279,9 @@ def main():
...
@@ -278,9 +279,9 @@ def main():
# initialise the agent
# initialise the agent
bn_user_model_filename
=
args
.
bn_model_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/user_model.bif
"
bn_user_model_filename
=
args
.
bn_model_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/user_model.bif
"
bn_agent_model_filename
=
args
.
bn_model_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/agent_model.bif
"
bn_agent_model_filename
=
args
.
bn_model_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/agent_model.bif
"
learned_policy_filename
=
args
.
output_policy_filename
learned_policy_filename
=
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/
"
+
str
(
session
+
1
)
+
"
/
"
+
args
.
output_policy_filename
learned_reward_filename
=
args
.
output_reward_filename
learned_reward_filename
=
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/
"
+
str
(
session
+
1
)
+
"
/
"
+
args
.
output_reward_filename
learned_value_f_filename
=
args
.
output_value_filename
learned_value_f_filename
=
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/
"
+
str
(
session
+
1
)
+
"
/
"
+
args
.
output_value_filename
therapist_patient_interaction_folder
=
args
.
therapist_patient_interaction_folder
therapist_patient_interaction_folder
=
args
.
therapist_patient_interaction_folder
agent_patient_interaction_folder
=
args
.
agent_patient_interaction_folder
agent_patient_interaction_folder
=
args
.
agent_patient_interaction_folder
scaling_factor
=
1
scaling_factor
=
1
...
@@ -312,6 +313,7 @@ def main():
...
@@ -312,6 +313,7 @@ def main():
#output folders
#output folders
output_folder_data_path
=
os
.
getcwd
()
+
"
/results/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/
"
+
str
(
session
)
output_folder_data_path
=
os
.
getcwd
()
+
"
/results/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/
"
+
str
(
session
)
if
not
os
.
path
.
exists
(
os
.
getcwd
()
+
"
/results
"
+
"
/
"
+
str
(
user_id
)):
if
not
os
.
path
.
exists
(
os
.
getcwd
()
+
"
/results
"
+
"
/
"
+
str
(
user_id
)):
os
.
mkdir
(
os
.
getcwd
()
+
"
/results
"
+
"
/
"
+
str
(
user_id
))
os
.
mkdir
(
os
.
getcwd
()
+
"
/results
"
+
"
/
"
+
str
(
user_id
))
if
not
os
.
path
.
exists
(
os
.
getcwd
()
+
"
/results
"
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)):
if
not
os
.
path
.
exists
(
os
.
getcwd
()
+
"
/results
"
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)):
...
@@ -319,13 +321,19 @@ def main():
...
@@ -319,13 +321,19 @@ def main():
if
not
os
.
path
.
exists
(
output_folder_data_path
):
if
not
os
.
path
.
exists
(
output_folder_data_path
):
os
.
mkdir
(
output_folder_data_path
)
os
.
mkdir
(
output_folder_data_path
)
if
not
os
.
path
.
exists
(
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
)):
os
.
mkdir
(
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
))
if
not
os
.
path
.
exists
(
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)):
os
.
mkdir
(
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
))
if
not
os
.
path
.
exists
(
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/
"
+
str
(
session
+
1
)):
os
.
mkdir
(
args
.
agent_patient_interaction_folder
+
"
/
"
+
str
(
user_id
)
+
"
/
"
+
str
(
with_feedback
)
+
"
/
"
+
str
(
session
+
1
))
#1. CREATE INITIAL USER COGNITIVE MODEL FROM DATA
#1. CREATE INITIAL USER COGNITIVE MODEL FROM DATA
df_from_data
,
episode_length
=
merge_user_log
(
tpi_folder_pathname
=
therapist_patient_interaction_folder
,
df_from_data
,
episode_length
=
merge_user_log
(
tpi_folder_pathname
=
therapist_patient_interaction_folder
,
file_output
=
output_folder_data_path
+
"
/summary_bn_variables_from_data.csv
"
,
file_output
=
output_folder_data_path
+
"
/summary_bn_variables_from_data.csv
"
,
user_id
=
user_id
,
user_id
=
user_id
,
with_feedback
=
with_feedback
,
with_feedback
=
with_feedback
,
rpi_folder_pathname
=
None
,
#
agent_patient_interaction_folder,
rpi_folder_pathname
=
agent_patient_interaction_folder
,
column_to_remove
=
None
)
column_to_remove
=
None
)
#2. CREATE POLICY FROM DATA
#2. CREATE POLICY FROM DATA
...
@@ -397,11 +405,11 @@ def main():
...
@@ -397,11 +405,11 @@ def main():
maxent_V
,
maxent_P
=
vi
.
value_iteration
(
cognitive_game_world
.
p_transition
,
maxent_R
,
gamma
=
0.99
,
error
=
1e-2
,
maxent_V
,
maxent_P
=
vi
.
value_iteration
(
cognitive_game_world
.
p_transition
,
maxent_R
,
gamma
=
0.99
,
error
=
1e-2
,
deterministic
=
False
)
deterministic
=
False
)
print
(
maxent_P
)
print
(
maxent_P
)
with
open
(
output_folder_data_path
+
"
/
"
+
learned_policy_filename
,
'
wb
'
)
as
f
:
with
open
(
learned_policy_filename
,
'
wb
'
)
as
f
:
pickle
.
dump
(
maxent_P
,
f
,
protocol
=
2
)
pickle
.
dump
(
maxent_P
,
f
,
protocol
=
2
)
with
open
(
output_folder_data_path
+
"
/
"
+
learned_reward_filename
,
'
wb
'
)
as
f
:
with
open
(
learned_reward_filename
,
'
wb
'
)
as
f
:
pickle
.
dump
(
maxent_R
,
f
,
protocol
=
2
)
pickle
.
dump
(
maxent_R
,
f
,
protocol
=
2
)
with
open
(
output_folder_data_path
+
"
/
"
+
learned_value_f_filename
,
'
wb
'
)
as
f
:
with
open
(
learned_value_f_filename
,
'
wb
'
)
as
f
:
pickle
.
dump
(
maxent_V
,
f
,
protocol
=
2
)
pickle
.
dump
(
maxent_V
,
f
,
protocol
=
2
)
# if n>0:
# if n>0:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment