diff --git a/utils.py b/utils.py index d9a77d15c2d71abf5cb641a18453816fbc479326..bbd2bf0063f6a00f72a481aaa0478718eaca71c8 100644 --- a/utils.py +++ b/utils.py @@ -55,6 +55,27 @@ def plot2D_assistance(save_path, n_episodes, *y): width=barWidth, label='lev_4') + plt.legend() + # Custom X axis + plt.xticks(r, x, fontweight='bold') + plt.ylabel("performance") + plt.savefig(save_path) + plt.show() + +def plot2D_feedback(save_path, n_episodes, *y): + # The position of the bars on the x-axis + barWidth = 0.35 + r = np.arange(n_episodes) # the x locations for the groups + # Get values from the group and categories + x = [i for i in range(n_episodes)] + + feedback_no = list(map(lambda x:x[0], y[0])) + feedback_yes = list(map(lambda x:x[1], y[0])) + + # plot bars + plt.figure(figsize=(10, 7)) + plt.bar(r, feedback_no, edgecolor='white', width=barWidth, label="feedback_no") + plt.bar(r, feedback_yes, bottom=np.array(feedback_no), edgecolor='white', width=barWidth, label='feedback_yes') plt.legend() # Custom X axis plt.xticks(r, x, fontweight='bold')