"""Image-tagging inference script for the Camie Tagger ONNX model.

Loads the ONNX model once at import time, preprocesses an input image to
the model's 512x512 RGB NCHW tensor, and converts the refined output
logits into tag predictions using per-category probability thresholds
read from a JSON file.
"""

import json
import sys

import numpy as np
import onnxruntime as ort
from PIL import Image

# Loaded once at module import and reused by every inference() call.
# NOTE(review): module-level side effect -- importing this file requires
# the .onnx model to be present in the working directory.
session = ort.InferenceSession(
    "camie_tagger_initial_v15.onnx", providers=["CPUExecutionProvider"]
)


def preprocess_image(img_path):
    """Load and resize an image into the model's input tensor.

    Args:
        img_path (str): Path to the image file.

    Returns:
        np.ndarray: float32 array of shape (1, 3, 512, 512) in NCHW
        layout, with pixel values scaled to [0, 1].
    """
    img = Image.open(img_path).convert("RGB").resize((512, 512))
    x = np.asarray(img, dtype=np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(x, 0)  # add batch dimension -> (1, 3, 512, 512)


def load_thresholds(threshold_json_path, mode="balanced"):
    """Load per-category probability thresholds for the chosen mode.

    Args:
        threshold_json_path (str): Path to the thresholds JSON file.
        mode (str): Threshold profile, e.g. "balanced",
            "high_precision", or "high_recall".

    Returns:
        tuple[dict, float]: ``(thresholds_by_category, fallback_threshold)``
        where the dict maps category name -> threshold, and the fallback
        is the overall threshold used when a category has no entry.
    """
    with open(threshold_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The overall threshold for the chosen mode; used whenever a
    # category does not define its own threshold for that mode.
    fallback_threshold = data["overall"][mode]["threshold"]

    thresholds_by_category = {}
    for cat_name, cat_modes in data.get("categories", {}).items():
        if mode in cat_modes and "threshold" in cat_modes[mode]:
            thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
        else:
            thresholds_by_category[cat_name] = fallback_threshold

    return thresholds_by_category, fallback_threshold


def inference(
    input_path,
    output_format="verbose",
    mode="balanced",
    threshold_json_path="thresholds.json",
    metadata_path="metadata.json",
):
    """Run inference on an image and format the predicted tags.

    Applies category-wise thresholds from ``thresholds.json`` for the
    chosen mode to the model's refined output probabilities.

    Args:
        input_path (str): Path to the image file for inference.
        output_format (str): Either "verbose" or "as_prompt".
        mode (str): "balanced", "high_precision", or "high_recall".
        threshold_json_path (str): Path to the JSON file with category
            thresholds.
        metadata_path (str): Path to the metadata JSON file with
            tag-index and category information.

    Returns:
        str: Predicted tags, either as a multi-line report ("verbose")
        or a comma-separated prompt string ("as_prompt").
    """
    # 1) Preprocess the image into the model's (1, 3, 512, 512) tensor.
    input_tensor = preprocess_image(input_path)

    # 2) Run the ONNX model; it emits two logit tensors. Only the
    # refined logits are used below; the initial logits are unpacked
    # for clarity. Shapes assumed (1, num_tags) each -- TODO confirm.
    input_name = session.get_inputs()[0].name
    initial_logits, refined_logits = session.run(None, {input_name: input_tensor})

    # 3) Sigmoid: logits -> independent per-tag probabilities.
    refined_probs = 1.0 / (1.0 + np.exp(-refined_logits))

    # 4) Metadata maps tag indices to names and names to categories.
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]  # keys are string indices, e.g. "0"
    tag_to_category = metadata.get("tag_to_category", {})

    thresholds_by_category, fallback_threshold = load_thresholds(
        threshold_json_path, mode
    )

    # 5) Keep every tag whose probability clears its category threshold.
    results_by_category = {}
    for i in range(refined_probs.shape[1]):
        prob = float(refined_probs[0, i])
        tag_name = idx_to_tag[str(i)]  # str(i): metadata uses string keys
        category = tag_to_category.get(tag_name, "general")
        cat_threshold = thresholds_by_category.get(category, fallback_threshold)
        if prob >= cat_threshold:
            results_by_category.setdefault(category, []).append((tag_name, prob))

    # 6) Format the predictions.
    if output_format == "as_prompt":
        # Flatten all predicted tags; underscores become spaces.
        all_predicted_tags = [
            tname.replace("_", " ")
            for tags_list in results_by_category.values()
            for tname, _prob in tags_list
        ]
        return ", ".join(all_predicted_tags)

    # "verbose": multi-line report, tags sorted by descending probability.
    lines = ["Predicted Tags by Category:\n"]
    for cat, tags_list in results_by_category.items():
        lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
        for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
            lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
        lines.append("")  # blank line after each category
    return "\n".join(lines)


if __name__ == "__main__":
    # Bug fix: the original hard-coded an empty image path, which always
    # failed. Take the image path from the command line instead.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")
    print(inference(sys.argv[1], output_format="as_prompt"))