luckycanucky committed · Commit 53edc5f · verified · 1 Parent(s): 633ddbf

Update app.py

Files changed (1):
  1. app.py +68 -41
app.py CHANGED
Old version of app.py:

@@ -1,71 +1,102 @@
 import numpy as np
 import cv2
 import onnxruntime
 import gradio as gr
 from PIL import Image
 
-# === Upscaler Logic ===
-def pre_process(img: np.array) -> np.array:
-    img = np.transpose(img[:, :, :3], (2, 0, 1))
     return np.expand_dims(img, axis=0).astype(np.float32)
 
 
-def post_process(img: np.array) -> np.array:
-    img = np.squeeze(img)
-    return np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
 
 
-
-def get_session(model_path: str):
     if model_path not in get_session.cache:
         opts = onnxruntime.SessionOptions()
         opts.intra_op_num_threads = 1
         opts.inter_op_num_threads = 1
 
-        available = onnxruntime.get_available_providers()
-        # Use GPU only if available, otherwise fall back to CPU
         providers = []
-        if "CUDAExecutionProvider" in available:
-            providers.append("CUDAExecutionProvider")
         providers.append("CPUExecutionProvider")
 
         sess = onnxruntime.InferenceSession(model_path, opts, providers=providers)
         get_session.cache[model_path] = sess
     return get_session.cache[model_path]
 get_session.cache = {}
 
 
-def inference(model_path: str, img_array: np.array) -> np.array:
     session = get_session(model_path)
-    inputs = {session.get_inputs()[0].name: img_array}
-    return session.run(None, inputs)[0]
 
-# PIL to BGR conversion
-def convert_pil_to_cv2(image: Image.Image) -> np.array:
     arr = np.array(image)
     if arr.ndim == 2:
         return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
-    return arr[:, :, ::-1].copy()
-
-# Upscale handler
-def upscale(image, model_choice):
-    model_path = f"models/{model_choice}.ort"
     img = convert_pil_to_cv2(image)
-    if img.ndim == 2:
-        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
 
     if img.shape[2] == 4:
-        alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
-        out_a = post_process(inference(model_path, pre_process(alpha)))
-        out_a = cv2.cvtColor(out_a, cv2.COLOR_BGR2GRAY)
         rgb = img[:, :, :3]
-        out_rgb = post_process(inference(model_path, pre_process(rgb)))
         rgba = cv2.cvtColor(out_rgb, cv2.COLOR_BGR2BGRA)
-        rgba[:, :, 3] = out_a
         return rgba
-    return post_process(inference(model_path, pre_process(img)))
 
-# === Dark Blue-Grey Theme CSS & Animations ===
 custom_css = """
 /* Dark Gradient Background */
 body .gradio-container {
@@ -78,7 +109,6 @@ body .gradio-container {
   50% { background-position: 100% 100%; }
   100% { background-position: 0% 0%; }
 }
-
 /* Title Styling */
 .fancy-title {
   font-family: 'Poppins', sans-serif;
@@ -94,7 +124,6 @@ body .gradio-container {
  0% { opacity: 0; transform: translateY(-10px); }
  100% { opacity: 1; transform: translateY(0); }
 }
-
 /* Inputs & Outputs */
 .gradio-image, .gradio-gallery {
  animation: fadeIn 1.2s ease-in;
@@ -106,14 +135,12 @@ body .gradio-container {
  from { opacity: 0; }
  to { opacity: 1; }
 }
-
 /* Radio Hover */
 .gradio-radio input[type="radio"] + label:hover {
  transform: scale(1.1);
  color: #e0e1dd;
  transition: transform 0.2s, color 0.2s;
 }
-
 /* Button Styling */
 .gradio-button {
  background: linear-gradient(90deg, #1b263b, #415a77);
@@ -130,8 +157,6 @@ body .gradio-container {
  background: linear-gradient(90deg, #415a77, #1b263b);
  transform: scale(1.03);
 }
-
-/* Layout tweaks */
 #upscale_btn { margin-top: 1rem; }
 .gradio-row { gap: 1rem; }
 """
@@ -141,13 +166,15 @@ with gr.Blocks(css=custom_css) as demo:
     gr.HTML("<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>")
     with gr.Row():
         inp = gr.Image(type="pil", label="Drop Your Image Here")
-        model = gr.Radio([
-            "modelx2", "modelx2_25JXL", "modelx4", "minecraft_modelx4"
-        ], label="Upscaler Model", value="modelx2")
     btn = gr.Button("Upscale Image", elem_id="upscale_btn")
     out = gr.Image(label="Upscaled Output")
     btn.click(fn=upscale, inputs=[inp, model], outputs=out)
     gr.HTML("<p style='text-align:center; color:#e0e1dd;'>Powered by ONNX Runtime & Gradio Blocks</p>")
 
 if __name__ == "__main__":
-    demo.launch()
 
New version of app.py:

+import os
 import numpy as np
 import cv2
 import onnxruntime
 import gradio as gr
 from PIL import Image
 
+# === Pre-/Post-Processing ===
+def pre_process(img: np.ndarray) -> np.ndarray:
+    # Convert HWC-BGR to CHW-RGB and batch
+    img = img[:, :, :3]
+    img = img[:, :, ::-1]  # BGR to RGB
+    img = np.transpose(img, (2, 0, 1))
     return np.expand_dims(img, axis=0).astype(np.float32)
 
 
+def post_process(out: np.ndarray) -> np.ndarray:
+    # Remove batch dimension, convert CHW-RGB to HWC-BGR
+    img = np.squeeze(out, axis=0)
+    img = np.transpose(img, (1, 2, 0))
+    img = img[:, :, ::-1]  # RGB to BGR
+    img = np.clip(img, 0, 255).astype(np.uint8)
+    return img
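
As a point of reference, a minimal shape check for the pre-/post-processing pair above (not part of the commit; the 64x64 dummy image and assertions are purely illustrative):

# Illustrative sketch: round-trip the helpers defined above
import numpy as np
dummy = np.zeros((64, 64, 3), dtype=np.uint8)   # HWC, BGR, uint8
batch = pre_process(dummy)                      # -> (1, 3, 64, 64), float32, RGB
restored = post_process(batch)                  # -> (64, 64, 3), uint8, BGR
assert batch.shape == (1, 3, 64, 64)
assert restored.shape == (64, 64, 3)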
 
+# === ONNX Inference Session with Dynamic Providers ===
 
+def get_session(model_path: str) -> onnxruntime.InferenceSession:
     if model_path not in get_session.cache:
+        if not os.path.isfile(model_path):
+            raise FileNotFoundError(f"Model file not found: {model_path}")
         opts = onnxruntime.SessionOptions()
         opts.intra_op_num_threads = 1
         opts.inter_op_num_threads = 1
 
+        # Select CUDA if available
         providers = []
+        for p in onnxruntime.get_available_providers():
+            if p == "CUDAExecutionProvider":
+                providers.append(p)
         providers.append("CPUExecutionProvider")
 
         sess = onnxruntime.InferenceSession(model_path, opts, providers=providers)
         get_session.cache[model_path] = sess
     return get_session.cache[model_path]
+
 get_session.cache = {}
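
A quick way to confirm which execution provider a cached session actually picked is onnxruntime's get_providers(); the model path below just follows this app's models/ layout and is only an example:

# Illustrative sketch: inspect the provider selected for a cached session
sess = get_session("models/modelx2.ort")
print(sess.get_providers())   # e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'], or CPU only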
 
 
+def run_inference(model_path: str, input_tensor: np.ndarray) -> np.ndarray:
     session = get_session(model_path)
+    input_name = session.get_inputs()[0].name
+    return session.run(None, {input_name: input_tensor})[0]
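
run_inference assumes the model exposes a single 4-D float input; when in doubt, the session metadata can be checked first (a sketch using the same get_session helper, with an example model path):

# Illustrative sketch: inspect what the first model input expects
session = get_session("models/modelx2.ort")
first_input = session.get_inputs()[0]
print(first_input.name, first_input.shape, first_input.type)
# e.g. something like: input (1, 3, 'height', 'width') tensor(float)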
 
+# === Image Conversion ===
+def convert_pil_to_cv2(image: Image.Image) -> np.ndarray:
     arr = np.array(image)
+    # If grayscale
     if arr.ndim == 2:
         return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
+    # If RGBA
+    if arr.shape[2] == 4:
+        return arr[:, :, ::-1].copy()  # RGBA to ABGR
+    # RGB
+    return arr[:, :, ::-1].copy()  # RGB to BGR
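
One caveat in the RGBA branch: arr[:, :, ::-1] reverses all four channels, so an RGBA input comes back as ABGR and the plane that upscale() later reads as alpha (index 3) is actually the original red channel. If BGR plus an untouched alpha is the intent, an explicit OpenCV conversion is unambiguous; a hypothetical variant (the _bgra name is not from the commit):

# Illustrative alternative: swap RGB -> BGR while keeping alpha at index 3
def convert_pil_to_cv2_bgra(image: Image.Image) -> np.ndarray:
    arr = np.array(image)
    if arr.ndim == 2:
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    if arr.shape[2] == 4:
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGRA)   # alpha stays in place
    return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)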
+
+# === Upscale Handler ===
+def upscale(image: Image.Image, model_choice: str) -> np.ndarray:
+    """
+    Upscale an image (RGB or RGBA) using the selected ONNX model.
+    """
+    model_path = os.path.join("models", f"{model_choice}.ort")
     img = convert_pil_to_cv2(image)
 
+    # Handle alpha channel separately
     if img.shape[2] == 4:
+        # Split channels
         rgb = img[:, :, :3]
+        alpha = img[:, :, 3]
+
+        # Process RGB
+        in_rgb = pre_process(rgb)
+        out_rgb = post_process(run_inference(model_path, in_rgb))
+
+        # Process alpha as grayscale
+        alpha_bgr = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)
+        in_alpha = pre_process(alpha_bgr)
+        out_alpha = post_process(run_inference(model_path, in_alpha))
+        out_alpha = cv2.cvtColor(out_alpha, cv2.COLOR_BGR2GRAY)
+
+        # Merge back to RGBA
         rgba = cv2.cvtColor(out_rgb, cv2.COLOR_BGR2BGRA)
+        rgba[:, :, 3] = out_alpha
         return rgba
 
+    # No alpha
+    inp = pre_process(img)
+    return post_process(run_inference(model_path, inp))
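
For completeness, a sketch of calling the handler outside Gradio; input.png and output.png are hypothetical files, and the result comes back as a BGR or BGRA array, which is the channel order cv2.imwrite expects:

# Illustrative sketch: run the upscaler from a plain script
from PIL import Image
import cv2
img = Image.open("input.png")        # RGB, RGBA or grayscale
result = upscale(img, "modelx2")     # numpy array, BGR or BGRA
cv2.imwrite("output.png", result)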
+
+# === Custom Dark Blue-Grey CSS ===
 custom_css = """
 /* Dark Gradient Background */
 body .gradio-container {
 
  50% { background-position: 100% 100%; }
  100% { background-position: 0% 0%; }
 }
 /* Title Styling */
 .fancy-title {
  font-family: 'Poppins', sans-serif;
 
  0% { opacity: 0; transform: translateY(-10px); }
  100% { opacity: 1; transform: translateY(0); }
 }
 /* Inputs & Outputs */
 .gradio-image, .gradio-gallery {
  animation: fadeIn 1.2s ease-in;
 
  from { opacity: 0; }
  to { opacity: 1; }
 }
 /* Radio Hover */
 .gradio-radio input[type="radio"] + label:hover {
  transform: scale(1.1);
  color: #e0e1dd;
  transition: transform 0.2s, color 0.2s;
 }
 /* Button Styling */
 .gradio-button {
  background: linear-gradient(90deg, #1b263b, #415a77);
 
  background: linear-gradient(90deg, #415a77, #1b263b);
  transform: scale(1.03);
 }
 #upscale_btn { margin-top: 1rem; }
 .gradio-row { gap: 1rem; }
 """
 
     gr.HTML("<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>")
     with gr.Row():
         inp = gr.Image(type="pil", label="Drop Your Image Here")
+        model = gr.Radio(
+            choices=["modelx2", "modelx2_25JXL", "modelx4", "minecraft_modelx4"],
+            label="Upscaler Model",
+            value="modelx2"
+        )
     btn = gr.Button("Upscale Image", elem_id="upscale_btn")
     out = gr.Image(label="Upscaled Output")
     btn.click(fn=upscale, inputs=[inp, model], outputs=out)
     gr.HTML("<p style='text-align:center; color:#e0e1dd;'>Powered by ONNX Runtime & Gradio Blocks</p>")
 
 if __name__ == "__main__":
+    demo.launch()
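
The bare demo.launch() keeps the hosting defaults; if the app is ever run outside a managed Space, standard Gradio options such as a request queue or an explicit host and port can be added (a sketch, assuming the current Blocks API):

# Illustrative sketch: common launch variants
demo.queue(max_size=16)                                # queue requests instead of running them all at once
demo.launch(server_name="0.0.0.0", server_port=7860)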