File size: 6,234 Bytes
a72119e
 
 
 
 
 
496112d
 
 
 
8365126
 
 
a72119e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de54836
684bc28
496112d
 
6d754a8
d6302a9
 
d3daa33
d6302a9
6d754a8
496112d
6d754a8
26bcd6d
 
 
 
8f6aa89
 
6d754a8
293e082
8f6aa89
26bcd6d
 
 
 
 
 
 
 
 
 
 
 
 
8f6aa89
26bcd6d
 
 
 
 
 
 
 
 
3d15a76
26bcd6d
8f6aa89
26bcd6d
 
 
 
 
 
 
 
 
6d754a8
 
26bcd6d
6d754a8
 
4902bd9
6d754a8
 
26bcd6d
8365126
d3daa33
26bcd6d
 
 
 
8f6aa89
293e082
8f6aa89
6d754a8
de54836
d3daa33
a72119e
 
6d754a8
a72119e
 
 
 
 
6d754a8
e63457c
d3daa33
 
e63457c
 
6d754a8
 
2189235
6d754a8
293e082
 
 
8f6aa89
 
 
 
 
 
4902bd9
 
b1d6fce
d3daa33
6d754a8
4902bd9
d3daa33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3d15a76
d3daa33
 
b8c17c8
609f960
26bcd6d
d3daa33
 
 
 
26a50b2
b8c17c8
2189235
d3daa33
684bc28
26bcd6d
2189235
d3daa33
a72119e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import gradio as gr
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
import moviepy.editor as mp
from pydub import AudioSegment
from PIL import Image
import numpy as np
import os
import tempfile
import uuid

# Choose the float32 matmul precision. Index 0 selects "high"; switching the
# index to 1 would select "highest" instead.
torch.set_float32_matmul_precision(["high", "highest"][0])

# Load the BiRefNet segmentation model from the Hugging Face Hub once at import
# time so every request reuses the same weights.
# NOTE(review): trust_remote_code=True lets the hub repository supply model
# code that is executed locally -- confirm the repository/revision is trusted.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to("cuda")
# Preprocessing applied to every frame before it is fed to the model in
# process(): resize to 1024x1024, convert to a tensor, then normalize.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        # Per-channel mean / std; presumably the usual ImageNet statistics --
        # TODO confirm against the model card.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

def _progress_html(current, total):
    """Return the HTML snippet whose CSS variables drive the progress bar."""
    return (
        f'<div class="progress-container"><div class="progress-bar" '
        f'style="--current: {current}; --total: {total};"></div></div>'
    )


@spaces.GPU
def fn(vid, bg_type="Color", bg_image=None, color="#00FF00", fps=0):
    """Replace the background of every frame of a video, streaming progress.

    Args:
        vid: Path of the input video file.
        bg_type: "Color" to composite onto a solid color; any other value
            composites onto ``bg_image``.
        bg_image: Path of the background image (used when ``bg_type`` is not
            "Color").
        color: Background color as a "#RRGGBB" string (used when ``bg_type``
            is "Color").
        fps: Output frame rate; 0 means "use the input video's frame rate".

    Yields:
        3-tuples matching the Gradio outputs
        ``(stream_image, out_video, progress_bar)``: visibility updates,
        each processed frame as it is produced, and finally the path of the
        rendered video.

    Raises:
        gr.Error: If image mode is selected without a background image, or
            if no frame could be read from the input video.
    """
    if bg_type != "Color" and not bg_image:
        # Fail early with a readable message instead of crashing inside
        # process() on ``None.startswith``.
        raise gr.Error("Please provide a background image.")
    background = color if bg_type == "Color" else bg_image

    video = mp.VideoFileClip(vid)
    try:
        # Load original fps if fps value is equal to 0
        if fps == 0:
            fps = video.fps

        # Keep a handle on the original audio track so it can be re-attached.
        audio = video.audio

        # Process video in chunks of 1 second
        chunk_duration = 1  # seconds
        total_duration = video.duration
        start_time = 0

        progress = _progress_html(start_time, total_duration)
        yield gr.update(visible=True), gr.update(visible=False), progress

        processed_image = None
        frame_paths = []  # written frames, in playback order

        # A private per-call directory: concurrent or aborted runs can no
        # longer mix their frames, and it is removed automatically on exit.
        with tempfile.TemporaryDirectory(prefix="bg_frames_") as frame_dir:
            while start_time < total_duration:
                end_time = min(start_time + chunk_duration, total_duration)
                chunk = video.subclip(start_time, end_time)

                for frame in chunk.iter_frames(fps=fps):
                    processed_image = process(Image.fromarray(frame), background)
                    # Zero-padded global index; the explicit list (not a
                    # directory listing) defines the frame order, so there is
                    # no lexicographic mis-sorting of "10" before "2".
                    frame_path = os.path.join(
                        frame_dir, f"frame_{len(frame_paths):08d}.png"
                    )
                    processed_image.save(frame_path)
                    frame_paths.append(frame_path)
                    yield processed_image, None, progress

                # Report the time actually completed so the bar reaches 100%.
                progress = _progress_html(end_time, total_duration)
                yield None, None, progress

                start_time += chunk_duration

            if not frame_paths:
                raise gr.Error("No frames could be read from the input video.")

            # Passing file names lets moviepy read frames lazily instead of
            # holding the whole decoded video in memory.
            processed_video = mp.ImageSequenceClip(frame_paths, fps=fps)

            # Add the original audio back to the processed video
            processed_video = processed_video.set_audio(audio)

            # Write the result OUTSIDE the frame directory, under a unique
            # name, so cleaning up the frames cannot delete the output.
            temp_filepath = os.path.join(
                tempfile.gettempdir(), f"processed_{uuid.uuid4().hex}.mp4"
            )
            processed_video.write_videofile(temp_filepath, codec="libx264")
    finally:
        # Release the source readers even on error or early cancellation.
        video.close()

    yield gr.update(visible=False), gr.update(visible=True), progress
    # Return the path to the temporary file
    yield processed_image, temp_filepath, progress


def process(image, bg):
    """Composite the segmented foreground of ``image`` onto a new background.

    Args:
        image: PIL image holding one video frame.
        bg: Either a "#RRGGBB" color string or the path of a background
            image file.

    Returns:
        A PIL image of the same size as ``image`` in which the predicted
        foreground is kept and everything else is replaced by the background.
    """
    size = image.size

    # Run the segmentation model on a single-image batch; the last model
    # output is turned into probabilities and moved back to the CPU.
    batch = transform_image(image).unsqueeze(0).to("cuda")
    with torch.no_grad():
        probabilities = birefnet(batch)[-1].sigmoid().cpu()

    # Convert the prediction into a mask image at the frame's resolution.
    to_pil = transforms.ToPILImage()
    alpha = to_pil(probabilities[0].squeeze()).resize(size)

    if not bg.startswith("#"):
        # Anything that is not a hex color is treated as an image path.
        backdrop = Image.open(bg).convert("RGBA").resize(size)
    else:
        red, green, blue = (int(bg[k : k + 2], 16) for k in (1, 3, 5))
        backdrop = Image.new("RGBA", size, (red, green, blue, 255))

    # Keep the frame where the mask is set, the backdrop elsewhere.
    return Image.composite(image, backdrop, alpha)


# Styles for the progress bar emitted by fn(): the bar's width is derived from
# the --current / --total CSS custom properties set inline on .progress-bar.
css="""
.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
"""

# Build the Gradio UI. Components appear in the order they are created here.
with gr.Blocks(css=css, theme="ocean") as demo:
    # Top row: input video, the (initially hidden) live frame preview, and the
    # final rendered video.
    with gr.Row():
        in_video = gr.Video(label="Input Video")
        stream_image = gr.Image(label="Streaming Output", visible=False)
        out_video = gr.Video(label="Final Output Video")
    submit_button = gr.Button("Change Background")
    # Second row: output options.
    with gr.Row():
        fps_slider = gr.Slider(
            minimum=0,
            maximum=60,
            step=1,
            value=0,
            label="Output FPS (0 will inherit the original fps value)",
        )
        bg_type = gr.Radio(["Color", "Image"], label="Background Type", value="Color")
        color_picker = gr.ColorPicker(label="Background Color", value="#00FF00", visible=True)
        # type="filepath" hands fn() a path string, which process() opens
        # with Image.open.
        bg_image = gr.Image(label="Background Image", type="filepath", visible=False)

    def update_visibility(bg_type):
        """Show the color picker for "Color", otherwise show the image input.

        Returns a pair of updates for (color_picker, bg_image).
        """
        if bg_type == "Color":
            return gr.update(visible=True), gr.update(visible=False)
        else:
            return gr.update(visible=False), gr.update(visible=True)

    # Toggle which background input is visible whenever the radio changes.
    bg_type.change(update_visibility, inputs=bg_type, outputs=[color_picker, bg_image])

    # Receives the progress-bar HTML yielded by fn().
    progress_bar = gr.Markdown(elem_id="progress")

    # Only three inputs are wired for the examples, so fn() runs them with its
    # default color and fps values.
    examples = gr.Examples(
        [["rickroll-2sec.mp4", "Image", "images.webp"], ["rickroll-2sec.mp4", "Color", None]],
        inputs=[in_video, bg_type, bg_image],
        outputs=[stream_image, out_video, progress_bar],
        fn=fn,
        cache_examples=True,
        cache_mode="eager",
    )


    # Main action: run fn() with all five inputs and stream its yields into
    # the preview image, the output video and the progress bar.
    submit_button.click(
        fn,
        inputs=[in_video, bg_type, bg_image, color_picker, fps_slider],
        outputs=[stream_image, out_video, progress_bar],
    )

# Start the Gradio server only when this file is executed as a script
# (not when it is imported).
if __name__ == "__main__":
    demo.launch(show_error=True)