"""Gradio demo for Birder image-classification models.

Lists the image-classification models published by the ``birder-project``
author on the Hugging Face Hub and lets the user run one of them on an
uploaded image, displaying the top predictions.
"""

import functools
import logging

import birder
import gradio as gr
import numpy as np
from birder.inference.classification import infer_image
from huggingface_hub import HfApi

logger = logging.getLogger(__name__)

# Number of top predictions computed per image and shown in the output label.
TOP_K = 3

# Upper bound on how many loaded models stay in memory at once; each cache
# entry holds a full network, so this bounds memory use.
_MODEL_CACHE_SIZE = 4


def get_birder_classification_models():
    """Return the short names of birder-project image-classification models.

    Queries the Hugging Face Hub via ``HfApi.list_models`` and strips the
    ``author/`` prefix from each model id. Any Hub/network error propagates
    to the caller.
    """
    api = HfApi()
    models = api.list_models(author="birder-project", tags="image-classification")
    return [model.modelId.split("/")[-1] for model in models]


@functools.lru_cache(maxsize=_MODEL_CACHE_SIZE)
def _load_model(model_name):
    """Load *model_name* once and build everything needed to run inference.

    Returns ``(net, transform, idx_to_class)``. Results are cached so that
    repeated predictions with the same model do not reload its weights;
    exceptions are not cached, so a failed load is retried on the next call.
    """
    (net, class_to_idx, signature, rgb_stats) = birder.load_pretrained_model(model_name, inference=True)
    size = birder.get_size_from_signature(signature)
    transform = birder.classification_transform(size, rgb_stats)
    # Invert the mapping so output indices can be turned back into class names.
    idx_to_class = {v: k for k, v in class_to_idx.items()}
    return (net, transform, idx_to_class)


def load_model_and_predict(image, model_name, top_k=TOP_K):
    """Classify *image* with *model_name* and return the top predictions.

    Args:
        image: Input image as passed by the Gradio ``Image(type="pil")``
            component; ``None`` when the user submitted without an image.
        model_name: Short name of a birder pretrained model.
        top_k: Number of highest-scoring classes to return.

    Returns:
        A list of ``(class_name, score)`` tuples ordered from highest to
        lowest score. On failure, a single ``("Error: ...", 0.0)`` entry so
        the message is shown in the UI instead of crashing the request.
    """
    if image is None:
        # Avoid loading a model just to fail on the missing input.
        return [("Error: no image provided", 0.0)]

    try:
        (net, transform, idx_to_class) = _load_model(model_name)
        (out, _) = infer_image(net, image, transform)
        scores = out[0]
        # Reverse first, then take the head, so top_k=0 yields an empty list
        # (a ``[-top_k:]`` slice would return every class for 0).
        topk_idx = np.argsort(scores)[::-1][:top_k]
        return [(idx_to_class[int(idx)], float(scores[idx])) for idx in topk_idx]
    except Exception as e:  # UI boundary: report the failure, but keep a trace in the logs
        logger.exception("Prediction with model %r failed", model_name)
        return [(f"Error: {e}", 0.0)]


def predict(image, model_name):
    """Gradio callback: map predictions to a ``{label: confidence}`` dict.

    Each label embeds the confidence as a percentage, e.g. ``"Grey heron (91.20%)"``.
    """
    predictions = load_model_and_predict(image, model_name)
    return {f"{class_name} ({conf:.2%})": conf for class_name, conf in predictions}


def create_interface():
    """Build the Gradio interface for the classification demo.

    If the model list cannot be fetched from the Hub, the dropdown falls
    back to the models used in the bundled examples so the app still starts.
    """
    examples = [
        ["Common myna.jpeg", "mvit_v2_t_il-all"],
        ["Eurasian hoopoe.jpeg", "xcit_nano12_p16_il-common"],
        ["Grey heron.jpeg", "davit_tiny_il-all"],
    ]

    try:
        models = get_birder_classification_models()
    except Exception:
        logger.exception("Could not list birder models on the Hugging Face Hub; using example models")
        # dict.fromkeys de-duplicates while preserving example order.
        models = list(dict.fromkeys(model_name for _, model_name in examples))

    iface = gr.Interface(
        analytics_enabled=False,
        fn=predict,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Dropdown(
                choices=models,
                label="Select Model",
                value=models[0] if models else None,
            ),
        ],
        outputs=gr.Label(num_top_classes=TOP_K),
        examples=examples,
        title="Birder Image Classification",
        description="Select a model and upload an image or use one of the examples to get bird species predictions.",
    )

    return iface


# Launch the app
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()