lisonallen committed on
Commit
da097bc
·
1 Parent(s): 217c6bd

Fix: Make the Stop/结束生成 button work properly during generation

Browse files
Files changed (2) hide show
  1. app.py +66 -39
  2. diffusers_helper/k_diffusion/uni_pc_fm.py +34 -20
app.py CHANGED
@@ -671,21 +671,26 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
671
  last_update_time = time.time()
672
 
673
  try:
 
 
 
 
 
 
674
  preview = d['denoised']
675
  preview = vae_decode_fake(preview)
676
 
677
  preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
678
  preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
679
 
680
- if stream.input_queue.top() == 'end':
681
- stream.output_queue.push(('end', None))
682
- raise KeyboardInterrupt('User ends the task.')
683
-
684
  current_step = d['i'] + 1
685
  percentage = int(100.0 * current_step / steps)
686
  hint = f'Sampling {current_step}/{steps}'
687
  desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ...'
688
  stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
 
 
 
689
  except Exception as e:
690
  print(f"回调函数中出错: {e}")
691
  # 不中断采样过程
@@ -695,38 +700,53 @@ def worker(input_image, prompt, n_prompt, seed, total_second_length, latent_wind
695
  sampling_start_time = time.time()
696
  print(f"开始采样,设备: {device}, 数据类型: {transformer.dtype}, 使用TeaCache: {use_teacache and not cpu_fallback_mode}")
697
 
698
- generated_latents = sample_hunyuan(
699
- transformer=transformer,
700
- sampler='unipc',
701
- width=width,
702
- height=height,
703
- frames=num_frames,
704
- real_guidance_scale=cfg,
705
- distilled_guidance_scale=gs,
706
- guidance_rescale=rs,
707
- # shift=3.0,
708
- num_inference_steps=steps,
709
- generator=rnd,
710
- prompt_embeds=llama_vec,
711
- prompt_embeds_mask=llama_attention_mask,
712
- prompt_poolers=clip_l_pooler,
713
- negative_prompt_embeds=llama_vec_n,
714
- negative_prompt_embeds_mask=llama_attention_mask_n,
715
- negative_prompt_poolers=clip_l_pooler_n,
716
- device=device,
717
- dtype=transformer.dtype,
718
- image_embeddings=image_encoder_last_hidden_state,
719
- latent_indices=latent_indices,
720
- clean_latents=clean_latents,
721
- clean_latent_indices=clean_latent_indices,
722
- clean_latents_2x=clean_latents_2x,
723
- clean_latent_2x_indices=clean_latent_2x_indices,
724
- clean_latents_4x=clean_latents_4x,
725
- clean_latent_4x_indices=clean_latent_4x_indices,
726
- callback=callback,
727
- )
728
-
729
- print(f"采样完成,用时: {time.time() - sampling_start_time:.2f}秒")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
  except Exception as e:
731
  print(f"采样过程中出错: {e}")
732
  traceback.print_exc()
@@ -887,7 +907,7 @@ if IN_HF_SPACE and 'spaces' in globals():
887
 
888
  if flag == 'progress':
889
  preview, desc, html = data
890
- # 更新进度时不改变错误信息
891
  yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
892
 
893
  if flag == 'error':
@@ -964,7 +984,7 @@ else:
964
 
965
  if flag == 'progress':
966
  preview, desc, html = data
967
- # 更新进度时不改变错误信息
968
  yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
969
 
970
  if flag == 'error':
@@ -1011,7 +1031,14 @@ else:
1011
 
1012
 
1013
  def end_process():
1014
- stream.input_queue.push('end')
 
 
 
 
 
 
 
1015
 
1016
 
1017
  quick_prompts = [
 
671
  last_update_time = time.time()
672
 
673
  try:
674
+ # 首先检查是否有停止信号
675
+ if stream.input_queue.top() == 'end':
676
+ print("检测到停止信号,中断采样过程...")
677
+ stream.output_queue.push(('end', None))
678
+ raise KeyboardInterrupt('用户主动结束任务')
679
+
680
  preview = d['denoised']
681
  preview = vae_decode_fake(preview)
682
 
683
  preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
684
  preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
685
 
 
 
 
 
686
  current_step = d['i'] + 1
687
  percentage = int(100.0 * current_step / steps)
688
  hint = f'Sampling {current_step}/{steps}'
689
  desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}, Video length: {max(0, (total_generated_latent_frames * 4 - 3) / 30) :.2f} seconds (FPS-30). The video is being extended now ...'
690
  stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
691
+ except KeyboardInterrupt:
692
+ # 捕获并重新抛出中断异常,确保它能传播到采样函数
693
+ raise
694
  except Exception as e:
695
  print(f"回调函数中出错: {e}")
696
  # 不中断采样过程
 
700
  sampling_start_time = time.time()
701
  print(f"开始采样,设备: {device}, 数据类型: {transformer.dtype}, 使用TeaCache: {use_teacache and not cpu_fallback_mode}")
702
 
703
+ try:
704
+ generated_latents = sample_hunyuan(
705
+ transformer=transformer,
706
+ sampler='unipc',
707
+ width=width,
708
+ height=height,
709
+ frames=num_frames,
710
+ real_guidance_scale=cfg,
711
+ distilled_guidance_scale=gs,
712
+ guidance_rescale=rs,
713
+ # shift=3.0,
714
+ num_inference_steps=steps,
715
+ generator=rnd,
716
+ prompt_embeds=llama_vec,
717
+ prompt_embeds_mask=llama_attention_mask,
718
+ prompt_poolers=clip_l_pooler,
719
+ negative_prompt_embeds=llama_vec_n,
720
+ negative_prompt_embeds_mask=llama_attention_mask_n,
721
+ negative_prompt_poolers=clip_l_pooler_n,
722
+ device=device,
723
+ dtype=transformer.dtype,
724
+ image_embeddings=image_encoder_last_hidden_state,
725
+ latent_indices=latent_indices,
726
+ clean_latents=clean_latents,
727
+ clean_latent_indices=clean_latent_indices,
728
+ clean_latents_2x=clean_latents_2x,
729
+ clean_latent_2x_indices=clean_latent_2x_indices,
730
+ clean_latents_4x=clean_latents_4x,
731
+ clean_latent_4x_indices=clean_latent_4x_indices,
732
+ callback=callback,
733
+ )
734
+
735
+ print(f"采样完成,用时: {time.time() - sampling_start_time:.2f}秒")
736
+ except KeyboardInterrupt:
737
+ # 用户主动中断
738
+ print("用户主动中断采样过程")
739
+
740
+ # 如果已经有生成的视频,返回最后生成的视频
741
+ if last_output_filename:
742
+ stream.output_queue.push(('file', last_output_filename))
743
+ error_msg = "用户中断生成过程,但已生成部分视频"
744
+ else:
745
+ error_msg = "用户中断生成过程,未生成视频"
746
+
747
+ stream.output_queue.push(('error', error_msg))
748
+ stream.output_queue.push(('end', None))
749
+ return
750
  except Exception as e:
751
  print(f"采样过程中出错: {e}")
752
  traceback.print_exc()
 
907
 
908
  if flag == 'progress':
909
  preview, desc, html = data
910
+ # 更新进度时不改变错误信息,并确保停止按钮可交互
911
  yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
912
 
913
  if flag == 'error':
 
984
 
985
  if flag == 'progress':
986
  preview, desc, html = data
987
+ # 更新进度时不改变错误信息,并确保停止按钮可交互
988
  yield gr.update(), gr.update(visible=True, value=preview), desc, html, gr.update(interactive=False), gr.update(interactive=True)
989
 
990
  if flag == 'error':
 
1031
 
1032
 
1033
  def end_process():
1034
+ """停止生成过程函数 - 通过在队列中推送'end'信号来中断生成"""
1035
+ print("用户点击了停止按钮,发送停止信号...")
1036
+ # 确保stream已初始化
1037
+ if 'stream' in globals() and stream is not None:
1038
+ stream.input_queue.push('end')
1039
+ else:
1040
+ print("警告: stream未初始化,无法发送停止信号")
1041
+ return None
1042
 
1043
 
1044
  quick_prompts = [
diffusers_helper/k_diffusion/uni_pc_fm.py CHANGED
@@ -111,27 +111,41 @@ class FlowMatchUniPC:
111
  def sample(self, x, sigmas, callback=None, disable_pbar=False):
112
  order = min(3, len(sigmas) - 2)
113
  model_prev_list, t_prev_list = [], []
114
- for i in trange(len(sigmas) - 1, disable=disable_pbar):
115
- vec_t = sigmas[i].expand(x.shape[0])
116
-
117
- if i == 0:
118
- model_prev_list = [self.model_fn(x, vec_t)]
119
- t_prev_list = [vec_t]
120
- elif i < order:
121
- init_order = i
122
- x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
123
- model_prev_list.append(model_x)
124
- t_prev_list.append(vec_t)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  else:
126
- x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
127
- model_prev_list.append(model_x)
128
- t_prev_list.append(vec_t)
129
-
130
- model_prev_list = model_prev_list[-order:]
131
- t_prev_list = t_prev_list[-order:]
132
-
133
- if callback is not None:
134
- callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
135
 
136
  return model_prev_list[-1]
137
 
 
111
  def sample(self, x, sigmas, callback=None, disable_pbar=False):
112
  order = min(3, len(sigmas) - 2)
113
  model_prev_list, t_prev_list = [], []
114
+ try:
115
+ for i in trange(len(sigmas) - 1, disable=disable_pbar):
116
+ vec_t = sigmas[i].expand(x.shape[0])
117
+
118
+ if i == 0:
119
+ model_prev_list = [self.model_fn(x, vec_t)]
120
+ t_prev_list = [vec_t]
121
+ elif i < order:
122
+ init_order = i
123
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
124
+ model_prev_list.append(model_x)
125
+ t_prev_list.append(vec_t)
126
+ else:
127
+ x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
128
+ model_prev_list.append(model_x)
129
+ t_prev_list.append(vec_t)
130
+
131
+ model_prev_list = model_prev_list[-order:]
132
+ t_prev_list = t_prev_list[-order:]
133
+
134
+ if callback is not None:
135
+ try:
136
+ callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})
137
+ except KeyboardInterrupt as e:
138
+ print(f"User interruption detected: {e}")
139
+ # Return the last available result
140
+ return model_prev_list[-1]
141
+ except KeyboardInterrupt as e:
142
+ print(f"Process interrupted: {e}")
143
+ # Return the last available result if we have one
144
+ if model_prev_list:
145
+ return model_prev_list[-1]
146
  else:
147
+ # If no results yet, re-raise the exception
148
+ raise
 
 
 
 
 
 
 
149
 
150
  return model_prev_list[-1]
151