Spaces:
Runtime error
Runtime error
Update
Browse files
app.py
CHANGED
@@ -77,9 +77,9 @@ def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
|
|
77 |
return mat
|
78 |
|
79 |
|
80 |
-
def generate_z(seed: int, device: torch.device) -> torch.Tensor:
|
81 |
-
return torch.from_numpy(np.random.RandomState(seed).randn(
|
82 |
-
|
83 |
|
84 |
|
85 |
@torch.inference_mode()
|
@@ -90,7 +90,7 @@ def generate_image(model_name: str, class_index: int, seed: int,
|
|
90 |
model = model_dict[model_name]
|
91 |
seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
|
92 |
|
93 |
-
z = generate_z(seed, device)
|
94 |
|
95 |
label = torch.zeros([1, model.c_dim], device=device)
|
96 |
class_index = round(class_index)
|
@@ -117,7 +117,7 @@ def load_model(model_name: str, device: torch.device) -> nn.Module:
|
|
117 |
model.eval()
|
118 |
model.to(device)
|
119 |
with torch.inference_mode():
|
120 |
-
z = torch.zeros((1,
|
121 |
label = torch.zeros([1, model.c_dim], device=device)
|
122 |
model(z, label)
|
123 |
return model
|
|
|
77 |
return mat
|
78 |
|
79 |
|
80 |
+
def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
|
81 |
+
return torch.from_numpy(np.random.RandomState(seed).randn(
|
82 |
+
1, z_dim)).to(device).float()
|
83 |
|
84 |
|
85 |
@torch.inference_mode()
|
|
|
90 |
model = model_dict[model_name]
|
91 |
seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
|
92 |
|
93 |
+
z = generate_z(model.z_dim, seed, device)
|
94 |
|
95 |
label = torch.zeros([1, model.c_dim], device=device)
|
96 |
class_index = round(class_index)
|
|
|
117 |
model.eval()
|
118 |
model.to(device)
|
119 |
with torch.inference_mode():
|
120 |
+
z = torch.zeros((1, model.z_dim)).to(device)
|
121 |
label = torch.zeros([1, model.c_dim], device=device)
|
122 |
model(z, label)
|
123 |
return model
|