Skip to content
Snippets Groups Projects
Commit b4d9fdba authored by Antonio Andriella's avatar Antonio Andriella
Browse files

add policy_load and get_state_action methods

parent a5606b0e
No related branches found
No related tags found
No related merge requests found
...@@ -7,12 +7,14 @@ facial expression and gesture. Every time we check if the action ...@@ -7,12 +7,14 @@ facial expression and gesture. Every time we check if the action
import rospy import rospy
import random import random
import ast import ast
import pickle
import numpy as np
from robot_behaviour.face_reproducer import Face from robot_behaviour.face_reproducer import Face
from robot_behaviour.speech_reproducer import Speech from robot_behaviour.speech_reproducer import Speech
from robot_behaviour.gesture_reproducer import Gesture from robot_behaviour.gesture_reproducer import Gesture
class Robot: class Robot:
def __init__(self, speech, sentences_file, face=None, gesture=None): def __init__(self, speech, sentences_file, action_policy_filename=None, face=None, gesture=None):
''' '''
:param speech: instance of class Speech :param speech: instance of class Speech
:param sentences_file: the file where all the sentences are stored :param sentences_file: the file where all the sentences are stored
...@@ -23,6 +25,7 @@ class Robot: ...@@ -23,6 +25,7 @@ class Robot:
self.sentences = self.load_sentences(sentences_file) self.sentences = self.load_sentences(sentences_file)
self.face = face self.face = face
self.gesture = gesture self.gesture = gesture
self.action_policy = self.load_robot_policy(action_policy_filename)
self.action = { self.action = {
"instruction": self.instruction, "instruction": self.instruction,
...@@ -52,10 +55,31 @@ class Robot: ...@@ -52,10 +55,31 @@ class Robot:
"neutral" : self.neutral "neutral" : self.neutral
} }
def load_robot_policy(self, learned_policy_filename):
with open(learned_policy_filename, "rb") as f:
loaded_policy = pickle.load(f)
return loaded_policy
def get_irl_state_action(self, state_index, epsilon=0.1):
action = 0
print("Select it between the following:", self.action_policy[state_index])
if random.random() < epsilon:
new_list = (self.action_policy[state_index])
best_action_index = np.argmax(self.action_policy[state_index])
new_list[best_action_index] = 0
action = np.argmax(new_list)
else:
action = np.argmax(self.action_policy[state_index])
return action
def get_random_state_action(self):
return random.randint(0, 6)
def send_to_rest(self): def send_to_rest(self):
self.gesture.initial_pos() self.gesture.initial_pos()
def load_sentences(self, file): def load_sentences(self, file):
file = open(file, "r") file = open(file, "r")
contents = file.read() contents = file.read()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment