luckycanucky commited on
Commit
db25782
·
verified ·
1 Parent(s): 673f548

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -26
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
6
 
7
  # === Upscaler Logic ===
8
  def pre_process(img: np.array) -> np.array:
9
- img = np.transpose(img[:, :, 0:3], (2, 0, 1))
10
  return np.expand_dims(img, axis=0).astype(np.float32)
11
 
12
 
@@ -14,20 +14,30 @@ def post_process(img: np.array) -> np.array:
14
  img = np.squeeze(img)
15
  return np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
16
 
17
-
18
- # ONNX inference with session cache
19
- _session_cache = {}
20
- def inference(model_path: str, img_array: np.array) -> np.array:
21
- if model_path not in _session_cache:
22
  opts = onnxruntime.SessionOptions()
 
23
  opts.intra_op_num_threads = 1
24
  opts.inter_op_num_threads = 1
25
- _session_cache[model_path] = onnxruntime.InferenceSession(model_path, opts)
26
- session = _session_cache[model_path]
 
 
 
 
 
 
 
 
 
 
 
27
  inputs = {session.get_inputs()[0].name: img_array}
28
  return session.run(None, inputs)[0]
29
 
30
-
31
  # PIL to BGR conversion
32
  def convert_pil_to_cv2(image: Image.Image) -> np.array:
33
  arr = np.array(image)
@@ -35,33 +45,29 @@ def convert_pil_to_cv2(image: Image.Image) -> np.array:
35
  return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
36
  return arr[:, :, ::-1].copy()
37
 
38
-
39
  # Upscale handler
40
  def upscale(image, model_choice):
41
  model_path = f"models/{model_choice}.ort"
42
  img = convert_pil_to_cv2(image)
43
- # ensure 3 channels
44
  if img.ndim == 2:
45
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
46
 
47
- # 4-channel alpha handling
48
  if img.shape[2] == 4:
49
  alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
50
  out_a = post_process(inference(model_path, pre_process(alpha)))
51
  out_a = cv2.cvtColor(out_a, cv2.COLOR_BGR2GRAY)
52
- img_rgb = img[:, :, :3]
53
- out_rgb = post_process(inference(model_path, pre_process(img_rgb)))
54
- out_rgba = cv2.cvtColor(out_rgb, cv2.COLOR_BGR2BGRA)
55
- out_rgba[:, :, 3] = out_a
56
- return out_rgba
57
-
58
  return post_process(inference(model_path, pre_process(img)))
59
 
60
-
61
  # === Custom CSS for styling & animations ===
62
  custom_css = """
63
  body .gradio-container {
64
- background: linear-gradient(-45deg, #ff9a9e, #fad0c4, #fad0c4, #ffdde1);
65
  background-size: 400% 400%;
66
  animation: gradientBG 15s ease infinite;
67
  }
@@ -92,7 +98,7 @@ body .gradio-container {
92
  from { opacity: 0; }
93
  to { opacity: 1; }
94
  }
95
- .gradio-radio input[type="radio"] + label:hover {
96
  transform: scale(1.1);
97
  transition: transform 0.2s;
98
  }
@@ -115,15 +121,13 @@ body .gradio-container {
115
  }
116
  """
117
 
118
-
119
- # === Gradio Blocks App (no theme string) ===
120
  with gr.Blocks(css=custom_css) as demo:
121
  gr.HTML("<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>")
122
  with gr.Row():
123
  inp = gr.Image(type="pil", label="Drop Your Image Here")
124
  model = gr.Radio([
125
- "modelx2", "modelx2_25JXL",
126
- "modelx4", "minecraft_modelx4"
127
  ], label="Upscaler Model", value="modelx2")
128
  btn = gr.Button("Upscale Image", elem_id="upscale_btn")
129
  out = gr.Image(label="Upscaled Output", elem_classes="gradio-image")
@@ -132,4 +136,4 @@ with gr.Blocks(css=custom_css) as demo:
132
  gr.HTML("<p style='text-align:center; color:#555;'>Powered by ONNX Runtime & Gradio Blocks</p>")
133
 
134
  if __name__ == "__main__":
135
- demo.launch()
 
6
 
7
  # === Upscaler Logic ===
8
  def pre_process(img: np.array) -> np.array:
9
+ img = np.transpose(img[:, :, :3], (2, 0, 1))
10
  return np.expand_dims(img, axis=0).astype(np.float32)
11
 
12
 
 
14
  img = np.squeeze(img)
15
  return np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
16
 
17
+ # ONNX inference with GPU if available
18
+ def get_session(model_path: str):
19
+ # cache sessions
20
+ if model_path not in get_session.cache:
 
21
  opts = onnxruntime.SessionOptions()
22
+ # multi-threading
23
  opts.intra_op_num_threads = 1
24
  opts.inter_op_num_threads = 1
25
+ # try GPU first
26
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
27
+ get_session.cache[model_path] = onnxruntime.InferenceSession(
28
+ model_path,
29
+ sess_options=opts,
30
+ providers=providers
31
+ )
32
+ return get_session.cache[model_path]
33
+ get_session.cache = {}
34
+
35
+
36
+ def inference(model_path: str, img_array: np.array) -> np.array:
37
+ session = get_session(model_path)
38
  inputs = {session.get_inputs()[0].name: img_array}
39
  return session.run(None, inputs)[0]
40
 
 
41
  # PIL to BGR conversion
42
  def convert_pil_to_cv2(image: Image.Image) -> np.array:
43
  arr = np.array(image)
 
45
  return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
46
  return arr[:, :, ::-1].copy()
47
 
 
48
  # Upscale handler
49
  def upscale(image, model_choice):
50
  model_path = f"models/{model_choice}.ort"
51
  img = convert_pil_to_cv2(image)
52
+ # handle potential alpha channel
53
  if img.ndim == 2:
54
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
55
 
 
56
  if img.shape[2] == 4:
57
  alpha = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
58
  out_a = post_process(inference(model_path, pre_process(alpha)))
59
  out_a = cv2.cvtColor(out_a, cv2.COLOR_BGR2GRAY)
60
+ rgb = img[:, :, :3]
61
+ out_rgb = post_process(inference(model_path, pre_process(rgb)))
62
+ rgba = cv2.cvtColor(out_rgb, cv2.COLOR_BGR2BGRA)
63
+ rgba[:, :, 3] = out_a
64
+ return rgba
 
65
  return post_process(inference(model_path, pre_process(img)))
66
 
 
67
  # === Custom CSS for styling & animations ===
68
  custom_css = """
69
  body .gradio-container {
70
+ background: linear-gradient(-45deg, #ff9a9e, #fad0c4, #ffdde1);
71
  background-size: 400% 400%;
72
  animation: gradientBG 15s ease infinite;
73
  }
 
98
  from { opacity: 0; }
99
  to { opacity: 1; }
100
  }
101
+ .gradio-radio input[type=\"radio\"] + label:hover {
102
  transform: scale(1.1);
103
  transition: transform 0.2s;
104
  }
 
121
  }
122
  """
123
 
124
+ # === Gradio Blocks App ===
 
125
  with gr.Blocks(css=custom_css) as demo:
126
  gr.HTML("<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>")
127
  with gr.Row():
128
  inp = gr.Image(type="pil", label="Drop Your Image Here")
129
  model = gr.Radio([
130
+ "modelx2", "modelx2_25JXL", "modelx4", "minecraft_modelx4"
 
131
  ], label="Upscaler Model", value="modelx2")
132
  btn = gr.Button("Upscale Image", elem_id="upscale_btn")
133
  out = gr.Image(label="Upscaled Output", elem_classes="gradio-image")
 
136
  gr.HTML("<p style='text-align:center; color:#555;'>Powered by ONNX Runtime & Gradio Blocks</p>")
137
 
138
  if __name__ == "__main__":
139
+ demo.launch()