Spaces: Running
Update app.py
app.py CHANGED
@@ -2,138 +2,83 @@ import numpy as np
 import cv2
 import onnxruntime
 import gradio as gr
-from PIL import Image
 
-
+
 def pre_process(img: np.array) -> np.array:
-    …
+    # H, W, C -> C, H, W
+    img = np.transpose(img[:, :, 0:3], (2, 0, 1))
+    # C, H, W -> 1, C, H, W
+    img = np.expand_dims(img, axis=0).astype(np.float32)
+    return img
 
 
 def post_process(img: np.array) -> np.array:
+    # 1, C, H, W -> C, H, W
     img = np.squeeze(img)
-    …
-
-def get_session(model_path: str):
-    # cache sessions
-    if model_path not in get_session.cache:
-        opts = onnxruntime.SessionOptions()
-        # multi-threading
-        opts.intra_op_num_threads = 1
-        opts.inter_op_num_threads = 1
-        # try GPU first
-        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
-        get_session.cache[model_path] = onnxruntime.InferenceSession(
-            model_path,
-            sess_options=opts,
-            providers=providers
-        )
-    return get_session.cache[model_path]
-get_session.cache = {}
+    # C, H, W -> H, W, C
+    img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
+    return img
 
 
 def inference(model_path: str, img_array: np.array) -> np.array:
-    …
+    options = onnxruntime.SessionOptions()
+    options.intra_op_num_threads = 1
+    options.inter_op_num_threads = 1
+    ort_session = onnxruntime.InferenceSession(model_path, options)
+    ort_inputs = {ort_session.get_inputs()[0].name: img_array}
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    return ort_outs[0]
+
+
+def convert_pil_to_cv2(image):
+    # pil_image = image.convert("RGB")
+    open_cv_image = np.array(image)
+    # RGB to BGR
+    open_cv_image = open_cv_image[:, :, ::-1].copy()
+    return open_cv_image
+
+
+def upscale(image, model):
+    model_path = f"models/{model}.ort"
     img = convert_pil_to_cv2(image)
-    # handle potential alpha channel
     if img.ndim == 2:
         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
     if img.shape[2] == 4:
-        alpha = …
+        alpha = img[:, :, 3]  # GRAY
+        alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)  # BGR
+        alpha_output = post_process(inference(model_path, pre_process(alpha)))  # BGR
+        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)  # GRAY
+
+        img = img[:, :, 0:3]  # BGR
+        image_output = post_process(inference(model_path, pre_process(img)))  # BGR
+        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)  # BGRA
+        image_output[:, :, 3] = alpha_output
+
+    elif img.shape[2] == 3:
+        image_output = post_process(inference(model_path, pre_process(img)))  # BGR
+
+    return image_output
- …
-    animation: fadeIn 1s ease-in;
-    border-radius: 12px;
-    box-shadow: 0 8px 16px rgba(0,0,0,0.2);
-}
-@keyframes fadeIn {
-    from { opacity: 0; }
-    to { opacity: 1; }
-}
-.gradio-radio input[type="radio"] + label:hover {
-    transform: scale(1.1);
-    transition: transform 0.2s;
-}
-.gradio-button {
-    background: linear-gradient(90deg, #FF8A00, #E52E71);
-    border: none;
-    border-radius: 8px;
-    color: white;
-    font-weight: bold;
-    padding: 12px 24px;
-    cursor: pointer;
-    transition: background 0.3s, transform 0.2s;
-}
-.gradio-button:hover {
-    background: linear-gradient(90deg, #E52E71, #FF8A00);
-    transform: scale(1.05);
-}
-#upscale_btn {
-    margin-top: 10px;
-}
-"""
-
-# === Gradio Blocks App ===
-with gr.Blocks(css=custom_css) as demo:
-    gr.HTML("<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>")
-    with gr.Row():
-        inp = gr.Image(type="pil", label="Drop Your Image Here")
-        model = gr.Radio([
-            "modelx2", "modelx2_25JXL", "modelx4", "minecraft_modelx4"
-        ], label="Upscaler Model", value="modelx2")
-    btn = gr.Button("Upscale Image", elem_id="upscale_btn")
-    out = gr.Image(label="Upscaled Output", elem_classes="gradio-image")
-
-    btn.click(fn=upscale, inputs=[inp, model], outputs=out)
-    gr.HTML("<p style='text-align:center; color:#555;'>Powered by ONNX Runtime & Gradio Blocks</p>")
-
-if __name__ == "__main__":
-    demo.launch()
+
+
+css = ".output-image, .input-image, .image-preview {height: 480px !important} "
+model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]
+
+gr.Interface(
+    fn=upscale,
+    inputs=[
+        gr.Image(type="pil", label="Input Image"),
+        gr.Radio(
+            model_choices,
+            type="value",
+            value=None,
+            label="Choose Upscaler",
+        ),
+    ],
+    outputs="image",
+    title="Image Upscaler! Multiple AI",
+    description="Model: [Anchor-based Plain Net for Mobile Image Super-Resolution](https://arxiv.org/abs/2105.09750). Repository: [SR Mobile PyTorch](https://github.com/w11wo/sr_mobile_pytorch)",
+    allow_flagging="never",
+    css=css,
+).launch()
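The updated upscale() can be sanity-checked outside Gradio with a short smoke test. This is a sketch, not part of the commit: input.png is a hypothetical test image, the functions above are assumed to be defined in the current session, and the Space's models/modelx2.ort file must be on disk.

# Local smoke test for upscale() (sketch; assumes the functions above are
# already defined and that models/modelx2.ort exists).
from PIL import Image

pil_img = Image.open("input.png")        # hypothetical test image
out = upscale(pil_img, "modelx2")        # uint8 numpy array, 3 or 4 channels
print(out.shape, out.dtype)              # the x2 model should double H and W
Image.fromarray(out).save("output.png")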
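Note that the commit also drops the old get_session cache, so the new inference() builds a fresh InferenceSession on every request. If per-model session reuse is wanted again, one possible sketch of the same idea uses functools.lru_cache; get_cached_session is a hypothetical helper, not part of this commit.

import functools

import onnxruntime


@functools.lru_cache(maxsize=None)
def get_cached_session(model_path: str) -> onnxruntime.InferenceSession:
    # One session per model path, mirroring the removed get_session() helper.
    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    return onnxruntime.InferenceSession(model_path, options)

inference() could then call get_cached_session(model_path) instead of constructing a new session on each call, trading a small amount of memory for faster repeat requests.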
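For reference, the tensor-layout contract between pre_process and post_process can be checked standalone. Illustrative shapes only; this check is not part of the commit.

import numpy as np

# pre_process: H, W, C uint8 -> 1, C, H, W float32
img = np.zeros((64, 48, 3), dtype=np.uint8)
batched = np.expand_dims(np.transpose(img[:, :, 0:3], (2, 0, 1)), axis=0).astype(np.float32)
assert batched.shape == (1, 3, 64, 48)

# post_process reverses it: 1, C, H, W -> H, W, C uint8 (channel order flipped)
restored = np.transpose(np.squeeze(batched), (1, 2, 0))[:, :, ::-1].astype(np.uint8)
assert restored.shape == (64, 48, 3)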