import gradio as gr
import os
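# torchvision and einops are not used directly in this file, but several of the
# trust_remote_code vision checkpoints listed below import them at load time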
import torch, torchvision, einops
import spaces
import subprocess
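# Only AutoModelForCausalLM is used below; importing the other classes up front
# makes any missing-class errors in the installed transformers version surface at startup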
from transformers import AutoModelForCausalLM, AutoModel, AutoModelForVision2Seq, PaliGemmaForConditionalGeneration, LlavaForConditionalGeneration, LlavaNextForConditionalGeneration
from huggingface_hub import login
# Install flash-attn at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD skips compiling
# the CUDA extension so the install completes on Spaces hardware
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    # Merge with os.environ: passing env= alone would replace the whole
    # environment and drop PATH, CUDA variables, etc.
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
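
# HF_TOKEN must be configured as a Space secret; some of the example models
# below (meta-llama, gemma, paligemma) are gated and require an authenticated token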
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token, add_to_git_credential=True)

# Cache mapping model names to their architecture summaries (only the string
# representation is kept, not the model itself)
model_cache = {}
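# Note: the cache lives in process memory, so it resets whenever the Space restarts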

# Load a model and return (architecture summary, error message); runs on GPU via ZeroGPU
@spaces.GPU
def get_model_summary(model_name):
    if model_name in model_cache:
        return model_cache[model_name], ""
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).to(device)
        model_summary = str(model)
        model_cache[model_name] = model_summary
        # Drop the model and free GPU memory; repeated lookups would otherwise
        # accumulate full models on the device
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return model_summary, ""
    except Exception as e:
        return "", str(e)

# Create the Gradio Blocks interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            textbox = gr.Textbox(label="Model Name", placeholder="Enter the model name here or select an example below...", lines=1)

            gr.Markdown("### Vision Models")
            vision_examples = gr.Examples(
                examples=[
                    ["microsoft/llava-med-v1.5-mistral-7b"],
                    ["llava-hf/llava-v1.6-mistral-7b-hf"],
                    ["xtuner/llava-phi-3-mini-hf"],
                    ["xtuner/llava-llama-3-8b-v1_1-transformers"],
                    ["vikhyatk/moondream2"],
                    ["openbmb/MiniCPM-Llama3-V-2_5"],
                    ["microsoft/Phi-3-vision-128k-instruct"],
                    ["google/paligemma-3b-mix-224"],
                    ["HuggingFaceM4/idefics2-8b-chatty"],
                ],
                inputs=textbox,
            )

            gr.Markdown("### Other Models")
            other_examples = gr.Examples(
                examples=[
                    ["google/gemma-7b"],
                    ["microsoft/Phi-3-mini-4k-instruct"],
                    ["meta-llama/Meta-Llama-3-8B"],
                    ["mistralai/Mistral-7B-Instruct-v0.3"],
                ],
                inputs=textbox,
            )

            submit_button = gr.Button("Submit")

        with gr.Column():
            output = gr.Textbox(label="Model Architecture", lines=20, placeholder="Model architecture will appear here...", show_copy_button=True)
            error_output = gr.Textbox(label="Error", lines=10, placeholder="Exceptions will appear here...", show_copy_button=True)

    # Wire the button straight to get_model_summary; a pass-through wrapper adds nothing
    submit_button.click(fn=get_model_summary, inputs=textbox, outputs=[output, error_output])

# Launch the interface
demo.launch()
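
# Optional variant (untested here): demo.queue().launch() would serialize
# concurrent clicks, which can help if two large models loading at once
# exhaust GPU memory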