Skip to content

Commit

Permalink
Implement compare for action representation modules
Browse files Browse the repository at this point in the history
Summary: Implements `compare` for action representation modules

Reviewed By: yiwan-rl

Differential Revision: D67724761

fbshipit-source-id: 4a989878feb1ed91acaaf3a9d5b429b1229fb842
  • Loading branch information
rodrigodesalvobraz authored and facebook-github-bot committed Dec 31, 2024
1 parent cda3fa0 commit c29ed5f
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@ class ActionRepresentationModule(ABC, nn.Module):
@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
pass

@abstractmethod
def compare(self, other: "ActionRepresentationModule") -> str:
"""
Compares two action representation modules for equality.
Args:
other: The other action representation module to compare with.
Returns:
str: A string describing the differences, or an empty string if they are identical.
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

# pyre-strict

from typing import List

import torch

from pearl.action_representation_modules.action_representation_module import (
Expand Down Expand Up @@ -42,3 +44,34 @@ def max_number_actions(self) -> int:
@property
def representation_dim(self) -> int:
return self._bits_num

def compare(self, other: ActionRepresentationModule) -> str:
"""
Compares two BinaryActionTensorRepresentationModule instances for equality,
checking the bits_num and max_number_actions.
Args:
other: The other ActionRepresentationModule to compare with.
Returns:
str: A string describing the differences, or an empty string if they are identical.
"""

differences: List[str] = []

if not isinstance(other, BinaryActionTensorRepresentationModule):
differences.append(
"other is not an instance of BinaryActionTensorRepresentationModule"
)
else:
if self._bits_num != other._bits_num:
differences.append(
f"bits_num is different: {self._bits_num} vs {other._bits_num}"
)
if self.max_number_actions != other.max_number_actions:
differences.append(
f"max_number_actions is different: {self.max_number_actions} "
+ f"vs {other.max_number_actions}"
)

return "\n".join(differences)
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

# pyre-strict

from typing import List

import torch
from pearl.action_representation_modules.action_representation_module import (
ActionRepresentationModule,
Expand Down Expand Up @@ -41,3 +43,35 @@ def max_number_actions(self) -> int | None:
@property
def representation_dim(self) -> int | None:
return self._representation_dim

def compare(self, other: ActionRepresentationModule) -> str:
"""
Compares two IdentityActionRepresentationModule instances for equality,
checking max_number_actions and representation_dim.
Args:
other: The other ActionRepresentationModule to compare with.
Returns:
str: A string describing the differences, or an empty string if they are identical.
"""

differences: List[str] = []

if not isinstance(other, IdentityActionRepresentationModule):
differences.append(
"other is not an instance of IdentityActionRepresentationModule"
)
else:
if self.max_number_actions != other.max_number_actions:
differences.append(
f"max_number_actions is different: {self.max_number_actions} "
+ f"vs {other.max_number_actions}"
)
if self.representation_dim != other.representation_dim:
differences.append(
f"representation_dim is different: {self.representation_dim} "
+ f"vs {other.representation_dim}"
)

return "\n".join(differences)
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

# pyre-strict

from typing import List

import torch
import torch.nn.functional as F

Expand Down Expand Up @@ -41,3 +43,29 @@ def max_number_actions(self) -> int:
@property
def representation_dim(self) -> int:
return self._max_number_actions

def compare(self, other: ActionRepresentationModule) -> str:
"""
Compares two OneHotActionTensorRepresentationModule instances for equality,
checking the max_number_actions.
Args:
other: The other ActionRepresentationModule to compare with.
Returns:
str: A string describing the differences, or an empty string if they are identical.
"""

differences: List[str] = []

if not isinstance(other, OneHotActionTensorRepresentationModule):
differences.append(
"other is not an instance of OneHotActionTensorRepresentationModule"
)
else:
if self.max_number_actions != other.max_number_actions:
differences.append(
f"max_number_actions is different: {self.max_number_actions} vs {other.max_number_actions}"
)

return "\n".join(differences)
61 changes: 61 additions & 0 deletions test/unit/with_pytorch/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

import torch
from later.unittest import TestCase
from pearl.action_representation_modules.binary_action_representation_module import (
BinaryActionTensorRepresentationModule,
)
from pearl.action_representation_modules.identity_action_representation_module import (
IdentityActionRepresentationModule,
)
from pearl.action_representation_modules.one_hot_action_representation_module import (
OneHotActionTensorRepresentationModule,
)
from pearl.history_summarization_modules.identity_history_summarization_module import (
IdentityHistorySummarizationModule,
)
Expand Down Expand Up @@ -420,3 +429,55 @@ def test_compare_warmup(self) -> None:

# Now the comparison should show a difference
self.assertNotEqual(module1.compare(module2), "")

def test_compare_one_hot_action_tensor_representation_module(self) -> None:
module1 = OneHotActionTensorRepresentationModule(max_number_actions=4)
module2 = OneHotActionTensorRepresentationModule(max_number_actions=4)

# Compare module1 with itself
self.assertEqual(module1.compare(module1), "")

# Compare module1 with module2 (should have no differences)
self.assertEqual(module1.compare(module2), "")

# Modify an attribute of module2 to create a difference
module2._max_number_actions = 5

# Now the comparison should show a difference
self.assertNotEqual(module1.compare(module2), "")

def test_compare_binary_action_tensor_representation_module(self) -> None:
module1 = BinaryActionTensorRepresentationModule(bits_num=3)
module2 = BinaryActionTensorRepresentationModule(bits_num=3)

# Compare module1 with itself
self.assertEqual(module1.compare(module1), "")

# Compare module1 with module2 (should have no differences)
self.assertEqual(module1.compare(module2), "")

# Modify an attribute of module2 to create a difference
module2._bits_num = 4

# Now the comparison should show a difference
self.assertNotEqual(module1.compare(module2), "")

def test_compare_identity_action_representation_module(self) -> None:
module1 = IdentityActionRepresentationModule(
max_number_actions=4, representation_dim=2
)
module2 = IdentityActionRepresentationModule(
max_number_actions=4, representation_dim=2
)

# Compare module1 with itself
self.assertEqual(module1.compare(module1), "")

# Compare module1 with module2 (should have no differences)
self.assertEqual(module1.compare(module2), "")

# Modify an attribute of module2 to create a difference
module2._max_number_actions = 5

# Now the comparison should show a difference
self.assertNotEqual(module1.compare(module2), "")

0 comments on commit c29ed5f

Please sign in to comment.