Skip to content

Commit

Permalink
Prevent messages persisting between iterations (#973)
Browse files Browse the repository at this point in the history
* Implement opt-in message list persistence

Message lists are now non-persistent by default (messages from the previous iteration are no longer available prior to output in the current iteration)

message lists can be marked as persistent by calling .getPersistent() on the description object.

Adds brute-force tests to check get/set works as intended

Adds brute-force tests for persistence / non-persistence working on a multi-iteation simulation

Adds Bucket persistence / non-persistence tests (C++ only)

* Move MessageBruteForce::CDescription::setPersistent to the header

This fixes a swig linker issue on Windows CI, which is somehow related to the using statements.
  • Loading branch information
ptheywood authored Nov 23, 2022
1 parent 7d38299 commit de88c97
Show file tree
Hide file tree
Showing 12 changed files with 574 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ class MessageArray::Description : public CDescription {
Description& operator=(const Description& other_message) = default;
Description& operator=(Description&& other_message) = default;

using MessageBruteForce::CDescription::setPersistent;
using MessageBruteForce::CDescription::newVariable;
#ifdef SWIG
using MessageBruteForce::CDescription::newVariableArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class MessageArray2D::Description : public CDescription {
Description& operator=(const Description& other_message) = default;
Description& operator=(Description&& other_message) = default;

using MessageBruteForce::CDescription::setPersistent;
using MessageBruteForce::CDescription::newVariable;
#ifdef SWIG
using MessageBruteForce::CDescription::newVariableArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ class MessageArray3D::Description : public CDescription {
Description& operator=(const Description& other_message) = default;
Description& operator=(Description&& other_message) = default;

using MessageBruteForce::CDescription::setPersistent;
using MessageBruteForce::CDescription::newVariable;
#ifdef SWIG
using MessageBruteForce::CDescription::newVariableArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ struct MessageBruteForce::Data {
* Name of the message, used to refer to the message in many functions
*/
std::string name;
/**
* Boolean indicating if the message list is allowed to persist iterations or not
*/
bool persistent;
/**
* The number of functions that have optional output of this message type
* This value is modified by AgentFunctionDescription
Expand Down Expand Up @@ -197,6 +201,11 @@ class MessageBruteForce::CDescription {
* @return The message's name
*/
std::string getName() const;
/**
* Query if the message list is a persistent message list or not (messages will persist from one iteration to the next)
* @return if the message list is persistent or not
*/
bool getPersistent() const;
/**
* @param variable_name Name used to refer to the desired variable
* @return The type of the named variable
Expand Down Expand Up @@ -233,6 +242,13 @@ class MessageBruteForce::CDescription {
/// These mutable accessors will only be available via mutable subclasses
/// This solves a multiple inheritance issue
///
/**
* Set that the message list should be persistent or not
* @param persistent new value for message list persistence
*/
void setPersistent(const bool persistent) {
message->persistent = persistent;
}
/**
* Adds a new variable to the message
* @param variable_name Name of the variable
Expand Down Expand Up @@ -298,6 +314,7 @@ class MessageBruteForce::Description : public CDescription {
Description& operator=(const Description& other_message) = default;
Description& operator=(Description&& other_message) = default;

using MessageBruteForce::CDescription::setPersistent;
using MessageBruteForce::CDescription::newVariable;
#ifdef SWIG
using MessageBruteForce::CDescription::newVariableArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class MessageBucket::Description : public CDescription {
Description& operator=(const Description & other_message) = default;
Description& operator=(Description && other_message) = default;

using MessageBruteForce::CDescription::setPersistent;
using MessageBruteForce::CDescription::newVariable;
#ifdef SWIG
using MessageBruteForce::CDescription::newVariableArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class MessageSpatial2D::Description : public CDescription {
Description& operator=(const Description & other_message) = default;
Description& operator=(Description && other_message) = default;

using MessageBruteForce::CDescription::setPersistent;
using MessageBruteForce::CDescription::newVariable;
#ifdef SWIG
using MessageBruteForce::CDescription::newVariableArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class MessageSpatial3D::Description : public CDescription {
Description& operator=(const Description& other_message) = default;
Description& operator=(Description&& other_message) = default;

using MessageBruteForce::CDescription::setPersistent;
using MessageBruteForce::CDescription::newVariable;
#ifdef SWIG
using MessageBruteForce::CDescription::newVariableArray;
Expand Down
9 changes: 9 additions & 0 deletions src/flamegpu/gpu/CUDASimulation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,15 @@ bool CUDASimulation::step() {
// Run the exit conditons, detecting wheter or not any we
bool exitRequired = this->stepExitConditions();

// Set message counts to zero, and set flags to update state of non-persistent message lists
for (auto &a : message_map) {
if (!a.second->getMessageDescription().persistent) {
a.second->setMessageCount(0);
a.second->setTruncateMessageListFlag();
a.second->setPBMConstructionRequiredFlag();
}
}

// Record, store and output the elapsed time of the step.
stepTimer->stop();
float stepMilliseconds = stepTimer->getElapsedSeconds();
Expand Down
7 changes: 7 additions & 0 deletions src/flamegpu/runtime/messaging/MessageBruteForce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ void MessageBruteForce::CUDAModelHandler::buildIndex(CUDAScatter &, unsigned int
MessageBruteForce::Data::Data(std::shared_ptr<const ModelData> _model, const std::string &message_name)
: model(_model)
, name(message_name)
, persistent(false)
, optional_outputs(0) { }
MessageBruteForce::Data::Data(std::shared_ptr<const ModelData> _model, const MessageBruteForce::Data &other)
: model(_model)
, variables(other.variables)
, name(other.name)
, persistent(other.persistent)
, optional_outputs(other.optional_outputs) { }
MessageBruteForce::Data *MessageBruteForce::Data::clone(const std::shared_ptr<const ModelData> &newParent) {
return new MessageBruteForce::Data(newParent, *this);
Expand All @@ -55,6 +57,7 @@ bool MessageBruteForce::Data::operator==(const MessageBruteForce::Data& rhs) con
return true;
if (name == rhs.name
// && model.lock() == rhs.model.lock() // Don't check weak pointers
&& persistent == rhs.persistent
&& variables.size() == rhs.variables.size()) {
{ // Compare variables
for (auto &v : variables) {
Expand Down Expand Up @@ -109,6 +112,10 @@ std::string MessageBruteForce::CDescription::getName() const {
return message->name;
}

bool MessageBruteForce::CDescription::getPersistent() const {
return message->persistent;
}

const std::type_index& MessageBruteForce::CDescription::getVariableType(const std::string& variable_name) const {
auto f = message->variables.find(variable_name);
if (f != message->variables.end()) {
Expand Down
164 changes: 164 additions & 0 deletions tests/swig/python/runtime/messaging/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,84 @@
return flamegpu::ALIVE;
}
"""
even_only_condition = """
FLAMEGPU_AGENT_FUNCTION_CONDITION(EvenOnlyCondition) {
return FLAMEGPU->getStepCounter() % 2 == 0;
}
"""
out_simple = """
FLAMEGPU_AGENT_FUNCTION(out_simple, flamegpu::MessageNone, flamegpu::MessageBruteForce) {
int id = FLAMEGPU->getVariable<int>("id");
FLAMEGPU->message_out.setVariable<int>("id", id);
return flamegpu::ALIVE;
}
"""

in_simple = """
FLAMEGPU_AGENT_FUNCTION(in_simple, flamegpu::MessageBruteForce, flamegpu::MessageNone) {
const int id = FLAMEGPU->getVariable<int>("id");
unsigned int count = 0;
unsigned int sum = 0;
for (auto &m : FLAMEGPU->message_in) {
count++;
sum += m.getVariable<int>("id");
}
FLAMEGPU->setVariable<unsigned int>("count", count);
FLAMEGPU->setVariable<unsigned int>("sum", sum);
return flamegpu::ALIVE;
}"""

class InitPopulationEvenOutputOnly(pyflamegpu.HostFunctionCallback):
def __init__(self):
super().__init__()

def run(self, FLAMEGPU):
# Generate a basic pop
agent = FLAMEGPU.agent(AGENT_NAME)
for i in range (AGENT_COUNT) :
instance = agent.newAgent()
instance.setVariableInt("id", i)
instance.setVariableUInt("count", 0)
instance.setVariableUInt("sum", 0)

class AssertEvenOutputOnly(pyflamegpu.HostFunctionCallback):
def __init__(self):
super().__init__()

def run(self, FLAMEGPU):
agent = FLAMEGPU.agent(AGENT_NAME)
# Get the population data
av = agent.getPopulationData()
# Iterate the population, ensuring that each agent read the correct number of messages and got the correct sum of messages.
# These values expect only a single bin is used, in the interest of simplicitly.
exepctedCountEven = agent.count()
expectedCountOdd = 0
for a in av:
if (FLAMEGPU.getStepCounter() % 2 == 0):
# Even iterations expect the count to match the number of agents, and sum to be non zero.
assert a.getVariableUInt("count") == exepctedCountEven
assert a.getVariableUInt("sum") != 0
else:
# Odd iters expect 0 count and 0 sum
assert a.getVariableUInt("count") == expectedCountOdd
assert a.getVariableUInt("sum") == 0

class AssertPersistent(pyflamegpu.HostFunctionCallback):
def __init__(self):
super().__init__()

def run(self, FLAMEGPU):
agent = FLAMEGPU.agent(AGENT_NAME)
# Get the population data
av = agent.getPopulationData()
# Iterate the population, ensuring that each agent read the correct number of messages and got the correct sum of messages.
# These values expect only a single bin is used, in the interest of simplicitly.
exepctedCountEven = agent.count()
for a in av:
if (FLAMEGPU.getStepCounter() % 2 == 0):
# all iterations expect the count to match the number of agents, and sum to be non zero.
assert a.getVariableUInt("count") == exepctedCountEven
assert a.getVariableUInt("sum") != 0

class TestMessage_BruteForce(TestCase):

Expand Down Expand Up @@ -321,3 +399,89 @@ def test_ReadEmpty(self):
assert len(pop_out) == 1
ai = pop_out.front()
assert ai.getVariableUInt("count") == 0

def test_getSetPersistent(self):
"""Test that getting and setting a message lists's persistent flag behaves as intended
"""
model = pyflamegpu.ModelDescription("Model")
message = model.newMessageBruteForce("location")
# message lists should be non-persistent by default
assert message.getPersistent() == False
# Settiog the persistent value ot true should not throw
message.setPersistent(True)
# The value should now be true
assert message.getPersistent() == True
# Set it to true again, to make sure it isn't an invert
message.setPersistent(True)
assert message.getPersistent() == True
# And flip it back to false for good measure
message.setPersistent(False)
assert message.getPersistent() == False

def test_PersistenceOff(self):
"""Test for persistence / non persistence of messaging, by emitting messages on even iters, but reading on all iters.
"""
model = pyflamegpu.ModelDescription("TestMessage_BruteForce")
message = model.newMessageBruteForce("msg")
message.newVariableInt("id")

# agent
agent = model.newAgent(AGENT_NAME)
agent.newVariableInt("id")
agent.newVariableUInt("count", 0) # Count the number of messages read
agent.newVariableUInt("sum", 0) # Count of IDs
ouf = agent.newRTCFunction("out", out_simple)
ouf.setMessageOutput("msg")
ouf.setMessageOutputOptional(True)
ouf.setRTCFunctionCondition(even_only_condition)
inf = agent.newRTCFunction("in", in_simple)
inf.setMessageInput("msg")

# Define layers
model.newLayer().addAgentFunction(ouf)
model.newLayer().addAgentFunction(inf)
# init function for pop
init_population_even_output_only = InitPopulationEvenOutputOnly()
model.addInitFunctionCallback(init_population_even_output_only)
# add a step function which validates the correct number of messages was read
assert_even_output_only = AssertEvenOutputOnly()
model.addStepFunctionCallback(assert_even_output_only)

cudaSimulation = pyflamegpu.CUDASimulation(model)
# Execute model
cudaSimulation.SimulationConfig().steps = 2
cudaSimulation.simulate()

def test_PersistenceOn(self):
"""Test for persistence / non persistence of messaging, by emitting messages on even iters, but reading on all iters.
"""
model = pyflamegpu.ModelDescription("TestMessage_BruteForce")
message = model.newMessageBruteForce("msg")
message.newVariableInt("id")

# agent
agent = model.newAgent(AGENT_NAME)
agent.newVariableInt("id")
agent.newVariableUInt("count", 0) # Count the number of messages read
agent.newVariableUInt("sum", 0) # Count of IDs
ouf = agent.newRTCFunction("out", out_simple)
ouf.setMessageOutput("msg")
ouf.setMessageOutputOptional(True)
ouf.setRTCFunctionCondition(even_only_condition)
inf = agent.newRTCFunction("in", in_simple)
inf.setMessageInput("msg")

# Define layers
model.newLayer().addAgentFunction(ouf)
model.newLayer().addAgentFunction(inf)
# init function for pop
init_population_even_output_only = InitPopulationEvenOutputOnly()
model.addInitFunctionCallback(init_population_even_output_only)
# add a step function which validates the correct number of messages was read
assert_persistent = AssertPersistent()
model.addStepFunctionCallback(assert_persistent)

cudaSimulation = pyflamegpu.CUDASimulation(model)
# Execute model
cudaSimulation.SimulationConfig().steps = 2
cudaSimulation.simulate()
Loading

0 comments on commit de88c97

Please sign in to comment.