diff --git a/examples/rllib_examples/from_checkpoint.py b/examples/rllib_examples/from_checkpoint.py
index 081f897b6486275687954d7a6fb6d1b03ab5dabe..c7087e9359795fe6e42c4e8c99de7409635297af 100644
--- a/examples/rllib_examples/from_checkpoint.py
+++ b/examples/rllib_examples/from_checkpoint.py
@@ -42,6 +42,8 @@ def get_actions(algo: Algorithm, obs: dict) -> list:
         else:
             timestep = obs[agent]["observations"][0]
 
+        # TODO: If the policy uses a recurrent (LSTM) model, supply its initial state here.
+
         # Calculate the single actions for this step
         _temp = algo.compute_single_action(
             observation=obs[agent],
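The TODO added above asks for LSTM-aware inference. A minimal sketch of that pattern, assuming RLlib's classic Policy API, where Policy.get_initial_state() seeds the recurrent state and compute_single_action(..., state=...) returns an (action, state_out, extra_fetches) tuple; the states dict, its threading across steps, and the explore flag are illustrative, not part of this patch:

    # Hypothetical sketch: keep one recurrent state per agent and thread it forward.
    states = {agent: algo.get_policy().get_initial_state() for agent in obs}

    for agent in obs:
        action, state_out, _ = algo.compute_single_action(
            observation=obs[agent],
            state=states[agent],  # feed the previous step's state back in
            explore=False,
        )
        states[agent] = state_out  # carry the new state into the next step
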
diff --git a/examples/rllib_examples/recurrent.py b/examples/rllib_examples/recurrent.py
index 1f04d79768881e0763681896990ef7945475618b..73eb669ef34d687ef9a411227e9343fbc11b0544 100644
--- a/examples/rllib_examples/recurrent.py
+++ b/examples/rllib_examples/recurrent.py
@@ -225,6 +225,12 @@ class LSTMActionMaskingModel(RecurrentNetwork, nn.Module):
             flat_inputs = torch.cat(
                 [flat_inputs, input_dict["prev_rewards"].unsqueeze(1)], dim=1
             )
+        if not isinstance(seq_lens, torch.Tensor) and not seq_lens:
+            # compute_single_action passes no seq_lens; default to length-1 sequences (1-D tensor)
+            seq_lens = torch.ones(flat_inputs.shape[0], dtype=torch.int32)
+        if not state:
+            # compute_single_action also passes no state; start from the model's initial state
+            state = [torch.unsqueeze(s, 0) for s in self.get_initial_state()]
 
         # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
         # as input_dict may have extra zero-padding beyond seq_lens.max().
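
The defaults added to forward() above rely on RLlib's convention that get_initial_state() returns per-sample tensors with no batch dimension, so unsqueezing dim 0 produces the [1, cell_size] layout that forward_rnn() expects. A standalone illustration of that shape contract (cell_size and the zero-filled states are made up for the example):

    import torch

    cell_size = 64
    # Stands in for get_initial_state(): [h0, c0], each of shape [cell_size].
    initial = [torch.zeros(cell_size), torch.zeros(cell_size)]
    # Unsqueeze dim 0 to get a batch of one: shape [1, cell_size].
    state = [torch.unsqueeze(s, 0) for s in initial]
    assert all(s.shape == (1, cell_size) for s in state)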