File size: 5,021 Bytes
dd06d6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# Convert ball images into environment maps in latitude-longitude (equirectangular) format.
import argparse
import os
import time
import warnings
from functools import partial
from multiprocessing import Pool

import numpy as np
import skimage
import skimage.io
import skimage.transform
import torch
from PIL import Image
from tqdm.auto import tqdm

# ezexr is optional: it is only needed to read/write .exr files.
try:
    import ezexr
except:
    pass
def create_argparser():
    """Build the command-line parser for the ball-to-envmap converter.

    --ball_dir and --envmap_dir are required string options (input and output
    directories); --envmap_height, --scale and --threads are integer options
    with defaults 256, 4 and 8.
    """
    # (flag, help) for the two mandatory directory options.
    # --envmap_dir may be a dataset name or a directory.
    directory_options = (
        ("--ball_dir", 'directory that contain the image'),
        ("--envmap_dir", 'directory to output environment map'),
    )
    # (flag, default, help) for the optional integer options.
    integer_options = (
        ("--envmap_height", 256, "size of the environment map height in pixel (height)"),
        ("--scale", 4, "scale factor"),
        ("--threads", 8, "num thread for pararell processing"),
    )

    cli = argparse.ArgumentParser()
    for flag, help_text in directory_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    for flag, default_value, help_text in integer_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    return cli
def create_envmap_grid(size: int):
    """
    BLENDER CONVENTION
    Create the grid of spherical coordinates for a lat-long environment map.

    size: number of rows of the grid (the map height); the grid has
        ``2 * size`` columns.

    Returns a numpy array of shape (size, 2 * size, 2) in torch's default
    dtype (float32 unless changed). Entry [row, col] is (theta, phi):
    theta runs linearly from 0 (left column) to 2*pi (right column) and phi
    from 0 (top row) to pi (bottom row), both endpoints included. The top-left
    entry is therefore (0, 0) and the bottom-right entry is (2*pi, pi).

    Note: here theta is the horizontal angle and phi the vertical one, which is
    the opposite of get_cartesian_from_spherical's parameter naming.
    """
    theta = torch.linspace(0, np.pi * 2, size * 2)
    phi = torch.linspace(0, np.pi, size)
    # indexing='xy' puts theta along the columns and phi along the rows,
    # giving two (size, 2 * size) arrays.
    theta, phi = torch.meshgrid(theta, phi, indexing='xy')
    return torch.stack([theta, phi], dim=-1).numpy()
def get_normal_vector(incoming_vector: np.ndarray, reflect_vector: np.ndarray):
    """
    BLENDER CONVENTION
    Return the unit surface normal that mirrors ``incoming_vector`` into ``reflect_vector``.

    For a perfect mirror the normal bisects the two directions, so it is the
    normalized half-vector N = (I + R) / |I + R|. The bisector property only
    holds when both inputs have the same length (e.g. unit vectors).

    incoming_vector: the vector from the point to the camera
    reflect_vector: the vector from the point to the light source

    Inputs may be arrays or sequences; they broadcast against each other and
    the vector components live on the last axis. The result has the broadcast
    shape. Where incoming_vector == -reflect_vector the sum is zero and the
    corresponding result is NaN.
    """
    # asarray so plain lists are added element-wise rather than concatenated.
    incoming_vector = np.asarray(incoming_vector)
    reflect_vector = np.asarray(reflect_vector)
    # This inverts the reflection law R = 2(N . I)N - I: N is parallel to I + R.
    half_vector = incoming_vector + reflect_vector
    return half_vector / np.linalg.norm(half_vector, axis=-1, keepdims=True)
def get_cartesian_from_spherical(theta: np.ndarray, phi: np.ndarray, r=1.0):
    """
    BLENDER CONVENTION
    Convert spherical coordinates to Cartesian (x, y, z).

    theta: vertical (polar) angle, measured from the +z axis
    phi: horizontal (azimuthal) angle in the x-y plane, from +x towards +y
    r: radius (scalar or array)

    theta, phi and r broadcast against each other; the result has shape
    ``broadcast_shape + (3,)`` with (x, y, z) on the last axis.
    """
    x = r * np.sin(theta) * np.cos(phi)
    y = r * np.sin(theta) * np.sin(phi)
    z = r * np.cos(theta)
    # z does not depend on phi, so it can have a smaller shape than x and y
    # (e.g. theta (N, 1), phi (1, M)); broadcast explicitly before stacking.
    return np.stack(np.broadcast_arrays(x, y, z), axis=-1)
def _require_ezexr():
    """Return the optional ``ezexr`` module, or raise a clear ImportError if it is missing."""
    try:
        return ezexr
    except NameError:
        # The top-level ``import ezexr`` is optional and silently skipped on failure.
        raise ImportError("the 'ezexr' package is required to read/write .exr files") from None


def _read_ball_image(ball_path: str, is_exr: bool):
    """Load the ball image at ``ball_path``; return None if a non-EXR file cannot be read."""
    if is_exr:
        return _require_ezexr().imread(ball_path)
    try:
        return skimage.img_as_float(skimage.io.imread(ball_path))
    except Exception as err:
        # Best-effort: skip files that are not readable images, but report it
        # instead of failing silently.
        warnings.warn(f"skipping {ball_path!r}: {err}")
        return None


def _write_envmap(path: str, env_map: np.ndarray, is_exr: bool):
    """Save ``env_map`` to ``path``: float32 EXR via ezexr, otherwise 8-bit via skimage."""
    if is_exr:
        _require_ezexr().imwrite(path, env_map.astype(np.float32))
    else:
        # img_as_ubyte raises on floats outside [-1, 1]; clip round-off and
        # out-of-range (HDR) values before quantizing.
        skimage.io.imsave(path, skimage.img_as_ubyte(np.clip(env_map, 0.0, 1.0)))


def process_image(args: argparse.Namespace, file_name: str):
    """Convert one ball image into a lat-long environment map.

    Reads ``args.ball_dir/file_name`` and writes the result to
    ``args.envmap_dir/file_name`` at ``args.envmap_height`` x
    ``2 * args.envmap_height`` pixels. Sampling happens on a grid
    ``args.scale`` times larger, which is then downsampled with anti-aliasing.

    Files whose output already exists are skipped, as are non-EXR files that
    cannot be read as images (a warning is issued). ``.exr`` files (any case)
    are read and written with ``ezexr`` as float32; other formats go through
    ``skimage`` and are saved as 8-bit.

    Always returns None; it is called for its side effect so it can be mapped
    over a ``multiprocessing.Pool``.
    """
    envmap_output_path = os.path.join(args.envmap_dir, file_name)
    if os.path.exists(envmap_output_path):
        return None

    is_exr = file_name.lower().endswith(".exr")
    ball_image = _read_ball_image(os.path.join(args.ball_dir, file_name), is_exr)
    if ball_image is None:
        return None

    # grid_sample needs a channel axis; drop it again before writing.
    is_single_channel = ball_image.ndim == 2
    if is_single_channel:
        ball_image = ball_image[..., None]

    # Vector from the ball surface towards the camera (+x), matching the
    # incoming_vector convention of get_normal_vector.
    to_camera = np.array([1, 0, 0])

    # Treat each envmap pixel's direction as the reflected ray and recover
    # the ball normal that produces that reflection.
    env_grid = create_envmap_grid(args.envmap_height * args.scale)
    reflect_vec = get_cartesian_from_spherical(env_grid[..., 1], env_grid[..., 0])
    normal = get_normal_vector(to_camera[None, None], reflect_vec)

    # The normal's (y, z) components give the lookup position on the ball
    # image in [0, 1], flipped so +y / +z map to the left / top of the image.
    pos = 1.0 - (normal + 1.0) / 2
    pos = pos[..., 1:]

    # Use pytorch's grid_sample for bilinear interpolation.
    with torch.no_grad():
        # grid_sample expects a [1, H_out, W_out, 2] grid in [-1, 1].
        grid = torch.from_numpy(pos)[None].float() * 2 - 1
        ball_tensor = torch.from_numpy(ball_image[None]).float().permute(0, 3, 1, 2)  # [1, C, H, W]
        env_map = torch.nn.functional.grid_sample(
            ball_tensor, grid, mode='bilinear', padding_mode='border', align_corners=True
        )
        env_map = env_map[0].permute(1, 2, 0).numpy()

    env_map = skimage.transform.resize(
        env_map, (args.envmap_height, args.envmap_height * 2), anti_aliasing=True
    )
    if is_single_channel:
        env_map = env_map[..., 0]
    _write_envmap(envmap_output_path, env_map, is_exr)
    return None
def main():
    """Convert every image file in --ball_dir into an environment map in --envmap_dir.

    Files are processed in a ``multiprocessing.Pool`` of ``--threads`` worker
    processes with a tqdm progress bar. The total elapsed time is printed at
    the end.
    """
    start_time = time.perf_counter()
    args = create_argparser().parse_args()
    os.makedirs(args.envmap_dir, exist_ok=True)

    # Only regular files: sub-directories inside ball_dir are not images.
    files = [
        name
        for name in sorted(os.listdir(args.ball_dir))
        if os.path.isfile(os.path.join(args.ball_dir, name))
    ]

    # Bind args so each worker call only needs the file name.
    process_func = partial(process_image, args)
    with Pool(args.threads) as pool:
        # Drain the iterator to drive the work; process_image always returns None.
        for _ in tqdm(pool.imap(process_func, files), total=len(files)):
            pass

    print("TOTAL TIME: ", time.perf_counter() - start_time)
if __name__ == "__main__":
    # Guard the entry point: multiprocessing workers started with the "spawn"
    # method re-import this module, and must not re-run main().
    main()
|