# image-upsacler / app.py
# luckycanucky's picture
# Update app.py
# 53edc5f verified
import os
import numpy as np
import cv2
import onnxruntime
import gradio as gr
from PIL import Image
# === Pre-/Post-Processing ===
def pre_process(img: np.ndarray) -> np.ndarray:
    """Convert an HWC BGR(A) image into a batched CHW RGB float32 tensor.

    Args:
        img: Image array of shape (H, W, C) whose first three channels are
            ordered B, G, R. Any channels beyond the third (e.g. alpha) are
            dropped.

    Returns:
        A C-contiguous ``float32`` array of shape (1, C', H, W), where C' is
        the number of channels kept (at most 3), with the channel order
        reversed (B, G, R -> R, G, B). Pixel values are cast, not rescaled.
    """
    bgr = img[:, :, :3]  # keep at most the first three channels
    rgb = bgr[:, :, ::-1]  # BGR -> RGB
    chw = np.transpose(rgb, (2, 0, 1))  # HWC -> CHW
    # The slicing and transposition above only create strided views, and a
    # plain ``astype`` would copy them in their HWC memory order. Materialize
    # a real C-contiguous buffer so the layout matches the (1, C, H, W) shape.
    return np.ascontiguousarray(chw[np.newaxis, ...], dtype=np.float32)
def post_process(out: np.ndarray) -> np.ndarray:
    """Convert a batched CHW RGB model output into an HWC BGR uint8 image.

    Args:
        out: Array of shape (1, C, H, W) with channels ordered R, G, B.
            Values outside [0, 255] are clipped.

    Returns:
        A C-contiguous ``uint8`` array of shape (H, W, C) with the channel
        order reversed (R, G, B -> B, G, R).

    Raises:
        ValueError: If the leading (batch) axis does not have length 1
            (raised by ``np.squeeze``).
    """
    chw = np.squeeze(out, axis=0)  # drop the batch axis
    hwc = np.transpose(chw, (1, 2, 0))  # CHW -> HWC
    bgr = hwc[:, :, ::-1]  # RGB -> BGR
    # Round to nearest before the integer cast: a bare ``astype(np.uint8)``
    # truncates toward zero (254.7 -> 254), which darkens the image slightly.
    quantized = np.clip(np.rint(bgr), 0, 255).astype(np.uint8)
    # The reversed/transposed views carry non-standard strides; hand back a
    # plain contiguous image buffer.
    return np.ascontiguousarray(quantized)
# === ONNX Inference Session with Dynamic Providers ===
def get_session(model_path: str) -> onnxruntime.InferenceSession:
    """Return the inference session for ``model_path``, creating it on first use.

    Sessions are memoized per path in ``get_session.cache``. A new session is
    configured with one intra-op and one inter-op thread and prefers the CUDA
    execution provider when ONNX Runtime reports it as available, with the CPU
    provider always listed as fallback.

    Args:
        model_path: Path of the model file to load.

    Returns:
        The cached or newly created ``onnxruntime.InferenceSession``.

    Raises:
        FileNotFoundError: If no session is cached for ``model_path`` and the
            path is not an existing file.
    """
    sessions = get_session.cache
    if model_path in sessions:
        return sessions[model_path]

    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}")

    options = onnxruntime.SessionOptions()
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1

    # CUDA first (only if this runtime offers it), CPU always last.
    providers = [
        name
        for name in onnxruntime.get_available_providers()
        if name == "CUDAExecutionProvider"
    ]
    providers.append("CPUExecutionProvider")

    sessions[model_path] = onnxruntime.InferenceSession(
        model_path, options, providers=providers
    )
    return sessions[model_path]


# Per-path session cache, stored as an attribute on the function itself.
get_session.cache = {}
def run_inference(model_path: str, input_tensor: np.ndarray) -> np.ndarray:
    """Run the model at ``model_path`` on ``input_tensor``.

    The tensor is fed to the model's first declared input and the first
    output of the run is returned.

    Args:
        model_path: Path of the model file; resolved through ``get_session``.
        input_tensor: Array passed as the model's first input.

    Returns:
        The first output produced by the session.
    """
    sess = get_session(model_path)
    first_input = sess.get_inputs()[0]
    outputs = sess.run(None, {first_input.name: input_tensor})
    return outputs[0]
# === Image Conversion ===
def convert_pil_to_cv2(image: "Image.Image") -> np.ndarray:
    """Convert a PIL image into an OpenCV-style array (BGR or BGRA).

    The decision is made purely from the shape of ``np.array(image)``:

    * 2-D array (grayscale): expanded to three identical B, G, R channels.
    * 4 channels (RGBA): colour channels reversed, alpha kept last -> BGRA.
    * otherwise (RGB): channels reversed -> BGR.

    Other PIL modes (palette, CMYK, ...) are not normalized here.

    Args:
        image: Source image (anything ``np.array`` can turn into an
            (H, W) or (H, W, C) array).

    Returns:
        A new C-contiguous array of shape (H, W, 3) or (H, W, 4).
    """
    arr = np.array(image)
    if arr.ndim == 2:
        # Grayscale: replicate the single plane into B, G and R.
        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
    if arr.shape[2] == 4:
        # Reverse only R, G, B. Reversing all four channels would yield ABGR
        # and move alpha into channel 0, whereas consumers read colour from
        # channels 0-2 and alpha from channel 3.
        return arr[:, :, [2, 1, 0, 3]].copy()
    # RGB -> BGR
    return arr[:, :, ::-1].copy()
# === Upscale Handler ===
def upscale(image: Image.Image, model_choice: str) -> np.ndarray:
    """
    Upscale an image (RGB or RGBA) using the selected ONNX model.

    The model file is looked up as ``models/<model_choice>.ort``. Images with
    an alpha channel are processed in two passes: the colour channels, and
    the alpha plane expanded to three channels, run through the same model,
    and collapsed back to a single plane.

    The alpha path assumes ``convert_pil_to_cv2`` yields colour in channels
    0-2 (B, G, R) and alpha in channel 3.

    Args:
        image: Input image from the UI; ``None`` when nothing was uploaded.
        model_choice: Base name of the model file (without ``.ort``).

    Returns:
        The upscaled image as a ``uint8`` array in RGB order (H, W, 3), or
        RGBA order (H, W, 4) when the input had an alpha channel. This is the
        order Gradio's image output expects for NumPy arrays.

    Raises:
        gr.Error: If no image was provided.
        FileNotFoundError: If the selected model file does not exist
            (raised by ``get_session``).
    """
    if image is None:
        # Clicking the button without an upload hands us None; report it in
        # the UI instead of failing later with an opaque IndexError.
        raise gr.Error("Please provide an image to upscale.")

    model_path = os.path.join("models", f"{model_choice}.ort")
    img = convert_pil_to_cv2(image)

    # No alpha: single pass.
    if img.shape[2] != 4:
        out_bgr = post_process(run_inference(model_path, pre_process(img)))
        # post_process returns BGR; the UI expects RGB.
        return cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGB)

    # Handle alpha channel separately
    colour = img[:, :, :3]
    alpha = np.ascontiguousarray(img[:, :, 3])

    # Process colour
    out_bgr = post_process(run_inference(model_path, pre_process(colour)))

    # Process alpha as grayscale: the model takes three channels, so the
    # plane is replicated, upscaled, and collapsed back to one channel.
    alpha_bgr = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR)
    out_alpha = post_process(run_inference(model_path, pre_process(alpha_bgr)))
    out_alpha = cv2.cvtColor(out_alpha, cv2.COLOR_BGR2GRAY)

    # Merge back to RGBA (not BGRA) for display.
    rgba = cv2.cvtColor(out_bgr, cv2.COLOR_BGR2RGBA)
    rgba[:, :, 3] = out_alpha
    return rgba
# === Custom Dark Blue-Grey CSS ===
# Stylesheet handed to ``gr.Blocks(css=...)``. It defines an animated gradient
# page background, the gradient-filled title (.fancy-title, used by the header
# HTML), fade-in/bordered image panels, a hover effect for radio labels, and
# button styling; #upscale_btn matches the elem_id given to the button.
# NOTE(review): selectors such as .gradio-image, .gradio-gallery, .gradio-radio,
# .gradio-button and .gradio-row presumably target Gradio-generated class
# names; confirm they exist in the installed Gradio version, otherwise those
# rules have no effect.
custom_css = """
/* Dark Gradient Background */
body .gradio-container {
background: linear-gradient(135deg, #0d1b2a, #1b263b, #415a77, #1b263b);
background-size: 400% 400%;
animation: bgFade 25s ease infinite;
}
@keyframes bgFade {
0% { background-position: 0% 0%; }
50% { background-position: 100% 100%; }
100% { background-position: 0% 0%; }
}
/* Title Styling */
.fancy-title {
font-family: 'Poppins', sans-serif;
font-size: 2.8rem;
background: linear-gradient(90deg, #778da9, #415a77);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
animation: fadeInText 2s ease-out;
text-align: center;
margin-bottom: 1rem;
}
@keyframes fadeInText {
0% { opacity: 0; transform: translateY(-10px); }
100% { opacity: 1; transform: translateY(0); }
}
/* Inputs & Outputs */
.gradio-image, .gradio-gallery {
animation: fadeIn 1.2s ease-in;
border-radius: 10px;
box-shadow: 0 4px 12px rgba(0,0,0,0.5);
border: 2px solid #415a77;
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
/* Radio Hover */
.gradio-radio input[type="radio"] + label:hover {
transform: scale(1.1);
color: #e0e1dd;
transition: transform 0.2s, color 0.2s;
}
/* Button Styling */
.gradio-button {
background: linear-gradient(90deg, #1b263b, #415a77);
border: 1px solid #778da9;
border-radius: 6px;
color: #e0e1dd;
font-weight: 600;
padding: 10px 22px;
cursor: pointer;
box-shadow: 0 2px 6px rgba(0,0,0,0.7);
transition: background 0.3s, transform 0.2s;
}
.gradio-button:hover {
background: linear-gradient(90deg, #415a77, #1b263b);
transform: scale(1.03);
}
#upscale_btn { margin-top: 1rem; }
.gradio-row { gap: 1rem; }
"""
# === Gradio Blocks App ===
# Layout: title, a row holding the input image and the model selector, the
# upscale button, the output image, and a footer line.
# NOTE(review): the original indentation was lost; the Row is assumed to wrap
# only the input image and the model selector. Confirm against the deployed
# layout.
_MODEL_CHOICES = ["modelx2", "modelx2_25JXL", "modelx4", "minecraft_modelx4"]
_TITLE_HTML = "<h1 class='fancy-title'>✨ Ultra AI Image Upscaler ✨</h1>"
_FOOTER_HTML = "<p style='text-align:center; color:#e0e1dd;'>Powered by ONNX Runtime & Gradio Blocks</p>"

with gr.Blocks(css=custom_css) as demo:
    gr.HTML(_TITLE_HTML)

    with gr.Row():
        inp = gr.Image(label="Drop Your Image Here", type="pil")
        model = gr.Radio(
            label="Upscaler Model",
            choices=_MODEL_CHOICES,
            value=_MODEL_CHOICES[0],
        )

    btn = gr.Button("Upscale Image", elem_id="upscale_btn")
    out = gr.Image(label="Upscaled Output")

    # Wire the button: (image, model name) -> upscaled image.
    btn.click(fn=upscale, inputs=[inp, model], outputs=out)

    gr.HTML(_FOOTER_HTML)


if __name__ == "__main__":
    demo.launch()