|
|
|
|
|
import numpy as np |
|
from PIL import Image |
|
import skimage |
|
import time |
|
import torch |
|
import argparse |
|
from multiprocessing import Pool |
|
from functools import partial |
|
from tqdm.auto import tqdm |
|
import os |
|
|
|
# ezexr is an optional dependency that is only used for ".exr" files.
# Only a failed import is tolerated here; anything else (including
# KeyboardInterrupt / SystemExit raised during import) must propagate.
try:
    import ezexr
except ImportError:
    pass
|
|
|
def create_argparser():
    """
    Build the command-line parser for this script.

    Options:
        --ball_dir       (str, required) input directory
        --envmap_dir     (str, required) output directory
        --envmap_height  (int, default 256)
        --scale          (int, default 4)
        --threads        (int, default 8)

    returns: the configured argparse.ArgumentParser (arguments are not parsed here)
    """
    # One (flag, keyword-arguments) entry per option, registered in this order.
    option_specs = (
        ("--ball_dir", dict(type=str, required=True, help='directory that contain the image')),
        ("--envmap_dir", dict(type=str, required=True, help='directory to output environment map')),
        ("--envmap_height", dict(type=int, default=256, help="size of the environment map height in pixel (height)")),
        ("--scale", dict(type=int, default=4, help="scale factor")),
        ("--threads", dict(type=int, default=8, help="num thread for pararell processing")),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)
    return arg_parser
|
|
|
def create_envmap_grid(size: int):
    """
    BLENDER CONVENTION
    Build the grid of spherical coordinates covered by an environment map.

    size: number of rows of the grid; the grid has 2 * size columns

    returns: float numpy array of shape (size, 2 * size, 2).
        [..., 0] is theta, going from 0 (left column) to 2*pi (right column)
        [..., 1] is phi, going from 0 (top row) to pi (bottom row)
        Both end points are included, so the top-left entry is (0, 0) and the
        bottom-right entry is (2*pi, pi).
    """
    theta_axis = torch.linspace(0, np.pi * 2, size * 2)
    phi_axis = torch.linspace(0, np.pi, size)

    # 'xy' indexing: theta varies along columns, phi along rows.
    theta_grid, phi_grid = torch.meshgrid(theta_axis, phi_axis, indexing='xy')

    # Pack both angles into a trailing axis of length 2 and hand back numpy.
    return torch.stack((theta_grid, phi_grid), dim=-1).numpy()
|
|
|
def get_normal_vector(incoming_vector: np.ndarray, reflect_vector: np.ndarray):
    """
    BLENDER CONVENTION
    Compute the unit vector that bisects the two given directions.

    incoming_vector: the vector from the point to the camera
    reflect_vector: the vector from the point to the light source

    returns: (incoming_vector + reflect_vector) scaled to unit length along
        the last axis; the two inputs are combined with numpy broadcasting
    """
    half_vector = incoming_vector + reflect_vector
    # keepdims keeps the length broadcastable against the vector components.
    length = np.linalg.norm(half_vector, axis=-1, keepdims=True)
    return half_vector / length
|
|
|
def get_cartesian_from_spherical(theta: np.array, phi: np.array, r = 1.0):
    """
    BLENDER CONVENTION
    Convert spherical coordinates to cartesian (x, y, z).

    theta: vertical angle (measured from the +Z axis: theta = 0 gives z = r)
    phi: horizontal angle (measured from the +X axis in the XY plane)
    r: radius

    returns: array with x, y, z stacked along a new last axis
    """
    # Distance from the Z axis, shared by the x and y components.
    planar_radius = r * np.sin(theta)

    x_coord = planar_radius * np.cos(phi)
    y_coord = planar_radius * np.sin(phi)
    z_coord = r * np.cos(theta)

    return np.stack([x_coord, y_coord, z_coord], axis=-1)
|
|
|
|
|
def process_image(args: argparse.Namespace, file_name: str):
    """
    BLENDER CONVENTION
    Unwrap one ball image into an environment map and write it to disk.

    For every direction of the environment-map grid, the normal halfway
    between the camera direction (+X) and that direction is computed; its
    (y, z) components select the pixel of the ball image to sample.

    args: parsed options; reads ball_dir, envmap_dir, envmap_height and scale
    file_name: name of the image inside args.ball_dir; the environment map is
        written under the same name inside args.envmap_dir

    returns: always None. Nothing is written when the output file already
        exists, or when a non-".exr" input cannot be loaded as an image.
    """
    # Skip work that was already done by a previous run.
    envmap_output_path = os.path.join(args.envmap_dir, file_name)
    if os.path.exists(envmap_output_path):
        return None

    is_exr = file_name.endswith(".exr")

    # Load the ball image as a floating point array.
    ball_path = os.path.join(args.ball_dir, file_name)
    if is_exr:
        ball_image = ezexr.imread(ball_path)
    else:
        try:
            ball_image = skimage.io.imread(ball_path)
            ball_image = skimage.img_as_float(ball_image)
        except Exception:
            # Best effort: anything in the directory that cannot be read as
            # an image is skipped. Only Exception is caught so that
            # KeyboardInterrupt / SystemExit still stop the worker.
            return None

    # Vector from the ball surface to the camera.
    I = np.array([1, 0, 0])

    # Directions of every pixel of the (super-sampled) environment map.
    # The grid stores (theta, phi); column 1 is the vertical angle and
    # column 0 the horizontal angle expected by get_cartesian_from_spherical.
    env_grid = create_envmap_grid(args.envmap_height * args.scale)
    reflect_vec = get_cartesian_from_spherical(env_grid[..., 1], env_grid[..., 0])
    normal = get_normal_vector(I[None, None], reflect_vec)

    # Turn the normal into a lookup position on the ball image:
    # [-1, 1] -> [0, 1], flip, then keep only the (y, z) components.
    pos = (normal + 1.0) / 2
    pos = 1.0 - pos
    pos = pos[..., 1:]

    with torch.no_grad():
        # grid_sample expects coordinates in [-1, 1] with a leading batch axis.
        grid = torch.from_numpy(pos)[None].float()
        grid = grid * 2 - 1

        # (H, W, C) -> (1, C, H, W)
        ball_image = torch.from_numpy(ball_image[None]).float()
        ball_image = ball_image.permute(0, 3, 1, 2)

        env_map = torch.nn.functional.grid_sample(ball_image, grid, mode='bilinear', padding_mode='border', align_corners=True)
        # (1, C, H, W) -> (H, W, C)
        env_map = env_map[0].permute(1, 2, 0).numpy()

    # Down-sample from the super-sampled size to the requested resolution.
    env_map_default = skimage.transform.resize(env_map, (args.envmap_height, args.envmap_height * 2), anti_aliasing=True)

    if is_exr:
        ezexr.imwrite(envmap_output_path, env_map_default.astype(np.float32))
    else:
        env_map_default = skimage.img_as_ubyte(env_map_default)
        skimage.io.imsave(envmap_output_path, env_map_default)
    return None
|
|
|
|
|
|
|
|
|
def main():
    """
    Script entry point.

    Parses the command line, makes sure the output directory exists, then
    runs process_image on every entry of args.ball_dir (in sorted order)
    through a multiprocessing Pool of args.threads workers, showing a tqdm
    progress bar. Prints the elapsed wall-clock time at the end.
    """
    t_begin = time.time()

    options = create_argparser().parse_args()
    os.makedirs(options.envmap_dir, exist_ok=True)

    file_names = sorted(os.listdir(options.ball_dir))

    # Bind the options once so the pool only has to ship the file name.
    worker = partial(process_image, options)

    with Pool(options.threads) as pool:
        # Results are all None; the loop only drives the lazy imap iterator
        # so the progress bar advances as files finish.
        for _ in tqdm(pool.imap(worker, file_names), total=len(file_names)):
            pass

    print("TOTAL TIME: ", time.time() - t_begin)


if __name__ == "__main__":
    main()
|
|