hysts HF Staff commited on
Commit
3442e02
·
1 Parent(s): a965d0b
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -77,9 +77,9 @@ def make_transform(translate: tuple[float, float], angle: float) -> np.ndarray:
77
  return mat
78
 
79
 
80
- def generate_z(seed: int, device: torch.device) -> torch.Tensor:
81
- return torch.from_numpy(np.random.RandomState(seed).randn(1,
82
- 64)).to(device)
83
 
84
 
85
  @torch.inference_mode()
@@ -90,7 +90,7 @@ def generate_image(model_name: str, class_index: int, seed: int,
90
  model = model_dict[model_name]
91
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
92
 
93
- z = generate_z(seed, device)
94
 
95
  label = torch.zeros([1, model.c_dim], device=device)
96
  class_index = round(class_index)
@@ -117,7 +117,7 @@ def load_model(model_name: str, device: torch.device) -> nn.Module:
117
  model.eval()
118
  model.to(device)
119
  with torch.inference_mode():
120
- z = torch.zeros((1, 64)).to(device)
121
  label = torch.zeros([1, model.c_dim], device=device)
122
  model(z, label)
123
  return model
 
77
  return mat
78
 
79
 
80
+ def generate_z(z_dim: int, seed: int, device: torch.device) -> torch.Tensor:
81
+ return torch.from_numpy(np.random.RandomState(seed).randn(
82
+ 1, z_dim)).to(device).float()
83
 
84
 
85
  @torch.inference_mode()
 
90
  model = model_dict[model_name]
91
  seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
92
 
93
+ z = generate_z(model.z_dim, seed, device)
94
 
95
  label = torch.zeros([1, model.c_dim], device=device)
96
  class_index = round(class_index)
 
117
  model.eval()
118
  model.to(device)
119
  with torch.inference_mode():
120
+ z = torch.zeros((1, model.z_dim)).to(device)
121
  label = torch.zeros([1, model.c_dim], device=device)
122
  model(z, label)
123
  return model