xizaoqu
committed on
Commit
·
e128dab
1
Parent(s):
9c45273
update
Browse files- algorithms/worldmem/df_video.py +20 -12
- app.py +16 -9
algorithms/worldmem/df_video.py
CHANGED
@@ -792,39 +792,46 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
792 |
|
793 |
@torch.no_grad()
|
794 |
def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
|
795 |
-
self_frames, self_poses, self_memory_c2w, self_frame_idx):
|
796 |
|
797 |
condition_similar_length = self.condition_similar_length
|
798 |
|
799 |
if self_frames is None:
|
|
|
|
|
|
|
800 |
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
801 |
self_frames = first_frame_encode.cpu()
|
802 |
-
|
803 |
self_poses = first_pose[None, None].to(device)
|
804 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
805 |
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
806 |
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
807 |
-
return first_frame.cpu(), self_frames.cpu().numpy(), self_poses.cpu(), self_memory_c2w.cpu(), self_frame_idx.cpu()
|
808 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
809 |
last_frame = self_frames[-1].clone()
|
810 |
-
self_poses = self_poses.to(device)
|
811 |
-
self_memory_c2w = self_memory_c2w.to(device)
|
812 |
-
self_frame_idx = self_frame_idx.to(device)
|
813 |
last_pose_condition = self_poses[-1].clone()
|
814 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
815 |
-
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None]
|
816 |
|
817 |
new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
|
818 |
new_pose_condition = last_pose_condition + new_pose_condition_offset
|
819 |
new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
|
820 |
new_pose_condition[:,3:] %= 360
|
821 |
-
|
822 |
-
self_poses = torch.cat([self_poses, new_pose_condition[None]
|
823 |
new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
|
824 |
-
self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None]
|
825 |
self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
|
826 |
|
827 |
-
conditions =
|
828 |
pose_conditions = self_poses.clone()
|
829 |
c2w_mat = self_memory_c2w .clone()
|
830 |
frame_idx = self_frame_idx.clone()
|
@@ -903,7 +910,8 @@ class WorldMemMinecraft(DiffusionForcingBase):
|
|
903 |
|
904 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
905 |
|
906 |
-
return xs_pred[-1,0].cpu(), self_frames.cpu()
|
|
|
907 |
|
908 |
|
909 |
def reset(self):
|
|
|
792 |
|
793 |
@torch.no_grad()
|
794 |
def interactive(self, first_frame, curr_actions, first_pose, context_frames_idx, device,
|
795 |
+
self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx):
|
796 |
|
797 |
condition_similar_length = self.condition_similar_length
|
798 |
|
799 |
if self_frames is None:
|
800 |
+
first_frame = torch.from_numpy(first_frame)
|
801 |
+
curr_actions = torch.from_numpy(curr_actions)
|
802 |
+
first_pose = torch.from_numpy(first_pose)
|
803 |
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
804 |
self_frames = first_frame_encode.cpu()
|
805 |
+
self_actions = curr_actions[None, None].to(device)
|
806 |
self_poses = first_pose[None, None].to(device)
|
807 |
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
808 |
self_memory_c2w = new_c2w_mat[None, None].to(device)
|
809 |
self_frame_idx = torch.tensor([[context_frames_idx]]).to(device)
|
810 |
+
return first_frame.cpu(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), self_poses.cpu().numpy(), self_memory_c2w.cpu().numpy(), self_frame_idx.cpu().numpy()
|
811 |
else:
|
812 |
+
self_frames = torch.from_numpy(self_frames)
|
813 |
+
self_actions = torch.from_numpy(self_actions).to(device)
|
814 |
+
self_poses = torch.from_numpy(self_poses).to(device)
|
815 |
+
self_memory_c2w = torch.from_numpy(self_memory_c2w).to(device)
|
816 |
+
self_frame_idx = torch.from_numpy(self_frame_idx).to(device)
|
817 |
+
curr_actions = curr_actions.to(device)
|
818 |
+
|
819 |
last_frame = self_frames[-1].clone()
|
|
|
|
|
|
|
820 |
last_pose_condition = self_poses[-1].clone()
|
821 |
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
822 |
+
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None], last_pose_condition)
|
823 |
|
824 |
new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
|
825 |
new_pose_condition = last_pose_condition + new_pose_condition_offset
|
826 |
new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
|
827 |
new_pose_condition[:,3:] %= 360
|
828 |
+
self_actions = torch.cat([self_actions, curr_actions[None, None]])
|
829 |
+
self_poses = torch.cat([self_poses, new_pose_condition[None]])
|
830 |
new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
|
831 |
+
self_memory_c2w = torch.cat([self_memory_c2w, new_c2w_mat[None]])
|
832 |
self_frame_idx = torch.cat([self_frame_idx, torch.tensor([[context_frames_idx]]).to(device)])
|
833 |
|
834 |
+
conditions = self_actions.clone()
|
835 |
pose_conditions = self_poses.clone()
|
836 |
c2w_mat = self_memory_c2w .clone()
|
837 |
frame_idx = self_frame_idx.clone()
|
|
|
910 |
|
911 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
912 |
|
913 |
+
return xs_pred[-1,0].cpu().numpy(), self_frames.cpu().numpy(), self_actions.cpu().numpy(), \
|
914 |
+
self_poses.cpu().numpy(), self_memory_c2w.cpu().numpy(), self_frame_idx.cpu().numpy()
|
915 |
|
916 |
|
917 |
def reset(self):
|
app.py
CHANGED
@@ -177,30 +177,33 @@ load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.
|
|
177 |
worldmem.to("cuda").eval()
|
178 |
|
179 |
|
180 |
-
actions =
|
181 |
-
poses =
|
182 |
|
183 |
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
184 |
|
185 |
self_frames = None
|
|
|
186 |
self_poses = None
|
187 |
self_memory_c2w = None
|
188 |
self_frame_idx = None
|
189 |
|
190 |
|
191 |
@spaces.GPU()
|
192 |
-
def run_interactive(first_frame, action, first_pose, curr_frame, device, self_frames,
|
193 |
-
|
|
|
194 |
action,
|
195 |
first_pose,
|
196 |
curr_frame,
|
197 |
device=device,
|
198 |
self_frames=self_frames,
|
|
|
199 |
self_poses=self_poses,
|
200 |
self_memory_c2w=self_memory_c2w,
|
201 |
self_frame_idx=self_frame_idx)
|
202 |
-
|
203 |
-
return self_frames
|
204 |
|
205 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
206 |
worldmem.sampling_timesteps = denoising_steps
|
@@ -215,6 +218,7 @@ def generate(keys):
|
|
215 |
global input_history
|
216 |
global memory_curr_frame
|
217 |
global self_frames
|
|
|
218 |
global self_poses
|
219 |
global self_memory_c2w
|
220 |
global self_frame_idx
|
@@ -222,12 +226,13 @@ def generate(keys):
|
|
222 |
for i in range(len(actions)):
|
223 |
memory_curr_frame += 1
|
224 |
|
225 |
-
new_frame, self_frames, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
|
226 |
actions[i],
|
227 |
None,
|
228 |
memory_curr_frame,
|
229 |
device=device,
|
230 |
self_frames=self_frames,
|
|
|
231 |
self_poses=self_poses,
|
232 |
self_memory_c2w=self_memory_c2w,
|
233 |
self_frame_idx=self_frame_idx)
|
@@ -254,6 +259,7 @@ def reset():
|
|
254 |
global input_history
|
255 |
global memory_frames
|
256 |
global self_frames
|
|
|
257 |
global self_poses
|
258 |
global self_memory_c2w
|
259 |
global self_frame_idx
|
@@ -263,16 +269,17 @@ def reset():
|
|
263 |
self_memory_c2w = None
|
264 |
self_frame_idx = None
|
265 |
memory_frames = []
|
266 |
-
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
267 |
memory_curr_frame = 0
|
268 |
input_history = ""
|
269 |
|
270 |
-
self_frames = run_interactive(memory_frames[0],
|
271 |
actions[0],
|
272 |
poses[0],
|
273 |
memory_curr_frame,
|
274 |
device=device,
|
275 |
self_frames=self_frames,
|
|
|
276 |
self_poses=self_poses,
|
277 |
self_memory_c2w=self_memory_c2w,
|
278 |
self_frame_idx=self_frame_idx)
|
|
|
177 |
worldmem.to("cuda").eval()
|
178 |
|
179 |
|
180 |
+
actions = np.zeros((1, 25), dtype=np.float32)
|
181 |
+
poses = np.zeros((1, 5), dtype=np.float32)
|
182 |
|
183 |
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))
|
184 |
|
185 |
self_frames = None
|
186 |
+
self_actions = None
|
187 |
self_poses = None
|
188 |
self_memory_c2w = None
|
189 |
self_frame_idx = None
|
190 |
|
191 |
|
192 |
@spaces.GPU()
|
193 |
+
def run_interactive(first_frame, action, first_pose, curr_frame, device, self_frames, self_actions,
|
194 |
+
self_poses, self_memory_c2w, self_frame_idx):
|
195 |
+
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = worldmem.interactive(first_frame,
|
196 |
action,
|
197 |
first_pose,
|
198 |
curr_frame,
|
199 |
device=device,
|
200 |
self_frames=self_frames,
|
201 |
+
self_actions=self_actions,
|
202 |
self_poses=self_poses,
|
203 |
self_memory_c2w=self_memory_c2w,
|
204 |
self_frame_idx=self_frame_idx)
|
205 |
+
|
206 |
+
return new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx
|
207 |
|
208 |
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
209 |
worldmem.sampling_timesteps = denoising_steps
|
|
|
218 |
global input_history
|
219 |
global memory_curr_frame
|
220 |
global self_frames
|
221 |
+
global self_actions
|
222 |
global self_poses
|
223 |
global self_memory_c2w
|
224 |
global self_frame_idx
|
|
|
226 |
for i in range(len(actions)):
|
227 |
memory_curr_frame += 1
|
228 |
|
229 |
+
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
|
230 |
actions[i],
|
231 |
None,
|
232 |
memory_curr_frame,
|
233 |
device=device,
|
234 |
self_frames=self_frames,
|
235 |
+
self_actions=self_actions,
|
236 |
self_poses=self_poses,
|
237 |
self_memory_c2w=self_memory_c2w,
|
238 |
self_frame_idx=self_frame_idx)
|
|
|
259 |
global input_history
|
260 |
global memory_frames
|
261 |
global self_frames
|
262 |
+
global self_actions
|
263 |
global self_poses
|
264 |
global self_memory_c2w
|
265 |
global self_frame_idx
|
|
|
269 |
self_memory_c2w = None
|
270 |
self_frame_idx = None
|
271 |
memory_frames = []
|
272 |
+
memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE).numpy())
|
273 |
memory_curr_frame = 0
|
274 |
input_history = ""
|
275 |
|
276 |
+
new_frame, self_frames, self_actions, self_poses, self_memory_c2w, self_frame_idx = run_interactive(memory_frames[0],
|
277 |
actions[0],
|
278 |
poses[0],
|
279 |
memory_curr_frame,
|
280 |
device=device,
|
281 |
self_frames=self_frames,
|
282 |
+
self_actions=self_actions,
|
283 |
self_poses=self_poses,
|
284 |
self_memory_c2w=self_memory_c2w,
|
285 |
self_frame_idx=self_frame_idx)
|