luckycanucky commited on
Commit
ac9616e
·
verified ·
1 Parent(s): f5a3de2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -67
app.py CHANGED
@@ -2,83 +2,142 @@ import numpy as np
2
  import cv2
3
  import onnxruntime
4
  import gradio as gr
 
5
 
6
-
7
  def pre_process(img: np.array) -> np.array:
8
- # H, W, C -> C, H, W
9
- img = np.transpose(img[:, :, 0:3], (2, 0, 1))
10
- # C, H, W -> 1, C, H, W
11
- img = np.expand_dims(img, axis=0).astype(np.float32)
12
- return img
13
 
14
 
15
  def post_process(img: np.array) -> np.array:
16
- # 1, C, H, W -> C, H, W
17
  img = np.squeeze(img)
18
- # C, H, W -> H, W, C
19
- img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
20
- return img
21
-
 
 
 
 
 
 
 
 
22
 
23
  def inference(model_path: str, img_array: np.array) -> np.array:
24
- options = onnxruntime.SessionOptions()
25
- options.intra_op_num_threads = 1
26
- options.inter_op_num_threads = 1
27
- ort_session = onnxruntime.InferenceSession(model_path, options)
28
- ort_inputs = {ort_session.get_inputs()[0].name: img_array}
29
- ort_outs = ort_session.run(None, ort_inputs)
30
-
31
- return ort_outs[0]
32
-
33
-
34
- def convert_pil_to_cv2(image):
35
- # pil_image = image.convert("RGB")
36
- open_cv_image = np.array(image)
37
- # RGB to BGR
38
- open_cv_image = open_cv_image[:, :, ::-1].copy()
39
- return open_cv_image
40
-
41
-
42
- def upscale(image, model):
43
- model_path = f"models/{model}.ort"
44
  img = convert_pil_to_cv2(image)
45
  if img.ndim == 2:
46
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
47
 
48
  if img.shape[2] == 4:
49
- alpha = img[:, :, 3] # GRAY
50
- alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR) # BGR
51
- alpha_output = post_process(inference(model_path, pre_process(alpha))) # BGR
52
- alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY) # GRAY
53
-
54
- img = img[:, :, 0:3] # BGR
55
- image_output = post_process(inference(model_path, pre_process(img))) # BGR
56
- image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA) # BGRA
57
- image_output[:, :, 3] = alpha_output
58
-
59
- elif img.shape[2] == 3:
60
- image_output = post_process(inference(model_path, pre_process(img))) # BGR
61
-
62
- return image_output
63
-
64
-
65
- css = ".output-image, .input-image, .image-preview {height: 480px !important} "
66
- model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]
67
-
68
- gr.Interface(
69
- fn=upscale,
70
- inputs=[
71
- gr.Image(type="pil", label="Input Image"),
72
- gr.Radio(
73
- model_choices,
74
- type="value",
75
- value=None,
76
- label="Choose Upscaler",
77
- ),
78
- ],
79
- outputs="image",
80
- title="Image Upscaler! Multiple AI",
81
- description="Model: [Anchor-based Plain Net for Mobile Image Super-Resolution](https://arxiv.org/abs/2105.09750). Repository: [SR Mobile PyTorch](https://github.com/w11wo/sr_mobile_pytorch)",
82
- allow_flagging="never",
83
- css=css,
84
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import cv2
3
  import onnxruntime
4
  import gradio as gr
5
+ from PIL import Image
6
 
7
# === Upscaler Logic ===
def pre_process(img: np.ndarray) -> np.ndarray:
    """Convert an H, W, C image into a float32 (1, C, H, W) batch for the model.

    Any channels beyond the first three (e.g. an alpha channel) are dropped;
    callers that need alpha upscale it separately (see ``upscale``).
    """
    # H, W, C -> C, H, W, keeping only the first three channels.
    chw = np.transpose(img[:, :, :3], (2, 0, 1))
    # C, H, W -> 1, C, H, W (batch of one), as float32 for onnxruntime.
    return np.expand_dims(chw, axis=0).astype(np.float32)
 
 
 
11
 
12
 
13
def post_process(img: np.ndarray) -> np.ndarray:
    """Convert a model output batch (1, C, H, W) to an H, W, C uint8 BGR image.

    The channel axis is reversed (RGB -> BGR) so the result matches the
    OpenCV-style arrays used elsewhere in this app.
    """
    # Drop only the batch axis: a bare np.squeeze() would also collapse a
    # height or width of 1 and then crash on the 3-axis transpose below.
    img = np.squeeze(img, axis=0)
    # C, H, W -> H, W, C, reverse channels, clamp to uint8.
    return np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
16
+
17
# ONNX inference with session cache and GPU if available
def get_session(model_path: str):
    """Return a cached onnxruntime session for *model_path*, creating it on first use."""
    session = get_session.cache.get(model_path)
    if session is None:
        options = onnxruntime.SessionOptions()
        # Single-threaded execution keeps CPU usage predictable on shared hosts.
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        # CUDA first when present; onnxruntime falls back to the CPU provider.
        session = onnxruntime.InferenceSession(
            model_path,
            options,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        get_session.cache[model_path] = session
    return session

# Per-process cache: model path -> InferenceSession.
get_session.cache = {}
27
 
28
def inference(model_path: str, img_array: np.array) -> np.array:
    """Run the (cached) ONNX model at *model_path* on a preprocessed input batch."""
    ort_session = get_session(model_path)
    input_name = ort_session.get_inputs()[0].name
    # run(None, ...) returns all model outputs; the first one is used.
    outputs = ort_session.run(None, {input_name: img_array})
    return outputs[0]
32
+
33
# PIL to BGR conversion
def convert_pil_to_cv2(image: Image.Image) -> np.ndarray:
    """Convert a PIL image to an OpenCV-style numpy array.

    Returns BGR for colour input, BGRA when the input carries an alpha
    channel, and expands grayscale input to 3-channel BGR.
    """
    arr = np.array(image)
    if arr.ndim == 2:  # grayscale -> 3-channel BGR
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    if arr.shape[2] == 4:
        # RGBA -> BGRA: only the colour channels swap. The previous blanket
        # arr[:, :, ::-1] reversed all four channels (ABGR), so upscale()'s
        # alpha branch was actually reading the red channel as alpha.
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGRA)
    # RGB -> BGR; copy() yields a contiguous, writable array.
    return arr[:, :, ::-1].copy()
39
+
40
# Upscale handler
def upscale(image, model_choice):
    """Upscale a PIL *image* with the ONNX model named by *model_choice*."""
    model_path = f"models/{model_choice}.ort"
    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] != 4:
        # No alpha channel: a single pass through the model suffices.
        return post_process(inference(model_path, pre_process(img)))

    # Image with alpha: upscale colour and alpha separately, then recombine.
    alpha_bgr = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
    upscaled_alpha = cv2.cvtColor(
        post_process(inference(model_path, pre_process(alpha_bgr))),
        cv2.COLOR_BGR2GRAY,
    )
    upscaled_rgb = post_process(inference(model_path, pre_process(img[:, :, :3])))
    result = cv2.cvtColor(upscaled_rgb, cv2.COLOR_BGR2BGRA)
    result[:, :, 3] = upscaled_alpha
    return result
57
+
58
# === Dark Blue-Grey Theme CSS & Animations ===
custom_css = """
/* Dark Gradient Background */
body .gradio-container {
    background: linear-gradient(135deg, #0d1b2a, #1b263b, #415a77, #1b263b);
    background-size: 400% 400%;
    animation: bgFade 25s ease infinite;
}
@keyframes bgFade {
    0% { background-position: 0% 0%; }
    50% { background-position: 100% 100%; }
    100% { background-position: 0% 0%; }
}

/* Title Styling */
.fancy-title {
    font-family: 'Poppins', sans-serif;
    font-size: 2.8rem;
    background: linear-gradient(90deg, #778da9, #415a77);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    animation: fadeInText 2s ease-out;
    text-align: center;
    margin-bottom: 1rem;
}
@keyframes fadeInText {
    0% { opacity: 0; transform: translateY(-10px); }
    100% { opacity: 1; transform: translateY(0); }
}

/* Inputs & Outputs */
.gradio-image, .gradio-gallery {
    animation: fadeIn 1.2s ease-in;
    border-radius: 10px;
    box-shadow: 0 4px 12px rgba(0,0,0,0.5);
    border: 2px solid #415a77;
}
@keyframes fadeIn {
    from { opacity: 0; }
    to { opacity: 1; }
}

/* Radio Hover */
.gradio-radio input[type="radio"] + label:hover {
    transform: scale(1.1);
    color: #e0e1dd;
    transition: transform 0.2s, color 0.2s;
}

/* Button Styling */
.gradio-button {
    background: linear-gradient(90deg, #1b263b, #415a77);
    border: 1px solid #778da9;
    border-radius: 6px;
    color: #e0e1dd;
    font-weight: 600;
    padding: 10px 22px;
    cursor: pointer;
    box-shadow: 0 2px 6px rgba(0,0,0,0.7);
    transition: background 0.3s, transform 0.2s;
}
.gradio-button:hover {
    background: linear-gradient(90deg, #415a77, #1b263b);
    transform: scale(1.03);
}

/* Layout tweaks */
#upscale_btn { margin-top: 1rem; }
.gradio-row { gap: 1rem; }
"""

# === Gradio Blocks App ===
# Model names must match the .ort files shipped under models/.
MODEL_CHOICES = ["modelx2", "modelx2_25JXL", "modelx4", "minecraft_modelx4"]

with gr.Blocks(css=custom_css) as demo:
    gr.HTML("<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>")
    with gr.Row():
        inp = gr.Image(type="pil", label="Drop Your Image Here")
        model = gr.Radio(MODEL_CHOICES, label="Upscaler Model", value="modelx2")
    btn = gr.Button("Upscale Image", elem_id="upscale_btn")
    out = gr.Image(label="Upscaled Output")
    btn.click(fn=upscale, inputs=[inp, model], outputs=out)
    gr.HTML("<p style='text-align:center; color:#e0e1dd;'>Powered by ONNX Runtime & Gradio Blocks</p>")

if __name__ == "__main__":
    demo.launch()