diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dcc8ad9e..0f8d0b76 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,8 @@ repos: hooks: - id: isort args: - - --skip wandb + - --profile=black + - --skip=wandb - repo: /~https://github.com/myint/autoflake rev: v1.4 hooks: diff --git a/experiments/league.py b/experiments/league.py index bb708928..aa315a94 100644 --- a/experiments/league.py +++ b/experiments/league.py @@ -12,9 +12,17 @@ import numpy as np import pandas as pd import torch -from peewee import (JOIN, CharField, DateTimeField, FloatField, - ForeignKeyField, Model, SmallIntegerField, SqliteDatabase, - fn) +from peewee import ( + JOIN, + CharField, + DateTimeField, + FloatField, + ForeignKeyField, + Model, + SmallIntegerField, + SqliteDatabase, + fn, +) from ppo_gridnet import Agent, MicroRTSStatsRecorder from stable_baselines3.common.vec_env import VecMonitor from trueskill import Rating, quality_1vs1, rate_1vs1 diff --git a/experiments/ppo_gridnet.py b/experiments/ppo_gridnet.py index 4754c4b3..4e8945f1 100644 --- a/experiments/ppo_gridnet.py +++ b/experiments/ppo_gridnet.py @@ -13,8 +13,7 @@ import torch.nn as nn import torch.optim as optim from gym.spaces import MultiDiscrete -from stable_baselines3.common.vec_env import (VecEnvWrapper, VecMonitor, - VecVideoRecorder) +from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor, VecVideoRecorder from torch.distributions.categorical import Categorical from torch.utils.tensorboard import SummaryWriter diff --git a/gym_microrts/envs/vec_env.py b/gym_microrts/envs/vec_env.py index 2a509ae2..5bc1ec5f 100644 --- a/gym_microrts/envs/vec_env.py +++ b/gym_microrts/envs/vec_env.py @@ -82,13 +82,15 @@ def __init__( from rts.units import UnitTypeTable self.real_utt = UnitTypeTable() - from ai.rewardfunction import (AttackRewardFunction, - ProduceBuildingRewardFunction, - ProduceCombatUnitRewardFunction, - ProduceWorkerRewardFunction, - ResourceGatherRewardFunction, - RewardFunctionInterface, - WinLossRewardFunction) + from ai.rewardfunction import ( + AttackRewardFunction, + ProduceBuildingRewardFunction, + ProduceCombatUnitRewardFunction, + ProduceWorkerRewardFunction, + ResourceGatherRewardFunction, + RewardFunctionInterface, + WinLossRewardFunction, + ) self.rfs = JArray(RewardFunctionInterface)( [ @@ -272,13 +274,15 @@ def __init__( from rts.units import UnitTypeTable self.real_utt = UnitTypeTable() - from ai.rewardfunction import (AttackRewardFunction, - ProduceBuildingRewardFunction, - ProduceCombatUnitRewardFunction, - ProduceWorkerRewardFunction, - ResourceGatherRewardFunction, - RewardFunctionInterface, - WinLossRewardFunction) + from ai.rewardfunction import ( + AttackRewardFunction, + ProduceBuildingRewardFunction, + ProduceCombatUnitRewardFunction, + ProduceWorkerRewardFunction, + ResourceGatherRewardFunction, + RewardFunctionInterface, + WinLossRewardFunction, + ) self.rfs = JArray(RewardFunctionInterface)( [ diff --git a/pyproject.toml b/pyproject.toml index c31cf1f1..c9ee721c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,4 +42,4 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.extras] spyder = ["spyder"] -cleanrl = ["cleanrl"] \ No newline at end of file +cleanrl = ["cleanrl"]