File size: 7,304 Bytes
c115883
3bc2cfb
c115883
3bc2cfb
44c7f77
c115883
 
 
 
44c7f77
 
3bc2cfb
c115883
 
 
 
 
 
 
5a98ee7
 
 
 
 
 
c115883
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bc2cfb
2a0d582
 
c115883
 
 
 
 
 
 
 
5a98ee7
 
 
c115883
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a0d582
 
5a98ee7
 
c115883
 
 
 
 
 
 
3bc2cfb
 
 
c115883
 
44c7f77
 
 
 
 
 
 
 
 
 
 
c115883
44c7f77
 
 
 
 
 
 
 
 
 
c115883
 
5a98ee7
 
 
44c7f77
 
 
 
3bc2cfb
44c7f77
 
 
 
 
 
 
c115883
 
 
 
44c7f77
c115883
 
 
 
3bc2cfb
5a98ee7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import glob
import time
import uuid
import gradio as gr
from htrflow.pipeline.pipeline import Pipeline
from htrflow.pipeline.steps import init_step
import os
from htrflow.volume.volume import Collection

from htrflow.pipeline.steps import auto_import
import yaml

# Maximum number of images allowed; override via the MAX_IMAGES environment variable.
MAX_IMAGES = int(os.getenv("MAX_IMAGES", "5"))


class PipelineWithProgress(Pipeline):
    """Pipeline variant that reports per-step progress to a Gradio progress tracker."""

    @classmethod
    def from_config(cls, config: dict):
        """Init pipeline from config, ensuring the correct subclass is instantiated.

        Args:
            config: Parsed pipeline configuration. ``config["steps"]`` must be an
                iterable of mappings, each with a ``"step"`` name and an optional
                ``"settings"`` mapping.

        Returns:
            An instance of ``cls`` built from one initialised step per entry.
        """
        return cls(
            [
                init_step(step["step"], step.get("settings", {}))
                for step in config["steps"]
            ]
        )

    def run(self, collection, start=0, progress=None):
        """
        Run pipeline on collection with Gradio progress support.

        Args:
            collection: The collection passed to the first executed step; each
                step's result is fed to the next one.
            start: Index of the first step to execute.
            progress: Optional callable invoked as ``progress(fraction, desc=...)``
                before each step. When ``None``, no progress is reported.

        Returns:
            The collection returned by the last executed step.

        Raises:
            gr.Error: If a step fails. The original exception is chained as
                ``__cause__`` so the traceback is preserved.
        """
        total_steps = len(self.steps)
        remaining_steps = self.steps[start:]

        for offset, step in enumerate(remaining_steps):
            # Label uses the absolute position so it never reads "step 4 / 2".
            step_name = f"{step} (step {start + offset + 1} / {total_steps})"

            if progress is not None:
                # The bar tracks only the steps executed in this call.
                progress(
                    (offset + 1) / len(remaining_steps),
                    desc=f"Running {step_name}",
                )

            try:
                collection = step.run(collection)
            except Exception as exc:
                message = f"HTRflow: Pipeline failed on step {step_name}"
                if self.pickle_path:
                    message += (
                        f". A backup collection is saved at {self.pickle_path}"
                    )
                # gr.Error is an exception class: it has to be raised for the
                # message to be shown in the UI; merely constructing it is a no-op.
                raise gr.Error(message) from exc
        return collection


def rewrite_export_dests(config):
    """
    Rewrite the 'dest' in all 'Export' steps to include 'tmp' and a UUID.

    The input config is not modified: steps and the settings of rewritten
    Export steps are copied before being changed. All other top-level keys
    (for example 'labels') are carried over to the new config unchanged.

    Args:
        config (dict | None): Parsed pipeline configuration. ``None`` is
            treated as an empty configuration.

    Returns:
        - A new config object with the updated 'dest' values.
        - A list of all updated 'dest' paths.
    """
    source = config or {}

    # Keep every top-level key, not only 'steps'.
    new_config = {key: value for key, value in source.items() if key != "steps"}
    new_config["steps"] = []
    updated_paths = []

    # One UUID per call: all exports of a single run share the same tmp folder.
    unique_id = str(uuid.uuid4())

    for step in source.get("steps") or []:
        new_step = dict(step)
        settings = new_step.get("settings")
        if (
            new_step.get("step") == "Export"
            and isinstance(settings, dict)
            and "dest" in settings
        ):
            # os.path.join discards everything before an absolute component, so
            # strip any drive/leading separators to keep the path under tmp/<uuid>.
            # NOTE(review): '..' segments are not sanitised here.
            relative_dest = os.path.splitdrive(settings["dest"])[1].lstrip("/\\")
            new_dest = os.path.join("tmp", unique_id, relative_dest)
            # Fresh settings dict so the caller's config is left untouched.
            new_step["settings"] = {**settings, "dest": new_dest}
            updated_paths.append(new_dest)
        new_config["steps"].append(new_step)

    return new_config, updated_paths


def run_htrflow(custom_template_yaml, batch_image_gallery, progress=gr.Progress()):
    """
    Executes the HTRflow pipeline based on the provided YAML configuration and batch images.

    This is a generator: on success it yields exactly once. On invalid input a
    warning toast is shown and the generator finishes without yielding, so no
    output component is updated.

    Args:
        custom_template_yaml (str): YAML string specifying the HTRflow pipeline configuration.
        batch_image_gallery (list): Uploaded gallery items; the first element of
            each item is used as the image to process.
        progress (gr.Progress): Gradio progress tracker, called as
            ``progress(fraction, desc=...)``.

    Yields:
        tuple: The last processed collection, the list of exported file paths,
        and ``gr.skip()`` for the third output.
    """
    if not custom_template_yaml:
        gr.Warning("HTRflow: Please insert a HTRflow-yaml template")
        return

    try:
        config = yaml.safe_load(custom_template_yaml)
    except Exception as e:
        # UI boundary: any parser failure becomes a toast instead of a crash.
        gr.Warning(f"HTRflow: Error loading YAML configuration: {e}")
        return

    # safe_load returns None for blank input and scalars/lists for non-mapping YAML.
    if not isinstance(config, dict) or not config.get("steps"):
        gr.Warning(
            "HTRflow: The YAML template must be a mapping with a non-empty 'steps' list"
        )
        return

    if not batch_image_gallery:
        gr.Warning("HTRflow: You must upload atleast 1 image or more")
        return

    temp_config, tmp_output_paths = rewrite_export_dests(config)

    progress(0, desc="HTRflow: Starting")
    time.sleep(0.3)

    print(temp_config)

    images = [gallery_item[0] for gallery_item in batch_image_gallery]

    pipe = PipelineWithProgress.from_config(temp_config)
    collections = auto_import(images)

    gr.Info(
        f"HTRflow: processing {len(images)} {'image' if len(images) == 1 else 'images'}."
    )
    progress(0.1, desc="HTRflow: Processing")

    # Read labels from the parsed config itself so they are honoured regardless
    # of which top-level keys the rewritten config carries.
    label_format = config.get("labels")

    last_collection = None
    for collection in collections:
        if label_format:
            collection.set_label_format(**label_format)

        collection.label = "HTRflow_demo_output"
        last_collection = pipe.run(collection, progress=progress)

    if last_collection is None:
        # Nothing was imported, so there is no result to hand to the UI.
        gr.Warning("HTRflow: No images could be imported for processing")
        return

    exported_files = tracking_exported_files(tmp_output_paths)

    time.sleep(0.5)
    progress(1, desc="HTRflow: Finish")
    gr.Info("HTRflow: Finish")

    yield last_collection, exported_files, gr.skip()


def tracking_exported_files(tmp_output_paths):
    """
    Look for files with specific extensions in the provided tmp_output_paths,
    including subdirectories. Eliminates duplicate files.

    Folder paths are treated literally: glob metacharacters in a folder name
    (``[``, ``]``, ``*``, ``?``) are escaped before searching.

    Args:
        tmp_output_paths (list): List of temporary output directories to search.

    Returns:
        list: Sorted, unique paths of all matching files found in the directories.
    """
    accepted_extensions = (".txt", ".xml", ".json")

    exported_files = set()

    print(tmp_output_paths)

    # TODO: fix so that we get the file extension for page and alto...

    for tmp_folder in tmp_output_paths:
        # Only the folder part is escaped; the '**' and '*' wildcards appended
        # below must stay active.
        literal_folder = glob.escape(tmp_folder)
        for ext in accepted_extensions:
            search_pattern = os.path.join(literal_folder, "**", f"*{ext}")
            exported_files.update(glob.glob(search_pattern, recursive=True))

    return sorted(exported_files)


# Submit tab: upload images, edit the pipeline YAML, run it and collect the output files.
with gr.Blocks() as submit:
    # Receives the first output of run_htrflow (the processed collection).
    collection_submit_state = gr.State()

    with gr.Column(variant="panel"):
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=1):
                    batch_image_gallery = gr.Gallery(
                        file_types=["image"],
                        label="Upload the images you want to transcribe",
                        interactive=True,
                        height=400,
                        object_fit="cover",
                        columns=5,
                        # preview=True,
                    )

                with gr.Column(scale=1):
                    custom_template_yaml = gr.Code(
                        value="",
                        language="yaml",
                        label="Pipeline",
                        interactive=True,
                    )
        with gr.Row():
            run_button = gr.Button("Submit", variant="primary", scale=0, min_width=200)
            # Hidden textbox, made visible only while a run is in progress.
            progess_bar = gr.Textbox(visible=False, show_label=False)
            collection_output_files = gr.Files(
                label="Output Files", scale=0, min_width=400, visible=False
            )

    @batch_image_gallery.upload(
        inputs=batch_image_gallery,
        outputs=[batch_image_gallery],
    )
    def validate_images(images):
        """Clear the gallery and warn when more than MAX_IMAGES images are uploaded."""
        # Compare against MAX_IMAGES so the enforced limit matches both the
        # environment override and the number shown in the warning.
        if images is not None and len(images) > MAX_IMAGES:
            gr.Warning(f"Maximum images you can upload is set to: {MAX_IMAGES}")
            return gr.update(value=None)
        return images

    # 1) show the progress box and hide old results, 2) run the pipeline,
    # 3) hide the progress box and show the exported files.
    run_button.click(
        lambda: (gr.update(visible=True), gr.update(visible=False)),
        outputs=[progess_bar, collection_output_files],
    ).then(
        fn=run_htrflow,
        inputs=[custom_template_yaml, batch_image_gallery],
        outputs=[collection_submit_state, collection_output_files, progess_bar],
    ).then(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        outputs=[progess_bar, collection_output_files],
    )

# TODO: validate yaml before submitting...?
# TODO: Add toast gr.Warning: Lose previous run...