luckycanucky committed on
Commit
f5a3de2
·
verified ·
1 Parent(s): db25782

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -121
app.py CHANGED
@@ -2,138 +2,83 @@ import numpy as np
2
  import cv2
3
  import onnxruntime
4
  import gradio as gr
5
- from PIL import Image
6
 
7
- # === Upscaler Logic ===
8
  def pre_process(img: np.array) -> np.array:
9
- img = np.transpose(img[:, :, :3], (2, 0, 1))
10
- return np.expand_dims(img, axis=0).astype(np.float32)
 
 
 
11
 
12
 
13
  def post_process(img: np.array) -> np.array:
 
14
  img = np.squeeze(img)
15
- return np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
16
-
17
- # ONNX inference with GPU if available
18
- def get_session(model_path: str):
19
- # cache sessions
20
- if model_path not in get_session.cache:
21
- opts = onnxruntime.SessionOptions()
22
- # multi-threading
23
- opts.intra_op_num_threads = 1
24
- opts.inter_op_num_threads = 1
25
- # try GPU first
26
- providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
27
- get_session.cache[model_path] = onnxruntime.InferenceSession(
28
- model_path,
29
- sess_options=opts,
30
- providers=providers
31
- )
32
- return get_session.cache[model_path]
33
- get_session.cache = {}
34
 
35
 
36
  def inference(model_path: str, img_array: np.array) -> np.array:
37
- session = get_session(model_path)
38
- inputs = {session.get_inputs()[0].name: img_array}
39
- return session.run(None, inputs)[0]
40
-
41
- # PIL to BGR conversion
42
- def convert_pil_to_cv2(image: Image.Image) -> np.array:
43
- arr = np.array(image)
44
- if arr.ndim == 2:
45
- return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
46
- return arr[:, :, ::-1].copy()
47
-
48
- # Upscale handler
49
- def upscale(image, model_choice):
50
- model_path = f"models/{model_choice}.ort"
 
 
 
 
 
 
51
  img = convert_pil_to_cv2(image)
52
- # handle potential alpha channel
53
  if img.ndim == 2:
54
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
55
 
56
  if img.shape[2] == 4:
57
- alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
58
- out_a = post_process(inference(model_path, pre_process(alpha)))
59
- out_a = cv2.cvtColor(out_a, cv2.COLOR_BGR2GRAY)
60
- rgb = img[:, :, :3]
61
- out_rgb = post_process(inference(model_path, pre_process(rgb)))
62
- rgba = cv2.cvtColor(out_rgb, cv2.COLOR_BGR2BGRA)
63
- rgba[:, :, 3] = out_a
64
- return rgba
65
- return post_process(inference(model_path, pre_process(img)))
66
-
67
- # === Custom CSS for styling & animations ===
68
- custom_css = """
69
- body .gradio-container {
70
- background: linear-gradient(-45deg, #ff9a9e, #fad0c4, #ffdde1);
71
- background-size: 400% 400%;
72
- animation: gradientBG 15s ease infinite;
73
- }
74
- @keyframes gradientBG {
75
- 0% { background-position: 0% 50%; }
76
- 50% { background-position: 100% 50%; }
77
- 100% { background-position: 0% 50%; }
78
- }
79
- .fancy-title {
80
- font-family: 'Poppins', sans-serif;
81
- font-size: 3rem;
82
- background: linear-gradient(90deg, #7F7FD5, #86A8E7, #91EAE4);
83
- -webkit-background-clip: text;
84
- -webkit-text-fill-color: transparent;
85
- animation: fadeInText 2s ease-in-out;
86
- text-align: center;
87
- }
88
- @keyframes fadeInText {
89
- 0% { opacity: 0; transform: translateY(-20px); }
90
- 100% { opacity: 1; transform: translateY(0); }
91
- }
92
- .gradio-image {
93
- animation: fadeIn 1s ease-in;
94
- border-radius: 12px;
95
- box-shadow: 0 8px 16px rgba(0,0,0,0.2);
96
- }
97
- @keyframes fadeIn {
98
- from { opacity: 0; }
99
- to { opacity: 1; }
100
- }
101
- .gradio-radio input[type=\"radio\"] + label:hover {
102
- transform: scale(1.1);
103
- transition: transform 0.2s;
104
- }
105
- .gradio-button {
106
- background: linear-gradient(90deg, #FF8A00, #E52E71);
107
- border: none;
108
- border-radius: 8px;
109
- color: white;
110
- font-weight: bold;
111
- padding: 12px 24px;
112
- cursor: pointer;
113
- transition: background 0.3s, transform 0.2s;
114
- }
115
- .gradio-button:hover {
116
- background: linear-gradient(90deg, #E52E71, #FF8A00);
117
- transform: scale(1.05);
118
- }
119
- #upscale_btn {
120
- margin-top: 10px;
121
- }
122
- """
123
-
124
- # === Gradio Blocks App ===
125
- with gr.Blocks(css=custom_css) as demo:
126
- gr.HTML("<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>")
127
- with gr.Row():
128
- inp = gr.Image(type="pil", label="Drop Your Image Here")
129
- model = gr.Radio([
130
- "modelx2", "modelx2_25JXL", "modelx4", "minecraft_modelx4"
131
- ], label="Upscaler Model", value="modelx2")
132
- btn = gr.Button("Upscale Image", elem_id="upscale_btn")
133
- out = gr.Image(label="Upscaled Output", elem_classes="gradio-image")
134
-
135
- btn.click(fn=upscale, inputs=[inp, model], outputs=out)
136
- gr.HTML("<p style='text-align:center; color:#555;'>Powered by ONNX Runtime & Gradio Blocks</p>")
137
-
138
- if __name__ == "__main__":
139
- demo.launch()
 
2
  import cv2
3
  import onnxruntime
4
  import gradio as gr
 
5
 
6
+
7
def pre_process(img: np.ndarray) -> np.ndarray:
    """Convert an HWC image array into the NCHW float32 batch the model expects.

    Args:
        img: Image of shape (H, W, C) with at least 3 channels; any channels
            beyond the first three (e.g. alpha) are dropped.

    Returns:
        Array of shape (1, 3, H, W), dtype float32.
    """
    # H, W, C -> C, H, W (keep only the first three channels)
    chw = np.transpose(img[:, :, :3], (2, 0, 1))
    # C, H, W -> 1, C, H, W, as float32 for ONNX Runtime
    return np.expand_dims(chw, axis=0).astype(np.float32)
13
 
14
 
15
  def post_process(img: np.array) -> np.array:
16
+ # 1, C, H, W -> C, H, W
17
  img = np.squeeze(img)
18
+ # C, H, W -> H, W, C
19
+ img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
20
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
# Cache of ONNX Runtime sessions keyed by model path: building a session is
# expensive, so each model is loaded at most once per process.
_SESSION_CACHE = {}


def inference(model_path: str, img_array: np.ndarray) -> np.ndarray:
    """Run the ONNX model at *model_path* on a pre-processed input batch.

    Args:
        model_path: Filesystem path to the .ort model file.
        img_array: Input batch, e.g. the (1, 3, H, W) float32 array produced
            by pre_process().

    Returns:
        The model's first output array.
    """
    session = _SESSION_CACHE.get(model_path)
    if session is None:
        options = onnxruntime.SessionOptions()
        # Keep CPU usage predictable on shared hosts (e.g. HF Spaces).
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        session = onnxruntime.InferenceSession(model_path, options)
        _SESSION_CACHE[model_path] = session
    ort_inputs = {session.get_inputs()[0].name: img_array}
    return session.run(None, ort_inputs)[0]
32
+
33
+
34
def convert_pil_to_cv2(image) -> np.ndarray:
    """Convert a PIL image (or array-like) into an OpenCV-style BGR(A) array.

    Fixes two defects of the naive ``arr[:, :, ::-1]`` conversion: a
    grayscale input produces a 2-D array (three-axis indexing raised
    IndexError), and an RGBA input was reversed to ABGR, scrambling the
    alpha channel that the caller later reads as ``img[:, :, 3]``.

    Args:
        image: PIL.Image.Image or anything ``np.array()`` accepts.

    Returns:
        2-D grayscale array unchanged, (H, W, 3) BGR, or (H, W, 4) BGRA.
    """
    arr = np.array(image)
    if arr.ndim == 2:
        # Grayscale: no channel axis to swap; the caller expands to BGR.
        return arr
    if arr.shape[2] == 4:
        # RGBA -> BGRA: swap R and B, keep alpha in place.
        return arr[:, :, [2, 1, 0, 3]].copy()
    # RGB -> BGR
    return arr[:, :, ::-1].copy()
40
+
41
+
42
def upscale(image, model):
    """Upscale *image* with the selected ONNX model.

    Args:
        image: Input image (gradio supplies a PIL.Image.Image).
        model: Model name; resolved to ``models/<model>.ort``.

    Returns:
        Upscaled image as an (H, W, 3) BGR or (H, W, 4) BGRA uint8 array.

    Raises:
        ValueError: If image/model is missing, or the image has an
            unsupported number of channels.
    """
    if image is None or not model:
        raise ValueError("Both an input image and a model choice are required.")
    model_path = f"models/{model}.ort"
    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] == 4:
        # The model only handles 3 channels: upscale the alpha plane
        # separately by replicating it into a BGR image, then collapse back.
        alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
        alpha_output = post_process(inference(model_path, pre_process(alpha)))
        alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY)

        # Upscale the color planes, then re-attach the upscaled alpha.
        image_output = post_process(inference(model_path, pre_process(img[:, :, :3])))
        image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA)
        image_output[:, :, 3] = alpha_output
        return image_output

    if img.shape[2] == 3:
        return post_process(inference(model_path, pre_process(img)))

    # Previously any other channel count fell through to UnboundLocalError.
    raise ValueError(f"Unsupported channel count: {img.shape[2]}")
63
+
64
+
65
# Constrain the gradio image widgets to a fixed height.
css = ".output-image, .input-image, .image-preview {height: 480px !important} "
# Names must match the .ort files under models/.
model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]

demo = gr.Interface(
    fn=upscale,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            model_choices,
            type="value",
            # Pre-select a model: with value=None a user could submit without
            # choosing one, and upscale() built the bogus path "models/None.ort".
            value=model_choices[0],
            label="Choose Upscaler",
        ),
    ],
    outputs="image",
    title="Image Upscaler! Multiple AI",
    description="Model: [Anchor-based Plain Net for Mobile Image Super-Resolution](https://arxiv.org/abs/2105.09750). Repository: [SR Mobile PyTorch](https://github.com/w11wo/sr_mobile_pytorch)",
    allow_flagging="never",
    css=css,
)

# Guard the launch so importing this module (e.g. in tests) does not start
# the server; running app.py as a script behaves exactly as before.
if __name__ == "__main__":
    demo.launch()