
Issues in ReplayBuffer #201

Open
hyeok9855 opened this issue Oct 18, 2024 · 10 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed), high priority (Let's do these first!)

Comments

@hyeok9855 (Collaborator) commented Oct 18, 2024

Here I note several issues in the current ReplayBuffer, including the points raised by @younik in #193.

  1. Types of objects (see #193): Currently, ReplayBuffer can take Trajectories, Transitions (both of which inherit from Container), and tuple[States, States] (which does not). This makes the overall code messier than it needs to be. Possible solutions:
    a. Make States inherit from Container, and let the ReplayBuffer accept Container objects.
    b. Make subclasses of ReplayBuffer depending on the stored object type.
  2. Supporting a terminal-state-only buffer: I think the most popular form of replay buffer simply stores a terminal state x (with its reward), as in Max's code, and samples a training trajectory with the backward policy (see the sketch after this list). This is preferable because 1) it significantly reduces memory usage (vs. trajectory-based storage), and 2) backward sampling gives more diverse learning signals, which I believe helps training. We need to support this, but it should be considered together with issue 1 above.
  3. Prioritization: I think using PrioritizedReplayBuffer by default (instead of ReplayBuffer), with an optional parameter that enables prioritization, would make the code more concise.
  4. Device (minor): Currently, the device of objects in the buffer follows that of env. Maybe we should support both CPU and CUDA storage, since they have different pros and cons: e.g., storing as CUDA tensors is faster when training on CUDA, but it consumes more GPU memory. By the way, there are some minor errors in ReplayBuffer when CUDA is enabled; we need to check these.
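To make issue 2 concrete, here is a rough sketch of what a terminal-state-only buffer could look like. Everything below (the class name, storing raw tensors, FIFO eviction) is illustrative only and not existing torchgfn API; full training trajectories would be obtained by resampling backward from the sampled terminal states with the backward policy.

```python
import torch


class TerminalStateBuffer:
    """Hypothetical buffer that stores only terminal states x and their log-rewards."""

    def __init__(self, capacity: int, state_shape: tuple, device: str = "cpu"):
        self.capacity = capacity
        self.device = torch.device(device)  # issue 4: device is an explicit choice
        self.states = torch.empty((capacity, *state_shape), device=self.device)
        self.log_rewards = torch.empty((capacity,), device=self.device)
        self.size = 0
        self._ptr = 0

    def add(self, terminal_states: torch.Tensor, log_rewards: torch.Tensor) -> None:
        # Simple FIFO eviction; a prioritized variant (issue 3) could drop
        # low-priority entries instead.
        for s, r in zip(terminal_states, log_rewards):
            self.states[self._ptr] = s.to(self.device)
            self.log_rewards[self._ptr] = r.to(self.device)
            self._ptr = (self._ptr + 1) % self.capacity
            self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int) -> tuple:
        # Return terminal states only; the caller samples full trajectories
        # backward from these states with the backward policy.
        idx = torch.randint(0, self.size, (batch_size,), device=self.device)
        return self.states[idx], self.log_rewards[idx]
```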

I hope we can discuss these thoroughly before tackling them!

@hyeok9855 added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Oct 18, 2024
@hyeok9855 changed the title from "Issues in replay buffer" to "Issues in ReplayBuffer" on Oct 18, 2024
@josephdviviano (Collaborator)

Thanks @hyeok9855, this is great.

My thoughts: I think there are many use cases where a buffer would be better off holding the full trajectory (for algorithmic reasons), but I also agree that often you only want to store the terminal state.

While making States a type of Container actually seems reasonable to me, I'm not at all sure how you would reason about buffers that only accept states. It seems like we would need to, in some cases at least, reinvent all the trajectory-level features already present in the Trajectories or Transitions classes.

So, based on this, I think creating several buffer types via subclassing makes sense.

I'm thinking the hierarchy could have Transition-, Trajectory-, and State-level buffers, which can optionally inherit from different kinds of buffer behaviour (right now we have Replay and PrioritizedReplay); one could imagine many kinds of buffers (see the sketch below).
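One way this two-axis hierarchy (stored object type vs. buffer behaviour) could be sketched; all class names below are illustrative placeholders, not a proposal for the final API:

```python
from typing import Generic, TypeVar

# Placeholders standing in for the real gfn container/state classes.
class Trajectories: ...
class Transitions: ...
class States: ...

T = TypeVar("T")


class BaseBuffer(Generic[T]):
    """Shared capacity/storage logic, parameterized by the stored object type."""

    def add(self, objs: T) -> None: ...
    def sample(self, n: int) -> T: ...


class ReplayBuffer(BaseBuffer[T]):
    """Uniform sampling, FIFO eviction."""


class PrioritizedReplayBuffer(BaseBuffer[T]):
    """Sampling/eviction driven by a priority score (e.g., log-reward)."""


# Object-level specializations on top of either behaviour:
class TrajectoryReplayBuffer(ReplayBuffer[Trajectories]): ...
class TransitionReplayBuffer(ReplayBuffer[Transitions]): ...
class TerminalStatePrioritizedBuffer(PrioritizedReplayBuffer[States]): ...
```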

Agreed re: merging prioritized and normal buffers, but I can imagine more kinds of buffers in the future that should be their own classes. Our "prioritized" buffer is really a specific kind of prioritized buffer that unfortunately won't work for everyone.

Agreed 100% on the device handling.

@josephdviviano (Collaborator)

I'm also happy to help with this :)

@josephdviviano (Collaborator) commented Oct 31, 2024

Thanks. Regarding this, I'm wondering whether we could simply use the ReplayBuffer implementation from the torchrl distribution and inherit all of its functionality. This would obviously be its own PR, and might not be possible if too many decisions made elsewhere in the repo are incompatible (but at the very least it is a good reference implementation for us).

@saleml @younik thoughts?

/~https://github.com/pytorch/rl/tree/main/torchrl/data/replay_buffers
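For reference, a rough sketch of what using the torchrl buffer looks like, based on its documented API (exact argument names may differ across versions, and whether gfn containers can be stored directly, e.g. after conversion to tensors or a TensorDict, is exactly what would need to be checked):

```python
import torch
from torchrl.data.replay_buffers import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

capacity = 10_000

# Storage and sampler are decoupled: the storage decides where/how data lives,
# the sampler decides how batches are drawn from it.
buffer = ReplayBuffer(
    storage=LazyTensorStorage(capacity, device="cpu"),
    sampler=PrioritizedSampler(max_capacity=capacity, alpha=0.7, beta=0.9),
    batch_size=64,
)

# Add a batch of items (here, dummy terminal states); extend returns the
# storage indices of the inserted items.
terminal_states = torch.randn(128, 4)
indices = buffer.extend(terminal_states)

# Priorities can be updated per item (e.g., from log-rewards or losses)...
buffer.update_priority(indices, torch.rand(128))

# ...and sampling then follows the priority distribution.
batch = buffer.sample()
```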

@younik (Collaborator) commented Oct 31, 2024

Reusing the TorchRL ReplayBuffer is a very good idea. They decouple sampling from the replay buffer, which is a great design choice.

I will look into the challenges of migrating to it in more depth.

@josephdviviano added the high priority (Let's do these first!) label on Jan 13, 2025
@hyeok9855 (Collaborator, Author)

Personally, I don't want to pull in torchrl only to use its buffer. I will first try refactoring without it; let's see.

@hyeok9855 (Collaborator, Author) commented Feb 6, 2025

Regarding prioritization, the current implementation of PrioritizedReplayBuffer uses the reward (as the priority score) only to sort the samples in the buffer and drop the low-reward ones when the buffer is full.

However, I think the core of prioritization is prioritized sampling, as is done here.

Here are some ideas for improving prioritization:

  1. Separate the storage and the sampler, similar to torchrl: the storage manages what happens when it is full (e.g., FIFO, or sort by priority and drop the tail), and the sampler manages how batches are drawn, including prioritized sampling.
  2. Allow any scalar metric for prioritization (e.g., reward, training loss, etc.); see the sketch after this list.
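For illustration, a minimal version of the sampler side of idea 1, allowing any scalar metric as the priority (a sketch only, not the proposed implementation):

```python
import torch


def prioritized_sample(priorities: torch.Tensor, batch_size: int, alpha: float = 1.0) -> torch.Tensor:
    """Draw buffer indices with probability proportional to priority**alpha.

    `priorities` can be any non-negative scalar metric per stored item
    (reward, training loss, ...); alpha=0 recovers uniform sampling.
    """
    probs = priorities.clamp_min(1e-8).pow(alpha)
    return torch.multinomial(probs / probs.sum(), batch_size, replacement=True)


# Example: prioritize by reward attached to the stored items.
rewards = torch.rand(1000)
batch_indices = prioritized_sample(rewards, batch_size=64)
```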

@hyeok9855 (Collaborator, Author) commented Feb 6, 2025

Can anyone explain why we need terminating_states (here) in the replay buffer? It isn't used anywhere else outside of the buffer.

If it's not necessary, I will remove it during the refactoring, because it complicates things.

@saleml (Collaborator) commented Feb 8, 2025

Thanks @hyeok9855 for the comments.

  • I agree with you; torchrl might be overkill.
  • terminating_states is for the flow matching loss. Unfortunately, this loss requires two input objects, unlike all the other losses. I have also learned that this is a design choice of this library, and flow matching does not have to be implemented this way. Do you have a strong opinion on this, @hyeok9855?

@hyeok9855 (Collaborator, Author)

Thanks @saleml.

I've just checked the flow matching loss. I don't have a strong opinion, but I wonder if we could improve it in the following way:

Could we simply add a flag indicating whether a state is terminal (perhaps based on the existence of _log_reward?) and then calculate everything inside flow_matching_loss (rather than using an additional reward_matching_loss)? A rough sketch of the idea follows.
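The shape such a change could take; the method names below (flow_matching_residual, reward_matching_residual) and the log_rewards argument are purely illustrative and not the current gfn API:

```python
import torch


def flow_matching_loss(gflownet, states, log_rewards):
    # Hypothetical: a NaN log-reward marks a non-terminal state, so the buffer
    # only needs one States object instead of a separate `terminating_states`.
    is_terminal = ~torch.isnan(log_rewards)

    # Flow-matching term on non-terminal states.
    loss = gflownet.flow_matching_residual(states[~is_terminal]).pow(2).mean()

    # Reward-matching term folded into the same call for terminal states.
    if is_terminal.any():
        loss = loss + gflownet.reward_matching_residual(
            states[is_terminal], log_rewards[is_terminal]
        ).pow(2).mean()
    return loss
```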

@saleml (Collaborator) commented Feb 11, 2025

I'm definitely ok with your suggestion. Some authors refrain from using the "reward_matching_loss" term altogether.
@josephdviviano, I'd love your opinion here too.
