-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgloss.py
24 lines (21 loc) · 830 Bytes
/
gloss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch.nn as nn
import torch
class GANLoss(nn.Module):
    """GAN objective that builds its own target tensor from a real/fake flag.

    The caller passes the discriminator output and a boolean saying whether
    that output should be scored against the "real" or the "fake" label; the
    matching label is broadcast to the shape of the predictions and handed
    to the underlying criterion.

    Args:
        gan_mode: ``'vanilla'`` selects ``nn.BCEWithLogitsLoss`` and
            ``'lsgan'`` selects ``nn.MSELoss``.
        real_label: Target value used when the predictions should be
            treated as real.
        fake_label: Target value used when the predictions should be
            treated as fake.

    Raises:
        ValueError: If ``gan_mode`` is not one of the supported modes.
    """

    def __init__(
        self,
        gan_mode: str = 'vanilla',
        real_label: float = 1.0,
        fake_label: float = 0.0,
    ) -> None:
        super().__init__()
        # Registered as buffers so the label tensors follow the module
        # through ``.to()`` / ``.cuda()`` and are stored in ``state_dict``.
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        else:
            # Fail at construction time; otherwise ``self.loss`` would be
            # left undefined and the error would only surface on first use.
            raise ValueError(
                f"unsupported gan_mode {gan_mode!r}; "
                "expected 'vanilla' or 'lsgan'"
            )

    def get_labels(
        self, preds: torch.Tensor, target_is_real: bool
    ) -> torch.Tensor:
        """Return the real or fake label expanded to the shape of ``preds``.

        The result is an expanded view of the stored label buffer, not a
        freshly allocated tensor.
        """
        labels = self.real_label if target_is_real else self.fake_label
        return labels.expand_as(preds)

    def forward(
        self, preds: torch.Tensor, target_is_real: bool
    ) -> torch.Tensor:
        """Compute the loss of ``preds`` against the selected label.

        Defined as ``forward`` (not ``__call__``) so that invoking the
        module goes through ``nn.Module.__call__`` and registered hooks run.

        Args:
            preds: Discriminator output.
            target_is_real: ``True`` to score against ``real_label``,
                ``False`` to score against ``fake_label``.

        Returns:
            The value produced by the configured criterion.
        """
        labels = self.get_labels(preds, target_is_real)
        return self.loss(preds, labels)