Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Identifiable agents #2

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 54 additions & 13 deletions agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,24 +212,50 @@ def forward(self, desc):
class MessageProcessor(nn.Module):
'''Processes a received message from an agent'''

def __init__(self, m_dim, hid_dim, cuda):
def __init__(self, m_dim, hid_dim, cuda, identify_agents, num_agents):
super(MessageProcessor, self).__init__()
self.m_dim = m_dim
self.hid_dim = hid_dim
self.use_cuda = cuda
self.rnn = nn.GRUCell(self.m_dim, self.hid_dim)
self.identify_agents = identify_agents
self.num_agents = num_agents
if self.identify_agents:
self.rnn = nn.GRUCell(self.m_dim + self.num_agents, self.hid_dim)
else:
self.rnn = nn.GRUCell(self.m_dim, self.hid_dim)
self.reset_parameters()

def reset_parameters(self):
reset_parameters_util(self)

def forward(self, m, h, use_message):
def concat_identifier(self, message, m_identifier):
# Concatenate agent identifier with message
bs = message.size(0)
agent_identity = torch.zeros(bs, self.num_agents)
# -1 is the "blank message" option
if m_identifier != -1:
agent_identity[:, m_identifier] = 1
if self.use_cuda:
agent_identity = agent_identity.cuda()
# debuglogger.info(f'Agent identifier: {agent_identity}')
# debuglogger.info(f'Message: {message}')
new_message = torch.cat([agent_identity, message.data], dim=1)
# debuglogger.info(f'Combined shape: {new_message.shape}')
# debuglogger.info(f'Combined: {new_message}')
new_message = _Variable(new_message)
return new_message

def forward(self, m, h, use_message, m_identifier):
if use_message:
debuglogger.debug(f'Using message')
if self.identify_agents:
m = self.concat_identifier(m, m_identifier)
return self.rnn(m, h)
else:
debuglogger.debug(f'Ignoring message, using blank instead...')
blank_msg = _Variable(torch.zeros_like(m.data))
if self.identify_agents:
blank_msg = self.concat_identifier(blank_msg, m_identifier)
if self.use_cuda:
blank_msg = blank_msg.cuda()
return self.rnn(blank_msg, h)
Expand Down Expand Up @@ -292,6 +318,8 @@ def forward(self, y_scores, h_c, desc, training):
w_feats = w_binary
# debuglogger.debug(f'w_binary: {w_binary}')
else:
debuglogger.warn(f'Error: Training loop with real valued messages not vetted yet. Please set FLAGS.use_binary to true')
sys.exit()
w_feats = w_scores
w_probs = None
# debuglogger.info(f'Message : {w_feats}')
Expand Down Expand Up @@ -334,7 +362,9 @@ def __init__(self,
use_MLP,
cuda,
im_from_scratch,
dropout):
dropout,
identify_agents,
num_agents):
super(Agent, self).__init__()
self.im_feature_type = im_feature_type
self.im_feat_dim = im_feat_dim
Expand All @@ -348,14 +378,16 @@ def __init__(self,
self.use_MLP = use_MLP
self.attn_dim = attn_dim
self.use_cuda = cuda
self.identify_agents = identify_agents
self.num_agents = num_agents
if im_from_scratch:
self.image_processor = ImageProcessorFromScratch(
im_feat_dim, h_dim, use_attn, attn_dim, dropout)
else:
self.image_processor = ImageProcessor(
im_feat_dim, h_dim, use_attn, attn_dim)
self.text_processor = TextProcessor(desc_dim, h_dim)
self.message_processor = MessageProcessor(m_dim, h_dim, cuda)
self.message_processor = MessageProcessor(m_dim, h_dim, cuda, identify_agents, num_agents)
self.message_generator = MessageGenerator(m_dim, h_dim, use_binary)
self.reward_estimator = RewardEstimator(h_dim)
# Network for combining processed image and message representations
Expand Down Expand Up @@ -425,10 +457,10 @@ def predict_classes(self, h_c, desc_proc, batch_size):
debuglogger.debug(f'y: {y.size()}')
return y

def forward(self, x, m, t, desc, use_message, batch_size, training):
def forward(self, x, m, t, desc, use_message, batch_size, training, m_identifier):
"""
Update State:
h_z = message_processor(m, h_z)
h_z = message_processor(m, h_z, m_identifier)

Image processing
h_i = image_processor(x, h_z)
Expand Down Expand Up @@ -465,8 +497,9 @@ def forward(self, x, m, t, desc, use_message, batch_size, training):
desc: List of description vectors used in communication and predictions.
batch_size: size of batch
training: whether agent is training or not
m_identifier: identity of agent that sent the message
Output:
s, s_probs: A STOP bit and its associated probability, indicating whether the agent has decided to make a selection. The conversation will continue until both agents have selected STOP.
s, s_probs: A STOP bit and its associated probability, indicating whether the agent has decided to make a selection. If the exchange length is not set to FIXED then the conversation will continue until both agents have selected STOP.
w, w_probs: A binary message and the probability of each bit in the message being ``1``.
y: A prediction for each class described in the descriptions.
r: An estimate of the reward the agent will receive
Expand All @@ -482,7 +515,7 @@ def forward(self, x, m, t, desc, use_message, batch_size, training):
self.h_z = self.initial_state(batch_size)

# Process message sent from the other agent
self.h_z = self.message_processor(m, self.h_z, use_message)
self.h_z = self.message_processor(m, self.h_z, use_message, m_identifier)
debuglogger.debug(f'h_z: {self.h_z.size()}')

# Process the image
Expand Down Expand Up @@ -557,7 +590,9 @@ def forward(self, x, m, t, desc, use_message, batch_size, training):
dropout = 0.3
use_MLP = False
cuda = False
im_from_scratch = True
im_from_scratch = False
identify_agents = True
num_agents = 7
agent = Agent(im_feature_type,
im_feat_dim,
h_dim,
Expand All @@ -571,19 +606,25 @@ def forward(self, x, m, t, desc, use_message, batch_size, training):
use_MLP,
cuda,
im_from_scratch,
dropout)
dropout,
identify_agents,
num_agents)
print(agent)
total_params = sum([functools.reduce(lambda x, y: x * y, p.size(), 1.0)
for p in agent.parameters()])
image_proc_params = sum([functools.reduce(lambda x, y: x * y, p.size(), 1.0)
for p in agent.image_processor.parameters()])
print(f'Total params: {total_params}, image proc params: {image_proc_params}')
x = _Variable(torch.ones(batch_size, 3, im_feat_dim, im_feat_dim))
if im_from_scratch:
x = _Variable(torch.ones(batch_size, 3, im_feat_dim, im_feat_dim))
else:
x = _Variable(torch.ones(batch_size, im_feat_dim))
m = _Variable(torch.ones(batch_size, m_dim))
desc = _Variable(torch.ones(batch_size, num_classes, desc_dim))

agent_identifier = -1
for i in range(2):
s, w, y, r = agent(x, m, i, desc, use_message, batch_size, training)
s, w, y, r = agent(x, m, i, desc, use_message, batch_size, training, agent_identifier)
# print(f's_binary: {s[0]}')
# print(f's_probs: {s[1]}')
# print(f'w_binary: {w[0]}')
Expand Down
Loading