# Hugging Face file-viewer header (kept as a comment so the module parses):
# author xianbao · "Update app.py" · commit e38b1d0 (verified) · 2.05 kB
import os
import time
from typing import Iterable

import gradio as gr
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
def get_cache_dir():
    """Create and return a fresh, empty local directory for one mirror job.

    The previous implementation used a random English word as a relative
    directory name. Two concurrent requests (or a leftover directory from an
    earlier run) could receive the same name and mix two repos' files.
    ``tempfile.mkdtemp`` atomically creates a new directory that no other
    caller can receive.

    Returns:
        str: Absolute path of the newly created, empty directory.
    """
    import tempfile

    return tempfile.mkdtemp(prefix="ms2hf-")
def pull_from_ms(repo_id, cache_dir):
    """Download a ModelScope model repository into ``cache_dir``.

    If the ``MS_TOKEN`` environment variable holds a non-blank value, the
    ModelScope hub client logs in with it first. Otherwise the download is
    attempted without logging in.

    Args:
        repo_id: ModelScope repo id (case sensitive).
        cache_dir: Local directory passed to ``snapshot_download`` as its
            cache location.

    Returns:
        str: Human-readable status line for the UI.
    """
    from modelscope import HubApi
    from modelscope import snapshot_download

    # The token is optional, so read it with .get(). Indexing os.environ
    # raised KeyError whenever MS_TOKEN was unset, which made the emptiness
    # check below unreachable for that case.
    token = (os.environ.get('MS_TOKEN') or '').strip()
    if token:
        api = HubApi()
        api.login(token)
    model_path = snapshot_download(repo_id, cache_dir=cache_dir)
    # NOTE(review): the push stage reads from f"{cache_dir}/{repo_id}"; confirm
    # that matches model_path for the installed modelscope version.
    return f'Pulled {repo_id} to {model_path}'
def push_to_hf(cache_dir, ms_repo_id, hf_repo_id):
    """Upload a previously pulled ModelScope snapshot to a Hugging Face model repo.

    Args:
        cache_dir: Directory the ModelScope snapshot was downloaded into.
        ms_repo_id: ModelScope repo id; the files are read from
            ``{cache_dir}/{ms_repo_id}``.
        hf_repo_id: Target Hugging Face model repo id. It is created if it
            does not exist yet.

    Returns:
        str: Human-readable status line for the UI.

    Raises:
        gr.Error: If the ``HF_TOKEN`` environment variable is unset or blank.
    """
    from huggingface_hub import HfApi

    # .get() so a missing variable reaches the friendly gr.Error below
    # instead of surfacing as a bare KeyError.
    token = (os.environ.get('HF_TOKEN') or '').strip()
    if not token:
        raise gr.Error("Please enter your HF_TOKEN")
    api = HfApi(token=token)  # Token is not persisted on the machine.
    # upload_folder expects the target repo to exist already; exist_ok=True
    # makes this a no-op when it does.
    api.create_repo(hf_repo_id, repo_type="model", exist_ok=True)
    api.upload_folder(
        folder_path=f"{cache_dir}/{ms_repo_id}",
        repo_id=hf_repo_id,
        repo_type="model",
    )
    # Previously referenced an undefined `repo_id`, raising NameError on success.
    return f'Pushed to {hf_repo_id}'
def handle(ms_repo_id, hf_repo_id):
    """Mirror a ModelScope model repo to Hugging Face (Submit-button handler).

    Runs the pull stage and then the push stage, stopping at the first
    failure.

    Args:
        ms_repo_id: Source ModelScope repo id.
        hf_repo_id: Target Hugging Face repo id.

    Returns:
        tuple[str, str]: ``(output, error)`` texts for the two result
        textboxes. The first holds the status messages of the stages that
        completed. The second holds the failing stage's message, or is empty
        if every stage succeeded.
    """
    cache_dir = get_cache_dir()
    # NOTE(review): cache_dir is never removed, so each run leaves a full model
    # copy on disk; consider deleting it after the push finishes.
    stages = [
        (pull_from_ms, (ms_repo_id, cache_dir)),
        # push_to_hf needs the ModelScope id to locate the downloaded folder;
        # it used to be omitted, so every push failed with a TypeError.
        (push_to_hf, (cache_dir, ms_repo_id, hf_repo_id)),
    ]
    results = []
    errors = []
    for func, args in stages:
        try:
            results.append(func(*args))
        except Exception as e:  # UI boundary: report the failure, don't crash
            errors.append(f"{func.__name__}: {e}")
            break  # later stages depend on earlier ones
    return '\n\n'.join(results), '\n\n'.join(errors)
with gr.Blocks() as demo:
    ms_repo_id = gr.Textbox(label="Model Scope Repo ID (case sensitive)")
    hf_repo_id = gr.Textbox(label="Target HF Model Repo ID (case sensitive)")
    with gr.Row():
        button = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")
    output = gr.Textbox(label="Output")
    error = gr.Textbox(label="Error")
    button.click(handle, [ms_repo_id, hf_repo_id], [output, error])
    # The Clear button was created but never wired up; reset every field.
    clear.click(
        lambda: ("", "", "", ""),
        None,
        [ms_repo_id, hf_repo_id, output, error],
    )

if __name__ == "__main__":
    demo.launch()