luckycanucky committed on
Commit
6c7000b
·
verified ·
1 Parent(s): d74cbc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -67
app.py CHANGED
@@ -3,82 +3,128 @@ import cv2
3
  import onnxruntime
4
  import gradio as gr
5
 
6
-
7
  def pre_process(img: np.array) -> np.array:
8
- # H, W, C -> C, H, W
9
  img = np.transpose(img[:, :, 0:3], (2, 0, 1))
10
- # C, H, W -> 1, C, H, W
11
- img = np.expand_dims(img, axis=0).astype(np.float32)
12
- return img
13
-
14
 
15
  def post_process(img: np.array) -> np.array:
16
- # 1, C, H, W -> C, H, W
17
  img = np.squeeze(img)
18
- # C, H, W -> H, W, C
19
- img = np.transpose(img, (1, 2, 0))[:, :, ::-1].astype(np.uint8)
20
- return img
21
-
22
 
 
 
23
  def inference(model_path: str, img_array: np.array) -> np.array:
24
- options = onnxruntime.SessionOptions()
25
- options.intra_op_num_threads = 1
26
- options.inter_op_num_threads = 1
27
- ort_session = onnxruntime.InferenceSession(model_path, options)
28
- ort_inputs = {ort_session.get_inputs()[0].name: img_array}
29
- ort_outs = ort_session.run(None, ort_inputs)
30
-
31
- return ort_outs[0]
32
-
33
-
34
- def convert_pil_to_cv2(image):
35
- # pil_image = image.convert("RGB")
36
- open_cv_image = np.array(image)
37
- # RGB to BGR
38
- open_cv_image = open_cv_image[:, :, ::-1].copy()
39
- return open_cv_image
40
-
41
-
42
- def upscale(image, model):
43
- model_path = f"models/{model}.ort"
44
  img = convert_pil_to_cv2(image)
45
  if img.ndim == 2:
46
  img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
47
 
 
48
  if img.shape[2] == 4:
49
- alpha = img[:, :, 3] # GRAY
50
- alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR) # BGR
51
- alpha_output = post_process(inference(model_path, pre_process(alpha))) # BGR
52
- alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY) # GRAY
53
-
54
- img = img[:, :, 0:3] # BGR
55
- image_output = post_process(inference(model_path, pre_process(img))) # BGR
56
- image_output = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA) # BGRA
57
- image_output[:, :, 3] = alpha_output
58
-
59
- elif img.shape[2] == 3:
60
- image_output = post_process(inference(model_path, pre_process(img))) # BGR
61
-
62
- return image_output
63
-
64
-
65
- css = ".output-image, .input-image, .image-preview {height: 480px !important} "
66
- model_choices = ["modelx2", "modelx2 25 JXL", "modelx4", "minecraft_modelx4"]
67
-
68
- gr.Interface(
69
- fn=upscale,
70
- inputs=[
71
- gr.Image(type="pil", label="Input Image"),
72
- gr.Radio(
73
- model_choices,
74
- type="value",
75
- value=None,
76
- label="Choose Upscaler",
77
- ),
78
- ],
79
- outputs="image",
80
- title="Image Upscaler! Multiple AI",
81
- description="Model: [Anchor-based Plain Net for Mobile Image Super-Resolution](https://arxiv.org/abs/2105.09750). Repository: [SR Mobile PyTorch](https://github.com/w11wo/sr_mobile_pytorch)",
82
- allow_flagging="never",
83
- css=css,
84
- ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import onnxruntime
4
  import gradio as gr
5
 
6
# === Ultra-Cleaner & Upscaler Logic (unchanged) ===
def pre_process(img: np.array) -> np.array:
    """Prepare an H, W, C image for the network as a 1, C, H, W float32 batch.

    Any channel beyond the first three (e.g. alpha) is dropped.
    """
    chw = img[:, :, :3].transpose((2, 0, 1))  # channels first, alpha dropped
    return chw[np.newaxis, ...].astype(np.float32)  # add batch axis for ONNX
 
 
 
10
 
11
def post_process(img: np.array) -> np.array:
    """Turn a 1, C, H, W network output back into an H, W, C uint8 image.

    The channel order is reversed (e.g. RGB <-> BGR) to match what the
    caller expects from the model output.
    """
    chw = np.squeeze(img)            # drop the batch axis
    hwc = chw.transpose((1, 2, 0))   # channels last
    return hwc[:, :, ::-1].astype(np.uint8)  # reverse channel order, clamp to uint8
 
 
 
14
 
15
# ONNX inference — sessions are created once per model path and reused.
_session_cache = {}


def inference(model_path: str, img_array: np.array) -> np.array:
    """Run the ONNX model at *model_path* on *img_array*.

    A per-path InferenceSession cache avoids re-loading the model on every
    call; threads are pinned to 1 to keep resource usage predictable.
    Returns the first model output.
    """
    session = _session_cache.get(model_path)
    if session is None:
        opts = onnxruntime.SessionOptions()
        opts.intra_op_num_threads = 1
        opts.inter_op_num_threads = 1
        session = onnxruntime.InferenceSession(model_path, opts)
        _session_cache[model_path] = session
    feed = {session.get_inputs()[0].name: img_array}
    return session.run(None, feed)[0]
26
+
27
# Convert PIL to BGR
from PIL import Image


def convert_pil_to_cv2(image: Image.Image) -> np.array:
    """Convert a PIL image to an OpenCV-style BGR (or BGRA) numpy array.

    Fixes: the previous `arr[:, :, ::-1]` full reversal turned RGBA into
    ABGR, so the alpha plane landed in channel 0 while the downstream
    upscaler reads alpha from channel 3. RGBA is now converted with
    cv2.COLOR_RGBA2BGRA, which keeps alpha in channel 3.
    """
    arr = np.array(image)
    if arr.ndim == 2:
        # Single-channel (grayscale) image -> 3-channel BGR.
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    if arr.shape[2] == 4:
        # RGBA -> BGRA: swap only the color channels, alpha stays last.
        return cv2.cvtColor(arr, cv2.COLOR_RGBA2BGRA)
    # RGB -> BGR; copy so the result owns its memory.
    return arr[:, :, ::-1].copy()
32
+
33
# Upscale

def upscale(image, model_choice):
    """Upscale a PIL *image* with the selected `.ort` model.

    Returns a BGR numpy array, or BGRA when the input carried an alpha
    channel (color and alpha are upscaled in separate passes and
    recombined).
    """
    model_path = f"models/{model_choice}.ort"

    def run(bgr):
        # One full pre-process -> inference -> post-process pass
        # on a 3-channel image.
        return post_process(inference(model_path, pre_process(bgr)))

    img = convert_pil_to_cv2(image)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if img.shape[2] != 4:
        # Plain 3-channel path.
        return run(img)

    # BGRA: upscale the alpha plane as a temporary BGR image, then the
    # color planes, and merge alpha back into channel 3 of the result.
    alpha_bgr = cv2.cvtColor(img[:, :, 3], cv2.COLOR_GRAY2BGR)
    up_alpha = cv2.cvtColor(run(alpha_bgr), cv2.COLOR_BGR2GRAY)
    up_color = run(img[:, :, 0:3])
    result = cv2.cvtColor(up_color, cv2.COLOR_BGR2BGRA)
    result[:, :, 3] = up_alpha
    return result
54
+
55
# === Custom CSS for gradients & animations ===
# Page-level stylesheet passed to gr.Blocks(css=...): animated gradient page
# background, gradient-filled title text, fade-in images, and hover effects
# for radio labels and buttons.  The class names target Gradio's generated
# DOM (.gradio-container, .gradio-image, .gradio-radio, .gradio-button) —
# presumably matching the installed Gradio version; verify the selectors
# still apply after a Gradio upgrade.
custom_css = """
body .gradio-container {
  background: linear-gradient(-45deg, #ff9a9e, #fad0c4, #fad0c4, #ffdde1);
  background-size: 400% 400%;
  animation: gradientBG 15s ease infinite;
}

@keyframes gradientBG {
  0% { background-position: 0% 50%; }
  50% { background-position: 100% 50%; }
  100% { background-position: 0% 50%; }
}

.fancy-title {
  font-family: 'Poppins', sans-serif;
  font-size: 3rem;
  background: linear-gradient(90deg, #7F7FD5, #86A8E7, #91EAE4);
  -webkit-background-clip: text;
  -webkit-text-fill-color: transparent;
  animation: fadeInText 2s ease-in-out;
  text-align: center;
}

@keyframes fadeInText {
  0% { opacity: 0; transform: translateY(-20px); }
  100% { opacity: 1; transform: translateY(0); }
}

.gradio-image {
  animation: fadeIn 1s ease-in;
  border-radius: 12px;
  box-shadow: 0 8px 16px rgba(0,0,0,0.2);
}

@keyframes fadeIn {
  from { opacity: 0; }
  to { opacity: 1; }
}

.gradio-radio input[type="radio"] + label:hover {
  transform: scale(1.1);
  transition: transform 0.2s;
}

.gradio-button {
  background: linear-gradient(90deg, #FF8A00, #E52E71);
  border: none;
  border-radius: 8px;
  color: white;
  font-weight: bold;
  padding: 12px 24px;
  cursor: pointer;
  transition: background 0.3s, transform 0.2s;
}

.gradio-button:hover {
  background: linear-gradient(90deg, #E52E71, #FF8A00);
  transform: scale(1.05);
}
"""
116
+
117
# === Gradio Blocks App ===
# UI wiring: image in + model radio -> upscale() -> image out.
with gr.Blocks(css=custom_css, theme="gradio") as demo:
    gr.HTML("<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>")
    with gr.Row():
        inp = gr.Image(type="pil", label="Drop Your Image Here")
        # NOTE(review): choice strings become models/<choice>.ort paths —
        # confirm a file named models/modelx2_25JXL.ort actually exists
        # (an earlier version of this UI offered "modelx2 25 JXL").
        model = gr.Radio(
            ["modelx2", "modelx2_25JXL", "modelx4", "minecraft_modelx4"],
            label="Upscaler Model",
            value="modelx2",
        )
    # Fix: gr.Button has no `css` keyword argument — passing one raises
    # TypeError at startup. Style the button via its elem_id / the
    # .gradio-button rules in the page-level CSS instead.
    btn = gr.Button("Upscale Image", elem_id="upscale_btn")
    out = gr.Image(label="Upscaled Output", elem_classes="gradio-image")

    btn.click(fn=upscale, inputs=[inp, model], outputs=out)
    gr.HTML("<p style='text-align:center; color:#555;'>Powered by ONNX Runtime & Gradio Blocks</p>")

if __name__ == "__main__":
    demo.launch()