From 1b0702e31c546dff20c94dd6d4c83324364230b6 Mon Sep 17 00:00:00 2001
From: Yi Wan
Date: Fri, 13 Dec 2024 14:02:05 -0800
Subject: [PATCH] Rename VanillaCNN to CNNValueNetwork and divide state values
 by 255.

Summary:
This diff makes two changes:
1) Following our naming convention for Q-value networks, rename VanillaCNN to
   CNNValueNetwork.
2) For Atari games, raw image pixel values (0-255) are stored in the replay
   buffer (instead of values normalized to be within 0-1) to save memory, so
   the normalization now has to be done inside our CNN networks.

Reviewed By: rodrigodesalvobraz

Differential Revision: D66280552

fbshipit-source-id: 1346b6eb18cae8a831e7f071467487239723115a
---
 pearl/neural_networks/common/__init__.py       | 4 ++--
 pearl/neural_networks/common/value_networks.py | 6 +++---
 test/unit/with_pytorch/test_vanilla_cnns.py    | 8 ++++----
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/pearl/neural_networks/common/__init__.py b/pearl/neural_networks/common/__init__.py
index 968d6ca4..99e2f847 100644
--- a/pearl/neural_networks/common/__init__.py
+++ b/pearl/neural_networks/common/__init__.py
@@ -8,7 +8,7 @@
 from .epistemic_neural_networks import Ensemble, EpistemicNeuralNetwork, MLPWithPrior
 from .residual_wrapper import ResidualWrapper
-from .value_networks import ValueNetwork, VanillaCNN, VanillaValueNetwork
+from .value_networks import CNNValueNetwork, ValueNetwork, VanillaValueNetwork

 __all__ = [
     "Ensemble",
@@ -16,7 +16,7 @@
     "MLPWithPrior",
     "ResidualWrapper",
     "ValueNetwork",
-    "VanillaCNN",
+    "CNNValueNetwork",
     "VanillaValueNetwork",
     "Epinet",
 ]
diff --git a/pearl/neural_networks/common/value_networks.py b/pearl/neural_networks/common/value_networks.py
index 1bc41eaf..e0393af5 100644
--- a/pearl/neural_networks/common/value_networks.py
+++ b/pearl/neural_networks/common/value_networks.py
@@ -61,7 +61,7 @@ def xavier_init(self) -> None:
             nn.init.xavier_normal_(layer.weight)


-class VanillaCNN(ValueNetwork):
+class CNNValueNetwork(ValueNetwork):
     """
     Vanilla CNN with a convolutional block followed by an mlp block.
     Args:
@@ -101,7 +101,7 @@ def __init__(
             == len(strides)
             == len(paddings)
         )
-        super().__init__()
+        super(CNNValueNetwork, self).__init__()

         self._input_channels = input_channels_count
         self._input_height = input_height
@@ -142,7 +142,7 @@ def __init__(
         )

     def forward(self, x: Tensor) -> Tensor:
-        out_cnn = self._model_cnn(x)
+        out_cnn = self._model_cnn(x / 255.0)
         out_flattened = torch.flatten(out_cnn, start_dim=1, end_dim=-1)
         out_fc = self._model_fc(out_flattened)
         return out_fc
diff --git a/test/unit/with_pytorch/test_vanilla_cnns.py b/test/unit/with_pytorch/test_vanilla_cnns.py
index 19db6d97..dda9ff37 100644
--- a/test/unit/with_pytorch/test_vanilla_cnns.py
+++ b/test/unit/with_pytorch/test_vanilla_cnns.py
@@ -13,14 +13,14 @@
 import torch
 import torchvision

-from pearl.neural_networks.common.value_networks import VanillaCNN
+from pearl.neural_networks.common.value_networks import CNNValueNetwork
 from torch import optim
 from torch.utils.data import DataLoader, Subset
 from torchvision import transforms


-class TestVanillaCNNs(unittest.TestCase):
+class TestCNNValueNetworks(unittest.TestCase):
     def setUp(self) -> None:
         transform = transforms.Compose([transforms.ToTensor()])
         mnist_dataset = torchvision.datasets.MNIST(
@@ -42,7 +42,7 @@ def setUp(self) -> None:
             self.mnist_train_dataset, self.batch_size, shuffle=True
         )

-    def test_vanilla_cnns(self) -> None:
+    def test_cnns(self) -> None:
         """
         a simple cnn should be able to fit the mnist digit dataset and the
         training accuracy should be close to 90%
@@ -58,7 +58,7 @@ def test_vanilla_cnns(self) -> None:
         paddings = [2]
         hidden_dims_fully_connected = [64]
         output_dim = 10
-        network = VanillaCNN(
+        network = CNNValueNetwork(
             input_width=input_width,
             input_height=input_height,
             input_channels_count=input_channels,
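
Note (illustrative, not part of the patch): the sketch below mirrors the new
forward() behavior, in which raw 0-255 pixel frames from the replay buffer are
normalized inside the network rather than before storage. The module name,
layer sizes, and frame shape are assumptions made for this example, not
Pearl's actual defaults.

import torch
import torch.nn as nn


class TinyCNNValueNetwork(nn.Module):
    """Minimal stand-in for a CNN value network that normalizes internally."""

    def __init__(self) -> None:
        super().__init__()
        # Convolutional block followed by a fully connected block, analogous
        # to _model_cnn and _model_fc in the patch (sizes chosen arbitrarily).
        self._model_cnn = nn.Sequential(
            nn.Conv2d(4, 16, kernel_size=8, stride=4, padding=2),
            nn.ReLU(),
        )
        self._model_fc = nn.Sequential(
            nn.Linear(16 * 21 * 21, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inputs are raw pixel intensities in [0, 255], as stored in the
        # replay buffer after this change; scale them to [0, 1] here.
        out_cnn = self._model_cnn(x / 255.0)
        out_flattened = torch.flatten(out_cnn, start_dim=1, end_dim=-1)
        return self._model_fc(out_flattened)


if __name__ == "__main__":
    # A batch of two stacks of four 84x84 frames with raw uint8-range values.
    frames = torch.randint(0, 256, (2, 4, 84, 84)).float()
    print(TinyCNNValueNetwork()(frames).shape)  # torch.Size([2, 1])

Storing frames as raw uint8 values uses one byte per pixel instead of the four
bytes a normalized float32 value would need, which is where the memory saving
mentioned in the summary comes from.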