diff --git a/examples/rllib_examples/from_checkpoint.py b/examples/rllib_examples/from_checkpoint.py
index 081f897b6486275687954d7a6fb6d1b03ab5dabe..c7087e9359795fe6e42c4e8c99de7409635297af 100644
--- a/examples/rllib_examples/from_checkpoint.py
+++ b/examples/rllib_examples/from_checkpoint.py
@@ -42,6 +42,8 @@ def get_actions(algo: Algorithm, obs: dict) -> list:
         else:
             timestep = obs[agent]["observations"][0]

+        # TODO: Check if using LSTM models and then supply initial state.
+
         # Calculate the single actions for this step
         _temp = algo.compute_single_action(
             observation=obs[agent],
diff --git a/examples/rllib_examples/recurrent.py b/examples/rllib_examples/recurrent.py
index 1f04d79768881e0763681896990ef7945475618b..73eb669ef34d687ef9a411227e9343fbc11b0544 100644
--- a/examples/rllib_examples/recurrent.py
+++ b/examples/rllib_examples/recurrent.py
@@ -225,6 +225,12 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module):
         flat_inputs = torch.cat(
             [flat_inputs, input_dict["prev_rewards"].unsqueeze(1)], dim=1
         )
+        if not isinstance(seq_lens, torch.Tensor) and not seq_lens:
+            # This happens when calling compute_single_action
+            seq_lens = torch.ones([flat_inputs.shape[0], 1])
+        if not state:
+            # This also happens when calling compute_single_action
+            state = [torch.unsqueeze(s, 0) for s in self.get_initial_state()]

         # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
         # as input_dict may have extra zero-padding beyond seq_lens.max().
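
A minimal sketch of how the TODO in from_checkpoint.py could be resolved, assuming the old RLlib API stack with a single "default_policy"; `get_actions_with_state` and the `states` dict are hypothetical stand-ins for the example's real control flow. When an explicit `state` is passed for a recurrent policy, `Algorithm.compute_single_action` returns a tuple including the next state, which can then be threaded between steps:

    def get_actions_with_state(algo, obs, states):
        """Hypothetical variant of get_actions that threads LSTM state."""
        actions = {}
        for agent, agent_obs in obs.items():
            # Seed the hidden state from the policy on an agent's first step.
            if agent not in states:
                states[agent] = algo.get_policy("default_policy").get_initial_state()
            # With a state argument, compute_single_action returns
            # (action, state_out, extra_fetches); keep state_out for the
            # next call instead of discarding it.
            action, state_out, _ = algo.compute_single_action(
                observation=agent_obs,
                state=states[agent],
            )
            states[agent] = state_out
            actions[agent] = action
        return actions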
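
Note the tradeoff between the two changes: the fallback inside LSTMActionMaskingModel only keeps a stateless `compute_single_action` call from crashing by substituting a fresh initial state, so every such call behaves like the first step of an episode. Recovering actual recurrent behavior still requires the caller to supply and thread the state, which is what the TODO in from_checkpoint.py marks.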