xizaoqu commited on
Commit
e128dab
·
1 Parent(s): 9c45273
Files changed (2) hide show
  1. algorithms/worldmem/df_video.py +20 -12
  2. app.py +16 -9
algorithms/worldmem/df_video.py CHANGED
@@ -792,39 +792,46 @@ class WorldMemMinecraft(DiffusionForcingBase):
792
 
793
  @torch.no_grad()
794
  def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
795
- self_frames, self_poses, self_memory_c2w, self_frame_idx):
796
 
797
  condition_similar_length = self.condition_similar_length
798
 
799
  if self_frames is None:
 
 
 
800
  first_frame_encode = self.encode(first_frame[None, None].to(device))
801
  self_frames = first_frame_encode.cpu()
802
- self.actions = curr_actions[None, None].to(device)
803
  self_poses = first_pose[None, None].to(device)
804
  new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
805
  self_memory_c2w = new_c2w_mat[None, None].to(device)
806
  self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
807
- return first_frame.cpu(), self_frames.cpu().numpy(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
808
  else:
 
 
 
 
 
 
 
809
  last_frame = self_frames[-1].clone()
810
- self_poses = self_poses.to(device)
811
- self_memory_c2w = self_memory_c2w.to(device)
812
- self_frame_idx = self_frame_idx.to(device)
813
  last_pose_condition = self_poses[-1].clone()
814
  last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
815
- new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None].to(device), last_pose_condition)
816
 
817
  new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
818
  new_pose_condition = last_pose_condition + new_pose_condition_offset
819
  new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
820
  new_pose_condition[:,3:] %= 360
821
- self.actions = torch.cat([self.actions, curr_actions[None, None].to(device)])
822
- self_poses = torch.cat([self_poses, new_pose_condition[None].to(device)])
823
  new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
824
- self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None].to(device)])
825
  self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
826
 
827
- conditions = self.actions.clone()
828
  pose_conditions = self_poses.clone()
829
  c2w_mat = self_memory_c2w .clone()
830
  frame_idx = self_frame_idx.clone()
@@ -903,7 +910,8 @@ class WorldMemMinecraft(DiffusionForcingBase):
903
 
904
  xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
905
 
906
- return xs_pred[-1,0].cpu(), self_frames.cpu(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
 
907
 
908
 
909
  def reset(self):
 
792
 
793
  @torch.no_grad()
794
  def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
795
+ self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
796
 
797
  condition_similar_length = self.condition_similar_length
798
 
799
  if self_frames is None:
800
+ first_frame = torch.from_numpy(first_frame)
801
+ curr_actions = torch.from_numpy(curr_actions)
802
+ first_pose = torch.from_numpy(first_pose)
803
  first_frame_encode = self.encode(first_frame[None, None].to(device))
804
  self_frames = first_frame_encode.cpu()
805
+ self_actions = curr_actions[None, None].to(device)
806
  self_poses = first_pose[None, None].to(device)
807
  new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
808
  self_memory_c2w = new_c2w_mat[None, None].to(device)
809
  self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
810
+ return first_frame.cpu(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), self_poses.cpu().numpy(), self_memory_c2w.cpu().numpy(), self_frame_idx.cpu().numpy()
811
  else:
812
+ self_frames = torch.from_numpy(self_frames)
813
+ self_actions = torch.from_numpy(self_actions).to(device)
814
+ self_poses = torch.from_numpy(self_poses).to(device)
815
+ self_memory_c2w = torch.from_numpy(self_memory_c2w).to(device)
816
+ self_frame_idx = torch.from_numpy(self_frame_idx).to(device)
817
+ curr_actions = curr_actions.to(device)
818
+
819
  last_frame = self_frames[-1].clone()
 
 
 
820
  last_pose_condition = self_poses[-1].clone()
821
  last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
822
+ new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None], last_pose_condition)
823
 
824
  new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
825
  new_pose_condition = last_pose_condition + new_pose_condition_offset
826
  new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
827
  new_pose_condition[:,3:] %= 360
828
+ self_actions = torch.cat([self_actions, curr_actions[None, None]])
829
+ self_poses = torch.cat([self_poses, new_pose_condition[None]])
830
  new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
831
+ self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None]])
832
  self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
833
 
834
+ conditions = self_actions.clone()
835
  pose_conditions = self_poses.clone()
836
  c2w_mat = self_memory_c2w .clone()
837
  frame_idx = self_frame_idx.clone()
 
910
 
911
  xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
912
 
913
+ return xs_pred[-1,0].cpu().numpy(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), \
914
+ self_poses.cpu().numpy(), self_memory_c2w.cpu().numpy(), self_frame_idx.cpu().numpy()
915
 
916
 
917
  def reset(self):
app.py CHANGED
@@ -177,30 +177,33 @@ load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.
177
  worldmem.to("cuda").eval()
178
 
179
 
180
- actions = torch.zeros((1, 25))
181
- poses = torch.zeros((1, 5))
182
 
183
  memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
184
 
185
  self_frames = None
 
186
  self_poses = None
187
  self_memory_c2w = None
188
  self_frame_idx = None
189
 
190
 
191
  @spaces.GPU()
192
- def run_interactive(first_frame, action, first_pose, curr_frame, device, self_frames, self_poses, self_memory_c2w, self_frame_idx):
193
- new_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
 
194
  action,
195
  first_pose,
196
  curr_frame,
197
  device=device,
198
  self_frames=self_frames,
 
199
  self_poses=self_poses,
200
  self_memory_c2w=self_memory_c2w,
201
  self_frame_idx=self_frame_idx)
202
- # return new_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx
203
- return self_frames[:,:,0,0,0]
204
 
205
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
206
  worldmem.sampling_timesteps = denoising_steps
@@ -215,6 +218,7 @@ def generate(keys):
215
  global input_history
216
  global memory_curr_frame
217
  global self_frames
 
218
  global self_poses
219
  global self_memory_c2w
220
  global self_frame_idx
@@ -222,12 +226,13 @@ def generate(keys):
222
  for i in range(len(actions)):
223
  memory_curr_frame += 1
224
 
225
- new_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
226
  actions[i],
227
  None,
228
  memory_curr_frame,
229
  device=device,
230
  self_frames=self_frames,
 
231
  self_poses=self_poses,
232
  self_memory_c2w=self_memory_c2w,
233
  self_frame_idx=self_frame_idx)
@@ -254,6 +259,7 @@ def reset():
254
  global input_history
255
  global memory_frames
256
  global self_frames
 
257
  global self_poses
258
  global self_memory_c2w
259
  global self_frame_idx
@@ -263,16 +269,17 @@ def reset():
263
  self_memory_c2w = None
264
  self_frame_idx = None
265
  memory_frames = []
266
- memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
267
  memory_curr_frame = 0
268
  input_history = ""
269
 
270
- self_frames = run_interactive(memory_frames[0],
271
  actions[0],
272
  poses[0],
273
  memory_curr_frame,
274
  device=device,
275
  self_frames=self_frames,
 
276
  self_poses=self_poses,
277
  self_memory_c2w=self_memory_c2w,
278
  self_frame_idx=self_frame_idx)
 
177
  worldmem.to("cuda").eval()
178
 
179
 
180
+ actions = np.zeros((1, 25), dtype=np.float32)
181
+ poses = np.zeros((1, 5), dtype=np.float32)
182
 
183
  memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
184
 
185
  self_frames = None
186
+ self_actions = None
187
  self_poses = None
188
  self_memory_c2w = None
189
  self_frame_idx = None
190
 
191
 
192
  @spaces.GPU()
193
+ def run_interactive(first_frame, action, first_pose, curr_frame, device, self_frames, self_actions,
194
+ self_poses, self_memory_c2w, self_frame_idx):
195
+ new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
196
  action,
197
  first_pose,
198
  curr_frame,
199
  device=device,
200
  self_frames=self_frames,
201
+ self_actions=self_actions,
202
  self_poses=self_poses,
203
  self_memory_c2w=self_memory_c2w,
204
  self_frame_idx=self_frame_idx)
205
+
206
+ return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
207
 
208
  def set_denoising_steps(denoising_steps, sampling_timesteps_state):
209
  worldmem.sampling_timesteps = denoising_steps
 
218
  global input_history
219
  global memory_curr_frame
220
  global self_frames
221
+ global self_actions
222
  global self_poses
223
  global self_memory_c2w
224
  global self_frame_idx
 
226
  for i in range(len(actions)):
227
  memory_curr_frame += 1
228
 
229
+ new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
230
  actions[i],
231
  None,
232
  memory_curr_frame,
233
  device=device,
234
  self_frames=self_frames,
235
+ self_actions=self_actions,
236
  self_poses=self_poses,
237
  self_memory_c2w=self_memory_c2w,
238
  self_frame_idx=self_frame_idx)
 
259
  global input_history
260
  global memory_frames
261
  global self_frames
262
+ global self_actions
263
  global self_poses
264
  global self_memory_c2w
265
  global self_frame_idx
 
269
  self_memory_c2w = None
270
  self_frame_idx = None
271
  memory_frames = []
272
+ memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE).numpy())
273
  memory_curr_frame = 0
274
  input_history = ""
275
 
276
+ new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
277
  actions[0],
278
  poses[0],
279
  memory_curr_frame,
280
  device=device,
281
  self_frames=self_frames,
282
+ self_actions=self_actions,
283
  self_poses=self_poses,
284
  self_memory_c2w=self_memory_c2w,
285
  self_frame_idx=self_frame_idx)