Skip to content

Commit

Permalink
Merge pull request #1117 from eric-s-s/multi_queue
Browse files Browse the repository at this point in the history
windows multiprocessing (and added tempfile bonus)
  • Loading branch information
meatballs authored Aug 17, 2017
2 parents dd0bb42 + efb8135 commit e021ef7
Show file tree
Hide file tree
Showing 8 changed files with 427 additions and 142 deletions.
3 changes: 0 additions & 3 deletions axelrod/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os

on_windows = os.name == 'nt'
DEFAULT_TURNS = 200

# The order of imports matters!
Expand Down
17 changes: 10 additions & 7 deletions axelrod/fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import namedtuple
from tempfile import NamedTemporaryFile
import os
from tempfile import mkstemp
import matplotlib.pyplot as plt
import numpy as np
import tqdm

import axelrod as axl
from axelrod import on_windows, Player
from axelrod import Player
from axelrod.strategy_transformers import JossAnnTransformer, DualTransformer
from axelrod.interaction_utils import (
compute_final_score_per_turn, read_interactions_from_file)
Expand Down Expand Up @@ -300,11 +301,9 @@ def fingerprint(
the values are the mean score for the corresponding interactions.
"""

if on_windows and (filename is None): # pragma: no cover
in_memory = True
elif filename is None:
outputfile = NamedTemporaryFile(mode='w')
filename = outputfile.name
temp_file_descriptor = None
if not in_memory and filename is None:
temp_file_descriptor, filename = mkstemp()

edges, tourn_players = self.construct_tournament_elements(
step, progress_bar=progress_bar)
Expand All @@ -324,6 +323,10 @@ def fingerprint(
self.interactions = read_interactions_from_file(
filename, progress_bar=progress_bar)

if temp_file_descriptor is not None:
os.close(temp_file_descriptor)
os.remove(filename)

self.data = generate_data(self.interactions, self.points, edges)
return self.data

Expand Down
4 changes: 2 additions & 2 deletions axelrod/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __init__(self, players, interactions, repetitions=False,
----------
players : list
a list of player objects.
interactions : list
a list of dictionaries mapping tuples of player indices to
interactions : dict
a dictionary mapping tuples of player indices to
interactions (1 for each repetition)
repetitions : int
The number of repetitions
Expand Down
2 changes: 0 additions & 2 deletions axelrod/tests/integration/test_tournament.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ def test_serial_play(self):
actual_outcome = sorted(zip(self.player_names, scores))
self.assertEqual(actual_outcome, self.expected_outcome)

@unittest.skipIf(axelrod.on_windows,
"Parallel processing not supported on Windows")
def test_parallel_play(self):
tournament = axelrod.Tournament(
name=self.test_name,
Expand Down
48 changes: 46 additions & 2 deletions axelrod/tests/unit/test_fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
from tempfile import mkstemp
import unittest
from unittest.mock import patch
from hypothesis import given
import axelrod as axl
from axelrod.fingerprint import (create_points, create_jossann, create_probes,
Expand All @@ -17,6 +20,22 @@
C, D = axl.Action.C, axl.Action.D


class RecordedMksTemp(object):
"""This object records all results from RecordedMksTemp.mkstemp. It's for
testing that temp files are created and then destroyed."""
record = []

@staticmethod
def mkstemp(*args, **kwargs):
temp_file_info = mkstemp(*args, **kwargs)
RecordedMksTemp.record.append(temp_file_info)
return temp_file_info

@staticmethod
def reset_record():
RecordedMksTemp.record = []


class TestFingerprint(unittest.TestCase):

@classmethod
Expand Down Expand Up @@ -165,6 +184,33 @@ def test_progress_bar_fingerprint(self):
progress_bar=True)
self.assertEqual(sorted(data.keys()), self.expected_points)

@patch('axelrod.fingerprint.mkstemp', RecordedMksTemp.mkstemp)
def test_temp_file_creation(self):

RecordedMksTemp.reset_record()
af = AshlockFingerprint(self.strategy, self.probe)
filename = "test_outputs/test_fingerprint.csv"

# No temp file is created.
af.fingerprint(turns=1, repetitions=1, step=0.5, progress_bar=False,
in_memory=True)
af.fingerprint(turns=1, repetitions=1, step=0.5, progress_bar=False,
in_memory=True, filename=filename)
af.fingerprint(turns=1, repetitions=1, step=0.5, progress_bar=False,
in_memory=False, filename=filename)

self.assertEqual(RecordedMksTemp.record, [])

# Temp file is created and destroyed.
af.fingerprint(turns=1, repetitions=1, step=0.5, progress_bar=False,
in_memory=False, filename=None)

self.assertEqual(len(RecordedMksTemp.record), 1)
filename = RecordedMksTemp.record[0][1]
self.assertIsInstance(filename, str)
self.assertNotEqual(filename, '')
self.assertFalse(os.path.isfile(filename))

def test_fingerprint_with_filename(self):
filename = "test_outputs/test_fingerprint.csv"
af = AshlockFingerprint(self.strategy, self.probe)
Expand Down Expand Up @@ -196,8 +242,6 @@ def test_serial_fingerprint(self):
self.assertEqual(edge_keys, self.expected_edges)
self.assertEqual(coord_keys, self.expected_points)

@unittest.skipIf(axl.on_windows,
"Parallel processing not supported on Windows")
def test_parallel_fingerprint(self):
af = AshlockFingerprint(self.strategy, self.probe)
af.fingerprint(turns=10, repetitions=2, step=0.5, processes=2,
Expand Down
Loading

0 comments on commit e021ef7

Please sign in to comment.