Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,304 Bytes
c115883 3bc2cfb c115883 3bc2cfb 44c7f77 c115883 44c7f77 3bc2cfb c115883 5a98ee7 c115883 3bc2cfb 2a0d582 c115883 5a98ee7 c115883 2a0d582 5a98ee7 c115883 3bc2cfb c115883 44c7f77 c115883 44c7f77 c115883 5a98ee7 44c7f77 3bc2cfb 44c7f77 c115883 44c7f77 c115883 3bc2cfb 5a98ee7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
import glob
import time
import uuid
import gradio as gr
from htrflow.pipeline.pipeline import Pipeline
from htrflow.pipeline.steps import init_step
import os
from htrflow.volume.volume import Collection
from htrflow.pipeline.steps import auto_import
import yaml
# Upper bound on the number of images accepted per upload.
# Read from the MAX_IMAGES environment variable; falls back to 5 when unset.
MAX_IMAGES = int(os.getenv("MAX_IMAGES", "5"))
class PipelineWithProgress(Pipeline):
    """HTRflow ``Pipeline`` that reports per-step progress to a Gradio progress bar."""

    @classmethod
    def from_config(cls, config: dict):
        """Init pipeline from config, ensuring the correct subclass is instantiated.

        Args:
            config: Parsed pipeline configuration. ``config["steps"]`` is iterated;
                each entry must have a ``"step"`` name and may have a ``"settings"``
                mapping, which is forwarded to ``init_step``.

        Returns:
            A ``PipelineWithProgress`` built from the initialised steps.
        """
        return cls(
            [
                init_step(step["step"], step.get("settings", {}))
                for step in config["steps"]
            ]
        )

    def run(self, collection, start=0, progress=None):
        """
        Run the pipeline on ``collection`` with optional Gradio progress support.

        Args:
            collection: Input to the first step; each step's return value is fed
                to the next step.
            start: Index of the first step to run; earlier steps are skipped.
            progress: Optional ``gr.Progress``. When given, it is advanced once per
                step before that step runs.

        Returns:
            The value returned by the last step that ran (``collection`` unchanged
            if no steps ran).

        Raises:
            gr.Error: If a step raises. The message names the failing step and,
                when ``pickle_path`` is set, the backup location. The original
                exception is chained as ``__cause__``.
        """
        steps_to_run = self.steps[start:]
        total_steps = len(steps_to_run)
        for i, step in enumerate(steps_to_run):
            # Number steps against the whole pipeline so a resumed run (start > 0)
            # reads e.g. "step 4 / 5" rather than "step 4 / 2".
            step_name = f"{step} (step {start + i + 1} / {len(self.steps)})"
            # Kept outside the try so a progress-bar failure is not reported as a
            # pipeline-step failure.
            if progress is not None:
                progress((i + 1) / total_steps, desc=f"Running {step_name}")
            try:
                collection = step.run(collection)
            except Exception as exc:
                pickle_path = getattr(self, "pickle_path", None)
                if pickle_path:
                    message = (
                        f"HTRflow: Pipeline failed on step {step_name}. "
                        f"A backup collection is saved at {pickle_path}"
                    )
                else:
                    message = f"HTRflow: Pipeline failed on step {step_name}"
                # gr.Error only reaches the UI when it is raised; merely
                # constructing it (as before) silently discarded the message.
                raise gr.Error(message) from exc
        return collection
def rewrite_export_dests(config):
    """
    Rewrite the 'dest' in all 'Export' steps to live under ``tmp/<uuid>/``.

    The input ``config`` is never modified. Top-level keys other than ``"steps"``
    (for example ``"labels"``) are carried over unchanged into the new config.

    Args:
        config (dict | None): Parsed pipeline configuration. ``None`` is treated as
            an empty configuration.

    Returns:
        tuple:
            - A new config object with the updated 'dest' values.
            - A list of all updated 'dest' paths, in step order.
    """
    config = config or {}
    # One run directory shared by every Export step of this call.
    unique_id = str(uuid.uuid4())
    new_steps = []
    updated_paths = []
    for step in config.get("steps") or []:
        new_step = dict(step)
        # `settings:` with no value parses to None in YAML; treat it as empty.
        settings = new_step.get("settings") or {}
        if new_step.get("step") == "Export" and "dest" in settings:
            # Copy settings too: a shallow step copy would share (and mutate)
            # the caller's settings dict.
            settings = dict(settings)
            # Strip leading separators so an absolute dest cannot escape
            # tmp/<uuid>/ (os.path.join discards earlier parts for absolute paths).
            # NOTE(review): '..' components can still climb out of tmp/<uuid>/;
            # consider rejecting them since the YAML is user-supplied.
            relative_dest = str(settings["dest"]).lstrip("/\\")
            new_dest = os.path.join("tmp", unique_id, relative_dest)
            settings["dest"] = new_dest
            new_step["settings"] = settings
            updated_paths.append(new_dest)
        new_steps.append(new_step)
    return {**config, "steps": new_steps}, updated_paths
def run_htrflow(custom_template_yaml, batch_image_gallery, progress=gr.Progress()):
    """
    Execute the HTRflow pipeline described by a YAML template on the uploaded images.

    This is a generator; Gradio streams its yielded values to the outputs. On
    invalid input a warning toast is shown and ``gr.skip()`` is yielded for every
    output, so the UI is left unchanged.

    Args:
        custom_template_yaml (str): YAML string specifying the HTRflow pipeline configuration.
        batch_image_gallery (list): Gallery value; the first element of each item is
            passed to ``auto_import`` as an image.
        progress (gr.Progress): Gradio progress tracker (injected by Gradio).

    Yields:
        tuple: (last processed collection, sorted list of exported file paths, gr.skip()).
    """
    # A value *returned* from a generator is dropped by Gradio, so every early
    # exit yields an explicit "no change" for each of the three outputs.
    no_update = (gr.skip(), gr.skip(), gr.skip())

    if not custom_template_yaml or not custom_template_yaml.strip():
        gr.Warning("HTRflow: Please insert a HTRflow-yaml template")
        yield no_update
        return

    try:
        config = yaml.safe_load(custom_template_yaml)
    except yaml.YAMLError as e:
        gr.Warning(f"HTRflow: Error loading YAML configuration: {e}")
        yield no_update
        return

    # safe_load can return None or a scalar/list for valid-but-wrong YAML.
    if not isinstance(config, dict) or not config.get("steps"):
        gr.Warning("HTRflow: The YAML template must contain a non-empty 'steps' list")
        yield no_update
        return

    if not batch_image_gallery:
        gr.Warning("HTRflow: You must upload atleast 1 image or more")
        yield no_update
        return

    temp_config, tmp_output_paths = rewrite_export_dests(config)
    progress(0, desc="HTRflow: Starting")
    time.sleep(0.3)  # presumably so the "Starting" state is visible in the UI
    print(temp_config)

    images = [temp_img[0] for temp_img in batch_image_gallery]
    pipe = PipelineWithProgress.from_config(temp_config)
    collections = auto_import(images)
    gr.Info(
        f"HTRflow: processing {len(images)} {'image' if len(images) == 1 else 'images'}."
    )
    progress(0.1, desc="HTRflow: Processing")

    # Labels are read from the parsed config itself, so they are honoured no
    # matter which top-level keys rewrite_export_dests carries over.
    labels = config.get("labels")
    processed = None
    for collection in collections:
        if labels is not None:
            collection.set_label_format(**labels)
        collection.label = "HTRflow_demo_output"
        processed = pipe.run(collection, progress=progress)

    if processed is None:
        gr.Warning("HTRflow: No images could be imported")
        yield no_update
        return

    # Search the export folders once, after every collection has been processed.
    exported_files = tracking_exported_files(tmp_output_paths)
    time.sleep(0.5)
    progress(1, desc="HTRflow: Finish")
    gr.Info("HTRflow: Finish")
    yield processed, exported_files, gr.skip()
def tracking_exported_files(tmp_output_paths):
    """
    Find exported files with accepted extensions under the given output folders.

    Each folder is searched recursively (including files directly inside it).
    Duplicate matches, e.g. from a folder listed twice, are removed.

    Args:
        tmp_output_paths (list): Output directories to search. ``None`` is treated
            as an empty list.

    Returns:
        list: Sorted, unique paths of all files ending in .txt, .xml or .json.
    """
    accepted_extensions = (".txt", ".xml", ".json")
    exported_files = set()
    print(tmp_output_paths)
    # TODO: fix so that we get the file extension for page and alto...
    for tmp_folder in tmp_output_paths or []:
        # The folder name comes from user YAML; escape it so characters such as
        # '[', '*' or '?' are matched literally instead of as glob wildcards.
        base = glob.escape(tmp_folder)
        for ext in accepted_extensions:
            search_pattern = os.path.join(base, "**", f"*{ext}")
            exported_files.update(glob.glob(search_pattern, recursive=True))
    return sorted(exported_files)
with gr.Blocks() as submit:
    # Receives the processed collection yielded by run_htrflow (first output of
    # the run event). NOTE(review): presumably consumed by other tabs — verify.
    collection_submit_state = gr.State()

    with gr.Column(variant="panel"):
        with gr.Group():
            with gr.Row():
                with gr.Column(scale=1):
                    batch_image_gallery = gr.Gallery(
                        file_types=["image"],
                        label="Upload the images you want to transcribe",
                        interactive=True,
                        height=400,
                        object_fit="cover",
                        columns=5,
                    )
                with gr.Column(scale=1):
                    custom_template_yaml = gr.Code(
                        value="",
                        language="yaml",
                        label="Pipeline",
                        interactive=True,
                    )

        with gr.Row():
            run_button = gr.Button("Submit", variant="primary", scale=0, min_width=200)
            # Hidden textbox made visible during a run; it is the third output of
            # run_htrflow, so the progress bar is rendered on it.
            progess_bar = gr.Textbox(visible=False, show_label=False)
            collection_output_files = gr.Files(
                label="Output Files", scale=0, min_width=400, visible=False
            )

    @batch_image_gallery.upload(
        inputs=batch_image_gallery,
        outputs=[batch_image_gallery],
    )
    def validate_images(images):
        """Clear the gallery with a warning when more than MAX_IMAGES images are uploaded."""
        # Compare against MAX_IMAGES (not a hard-coded 5) so the check matches
        # the limit reported in the warning.
        if images and len(images) > MAX_IMAGES:
            gr.Warning(f"Maximum images you can upload is set to: {MAX_IMAGES}")
            return gr.update(value=None)
        return images

    # Show the progress placeholder and hide previous results, run the pipeline,
    # then swap back to showing the output files.
    run_button.click(
        lambda: (gr.update(visible=True), gr.update(visible=False)),
        outputs=[progess_bar, collection_output_files],
    ).then(
        fn=run_htrflow,
        inputs=[custom_template_yaml, batch_image_gallery],
        outputs=[collection_submit_state, collection_output_files, progess_bar],
    ).then(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        outputs=[progess_bar, collection_output_files],
    )

# TODO: validate yaml before submitting...?
# TODO: Add toast gr.Warning: Lose previous run...
|