Spaces:
Runtime error
Runtime error
File size: 1,163 Bytes
2fa4776 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
import torch.nn.functional as F
def generator_loss(discriminator, inputs, reconstructions, cond=None):
    """Return the generator's adversarial loss for a batch of reconstructions.

    The loss is the negated mean of the discriminator's output on the
    reconstructions, so minimizing it pushes those outputs upward.

    Args:
        discriminator: Callable mapping a tensor to a tensor of logits.
        inputs: Unused here; kept so the signature mirrors
            ``discriminator_loss``.
        reconstructions: Tensor produced by the generator. It is NOT
            detached, so gradients flow back through it.
        cond: Optional conditioning tensor. When given, it is concatenated
            to the reconstructions along dim 1 before the discriminator is
            called (presumably the channel axis -- confirm against caller).

    Returns:
        A scalar tensor: ``-mean(discriminator(fake))``.
    """
    fake = reconstructions.contiguous()
    if cond is not None:
        fake = torch.cat((fake, cond), dim=1)
    logits_fake = discriminator(fake)
    return -torch.mean(logits_fake)
def hinge_d_loss(logits_real, logits_fake):
    """Return the hinge loss for a discriminator.

    Computes ``0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))``.
    Real logits are penalized only when below +1, fake logits only when
    above -1.

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        A scalar tensor holding the averaged hinge loss.
    """
    real_term = F.relu(1.0 - logits_real).mean()
    fake_term = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
def discriminator_loss(discriminator, inputs, reconstructions, cond=None):
    """Return the discriminator's hinge loss on real inputs vs. reconstructions.

    Both ``inputs`` and ``reconstructions`` are detached before being passed
    to the discriminator, so this loss does not send gradients back into
    whatever produced them.

    Args:
        discriminator: Callable mapping a tensor to a tensor of logits.
        inputs: Tensor of real samples.
        reconstructions: Tensor of generated samples.
        cond: Optional conditioning tensor. When given, it is concatenated
            (not detached) to both the real and the generated batch along
            dim 1 (presumably the channel axis -- confirm against caller).

    Returns:
        A scalar tensor: ``hinge_d_loss(logits_real, logits_fake)``.
    """
    real = inputs.contiguous().detach()
    fake = reconstructions.contiguous().detach()
    if cond is not None:
        real = torch.cat((real, cond), dim=1)
        fake = torch.cat((fake, cond), dim=1)
    logits_real = discriminator(real)
    logits_fake = discriminator(fake)
    # hinge_d_loss already reduces to a scalar, so no further reduction here.
    return hinge_d_loss(logits_real, logits_fake)
|