File size: 2,046 Bytes
4c10d2b
 
 
8c1eec8
 
 
 
 
 
4c10d2b
e38b1d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c10d2b
c03f3f7
e38b1d0
 
8c1eec8
 
 
 
dbb2737
8c1eec8
e38b1d0
8c1eec8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr


from typing import Iterable
import gradio as gr
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
import time


def get_cache_dir():
    """Create and return a fresh, empty scratch directory for one mirror job.

    ``tempfile.mkdtemp`` guarantees a unique, newly created directory, so two
    jobs running at the same time never write into the same folder (a random
    dictionary word relative to the working directory could collide), and no
    third-party package is needed.

    Returns:
        Path of the newly created directory.
    """
    import tempfile

    return tempfile.mkdtemp(prefix="ms2hf-")


def pull_from_ms(repo_id, cache_dir):
    """Download a snapshot of a ModelScope repository into ``cache_dir``.

    Args:
        repo_id: ModelScope repository id to download.
        cache_dir: Local directory passed to ``snapshot_download``.

    Returns:
        A human-readable status message for the UI.
    """
    import os

    from modelscope import HubApi
    from modelscope import snapshot_download

    # MS_TOKEN is optional: when it is unset or blank we skip the login and
    # call snapshot_download without authenticating.
    token = os.environ.get('MS_TOKEN', '')
    if token.strip():
        api = HubApi()
        api.login(token)

    snapshot_download(repo_id, cache_dir=cache_dir)
    return f'Pulled {repo_id} to {cache_dir}'


def push_to_hf(cache_dir, ms_repo_id, hf_repo_id):
    """Upload a previously pulled ModelScope snapshot to a Hugging Face model repo.

    Args:
        cache_dir: Directory the ModelScope snapshot was downloaded into.
        ms_repo_id: ModelScope repo id; used to locate the snapshot folder
            inside ``cache_dir``.
        hf_repo_id: Target Hugging Face model repo id. It is created if it
            does not exist yet.

    Returns:
        A human-readable status message for the UI.

    Raises:
        gr.Error: If the ``HF_TOKEN`` environment variable is unset or blank.
    """
    import os

    from huggingface_hub import HfApi

    token = os.environ.get('HF_TOKEN', '')
    if not token.strip():
        raise gr.Error("Please enter your HF_TOKEN")

    api = HfApi(token=token)  # Token is passed to this client only, not persisted on the machine.
    # upload_folder requires the target repo to exist; create it on first use.
    api.create_repo(repo_id=hf_repo_id, repo_type="model", exist_ok=True)
    # NOTE(review): assumes snapshot_download stored the files under
    # <cache_dir>/<ms_repo_id>; confirm for repo names ModelScope may sanitize.
    api.upload_folder(
        folder_path=f"{cache_dir}/{ms_repo_id}",
        repo_id=hf_repo_id,
        repo_type="model",
    )
    return f'Pushed to {hf_repo_id}'
    
    
def handle(ms_repo_id, hf_repo_id):
    """Mirror a ModelScope model repo to a Hugging Face model repo.

    Runs the pull and push stages in order and stops at the first stage that
    raises. Exceptions are not propagated; they are reported as text.

    Args:
        ms_repo_id: Source ModelScope repo id.
        hf_repo_id: Target Hugging Face repo id.

    Returns:
        A ``(output, error)`` pair of strings for the two result textboxes.
        ``output`` is the status messages of the stages that succeeded, and
        ``error`` is the failure message; both are joined by blank lines.
        ``error`` is empty when every stage succeeded.
    """
    cache_dir = get_cache_dir()
    stages = [
        (pull_from_ms, (ms_repo_id, cache_dir), {}),
        # push_to_hf needs the ModelScope id to find the downloaded folder.
        (push_to_hf, (cache_dir, ms_repo_id, hf_repo_id), {}),
    ]

    results = []
    errors = []
    for func, args, kwargs in stages:
        try:
            results.append(func(*args, **kwargs))
        except Exception as e:  # UI boundary: show any stage failure as text
            errors.append(f'{func.__name__}: {e}')
            break  # later stages depend on earlier ones succeeding

    return '\n\n'.join(results), '\n\n'.join(errors)


with gr.Blocks() as demo:
    ms_repo_id = gr.Textbox(label="Model Scope Repo ID (case sensitive)")
    hf_repo_id = gr.Textbox(label="Target HF Model Repo ID (case sensitive)")
    with gr.Row():
        button = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")
    output = gr.Textbox(label="Output")
    error = gr.Textbox(label="Error")

    # handle returns (output, error), matching the two output components.
    button.click(handle, [ms_repo_id, hf_repo_id], [output, error])
    # Reset both inputs and both result boxes.
    clear.click(
        lambda: ("", "", "", ""),
        None,
        [ms_repo_id, hf_repo_id, output, error],
    )

if __name__ == "__main__":
    demo.launch()