hz2475 commited on
Commit
4cf9521
·
1 Parent(s): cd33a14
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
starvector/serve/controller.py DELETED
@@ -1,293 +0,0 @@
1
- """
2
- A controller manages distributed workers.
3
- It sends worker addresses to clients.
4
- """
5
- import argparse
6
- import asyncio
7
- import dataclasses
8
- from enum import Enum, auto
9
- import json
10
- import logging
11
- import time
12
- from typing import List, Union
13
- import threading
14
-
15
- from fastapi import FastAPI, Request
16
- from fastapi.responses import StreamingResponse
17
- import numpy as np
18
- import requests
19
- import uvicorn
20
-
21
- from starvector.serve.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
- from starvector.serve.util import build_logger, server_error_msg
23
-
24
- logger = build_logger("controller", "controller.log")
25
-
26
- class DispatchMethod(Enum):
27
- LOTTERY = auto()
28
- SHORTEST_QUEUE = auto()
29
-
30
- @classmethod
31
- def from_str(cls, name):
32
- if name == "lottery":
33
- return cls.LOTTERY
34
- elif name == "shortest_queue":
35
- return cls.SHORTEST_QUEUE
36
- else:
37
- raise ValueError(f"Invalid dispatch method")
38
-
39
-
40
- @dataclasses.dataclass
41
- class WorkerInfo:
42
- model_names: List[str]
43
- speed: int
44
- queue_length: int
45
- check_heart_beat: bool
46
- last_heart_beat: str
47
-
48
-
49
- def heart_beat_controller(controller):
50
- while True:
51
- time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
52
- controller.remove_stable_workers_by_expiration()
53
-
54
-
55
- class Controller:
56
- def __init__(self, dispatch_method: str):
57
- # Dict[str -> WorkerInfo]
58
- self.worker_info = {}
59
- self.dispatch_method = DispatchMethod.from_str(dispatch_method)
60
-
61
- self.heart_beat_thread = threading.Thread(
62
- target=heart_beat_controller, args=(self,))
63
- self.heart_beat_thread.start()
64
-
65
- logger.info("Init controller")
66
-
67
- def register_worker(self, worker_name: str, check_heart_beat: bool,
68
- worker_status: dict):
69
- if worker_name not in self.worker_info:
70
- logger.info(f"Register a new worker: {worker_name}")
71
- else:
72
- logger.info(f"Register an existing worker: {worker_name}")
73
-
74
- if not worker_status:
75
- worker_status = self.get_worker_status(worker_name)
76
- if not worker_status:
77
- return False
78
-
79
- self.worker_info[worker_name] = WorkerInfo(
80
- worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
81
- check_heart_beat, time.time())
82
-
83
- logger.info(f"Register done: {worker_name}, {worker_status}")
84
- return True
85
-
86
- def get_worker_status(self, worker_name: str):
87
- try:
88
- r = requests.post(worker_name + "/worker_get_status", timeout=5)
89
- except requests.exceptions.RequestException as e:
90
- logger.error(f"Get status fails: {worker_name}, {e}")
91
- return None
92
-
93
- if r.status_code != 200:
94
- logger.error(f"Get status fails: {worker_name}, {r}")
95
- return None
96
-
97
- return r.json()
98
-
99
- def remove_worker(self, worker_name: str):
100
- del self.worker_info[worker_name]
101
-
102
- def refresh_all_workers(self):
103
- old_info = dict(self.worker_info)
104
- self.worker_info = {}
105
-
106
- for w_name, w_info in old_info.items():
107
- if not self.register_worker(w_name, w_info.check_heart_beat, None):
108
- logger.info(f"Remove stale worker: {w_name}")
109
-
110
- def list_models(self):
111
- model_names = set()
112
-
113
- for w_name, w_info in self.worker_info.items():
114
- model_names.update(w_info.model_names)
115
-
116
- return list(model_names)
117
-
118
- def get_worker_address(self, model_name: str):
119
- if self.dispatch_method == DispatchMethod.LOTTERY:
120
- worker_names = []
121
- worker_speeds = []
122
- for w_name, w_info in self.worker_info.items():
123
- if model_name in w_info.model_names:
124
- worker_names.append(w_name)
125
- worker_speeds.append(w_info.speed)
126
- worker_speeds = np.array(worker_speeds, dtype=np.float32)
127
- norm = np.sum(worker_speeds)
128
- if norm < 1e-4:
129
- return ""
130
- worker_speeds = worker_speeds / norm
131
- if True: # Directly return address
132
- pt = np.random.choice(np.arange(len(worker_names)),
133
- p=worker_speeds)
134
- worker_name = worker_names[pt]
135
- return worker_name
136
-
137
- # Check status before returning
138
- while True:
139
- pt = np.random.choice(np.arange(len(worker_names)),
140
- p=worker_speeds)
141
- worker_name = worker_names[pt]
142
-
143
- if self.get_worker_status(worker_name):
144
- break
145
- else:
146
- self.remove_worker(worker_name)
147
- worker_speeds[pt] = 0
148
- norm = np.sum(worker_speeds)
149
- if norm < 1e-4:
150
- return ""
151
- worker_speeds = worker_speeds / norm
152
- continue
153
- return worker_name
154
- elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
155
- worker_names = []
156
- worker_qlen = []
157
- for w_name, w_info in self.worker_info.items():
158
- if model_name in w_info.model_names:
159
- worker_names.append(w_name)
160
- worker_qlen.append(w_info.queue_length / w_info.speed)
161
- if len(worker_names) == 0:
162
- return ""
163
- min_index = np.argmin(worker_qlen)
164
- w_name = worker_names[min_index]
165
- self.worker_info[w_name].queue_length += 1
166
- logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
167
- return w_name
168
- else:
169
- raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
170
-
171
- def receive_heart_beat(self, worker_name: str, queue_length: int):
172
- if worker_name not in self.worker_info:
173
- logger.info(f"Receive unknown heart beat. {worker_name}")
174
- return False
175
-
176
- self.worker_info[worker_name].queue_length = queue_length
177
- self.worker_info[worker_name].last_heart_beat = time.time()
178
- logger.info(f"Receive heart beat. {worker_name}")
179
- return True
180
-
181
- def remove_stable_workers_by_expiration(self):
182
- expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
183
- to_delete = []
184
- for worker_name, w_info in self.worker_info.items():
185
- if w_info.check_heart_beat and w_info.last_heart_beat < expire:
186
- to_delete.append(worker_name)
187
-
188
- for worker_name in to_delete:
189
- self.remove_worker(worker_name)
190
-
191
- def worker_api_generate_stream(self, params):
192
- worker_addr = self.get_worker_address(params["model"])
193
- if not worker_addr:
194
- logger.info(f"no worker: {params['model']}")
195
- ret = {
196
- "text": server_error_msg,
197
- "error_code": 2,
198
- }
199
- yield json.dumps(ret).encode() + b"\0"
200
-
201
- try:
202
- response = requests.post(worker_addr + "/worker_generate_stream",
203
- json=params, stream=True, timeout=5)
204
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
205
- if chunk:
206
- yield chunk + b"\0"
207
- except requests.exceptions.RequestException as e:
208
- logger.info(f"worker timeout: {worker_addr}")
209
- ret = {
210
- "text": server_error_msg,
211
- "error_code": 3,
212
- }
213
- yield json.dumps(ret).encode() + b"\0"
214
-
215
-
216
- # Let the controller act as a worker to achieve hierarchical
217
- # management. This can be used to connect isolated sub networks.
218
- def worker_api_get_status(self):
219
- model_names = set()
220
- speed = 0
221
- queue_length = 0
222
-
223
- for w_name in self.worker_info:
224
- worker_status = self.get_worker_status(w_name)
225
- if worker_status is not None:
226
- model_names.update(worker_status["model_names"])
227
- speed += worker_status["speed"]
228
- queue_length += worker_status["queue_length"]
229
-
230
- return {
231
- "model_names": list(model_names),
232
- "speed": speed,
233
- "queue_length": queue_length,
234
- }
235
-
236
-
237
- app = FastAPI()
238
-
239
- @app.post("/register_worker")
240
- async def register_worker(request: Request):
241
- data = await request.json()
242
- controller.register_worker(
243
- data["worker_name"], data["check_heart_beat"],
244
- data.get("worker_status", None))
245
-
246
- @app.post("/refresh_all_workers")
247
- async def refresh_all_workers():
248
- models = controller.refresh_all_workers()
249
-
250
-
251
- @app.post("/list_models")
252
- async def list_models():
253
- models = controller.list_models()
254
- return {"models": models}
255
-
256
-
257
- @app.post("/get_worker_address")
258
- async def get_worker_address(request: Request):
259
- data = await request.json()
260
- addr = controller.get_worker_address(data["model"])
261
- return {"address": addr}
262
-
263
- @app.post("/receive_heart_beat")
264
- async def receive_heart_beat(request: Request):
265
- data = await request.json()
266
- exist = controller.receive_heart_beat(
267
- data["worker_name"], data["queue_length"])
268
- return {"exist": exist}
269
-
270
-
271
- @app.post("/worker_generate_stream")
272
- async def worker_api_generate_stream(request: Request):
273
- params = await request.json()
274
- generator = controller.worker_api_generate_stream(params)
275
- return StreamingResponse(generator)
276
-
277
-
278
- @app.post("/worker_get_status")
279
- async def worker_api_get_status(request: Request):
280
- return controller.worker_api_get_status()
281
-
282
-
283
- if __name__ == "__main__":
284
- parser = argparse.ArgumentParser()
285
- parser.add_argument("--host", type=str, default="localhost")
286
- parser.add_argument("--port", type=int, default=21001)
287
- parser.add_argument("--dispatch-method", type=str, choices=[
288
- "lottery", "shortest_queue"], default="shortest_queue")
289
- args = parser.parse_args()
290
- logger.info(f"args: {args}")
291
-
292
- controller = Controller(args.dispatch_method)
293
- uvicorn.run(app, host=args.host, port=args.port, log_level="info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
starvector/serve/gradio_demo_with_updated_gradio.py DELETED
@@ -1,432 +0,0 @@
1
- import argparse
2
- import datetime
3
- import json
4
- import os
5
- import time
6
- import gradio as gr
7
- import requests
8
- from starvector.serve.conversation import default_conversation
9
- from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
10
- from starvector.serve.util import (build_logger, server_error_msg)
11
-
12
- logger = build_logger("gradio_web_server", "gradio_web_server.log")
13
- headers = {"User-Agent": "StarVector Client"}
14
-
15
- no_change_btn = gr.Button()
16
- enable_btn = gr.Button(interactive=True)
17
- disable_btn = gr.Button(interactive=False)
18
-
19
- priority = {
20
- "starvector-1.4b": "aaaaaaa",
21
- }
22
-
23
- def get_conv_log_filename():
24
- t = datetime.datetime.now()
25
- name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
26
- return name
27
-
28
- def get_model_list():
29
- ret = requests.post(args.controller_url + "/refresh_all_workers")
30
- assert ret.status_code == 200
31
- ret = requests.post(args.controller_url + "/list_models")
32
- models = ret.json()["models"]
33
- models.sort(key=lambda x: priority.get(x, x))
34
- logger.info(f"Models: {models}")
35
- return models
36
-
37
- get_window_url_params = """
38
- function() {
39
- const params = new URLSearchParams(window.location.search);
40
- url_params = Object.fromEntries(params);
41
- console.log(url_params);
42
- return url_params;
43
- }
44
- """
45
-
46
- def load_demo(url_params, request: gr.Request):
47
- logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
48
-
49
- dropdown_update = gr.Dropdown(visible=True)
50
- if "model" in url_params:
51
- model = url_params["model"]
52
- if model in models:
53
- dropdown_update = gr.Dropdown(value=model, visible=True)
54
-
55
- state = default_conversation.copy()
56
- return state, dropdown_update
57
-
58
-
59
- def load_demo_refresh_model_list(request: gr.Request):
60
- logger.info(f"load_demo. ip: {request.client.host}")
61
- models = get_model_list()
62
- state = default_conversation.copy()
63
- dropdown_update = gr.Dropdown(
64
- choices=models,
65
- value=models[0] if len(models) > 0 else ""
66
- )
67
- return state, dropdown_update
68
-
69
- def vote_last_response(state, vote_type, model_selector, request: gr.Request):
70
- with open(get_conv_log_filename(), "a") as fout:
71
- data = {
72
- "tstamp": round(time.time(), 4),
73
- "type": vote_type,
74
- "model": model_selector,
75
- "state": state.dict(),
76
- "ip": request.client.host,
77
- }
78
- fout.write(json.dumps(data) + "\n")
79
-
80
- def upvote_last_response(state, model_selector, request: gr.Request):
81
- logger.info(f"upvote. ip: {request.client.host}")
82
- vote_last_response(state, "upvote", model_selector, request)
83
- return ("",) + (disable_btn,) * 3
84
-
85
- def downvote_last_response(state, model_selector, request: gr.Request):
86
- logger.info(f"downvote. ip: {request.client.host}")
87
- vote_last_response(state, "downvote", model_selector, request)
88
- return ("",) + (disable_btn,) * 3
89
-
90
- def flag_last_response(state, model_selector, request: gr.Request):
91
- logger.info(f"flag. ip: {request.client.host}")
92
- vote_last_response(state, "flag", model_selector, request)
93
- return ("",) + (disable_btn,) * 3
94
-
95
- def regenerate(state, image_process_mode, request: gr.Request):
96
- logger.info(f"regenerate. ip: {request.client.host}")
97
- state.messages[-1][-1] = None
98
- prev_human_msg = state.messages[-2]
99
- if type(prev_human_msg[1]) in (tuple, list):
100
- prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode)
101
- state.skip_next = False
102
- return (state, None, None, None) + (disable_btn,) * 6
103
-
104
- def clear_history(request: gr.Request):
105
- logger.info(f"clear_history. ip: {request.client.host}")
106
- state = default_conversation.copy()
107
- return (state, None, None) + (disable_btn,) * 6
108
-
109
- def send_image(state, image, image_process_mode, request: gr.Request):
110
- logger.info(f"send_image. ip: {request.client.host}.")
111
- state.stop_sampling = False
112
- if image is None:
113
- state.skip_next = True
114
- return (state, None, None, image) + (no_change_btn,) * 6
115
-
116
- if image is not None:
117
- text = (image, image_process_mode)
118
- state.append_message(state.roles[0], text)
119
- state.append_message(state.roles[1], "▌")
120
- state.skip_next = False
121
- msg = state.to_gradio_svg_code()[0][1]
122
- return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 6
123
-
124
- def stop_sampling(state, image, request: gr.Request):
125
- logger.info(f"stop_sampling. ip: {request.client.host}")
126
- state.stop_sampling = True
127
- return (state, None, None, image) + (disable_btn,) * 6
128
-
129
- def http_bot(state, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
130
- logger.info(f"http_bot. ip: {request.client.host}")
131
- start_tstamp = time.time()
132
- model_name = model_selector
133
-
134
- if state.skip_next:
135
- # This generate call is skipped due to invalid inputs
136
- yield (state, None, None) + (no_change_btn,) * 6
137
- return
138
-
139
- # Query worker address
140
- controller_url = args.controller_url
141
- ret = requests.post(controller_url + "/get_worker_address",
142
- json={"model": model_name})
143
- worker_addr = ret.json()["address"]
144
- logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
145
-
146
- # No available worker
147
- if worker_addr == "":
148
- state.messages[-1][-1] = server_error_msg
149
- yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
150
- return
151
-
152
- # Construct prompt
153
- prompt = state.get_prompt()
154
-
155
- # Make requests
156
- pload = {
157
- "model": model_name,
158
- "prompt": prompt,
159
- "num_beams": int(num_beams),
160
- "temperature": float(temperature),
161
- "len_penalty": float(len_penalty),
162
- "top_p": float(top_p),
163
- "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
164
- }
165
- logger.info(f"==== request ====\n{pload}")
166
-
167
- pload['images'] = state.get_images()
168
-
169
- state.messages[-1][-1] = "▌"
170
- yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn)
171
-
172
- try:
173
- # Stream output
174
- if state.stop_sampling:
175
- state.messages[1][-1] = "▌"
176
- yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
177
- return
178
-
179
- response = requests.post(worker_addr + "/worker_generate_stream",
180
- headers=headers, json=pload, stream=True, timeout=100)
181
- for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
182
- if chunk:
183
- data = json.loads(chunk.decode())
184
- if data["error_code"] == 0:
185
- # output = data["text"].strip().replace('<', '&lt;').replace('>', '&gt;') # trick to avoid the SVG getting rendered
186
- output = data["text"].strip()
187
- state.messages[-1][-1] = output + "▌"
188
- st = state.to_gradio_svg_code()
189
- yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn)
190
- else:
191
- output = data["text"] + f" (error_code: {data['error_code']})"
192
- state.messages[-1][-1] = output
193
-
194
- yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
195
- return
196
- time.sleep(0.03)
197
- except requests.exceptions.RequestException as e:
198
- state.messages[-1][-1] = server_error_msg
199
- yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
200
- return
201
-
202
- yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 6
203
-
204
- finish_tstamp = time.time()
205
- logger.info(f"{output}")
206
-
207
- with open(get_conv_log_filename(), "a") as fout:
208
- data = {
209
- "tstamp": round(finish_tstamp, 4),
210
- "type": "chat",
211
- "model": model_name,
212
- "start": round(start_tstamp, 4),
213
- "finish": round(finish_tstamp, 4),
214
- "svg": state.messages[-1][-1],
215
- "ip": request.client.host,
216
- }
217
- fout.write(json.dumps(data) + "\n")
218
-
219
- title_markdown = ("""
220
- # 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text
221
- [[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | 📚 [[StarVector](https://arxiv.org/abs/2312.11556)]
222
- """)
223
-
224
- sub_title_markdown = (""" Throw an image and vectorize it! The model expects vector-like images to generate the corresponding svg code.""")
225
- tos_markdown = ("""
226
- ### Terms of use
227
- By using this service, users are required to agree to the following terms:
228
- The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
229
- Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
230
- For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
231
- """)
232
-
233
-
234
- learn_more_markdown = ("""
235
- ### License
236
- The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
237
- """)
238
-
239
- block_css = """
240
-
241
- #buttons button {
242
- min-width: min(120px,100%);
243
- }
244
-
245
- .gradio-container{
246
- max-width: 1200px!important
247
- }
248
-
249
- #svg_render{
250
- padding: 20px !important;
251
- }
252
-
253
- #svg_code{
254
- height: 200px !important;
255
- overflow: scroll !important;
256
- white-space: unset !important;
257
- flex-shrink: unset !important;
258
- }
259
-
260
-
261
- h1{display: flex;align-items: center;justify-content: center;gap: .25em}
262
- *{transition: width 0.5s ease, flex-grow 0.5s ease}
263
- """
264
-
265
- def build_demo(embed_mode, concurrency_count=10):
266
- with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
267
- state = gr.State()
268
- if not embed_mode:
269
- gr.Markdown(title_markdown)
270
- gr.Markdown(sub_title_markdown)
271
- with gr.Row():
272
- with gr.Column(scale=3):
273
- with gr.Row(elem_id="model_selector_row"):
274
- model_selector = gr.Dropdown(
275
- choices=models,
276
- value=models[0] if len(models) > 0 else "",
277
- interactive=True,
278
- show_label=False,
279
- container=False)
280
- imagebox = gr.Image(type="pil")
281
- image_process_mode = gr.Radio(
282
- ["Resize", "Pad", "Default"],
283
- value="Pad",
284
- label="Preprocess for non-square image", visible=False)
285
-
286
- cur_dir = os.path.dirname(os.path.abspath(__file__))
287
- gr.Examples(examples=[
288
- [f"{cur_dir}/examples/sample-4.png"],
289
- [f"{cur_dir}/examples/sample-7.png"],
290
- [f"{cur_dir}/examples/sample-16.png"],
291
- [f"{cur_dir}/examples/sample-17.png"],
292
- [f"{cur_dir}/examples/sample-18.png"],
293
- [f"{cur_dir}/examples/sample-0.png"],
294
- [f"{cur_dir}/examples/sample-1.png"],
295
- [f"{cur_dir}/examples/sample-6.png"],
296
- ], inputs=[imagebox])
297
-
298
- with gr.Column(scale=1, min_width=50):
299
- submit_btn = gr.Button(value="Send", variant="primary")
300
-
301
- with gr.Accordion("Parameters", open=True) as parameter_row:
302
- num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
303
- temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.05, interactive=True, label="Temperature",)
304
- len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
305
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, interactive=True, label="Top P",)
306
- max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2000, step=64, interactive=True, label="Max output tokens",)
307
-
308
- with gr.Column(scale=8):
309
- with gr.Row():
310
- svg_code = gr.Code(label="SVG Code", elem_id='svg_code', min_width=200, interactive=False, lines=5)
311
- with gr.Row():
312
- gr.Image(width=50, height=256, label="Rendered SVG", elem_id='svg_render')
313
- with gr.Row(elem_id="buttons") as button_row:
314
- upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
315
- downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
316
- flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
317
- stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False)
318
- regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, visible=False)
319
- clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
320
-
321
- if not embed_mode:
322
- gr.Markdown(tos_markdown)
323
- gr.Markdown(learn_more_markdown)
324
- url_params = gr.JSON(visible=False)
325
-
326
- # Register listeners
327
- btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn]
328
- upvote_btn.click(
329
- upvote_last_response,
330
- [state, model_selector],
331
- [upvote_btn, downvote_btn, flag_btn],
332
- queue=False
333
- )
334
- downvote_btn.click(
335
- downvote_last_response,
336
- [state, model_selector],
337
- [upvote_btn, downvote_btn, flag_btn],
338
- queue=False
339
- )
340
- flag_btn.click(
341
- flag_last_response,
342
- [state, model_selector],
343
- [upvote_btn, downvote_btn, flag_btn],
344
- queue=False
345
- )
346
-
347
- regenerate_btn.click(
348
- regenerate,
349
- [state, image_process_mode],
350
- [state, svg_code, svg_render, imagebox] + btn_list,
351
- queue=False
352
- ).then(
353
- http_bot,
354
- [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
355
- [state, svg_code, svg_render] + btn_list,
356
- concurrency_limit=concurrency_count
357
- )
358
-
359
- submit_btn.click(
360
- send_image,
361
- [state, imagebox, image_process_mode],
362
- [state, svg_code, svg_render, imagebox] + btn_list,
363
- queue=False
364
- ).then(
365
- http_bot,
366
- [state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
367
- [state, svg_code, svg_render] + btn_list,
368
- concurrency_limit=concurrency_count
369
- )
370
-
371
- clear_btn.click(
372
- clear_history,
373
- None,
374
- [state, svg_code, svg_render] + btn_list,
375
- queue=False
376
- )
377
-
378
- stop_btn.click(
379
- stop_sampling,
380
- [state, imagebox],
381
- [state, imagebox] + btn_list,
382
- queue=False
383
- ).then(
384
- clear_history,
385
- None,
386
- [state, svg_code, svg_render] + btn_list,
387
- queue=False
388
- )
389
-
390
- if args.model_list_mode == "once":
391
- demo.load(
392
- load_demo,
393
- [url_params],
394
- [state, model_selector],
395
- _js=get_window_url_params,
396
- )
397
- elif args.model_list_mode == "reload":
398
- demo.load(
399
- load_demo_refresh_model_list,
400
- None,
401
- [state, model_selector],
402
- queue=False
403
- )
404
- else:
405
- raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
406
-
407
- return demo
408
-
409
- if __name__ == "__main__":
410
- parser = argparse.ArgumentParser()
411
- parser.add_argument("--host", type=str, default="0.0.0.0")
412
- parser.add_argument("--port", type=int)
413
- parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
414
- parser.add_argument("--concurrency-count", type=int, default=15)
415
- parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
416
- parser.add_argument("--share", action="store_true")
417
- parser.add_argument("--moderate", action="store_true")
418
- parser.add_argument("--embed", action="store_true")
419
- args = parser.parse_args()
420
- logger.info(f"args: {args}")
421
-
422
- models = get_model_list()
423
-
424
- logger.info(args)
425
- demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
426
- demo.queue(
427
- api_open=False
428
- ).launch(
429
- server_name=args.host,
430
- server_port=args.port,
431
- share=args.share
432
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
starvector/serve/gradio_web_server.py DELETED
@@ -1,562 +0,0 @@
1
- import argparse
2
- import datetime
3
- import json
4
- import os
5
- import time
6
- import gradio as gr
7
- import requests
8
- from starvector.serve.conversation import default_conversation
9
- from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
10
- from starvector.serve.util import (build_logger, server_error_msg)
11
-
12
# Module-wide logger; build_logger is given the log file name as well.
logger = build_logger("gradio_web_server", "gradio_web_server.log")
# HTTP headers sent with the streaming request to a model worker.
headers = {"User-Agent": "StarVector Client"}

# Reusable button updates returned by the event handlers below.
no_change_btn = gr.Button.update()
enable_btn = gr.Button.update(interactive=True)
disable_btn = gr.Button.update(interactive=False)

# Sort-key overrides used by get_model_list(): a model listed here is sorted
# by the mapped string instead of its own name, so "aaaaaaa" pushes it
# towards the front of the list.
priority = {
    "starvector-1b-im2svg": "aaaaaaa",
}
22
-
23
def get_conv_log_filename():
    """Return the path of today's conversation log file inside LOGDIR.

    The file name has the form ``YYYY-MM-DD-conv.json`` and is derived from
    the local date at call time.
    """
    today = datetime.datetime.now()
    file_name = f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json"
    return os.path.join(LOGDIR, file_name)
27
-
28
def get_model_list():
    """Ask the controller to refresh its workers and return the model names.

    Reads the controller URL from the module-level ``args``.

    Returns:
        The list of model names reported by the controller, sorted by the
        ``priority`` overrides (models without an override sort by name).

    Raises:
        requests.exceptions.HTTPError: if either controller endpoint answers
            with an error status.
    """
    ret = requests.post(args.controller_url + "/refresh_all_workers")
    # Explicit status check instead of `assert`, which is stripped under -O.
    ret.raise_for_status()
    ret = requests.post(args.controller_url + "/list_models")
    # Fail with a clear HTTP error rather than a confusing JSON/KeyError.
    ret.raise_for_status()
    models = ret.json()["models"]
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models
36
-
37
def load_demo(url_params, request: gr.Request):
    """Initialise a session when the page loads.

    Args:
        url_params: Mapping of the page's URL query parameters.
        request: Gradio request object, used only to log the client address.

    Returns:
        A tuple ``(state, dropdown_update)``: a fresh copy of the default
        conversation and an update for the model dropdown. The dropdown is
        pre-selected when the URL carries a ``model`` parameter naming a
        model present in the module-level ``models`` list.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    if "model" in url_params and url_params["model"] in models:
        dropdown_update = gr.Dropdown.update(
            value=url_params["model"], visible=True)
    else:
        dropdown_update = gr.Dropdown.update(visible=True)

    return default_conversation.copy(), dropdown_update
49
-
50
# Maps the task name shown in the UI to the substring that a model name must
# contain to be offered for that task (see get_models_dropdown_from_task).
mapping_model_task = {
    'Image2SVG': 'im2svg',
    'Text2SVG': 'text2svg'
}
54
-
55
def get_models_dropdown_from_task(task):
    """Build a dropdown update listing the models that serve ``task``.

    The model list is re-fetched from the controller, then filtered to the
    names containing the tag that ``mapping_model_task`` assigns to ``task``.
    The first match becomes the selected value, or ``""`` if none matches.
    """
    available = get_model_list()
    matching = [name for name in available if mapping_model_task[task] in name]
    selected = matching[0] if len(matching) > 0 else ""
    return gr.Dropdown.update(choices=matching, value=selected)
63
-
64
-
65
def load_demo_refresh_model_list(task, request: gr.Request):
    """Page-load handler that refreshes the model list for ``task``.

    Returns:
        A tuple ``(state, dropdown_update)``: a fresh copy of the default
        conversation and the dropdown update for the refreshed model list.
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    model_dropdown = get_models_dropdown_from_task(task)
    return default_conversation.copy(), model_dropdown
70
-
71
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one feedback record to today's conversation log.

    Args:
        state: Conversation state; serialised with ``state.dict()``.
        vote_type: Label stored in the record's ``type`` field.
        model_selector: Name of the model the feedback refers to.
        request: Gradio request object, used for the client address.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        # One JSON document per line.
        log_file.write(json.dumps(record) + "\n")
81
-
82
def upvote_last_response(state, model_selector, request: gr.Request):
    """Record an upvote for the last response and disable the vote buttons.

    Returns:
        Three "disable" updates, one per output the click handler is wired
        to (upvote, downvote and flag buttons). The previous return value
        started with an empty string, which landed on the upvote button.
    """
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return (disable_btn,) * 3
86
-
87
def downvote_last_response(state, model_selector, request: gr.Request):
    """Record a downvote for the last response and disable the vote buttons.

    Returns:
        Three "disable" updates, one per output the click handler is wired
        to (upvote, downvote and flag buttons). The previous return value
        started with an empty string, which landed on the upvote button.
    """
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return (disable_btn,) * 3
91
-
92
def flag_last_response(state, model_selector, request: gr.Request):
    """Record a flag for the last response and disable the vote buttons.

    Returns:
        Three "disable" updates, one per output the click handler is wired
        to (upvote, downvote and flag buttons). The previous return value
        started with an empty string, which landed on the upvote button.
    """
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return (disable_btn,) * 3
96
-
97
def regenerate(state, image_process_mode, request: gr.Request):
    """Reset the last answer so the following ``http_bot`` call regenerates it.

    Args:
        state: Conversation state; its last message is the model answer and
            the one before it is the user message.
        image_process_mode: Pre-processing mode to apply to the user image.
        request: Gradio request object, used only for logging.

    Returns:
        ``(state, None, None, None)`` followed by seven "disable" button
        updates, matching the outputs the click handler is wired to.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    prev_human_msg = state.messages[-2]
    if isinstance(prev_human_msg[1], (tuple, list)):
        # Image messages are stored by send_data as (image, mode). Keep the
        # image and swap in the current mode; slicing with [:2] here used to
        # nest the old pair inside a new tuple: ((image, mode), mode).
        prev_human_msg[1] = (prev_human_msg[1][0], image_process_mode)
    state.skip_next = False
    return (state, None, None, None) + (disable_btn,) * 7
105
-
106
def clear_history(request: gr.Request):
    """Start over with a fresh conversation and blank outputs.

    Returns:
        ``(state, None, None)`` followed by seven "disable" button updates.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    blank_outputs = (fresh_state, None, None)
    return blank_outputs + (disable_btn,) * 7
110
-
111
def send_data(state, image, image_process_mode, text_caption, task, request: gr.Request):
    """Queue the user's input on the conversation before generation starts.

    For the ``Image2SVG`` task the user message is ``(image,
    image_process_mode)``; for any other task it is ``text_caption``. When
    the relevant input is ``None`` the state is flagged with ``skip_next`` so
    the following ``http_bot`` call does nothing.

    Returns:
        ``(state, svg_code, svg_render, image)`` followed by seven
        "no change" button updates. The two SVG entries are ``None`` when the
        input was missing.
    """
    logger.info(f"send_data. ip: {request.client.host}.")
    unchanged_buttons = (no_change_btn,) * 7

    if task == 'Image2SVG':
        user_message = None if image is None else (image, image_process_mode)
    else:
        user_message = text_caption

    # Guard clause: nothing to send.
    if user_message is None:
        state.skip_next = True
        return (state, None, None, image) + unchanged_buttons

    state.append_message(state.roles[0], user_message)
    # Placeholder cursor for the answer that is about to be streamed.
    state.append_message(state.roles[1], "▌")
    state.skip_next = False
    msg = state.to_gradio_svg_code()[0][1]
    return (state, msg, state.to_gradio_svg_render(), image) + unchanged_buttons
135
-
136
def download_files(state, request: gr.Request):
    """Unfinished handler for the "Download SVG" button.

    Fetches the SVG string and image from ``state.download_files()`` but
    does nothing with them yet and returns ``None``.
    """
    logger.info(f"download_files. ip: {request.client.host}")
    svg_str, image = state.download_files()

    # TODO: Figure out how to download the SVG in the users browser, idk how to do it now
141
-
142
def update_task(task):
    """Return the slider defaults and model dropdown for the chosen task.

    Returns:
        ``(len_penalty, temperature, top_p, dropdown_update)``. Only the
        length penalty depends on the task: 1.0 for ``Text2SVG``, 0.6
        otherwise.
    """
    dropdown_update = get_models_dropdown_from_task(task)
    length_penalty = 1.0 if task == "Text2SVG" else 0.6
    return length_penalty, 0.9, 0.95, dropdown_update
149
-
150
-
151
def stop_sampling(state, image, request: gr.Request):
    """Flag the conversation so that generation is not (re)started.

    Returns:
        ``(state, image)`` followed by seven "disable" button updates. This
        matches the nine outputs the stop button is wired to (state, image
        box and the seven buttons); the previous return value carried two
        extra ``None`` entries that shifted every later output by two slots.
    """
    logger.info(f"stop_sampling. ip: {request.client.host}")
    state.stop_sampling = True
    return (state, image) + (disable_btn,) * 7
155
-
156
def http_bot(state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
    """Stream a generation from a model worker into the UI.

    This is a generator used as a Gradio event handler. Every yielded tuple
    is ``(state, svg_code, svg_render)`` followed by seven button updates.

    Args:
        state: Conversation state whose last message receives the answer.
        task_selector: ``"Image2SVG"`` or another task name; decides how the
            prompt is built.
        text_caption: Prompt text used for tasks other than ``Image2SVG``.
        model_selector: Name of the model to query.
        num_beams, temperature, len_penalty, top_p, max_new_tokens:
            Sampling parameters forwarded to the worker.
        request: Gradio request object, used for logging the client address.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector
    # Defined up front so the final log line works even if the worker
    # stream yields no chunk at all.
    output = ""

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, None, None) + (no_change_btn,) * 7
        return

    # Query worker address. An unreachable controller or a malformed reply is
    # reported like "no worker available" instead of crashing the handler.
    controller_url = args.controller_url
    try:
        ret = requests.post(controller_url + "/get_worker_address",
                            json={"model": model_name})
        worker_addr = ret.json()["address"]
    except (requests.exceptions.RequestException, ValueError, KeyError):
        worker_addr = ""
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
        return

    # Construct prompt
    if task_selector == "Image2SVG":
        prompt = state.get_image_prompt()
    else:
        prompt = text_caption

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "num_beams": int(num_beams),
        "temperature": float(temperature),
        "len_penalty": float(len_penalty),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
    }
    logger.info(f"==== request ====\n{pload}")

    # Added after logging so the image payload is not written to the log.
    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

    try:
        # Stream output
        if state.stop_sampling:
            # Use the last message, like every other access in this function
            # (the previous hard-coded index 1 only matched it by accident).
            state.messages[-1][-1] = "▌"
            yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, enable_btn)
            return

        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=10)
        # The worker separates JSON chunks with NUL bytes.
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"].strip()
                    state.messages[-1][-1] = output + "▌"
                    st = state.to_gradio_svg_code()
                    yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn)
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    # Recompute here: `st` does not exist yet when the very
                    # first chunk already carries an error.
                    st = state.to_gradio_svg_code()
                    yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException:
        state.messages[-1][-1] = server_error_msg
        yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
        return

    yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7

    finish_tstamp = time.time()
    logger.info(f"{output}")

    # Append one JSON record per finished generation to today's log.
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "svg": state.messages[-1][-1],
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
248
-
249
# Page header, rendered by build_demo unless the demo runs in embed mode.
title_markdown = ("""
# 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text

[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | 📚 [[StarVector](https://arxiv.org/abs/2312.11556)]""")

# Short usage hint shown under the header. The trailing backslash joins the
# two source lines into a single paragraph.
sub_title_markdown = ("""**How does it work?** Select the task you want to perform, and the model will be automatically set. For **Text2SVG**, introduce a prompt in Text Caption. For **Image2SVG**, select an image and vectorize it. \
**Note**: The current model works on vector-like images like icons and or vector-like designs.""")
# Terms of use, rendered at the bottom of the page unless in embed mode.
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")

# License note, rendered after the terms of use.
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
""")

# Custom CSS passed to gr.Blocks. The #svg_render and #submit_btn rules target
# elem_id values set in build_demo; .selector is an elem_classes value there.
block_css = """

#buttons button {
min-width: min(120px,100%);
}

.gradio-container{
max-width: 1200px!important
}

.ͼ1 .cm-content {
white-space: unset !important;
flex-shrink: unset !important;
}

.ͼ2p .cm-scroller {
max-height: 200px;
overflow: scroll;
}

#svg_render{
padding: 20px !important;
}

#submit_btn{
max-height: 40px;
}

.selector{
max-height: 100px;
}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
303
def build_demo(embed_mode):
    """Build the StarVector Gradio interface and wire its event handlers.

    Reads the module-level ``models`` list (initial dropdown choices) and
    ``args.model_list_mode`` (which page-load handler to install).

    Args:
        embed_mode: When true, the title, subtitle, terms of use and license
            texts are left out.

    Returns:
        The assembled ``gr.Blocks`` demo, not yet queued or launched.

    Raises:
        ValueError: if ``args.model_list_mode`` is neither ``"once"`` nor
            ``"reload"``.
    """
    # The output components are created first and placed into the layout
    # later with .render().
    svg_render = gr.Image(label="Rendered SVG", elem_id='svg_render', height=300)
    svg_code = gr.Code(label="SVG Code", elem_id='svg_code', interactive=True, lines=5)

    with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.Markdown(title_markdown)
            gr.Markdown(sub_title_markdown)
        with gr.Row():
            # Left column: task/model selection, inputs and parameters.
            with gr.Column(scale=4):
                task_selector = gr.Dropdown(
                    choices=["Image2SVG", "Text2SVG"],
                    value="Image2SVG",
                    label="Task",
                    interactive=True,
                    show_label=True,
                    container=True,
                    elem_id="task_selector",
                    elem_classes=["selector"],
                )
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if len(models) > 0 else "",
                    label="Model",
                    interactive=True,
                    show_label=True,
                    container=True,
                    elem_classes=["selector"],
                )

                imagebox = gr.Image(type="pil", visible=True, elem_id="imagebox")
                image_process_mode = gr.Radio(
                    ["Resize", "Pad", "Default"],
                    value="Pad",
                    label="Preprocess for non-square image", visible=False)

                # Text input
                text_caption = gr.Textbox(label="Text Caption", visible=True, value="The icon of a yellow star", elem_id="text_caption")

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/sample-4.png"],
                    [f"{cur_dir}/examples/sample-7.png"],
                    [f"{cur_dir}/examples/sample-16.png"],
                    [f"{cur_dir}/examples/sample-17.png"],
                    [f"{cur_dir}/examples/sample-18.png"],
                    [f"{cur_dir}/examples/sample-0.png"],
                    [f"{cur_dir}/examples/sample-1.png"],
                    [f"{cur_dir}/examples/sample-6.png"],
                ], inputs=[imagebox], elem_id="examples")

                submit_btn = gr.Button(value="Send", variant="primary", elem_id="submit_btn", interactive=True)

                with gr.Accordion("Parameters", open=False):
                    num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
                    temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.9, step=0.05, interactive=True, label="Temperature",)
                    len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=8192, step=64, interactive=True, label="Max output tokens",)

            # Right column: generated code, rendered SVG and action buttons.
            with gr.Column(scale=9):
                with gr.Row():
                    svg_code.render()
                with gr.Row():
                    svg_render.render()

                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, visible=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
                    download_btn = gr.Button(value="Download SVG", interactive=False, visible=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        # The order of btn_list fixes the meaning of the seven button updates
        # that the handlers return.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn, download_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, svg_code, svg_render, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
            [state, svg_code, svg_render] + btn_list)

        submit_btn.click(
            send_data,
            [state, imagebox, image_process_mode, text_caption, task_selector],
            [state, svg_code, svg_render, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
            [state, svg_code, svg_render] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, svg_code, svg_render] + btn_list,
            queue=False
        )

        stop_btn.click(
            stop_sampling,
            [state, imagebox],
            [state, imagebox] + btn_list,
            queue=False
        ).then(
            clear_history,
            None,
            [state, svg_code, svg_render] + btn_list,
            queue=False
        )

        download_btn.click(
            download_files,
            [state],
            None,
            queue=False
        )
        # Switching the task updates the sampling defaults and model list on
        # the server, and toggles the input widgets in the browser.
        task_selector.change(
            update_task,
            inputs=[task_selector],
            outputs=[len_penalty, temperature, top_p, model_selector],
            queue=False,
            _js="""
            function(task) {
                var imageBoxElement = document.getElementById("imagebox");
                var textCaptionElement = document.getElementById("text_caption");
                var examplesElement = document.getElementById("examples");
                if (task === "Text2SVG") {
                    imageBoxElement.style.display = "none";
                    textCaptionElement.style.display = "block";
                    examplesElement.style.display = "none";
                } else if (task === "Image2SVG") {
                    imageBoxElement.style.display = "block";
                    textCaptionElement.style.display = "none";
                    examplesElement.style.display = "block";
                }
                return task;
            }
            """
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                # load_demo takes exactly one component input (url_params);
                # passing task_selector as well gave it one argument too many.
                [url_params],
                [state, model_selector],
                _js="""
                function() {
                    const params = new URLSearchParams(window.location.search);
                    url_params = Object.fromEntries(params);
                    console.log(url_params);
                    return url_params;

                }
                """,
                queue=False
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                [task_selector],
                [state, model_selector],
                _js="""
                function(task) {
                    var textCaptionElement = document.getElementById("text_caption");
                    var autoScrollBottom = true;
                    textCaptionElement.style.display = "none";
                    function updateScroll(){
                        if (autoScrollBottom) {
                            var element = document.getElementsByClassName("cm-scroller")[0];
                            element.scrollTop = element.scrollHeight;
                        }
                    }
                    function handleScroll() {
                        var element = document.getElementsByClassName("cm-scroller")[0];
                        //if (element.scrollHeight - element.scrollTop === element.clientHeight) {
                        if (element.scrollHeight - (element.scrollTop + element.clientHeight) < 0.2*(element.scrollTop)) {
                            // User has scrolled to the bottom, enable auto-scrolling
                            autoScrollBottom = true;
                            console.log("bottom");
                        } else {
                            console.log("not bottom");
                            // User has scrolled away from the bottom, disable auto-scrolling
                            autoScrollBottom = false;
                        }
                    }
                    setInterval(updateScroll,500);
                    var element = document.getElementsByClassName("cm-scroller")[0];
                    element.addEventListener("scroll", handleScroll);

                    return task;
                }

                """,
                queue=False,
            )

        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
535
-
536
if __name__ == "__main__":
    # Script entry point: parse the command line, fetch the model list from
    # the controller, then build and launch the Gradio demo.

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    # Accepted but not read anywhere else in this file.
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    # `args` is deliberately a module-level name: get_model_list, http_bot and
    # build_demo read it directly.
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Module-level as well: load_demo and build_demo read `models`.
    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed)
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
starvector/serve/model_worker.py DELETED
@@ -1,269 +0,0 @@
1
- """
2
- A model worker executes the model.
3
- """
4
- import argparse
5
- import asyncio
6
- import json
7
- import time
8
- import threading
9
- import uuid
10
- from fastapi import FastAPI, Request, BackgroundTasks
11
- from fastapi.responses import StreamingResponse
12
- import requests
13
- import torch
14
- import uvicorn
15
- from functools import partial
16
- from starvector.serve.constants import WORKER_HEART_BEAT_INTERVAL, CLIP_QUERY_LENGTH
17
- from starvector.serve.util import (build_logger, server_error_msg,
18
- pretty_print_semaphore)
19
- from starvector.model.builder import load_pretrained_model
20
- from starvector.serve.util import process_images, load_image_from_base64
21
- from threading import Thread
22
- from transformers import TextIteratorStreamer
23
-
24
# Number of bytes in one gibibyte.
GB = 1 << 30

# Short random identifier for this worker process; also part of the log file name.
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Number of generation requests received so far (incremented per request).
global_counter = 0
# Created lazily on the first request by the /worker_generate_stream handler.
model_semaphore = None
30
-
31
def heart_beat_worker(controller):
    """Send a heart beat every WORKER_HEART_BEAT_INTERVAL seconds, forever.

    Despite its name, ``controller`` is the object whose ``send_heart_beat``
    method is called; ModelWorker passes itself when it starts this function
    in a thread.
    """
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
35
-
36
class ModelWorker:
    """Loads one StarVector model and serves streamed generations for it.

    The worker optionally registers itself with a controller and keeps the
    registration alive with periodic heart beats.
    """

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device):
        """Load the model and, unless ``no_register`` is set, register it.

        Args:
            controller_addr: Base URL of the controller.
            worker_addr: URL under which this worker is reachable.
            worker_id: Identifier of this worker, used in logs.
            no_register: Skip controller registration and heart beats.
            model_path: Path or hub id of the model to load.
            model_base: Accepted for interface compatibility; not used here.
            model_name: Explicit model name; derived from ``model_path``
                when ``None``.
            load_8bit: Forwarded to the loader as ``load_in_8bit``.
            load_4bit: Forwarded to the loader as ``load_in_4bit``.
            device: Device string forwarded to the loader.
        """
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            # For ".../<run>/checkpoint-N" keep the run name in the model name.
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        # The task is inferred from the model name. Always set it, so that a
        # name carrying neither tag no longer fails later in generate_stream
        # with an AttributeError.
        lowered_name = self.model_name.lower()
        if "text2svg" in lowered_name:
            self.task = "Text2SVG"
        elif "im2svg" in lowered_name:
            self.task = "Image2SVG"
        else:
            self.task = "Image2SVG"
            logger.warning(
                f"Model name {self.model_name} contains neither 'text2svg' nor "
                f"'im2svg'; defaulting to task {self.task}.")

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, device=self.device, load_in_8bit=load_8bit, load_in_4bit=load_4bit)
        self.model.to(torch.bfloat16)
        self.is_multimodal = 'starvector' in self.model_name.lower()

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker and its current status to the controller."""
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Report the queue length to the controller, retrying until it answers.

        If the controller no longer knows this worker, register again.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return the number of requests currently running or waiting."""
        if model_semaphore is None:
            return 0
        else:
            # Taken slots plus the tasks blocked on the semaphore.
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Return the status dictionary reported to the controller."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        """Run one generation and yield the accumulated text as it grows.

        Args:
            params: Request dictionary. ``prompt`` is required; ``images``,
                ``num_beams``, ``temperature``, ``len_penalty``, ``top_p``
                and ``max_new_tokens`` are optional.

        Yields:
            NUL-terminated, JSON-encoded chunks with the keys ``text`` (all
            text generated so far) and ``error_code``.
        """
        tokenizer, model, image_processor, task = self.tokenizer, self.model, self.image_processor, self.task

        num_beams = int(params.get("num_beams", 1))
        temperature = float(params.get("temperature", 1.0))
        len_penalty = float(params.get("len_penalty", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 8192)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15)
        prompt = params["prompt"]

        if task == "Image2SVG":
            # Tolerate a missing or empty "images" entry: `image` stays None
            # instead of the loop failing on None or leaving it undefined.
            images = params.get("images", None) or []
            image = None
            # If several images are sent, the last one is used.
            for b64_image in images:
                if b64_image is not None and self.is_multimodal:
                    image = load_image_from_base64(b64_image)
                    image = process_images(image, image_processor)
                    # NOTE(review): the model is cast to bfloat16 in __init__
                    # while the image uses float16 -- confirm this is intended.
                    image = image.to(self.model.device, dtype=torch.float16)
                else:
                    image = None

            max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192)
            max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH)
            pre_pend = prompt
            batch = {}
            batch["image"] = image
            generate_method = model.model.generate_im2svg
        else:
            max_new_tokens = min(int(params.get("max_new_tokens", 128)), 8192)
            pre_pend = ""
            batch = {}
            batch['caption'] = [prompt]
            # Placeholder image: an all-zero tensor.
            batch['image'] = torch.zeros((3, 256, 256), dtype=torch.float16).to(self.model.device)
            generate_method = model.model.generate_text2svg

        if max_new_tokens < 1:
            yield json.dumps({"text": prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        # Generation runs in a background thread and feeds the streamer,
        # which is consumed below.
        thread = Thread(target=generate_method, kwargs=dict(
            batch=batch,
            prompt=prompt,
            use_nucleus_sampling=True,
            num_beams=num_beams,
            temperature=temperature,
            length_penalty=len_penalty,
            top_p=top_p,
            max_length=max_new_tokens,
            streamer=streamer,
        ))
        thread.start()

        generated_text = pre_pend
        for new_text in streamer:
            if new_text == " ":
                continue
            generated_text += new_text
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    @staticmethod
    def _error_chunk():
        """Return the NUL-terminated JSON chunk that reports a server error."""
        ret = {
            "text": server_error_msg,
            "error_code": 1,
        }
        return json.dumps(ret).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Wrap ``generate_stream`` so failures end the stream with an error chunk.

        Any exception raised while generating is printed and replaced by a
        final chunk carrying ``server_error_msg`` and ``error_code`` 1.
        """
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            yield self._error_chunk()
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            yield self._error_chunk()
        except Exception as e:
            print("Caught Unknown Error", e)
            yield self._error_chunk()
208
-
209
# FastAPI application that exposes the worker endpoints defined below.
app = FastAPI()
210
-
211
def release_model_semaphore(fn=None):
    """Release one slot of the model semaphore, then run ``fn`` if given."""
    model_semaphore.release()
    if fn is None:
        return
    fn()
215
-
216
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """Stream a generation for the JSON parameters in the request body.

    At most ``args.limit_model_concurrency`` generations run at once; the
    semaphore slot is released by a background task after the response has
    been sent.
    """
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    # Created lazily on the first request.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    try:
        worker.send_heart_beat()
        generator = worker.generate_stream_gate(params)
        background_tasks = BackgroundTasks()
        background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    except Exception:
        # Give the slot back if anything fails before the response (and with
        # it the release task) exists; otherwise the slot would leak.
        model_semaphore.release()
        raise
    return StreamingResponse(generator, background=background_tasks)
230
-
231
@app.post("/worker_get_status")
async def get_status(request: Request):
    """Return the worker's status dictionary (model names, speed, queue length)."""
    return worker.get_status()
234
-
235
if __name__ == "__main__":
    # Script entry point: parse the command line, load the model into a
    # ModelWorker and serve the FastAPI app with uvicorn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="joanrodai/starvector-1.4b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    # Accepted but not read anywhere else in this file.
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    # `args` is deliberately a module-level name: ModelWorker.get_queue_length
    # and the /worker_generate_stream handler read args.limit_model_concurrency.
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # The flag itself changes nothing; it only triggers this reminder.
    if args.multi_modal:
        logger.warning("Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.")

    # Module-level as well: the endpoint handlers use `worker`.
    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
starvector/serve/vllm_api_gradio/gradio_web_server.py CHANGED
@@ -231,7 +231,6 @@ def http_bot(state, task_selector, text_caption, model_selector, num_beams, temp
231
 
232
  yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
233
  return
234
- time.sleep(0.01)
235
  except requests.exceptions.RequestException as e:
236
  state.messages[-1][-1] = server_error_msg
237
  yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
 
231
 
232
  yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
233
  return
 
234
  except requests.exceptions.RequestException as e:
235
  state.messages[-1][-1] = server_error_msg
236
  yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)