Spaces:
Runtime error
Runtime error
Update glide_text2im/adv.py
Browse files- glide_text2im/adv.py +5 -5
glide_text2im/adv.py
CHANGED
|
@@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|
| 4 |
import torch.optim as optim
|
| 5 |
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
| 6 |
from .nn import mean_flat
|
| 7 |
-
from . import dist_util
|
| 8 |
import functools
|
| 9 |
|
| 10 |
class AdversarialLoss(nn.Module):
|
|
@@ -16,11 +16,11 @@ class AdversarialLoss(nn.Module):
|
|
| 16 |
self.gan_type = gan_type
|
| 17 |
self.gan_k = gan_k
|
| 18 |
|
| 19 |
-
model = NLayerDiscriminator().
|
| 20 |
self.discriminator = DDP(
|
| 21 |
model,
|
| 22 |
-
device_ids=[
|
| 23 |
-
output_device=
|
| 24 |
broadcast_buffers=False,
|
| 25 |
bucket_cap_mb=128,
|
| 26 |
find_unused_parameters=False,
|
|
@@ -41,7 +41,7 @@ class AdversarialLoss(nn.Module):
|
|
| 41 |
if (self.gan_type.find('WGAN') >= 0):
|
| 42 |
loss_d = (d_fake - d_real).mean()
|
| 43 |
if self.gan_type.find('GP') >= 0:
|
| 44 |
-
epsilon = torch.rand(real.size(0), 1, 1, 1).
|
| 45 |
epsilon = epsilon.expand(real.size())
|
| 46 |
hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
|
| 47 |
hat.requires_grad = True
|
|
|
|
| 4 |
import torch.optim as optim
|
| 5 |
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
|
| 6 |
from .nn import mean_flat
|
| 7 |
+
#from . import dist_util
|
| 8 |
import functools
|
| 9 |
|
| 10 |
class AdversarialLoss(nn.Module):
|
|
|
|
| 16 |
self.gan_type = gan_type
|
| 17 |
self.gan_k = gan_k
|
| 18 |
|
| 19 |
+
model = NLayerDiscriminator().cuda()
|
| 20 |
self.discriminator = DDP(
|
| 21 |
model,
|
| 22 |
+
device_ids=[torch.device('cuda')],
|
| 23 |
+
output_device=torch.device('cuda'),
|
| 24 |
broadcast_buffers=False,
|
| 25 |
bucket_cap_mb=128,
|
| 26 |
find_unused_parameters=False,
|
|
|
|
| 41 |
if (self.gan_type.find('WGAN') >= 0):
|
| 42 |
loss_d = (d_fake - d_real).mean()
|
| 43 |
if self.gan_type.find('GP') >= 0:
|
| 44 |
+
epsilon = torch.rand(real.size(0), 1, 1, 1).cuda()
|
| 45 |
epsilon = epsilon.expand(real.size())
|
| 46 |
hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
|
| 47 |
hat.requires_grad = True
|