optimize
Browse files
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
starvector/serve/controller.py
DELETED
@@ -1,293 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
A controller manages distributed workers.
|
3 |
-
It sends worker addresses to clients.
|
4 |
-
"""
|
5 |
-
import argparse
|
6 |
-
import asyncio
|
7 |
-
import dataclasses
|
8 |
-
from enum import Enum, auto
|
9 |
-
import json
|
10 |
-
import logging
|
11 |
-
import time
|
12 |
-
from typing import List, Union
|
13 |
-
import threading
|
14 |
-
|
15 |
-
from fastapi import FastAPI, Request
|
16 |
-
from fastapi.responses import StreamingResponse
|
17 |
-
import numpy as np
|
18 |
-
import requests
|
19 |
-
import uvicorn
|
20 |
-
|
21 |
-
from starvector.serve.constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
22 |
-
from starvector.serve.util import build_logger, server_error_msg
|
23 |
-
|
24 |
-
logger = build_logger("controller", "controller.log")
|
25 |
-
|
26 |
-
class DispatchMethod(Enum):
|
27 |
-
LOTTERY = auto()
|
28 |
-
SHORTEST_QUEUE = auto()
|
29 |
-
|
30 |
-
@classmethod
|
31 |
-
def from_str(cls, name):
|
32 |
-
if name == "lottery":
|
33 |
-
return cls.LOTTERY
|
34 |
-
elif name == "shortest_queue":
|
35 |
-
return cls.SHORTEST_QUEUE
|
36 |
-
else:
|
37 |
-
raise ValueError(f"Invalid dispatch method")
|
38 |
-
|
39 |
-
|
40 |
-
@dataclasses.dataclass
|
41 |
-
class WorkerInfo:
|
42 |
-
model_names: List[str]
|
43 |
-
speed: int
|
44 |
-
queue_length: int
|
45 |
-
check_heart_beat: bool
|
46 |
-
last_heart_beat: str
|
47 |
-
|
48 |
-
|
49 |
-
def heart_beat_controller(controller):
|
50 |
-
while True:
|
51 |
-
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
52 |
-
controller.remove_stable_workers_by_expiration()
|
53 |
-
|
54 |
-
|
55 |
-
class Controller:
|
56 |
-
def __init__(self, dispatch_method: str):
|
57 |
-
# Dict[str -> WorkerInfo]
|
58 |
-
self.worker_info = {}
|
59 |
-
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
60 |
-
|
61 |
-
self.heart_beat_thread = threading.Thread(
|
62 |
-
target=heart_beat_controller, args=(self,))
|
63 |
-
self.heart_beat_thread.start()
|
64 |
-
|
65 |
-
logger.info("Init controller")
|
66 |
-
|
67 |
-
def register_worker(self, worker_name: str, check_heart_beat: bool,
|
68 |
-
worker_status: dict):
|
69 |
-
if worker_name not in self.worker_info:
|
70 |
-
logger.info(f"Register a new worker: {worker_name}")
|
71 |
-
else:
|
72 |
-
logger.info(f"Register an existing worker: {worker_name}")
|
73 |
-
|
74 |
-
if not worker_status:
|
75 |
-
worker_status = self.get_worker_status(worker_name)
|
76 |
-
if not worker_status:
|
77 |
-
return False
|
78 |
-
|
79 |
-
self.worker_info[worker_name] = WorkerInfo(
|
80 |
-
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
|
81 |
-
check_heart_beat, time.time())
|
82 |
-
|
83 |
-
logger.info(f"Register done: {worker_name}, {worker_status}")
|
84 |
-
return True
|
85 |
-
|
86 |
-
def get_worker_status(self, worker_name: str):
|
87 |
-
try:
|
88 |
-
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
89 |
-
except requests.exceptions.RequestException as e:
|
90 |
-
logger.error(f"Get status fails: {worker_name}, {e}")
|
91 |
-
return None
|
92 |
-
|
93 |
-
if r.status_code != 200:
|
94 |
-
logger.error(f"Get status fails: {worker_name}, {r}")
|
95 |
-
return None
|
96 |
-
|
97 |
-
return r.json()
|
98 |
-
|
99 |
-
def remove_worker(self, worker_name: str):
|
100 |
-
del self.worker_info[worker_name]
|
101 |
-
|
102 |
-
def refresh_all_workers(self):
|
103 |
-
old_info = dict(self.worker_info)
|
104 |
-
self.worker_info = {}
|
105 |
-
|
106 |
-
for w_name, w_info in old_info.items():
|
107 |
-
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
108 |
-
logger.info(f"Remove stale worker: {w_name}")
|
109 |
-
|
110 |
-
def list_models(self):
|
111 |
-
model_names = set()
|
112 |
-
|
113 |
-
for w_name, w_info in self.worker_info.items():
|
114 |
-
model_names.update(w_info.model_names)
|
115 |
-
|
116 |
-
return list(model_names)
|
117 |
-
|
118 |
-
def get_worker_address(self, model_name: str):
|
119 |
-
if self.dispatch_method == DispatchMethod.LOTTERY:
|
120 |
-
worker_names = []
|
121 |
-
worker_speeds = []
|
122 |
-
for w_name, w_info in self.worker_info.items():
|
123 |
-
if model_name in w_info.model_names:
|
124 |
-
worker_names.append(w_name)
|
125 |
-
worker_speeds.append(w_info.speed)
|
126 |
-
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
127 |
-
norm = np.sum(worker_speeds)
|
128 |
-
if norm < 1e-4:
|
129 |
-
return ""
|
130 |
-
worker_speeds = worker_speeds / norm
|
131 |
-
if True: # Directly return address
|
132 |
-
pt = np.random.choice(np.arange(len(worker_names)),
|
133 |
-
p=worker_speeds)
|
134 |
-
worker_name = worker_names[pt]
|
135 |
-
return worker_name
|
136 |
-
|
137 |
-
# Check status before returning
|
138 |
-
while True:
|
139 |
-
pt = np.random.choice(np.arange(len(worker_names)),
|
140 |
-
p=worker_speeds)
|
141 |
-
worker_name = worker_names[pt]
|
142 |
-
|
143 |
-
if self.get_worker_status(worker_name):
|
144 |
-
break
|
145 |
-
else:
|
146 |
-
self.remove_worker(worker_name)
|
147 |
-
worker_speeds[pt] = 0
|
148 |
-
norm = np.sum(worker_speeds)
|
149 |
-
if norm < 1e-4:
|
150 |
-
return ""
|
151 |
-
worker_speeds = worker_speeds / norm
|
152 |
-
continue
|
153 |
-
return worker_name
|
154 |
-
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
155 |
-
worker_names = []
|
156 |
-
worker_qlen = []
|
157 |
-
for w_name, w_info in self.worker_info.items():
|
158 |
-
if model_name in w_info.model_names:
|
159 |
-
worker_names.append(w_name)
|
160 |
-
worker_qlen.append(w_info.queue_length / w_info.speed)
|
161 |
-
if len(worker_names) == 0:
|
162 |
-
return ""
|
163 |
-
min_index = np.argmin(worker_qlen)
|
164 |
-
w_name = worker_names[min_index]
|
165 |
-
self.worker_info[w_name].queue_length += 1
|
166 |
-
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
167 |
-
return w_name
|
168 |
-
else:
|
169 |
-
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
170 |
-
|
171 |
-
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
172 |
-
if worker_name not in self.worker_info:
|
173 |
-
logger.info(f"Receive unknown heart beat. {worker_name}")
|
174 |
-
return False
|
175 |
-
|
176 |
-
self.worker_info[worker_name].queue_length = queue_length
|
177 |
-
self.worker_info[worker_name].last_heart_beat = time.time()
|
178 |
-
logger.info(f"Receive heart beat. {worker_name}")
|
179 |
-
return True
|
180 |
-
|
181 |
-
def remove_stable_workers_by_expiration(self):
|
182 |
-
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
183 |
-
to_delete = []
|
184 |
-
for worker_name, w_info in self.worker_info.items():
|
185 |
-
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
186 |
-
to_delete.append(worker_name)
|
187 |
-
|
188 |
-
for worker_name in to_delete:
|
189 |
-
self.remove_worker(worker_name)
|
190 |
-
|
191 |
-
def worker_api_generate_stream(self, params):
|
192 |
-
worker_addr = self.get_worker_address(params["model"])
|
193 |
-
if not worker_addr:
|
194 |
-
logger.info(f"no worker: {params['model']}")
|
195 |
-
ret = {
|
196 |
-
"text": server_error_msg,
|
197 |
-
"error_code": 2,
|
198 |
-
}
|
199 |
-
yield json.dumps(ret).encode() + b"\0"
|
200 |
-
|
201 |
-
try:
|
202 |
-
response = requests.post(worker_addr + "/worker_generate_stream",
|
203 |
-
json=params, stream=True, timeout=5)
|
204 |
-
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
205 |
-
if chunk:
|
206 |
-
yield chunk + b"\0"
|
207 |
-
except requests.exceptions.RequestException as e:
|
208 |
-
logger.info(f"worker timeout: {worker_addr}")
|
209 |
-
ret = {
|
210 |
-
"text": server_error_msg,
|
211 |
-
"error_code": 3,
|
212 |
-
}
|
213 |
-
yield json.dumps(ret).encode() + b"\0"
|
214 |
-
|
215 |
-
|
216 |
-
# Let the controller act as a worker to achieve hierarchical
|
217 |
-
# management. This can be used to connect isolated sub networks.
|
218 |
-
def worker_api_get_status(self):
|
219 |
-
model_names = set()
|
220 |
-
speed = 0
|
221 |
-
queue_length = 0
|
222 |
-
|
223 |
-
for w_name in self.worker_info:
|
224 |
-
worker_status = self.get_worker_status(w_name)
|
225 |
-
if worker_status is not None:
|
226 |
-
model_names.update(worker_status["model_names"])
|
227 |
-
speed += worker_status["speed"]
|
228 |
-
queue_length += worker_status["queue_length"]
|
229 |
-
|
230 |
-
return {
|
231 |
-
"model_names": list(model_names),
|
232 |
-
"speed": speed,
|
233 |
-
"queue_length": queue_length,
|
234 |
-
}
|
235 |
-
|
236 |
-
|
237 |
-
app = FastAPI()
|
238 |
-
|
239 |
-
@app.post("/register_worker")
|
240 |
-
async def register_worker(request: Request):
|
241 |
-
data = await request.json()
|
242 |
-
controller.register_worker(
|
243 |
-
data["worker_name"], data["check_heart_beat"],
|
244 |
-
data.get("worker_status", None))
|
245 |
-
|
246 |
-
@app.post("/refresh_all_workers")
|
247 |
-
async def refresh_all_workers():
|
248 |
-
models = controller.refresh_all_workers()
|
249 |
-
|
250 |
-
|
251 |
-
@app.post("/list_models")
|
252 |
-
async def list_models():
|
253 |
-
models = controller.list_models()
|
254 |
-
return {"models": models}
|
255 |
-
|
256 |
-
|
257 |
-
@app.post("/get_worker_address")
|
258 |
-
async def get_worker_address(request: Request):
|
259 |
-
data = await request.json()
|
260 |
-
addr = controller.get_worker_address(data["model"])
|
261 |
-
return {"address": addr}
|
262 |
-
|
263 |
-
@app.post("/receive_heart_beat")
|
264 |
-
async def receive_heart_beat(request: Request):
|
265 |
-
data = await request.json()
|
266 |
-
exist = controller.receive_heart_beat(
|
267 |
-
data["worker_name"], data["queue_length"])
|
268 |
-
return {"exist": exist}
|
269 |
-
|
270 |
-
|
271 |
-
@app.post("/worker_generate_stream")
|
272 |
-
async def worker_api_generate_stream(request: Request):
|
273 |
-
params = await request.json()
|
274 |
-
generator = controller.worker_api_generate_stream(params)
|
275 |
-
return StreamingResponse(generator)
|
276 |
-
|
277 |
-
|
278 |
-
@app.post("/worker_get_status")
|
279 |
-
async def worker_api_get_status(request: Request):
|
280 |
-
return controller.worker_api_get_status()
|
281 |
-
|
282 |
-
|
283 |
-
if __name__ == "__main__":
|
284 |
-
parser = argparse.ArgumentParser()
|
285 |
-
parser.add_argument("--host", type=str, default="localhost")
|
286 |
-
parser.add_argument("--port", type=int, default=21001)
|
287 |
-
parser.add_argument("--dispatch-method", type=str, choices=[
|
288 |
-
"lottery", "shortest_queue"], default="shortest_queue")
|
289 |
-
args = parser.parse_args()
|
290 |
-
logger.info(f"args: {args}")
|
291 |
-
|
292 |
-
controller = Controller(args.dispatch_method)
|
293 |
-
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
starvector/serve/gradio_demo_with_updated_gradio.py
DELETED
@@ -1,432 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import datetime
|
3 |
-
import json
|
4 |
-
import os
|
5 |
-
import time
|
6 |
-
import gradio as gr
|
7 |
-
import requests
|
8 |
-
from starvector.serve.conversation import default_conversation
|
9 |
-
from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
|
10 |
-
from starvector.serve.util import (build_logger, server_error_msg)
|
11 |
-
|
12 |
-
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
13 |
-
headers = {"User-Agent": "StarVector Client"}
|
14 |
-
|
15 |
-
no_change_btn = gr.Button()
|
16 |
-
enable_btn = gr.Button(interactive=True)
|
17 |
-
disable_btn = gr.Button(interactive=False)
|
18 |
-
|
19 |
-
priority = {
|
20 |
-
"starvector-1.4b": "aaaaaaa",
|
21 |
-
}
|
22 |
-
|
23 |
-
def get_conv_log_filename():
|
24 |
-
t = datetime.datetime.now()
|
25 |
-
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
26 |
-
return name
|
27 |
-
|
28 |
-
def get_model_list():
|
29 |
-
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
30 |
-
assert ret.status_code == 200
|
31 |
-
ret = requests.post(args.controller_url + "/list_models")
|
32 |
-
models = ret.json()["models"]
|
33 |
-
models.sort(key=lambda x: priority.get(x, x))
|
34 |
-
logger.info(f"Models: {models}")
|
35 |
-
return models
|
36 |
-
|
37 |
-
get_window_url_params = """
|
38 |
-
function() {
|
39 |
-
const params = new URLSearchParams(window.location.search);
|
40 |
-
url_params = Object.fromEntries(params);
|
41 |
-
console.log(url_params);
|
42 |
-
return url_params;
|
43 |
-
}
|
44 |
-
"""
|
45 |
-
|
46 |
-
def load_demo(url_params, request: gr.Request):
|
47 |
-
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
48 |
-
|
49 |
-
dropdown_update = gr.Dropdown(visible=True)
|
50 |
-
if "model" in url_params:
|
51 |
-
model = url_params["model"]
|
52 |
-
if model in models:
|
53 |
-
dropdown_update = gr.Dropdown(value=model, visible=True)
|
54 |
-
|
55 |
-
state = default_conversation.copy()
|
56 |
-
return state, dropdown_update
|
57 |
-
|
58 |
-
|
59 |
-
def load_demo_refresh_model_list(request: gr.Request):
|
60 |
-
logger.info(f"load_demo. ip: {request.client.host}")
|
61 |
-
models = get_model_list()
|
62 |
-
state = default_conversation.copy()
|
63 |
-
dropdown_update = gr.Dropdown(
|
64 |
-
choices=models,
|
65 |
-
value=models[0] if len(models) > 0 else ""
|
66 |
-
)
|
67 |
-
return state, dropdown_update
|
68 |
-
|
69 |
-
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
70 |
-
with open(get_conv_log_filename(), "a") as fout:
|
71 |
-
data = {
|
72 |
-
"tstamp": round(time.time(), 4),
|
73 |
-
"type": vote_type,
|
74 |
-
"model": model_selector,
|
75 |
-
"state": state.dict(),
|
76 |
-
"ip": request.client.host,
|
77 |
-
}
|
78 |
-
fout.write(json.dumps(data) + "\n")
|
79 |
-
|
80 |
-
def upvote_last_response(state, model_selector, request: gr.Request):
|
81 |
-
logger.info(f"upvote. ip: {request.client.host}")
|
82 |
-
vote_last_response(state, "upvote", model_selector, request)
|
83 |
-
return ("",) + (disable_btn,) * 3
|
84 |
-
|
85 |
-
def downvote_last_response(state, model_selector, request: gr.Request):
|
86 |
-
logger.info(f"downvote. ip: {request.client.host}")
|
87 |
-
vote_last_response(state, "downvote", model_selector, request)
|
88 |
-
return ("",) + (disable_btn,) * 3
|
89 |
-
|
90 |
-
def flag_last_response(state, model_selector, request: gr.Request):
|
91 |
-
logger.info(f"flag. ip: {request.client.host}")
|
92 |
-
vote_last_response(state, "flag", model_selector, request)
|
93 |
-
return ("",) + (disable_btn,) * 3
|
94 |
-
|
95 |
-
def regenerate(state, image_process_mode, request: gr.Request):
|
96 |
-
logger.info(f"regenerate. ip: {request.client.host}")
|
97 |
-
state.messages[-1][-1] = None
|
98 |
-
prev_human_msg = state.messages[-2]
|
99 |
-
if type(prev_human_msg[1]) in (tuple, list):
|
100 |
-
prev_human_msg[1] = (prev_human_msg[1][:2], image_process_mode)
|
101 |
-
state.skip_next = False
|
102 |
-
return (state, None, None, None) + (disable_btn,) * 6
|
103 |
-
|
104 |
-
def clear_history(request: gr.Request):
|
105 |
-
logger.info(f"clear_history. ip: {request.client.host}")
|
106 |
-
state = default_conversation.copy()
|
107 |
-
return (state, None, None) + (disable_btn,) * 6
|
108 |
-
|
109 |
-
def send_image(state, image, image_process_mode, request: gr.Request):
|
110 |
-
logger.info(f"send_image. ip: {request.client.host}.")
|
111 |
-
state.stop_sampling = False
|
112 |
-
if image is None:
|
113 |
-
state.skip_next = True
|
114 |
-
return (state, None, None, image) + (no_change_btn,) * 6
|
115 |
-
|
116 |
-
if image is not None:
|
117 |
-
text = (image, image_process_mode)
|
118 |
-
state.append_message(state.roles[0], text)
|
119 |
-
state.append_message(state.roles[1], "▌")
|
120 |
-
state.skip_next = False
|
121 |
-
msg = state.to_gradio_svg_code()[0][1]
|
122 |
-
return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 6
|
123 |
-
|
124 |
-
def stop_sampling(state, image, request: gr.Request):
|
125 |
-
logger.info(f"stop_sampling. ip: {request.client.host}")
|
126 |
-
state.stop_sampling = True
|
127 |
-
return (state, None, None, image) + (disable_btn,) * 6
|
128 |
-
|
129 |
-
def http_bot(state, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
|
130 |
-
logger.info(f"http_bot. ip: {request.client.host}")
|
131 |
-
start_tstamp = time.time()
|
132 |
-
model_name = model_selector
|
133 |
-
|
134 |
-
if state.skip_next:
|
135 |
-
# This generate call is skipped due to invalid inputs
|
136 |
-
yield (state, None, None) + (no_change_btn,) * 6
|
137 |
-
return
|
138 |
-
|
139 |
-
# Query worker address
|
140 |
-
controller_url = args.controller_url
|
141 |
-
ret = requests.post(controller_url + "/get_worker_address",
|
142 |
-
json={"model": model_name})
|
143 |
-
worker_addr = ret.json()["address"]
|
144 |
-
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
145 |
-
|
146 |
-
# No available worker
|
147 |
-
if worker_addr == "":
|
148 |
-
state.messages[-1][-1] = server_error_msg
|
149 |
-
yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
|
150 |
-
return
|
151 |
-
|
152 |
-
# Construct prompt
|
153 |
-
prompt = state.get_prompt()
|
154 |
-
|
155 |
-
# Make requests
|
156 |
-
pload = {
|
157 |
-
"model": model_name,
|
158 |
-
"prompt": prompt,
|
159 |
-
"num_beams": int(num_beams),
|
160 |
-
"temperature": float(temperature),
|
161 |
-
"len_penalty": float(len_penalty),
|
162 |
-
"top_p": float(top_p),
|
163 |
-
"max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
|
164 |
-
}
|
165 |
-
logger.info(f"==== request ====\n{pload}")
|
166 |
-
|
167 |
-
pload['images'] = state.get_images()
|
168 |
-
|
169 |
-
state.messages[-1][-1] = "▌"
|
170 |
-
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn)
|
171 |
-
|
172 |
-
try:
|
173 |
-
# Stream output
|
174 |
-
if state.stop_sampling:
|
175 |
-
state.messages[1][-1] = "▌"
|
176 |
-
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
|
177 |
-
return
|
178 |
-
|
179 |
-
response = requests.post(worker_addr + "/worker_generate_stream",
|
180 |
-
headers=headers, json=pload, stream=True, timeout=100)
|
181 |
-
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
182 |
-
if chunk:
|
183 |
-
data = json.loads(chunk.decode())
|
184 |
-
if data["error_code"] == 0:
|
185 |
-
# output = data["text"].strip().replace('<', '<').replace('>', '>') # trick to avoid the SVG getting rendered
|
186 |
-
output = data["text"].strip()
|
187 |
-
state.messages[-1][-1] = output + "▌"
|
188 |
-
st = state.to_gradio_svg_code()
|
189 |
-
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn)
|
190 |
-
else:
|
191 |
-
output = data["text"] + f" (error_code: {data['error_code']})"
|
192 |
-
state.messages[-1][-1] = output
|
193 |
-
|
194 |
-
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
|
195 |
-
return
|
196 |
-
time.sleep(0.03)
|
197 |
-
except requests.exceptions.RequestException as e:
|
198 |
-
state.messages[-1][-1] = server_error_msg
|
199 |
-
yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn)
|
200 |
-
return
|
201 |
-
|
202 |
-
yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 6
|
203 |
-
|
204 |
-
finish_tstamp = time.time()
|
205 |
-
logger.info(f"{output}")
|
206 |
-
|
207 |
-
with open(get_conv_log_filename(), "a") as fout:
|
208 |
-
data = {
|
209 |
-
"tstamp": round(finish_tstamp, 4),
|
210 |
-
"type": "chat",
|
211 |
-
"model": model_name,
|
212 |
-
"start": round(start_tstamp, 4),
|
213 |
-
"finish": round(finish_tstamp, 4),
|
214 |
-
"svg": state.messages[-1][-1],
|
215 |
-
"ip": request.client.host,
|
216 |
-
}
|
217 |
-
fout.write(json.dumps(data) + "\n")
|
218 |
-
|
219 |
-
title_markdown = ("""
|
220 |
-
# 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text
|
221 |
-
[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | 📚 [[StarVector](https://arxiv.org/abs/2312.11556)]
|
222 |
-
""")
|
223 |
-
|
224 |
-
sub_title_markdown = (""" Throw an image and vectorize it! The model expects vector-like images to generate the corresponding svg code.""")
|
225 |
-
tos_markdown = ("""
|
226 |
-
### Terms of use
|
227 |
-
By using this service, users are required to agree to the following terms:
|
228 |
-
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
229 |
-
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
|
230 |
-
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
|
231 |
-
""")
|
232 |
-
|
233 |
-
|
234 |
-
learn_more_markdown = ("""
|
235 |
-
### License
|
236 |
-
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
|
237 |
-
""")
|
238 |
-
|
239 |
-
block_css = """
|
240 |
-
|
241 |
-
#buttons button {
|
242 |
-
min-width: min(120px,100%);
|
243 |
-
}
|
244 |
-
|
245 |
-
.gradio-container{
|
246 |
-
max-width: 1200px!important
|
247 |
-
}
|
248 |
-
|
249 |
-
#svg_render{
|
250 |
-
padding: 20px !important;
|
251 |
-
}
|
252 |
-
|
253 |
-
#svg_code{
|
254 |
-
height: 200px !important;
|
255 |
-
overflow: scroll !important;
|
256 |
-
white-space: unset !important;
|
257 |
-
flex-shrink: unset !important;
|
258 |
-
}
|
259 |
-
|
260 |
-
|
261 |
-
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
|
262 |
-
*{transition: width 0.5s ease, flex-grow 0.5s ease}
|
263 |
-
"""
|
264 |
-
|
265 |
-
def build_demo(embed_mode, concurrency_count=10):
|
266 |
-
with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
|
267 |
-
state = gr.State()
|
268 |
-
if not embed_mode:
|
269 |
-
gr.Markdown(title_markdown)
|
270 |
-
gr.Markdown(sub_title_markdown)
|
271 |
-
with gr.Row():
|
272 |
-
with gr.Column(scale=3):
|
273 |
-
with gr.Row(elem_id="model_selector_row"):
|
274 |
-
model_selector = gr.Dropdown(
|
275 |
-
choices=models,
|
276 |
-
value=models[0] if len(models) > 0 else "",
|
277 |
-
interactive=True,
|
278 |
-
show_label=False,
|
279 |
-
container=False)
|
280 |
-
imagebox = gr.Image(type="pil")
|
281 |
-
image_process_mode = gr.Radio(
|
282 |
-
["Resize", "Pad", "Default"],
|
283 |
-
value="Pad",
|
284 |
-
label="Preprocess for non-square image", visible=False)
|
285 |
-
|
286 |
-
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
287 |
-
gr.Examples(examples=[
|
288 |
-
[f"{cur_dir}/examples/sample-4.png"],
|
289 |
-
[f"{cur_dir}/examples/sample-7.png"],
|
290 |
-
[f"{cur_dir}/examples/sample-16.png"],
|
291 |
-
[f"{cur_dir}/examples/sample-17.png"],
|
292 |
-
[f"{cur_dir}/examples/sample-18.png"],
|
293 |
-
[f"{cur_dir}/examples/sample-0.png"],
|
294 |
-
[f"{cur_dir}/examples/sample-1.png"],
|
295 |
-
[f"{cur_dir}/examples/sample-6.png"],
|
296 |
-
], inputs=[imagebox])
|
297 |
-
|
298 |
-
with gr.Column(scale=1, min_width=50):
|
299 |
-
submit_btn = gr.Button(value="Send", variant="primary")
|
300 |
-
|
301 |
-
with gr.Accordion("Parameters", open=True) as parameter_row:
|
302 |
-
num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
|
303 |
-
temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.8, step=0.05, interactive=True, label="Temperature",)
|
304 |
-
len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
|
305 |
-
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, interactive=True, label="Top P",)
|
306 |
-
max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=2000, step=64, interactive=True, label="Max output tokens",)
|
307 |
-
|
308 |
-
with gr.Column(scale=8):
|
309 |
-
with gr.Row():
|
310 |
-
svg_code = gr.Code(label="SVG Code", elem_id='svg_code', min_width=200, interactive=False, lines=5)
|
311 |
-
with gr.Row():
|
312 |
-
gr.Image(width=50, height=256, label="Rendered SVG", elem_id='svg_render')
|
313 |
-
with gr.Row(elem_id="buttons") as button_row:
|
314 |
-
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
315 |
-
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
316 |
-
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
317 |
-
stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False)
|
318 |
-
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, visible=False)
|
319 |
-
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
|
320 |
-
|
321 |
-
if not embed_mode:
|
322 |
-
gr.Markdown(tos_markdown)
|
323 |
-
gr.Markdown(learn_more_markdown)
|
324 |
-
url_params = gr.JSON(visible=False)
|
325 |
-
|
326 |
-
# Register listeners
|
327 |
-
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn]
|
328 |
-
upvote_btn.click(
|
329 |
-
upvote_last_response,
|
330 |
-
[state, model_selector],
|
331 |
-
[upvote_btn, downvote_btn, flag_btn],
|
332 |
-
queue=False
|
333 |
-
)
|
334 |
-
downvote_btn.click(
|
335 |
-
downvote_last_response,
|
336 |
-
[state, model_selector],
|
337 |
-
[upvote_btn, downvote_btn, flag_btn],
|
338 |
-
queue=False
|
339 |
-
)
|
340 |
-
flag_btn.click(
|
341 |
-
flag_last_response,
|
342 |
-
[state, model_selector],
|
343 |
-
[upvote_btn, downvote_btn, flag_btn],
|
344 |
-
queue=False
|
345 |
-
)
|
346 |
-
|
347 |
-
regenerate_btn.click(
|
348 |
-
regenerate,
|
349 |
-
[state, image_process_mode],
|
350 |
-
[state, svg_code, svg_render, imagebox] + btn_list,
|
351 |
-
queue=False
|
352 |
-
).then(
|
353 |
-
http_bot,
|
354 |
-
[state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
|
355 |
-
[state, svg_code, svg_render] + btn_list,
|
356 |
-
concurrency_limit=concurrency_count
|
357 |
-
)
|
358 |
-
|
359 |
-
submit_btn.click(
|
360 |
-
send_image,
|
361 |
-
[state, imagebox, image_process_mode],
|
362 |
-
[state, svg_code, svg_render, imagebox] + btn_list,
|
363 |
-
queue=False
|
364 |
-
).then(
|
365 |
-
http_bot,
|
366 |
-
[state, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
|
367 |
-
[state, svg_code, svg_render] + btn_list,
|
368 |
-
concurrency_limit=concurrency_count
|
369 |
-
)
|
370 |
-
|
371 |
-
clear_btn.click(
|
372 |
-
clear_history,
|
373 |
-
None,
|
374 |
-
[state, svg_code, svg_render] + btn_list,
|
375 |
-
queue=False
|
376 |
-
)
|
377 |
-
|
378 |
-
stop_btn.click(
|
379 |
-
stop_sampling,
|
380 |
-
[state, imagebox],
|
381 |
-
[state, imagebox] + btn_list,
|
382 |
-
queue=False
|
383 |
-
).then(
|
384 |
-
clear_history,
|
385 |
-
None,
|
386 |
-
[state, svg_code, svg_render] + btn_list,
|
387 |
-
queue=False
|
388 |
-
)
|
389 |
-
|
390 |
-
if args.model_list_mode == "once":
|
391 |
-
demo.load(
|
392 |
-
load_demo,
|
393 |
-
[url_params],
|
394 |
-
[state, model_selector],
|
395 |
-
_js=get_window_url_params,
|
396 |
-
)
|
397 |
-
elif args.model_list_mode == "reload":
|
398 |
-
demo.load(
|
399 |
-
load_demo_refresh_model_list,
|
400 |
-
None,
|
401 |
-
[state, model_selector],
|
402 |
-
queue=False
|
403 |
-
)
|
404 |
-
else:
|
405 |
-
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
406 |
-
|
407 |
-
return demo
|
408 |
-
|
409 |
-
if __name__ == "__main__":
|
410 |
-
parser = argparse.ArgumentParser()
|
411 |
-
parser.add_argument("--host", type=str, default="0.0.0.0")
|
412 |
-
parser.add_argument("--port", type=int)
|
413 |
-
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
|
414 |
-
parser.add_argument("--concurrency-count", type=int, default=15)
|
415 |
-
parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
|
416 |
-
parser.add_argument("--share", action="store_true")
|
417 |
-
parser.add_argument("--moderate", action="store_true")
|
418 |
-
parser.add_argument("--embed", action="store_true")
|
419 |
-
args = parser.parse_args()
|
420 |
-
logger.info(f"args: {args}")
|
421 |
-
|
422 |
-
models = get_model_list()
|
423 |
-
|
424 |
-
logger.info(args)
|
425 |
-
demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
|
426 |
-
demo.queue(
|
427 |
-
api_open=False
|
428 |
-
).launch(
|
429 |
-
server_name=args.host,
|
430 |
-
server_port=args.port,
|
431 |
-
share=args.share
|
432 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
starvector/serve/gradio_web_server.py
DELETED
@@ -1,562 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import datetime
|
3 |
-
import json
|
4 |
-
import os
|
5 |
-
import time
|
6 |
-
import gradio as gr
|
7 |
-
import requests
|
8 |
-
from starvector.serve.conversation import default_conversation
|
9 |
-
from starvector.serve.constants import LOGDIR, CLIP_QUERY_LENGTH
|
10 |
-
from starvector.serve.util import (build_logger, server_error_msg)
|
11 |
-
|
12 |
-
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
13 |
-
headers = {"User-Agent": "StarVector Client"}
|
14 |
-
|
15 |
-
no_change_btn = gr.Button.update()
|
16 |
-
enable_btn = gr.Button.update(interactive=True)
|
17 |
-
disable_btn = gr.Button.update(interactive=False)
|
18 |
-
|
19 |
-
priority = {
|
20 |
-
"starvector-1b-im2svg": "aaaaaaa",
|
21 |
-
}
|
22 |
-
|
23 |
-
def get_conv_log_filename():
    """Return today's conversation-log path inside LOGDIR (YYYY-MM-DD-conv.json)."""
    today = datetime.datetime.now()
    filename = "{:d}-{:02d}-{:02d}-conv.json".format(today.year, today.month, today.day)
    return os.path.join(LOGDIR, filename)
|
27 |
-
|
28 |
-
def get_model_list():
    """Fetch the list of model names currently served by workers.

    Asks the controller to refresh its worker registry, then retrieves the
    model names, sorted by the hand-tuned ``priority`` table (models absent
    from the table sort by their own name).

    Returns:
        list[str]: sorted model names.

    Raises:
        requests.HTTPError: if either controller endpoint returns a non-2xx
            status. (Previously the first call was checked with ``assert``,
            which is stripped under ``python -O``, and the second call was
            not checked at all.)
    """
    ret = requests.post(args.controller_url + "/refresh_all_workers")
    ret.raise_for_status()
    ret = requests.post(args.controller_url + "/list_models")
    ret.raise_for_status()
    models = ret.json()["models"]
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models
|
36 |
-
|
37 |
-
def load_demo(url_params, request: gr.Request):
    """Initialise the demo on page load, honouring a ``?model=`` URL override.

    Returns a fresh conversation state and a Dropdown update that selects the
    requested model when it is among the known ones.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    requested = url_params.get("model") if "model" in url_params else None
    if requested in models:
        dropdown_update = gr.Dropdown.update(value=requested, visible=True)
    else:
        dropdown_update = gr.Dropdown.update(visible=True)

    return default_conversation.copy(), dropdown_update
|
49 |
-
|
50 |
-
# Maps the UI task label to the substring identifying matching model names.
mapping_model_task = {
    'Image2SVG': 'im2svg',
    'Text2SVG': 'text2svg',
}

def get_models_dropdown_from_task(task):
    """Build a Dropdown update listing only the models that serve *task*."""
    tag = mapping_model_task[task]
    matching = [name for name in get_model_list() if tag in name]
    default_choice = matching[0] if matching else ""
    return gr.Dropdown.update(choices=matching, value=default_choice)
|
63 |
-
|
64 |
-
|
65 |
-
def load_demo_refresh_model_list(task, request: gr.Request):
    """Re-query the controller for *task*'s models and reset conversation state."""
    logger.info(f"load_demo. ip: {request.client.host}")
    dropdown_update = get_models_dropdown_from_task(task)
    return default_conversation.copy(), dropdown_update
|
70 |
-
|
71 |
-
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one vote record as a JSON line to today's conversation log."""
    with open(get_conv_log_filename(), "a") as fout:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(record) + "\n")
|
81 |
-
|
82 |
-
def _record_vote(kind, state, model_selector, request):
    """Shared body of the three vote callbacks: log, persist, disable buttons."""
    logger.info(f"{kind}. ip: {request.client.host}")
    vote_last_response(state, kind, model_selector, request)
    return ("",) + (disable_btn,) * 7


def upvote_last_response(state, model_selector, request: gr.Request):
    """Gradio callback for the upvote button."""
    return _record_vote("upvote", state, model_selector, request)


def downvote_last_response(state, model_selector, request: gr.Request):
    """Gradio callback for the downvote button."""
    return _record_vote("downvote", state, model_selector, request)


def flag_last_response(state, model_selector, request: gr.Request):
    """Gradio callback for the flag button."""
    return _record_vote("flag", state, model_selector, request)
|
96 |
-
|
97 |
-
def regenerate(state, image_process_mode, request: gr.Request):
    """Drop the last model reply and re-queue the previous human turn."""
    logger.info(f"regenerate. ip: {request.client.host}")
    # Clear the pending assistant message.
    state.messages[-1][-1] = None
    previous = state.messages[-2]
    if type(previous[1]) in (tuple, list):
        # Re-attach the (possibly changed) image preprocessing mode.
        previous[1] = (previous[1][:2], image_process_mode)
    state.skip_next = False
    return (state, None, None, None) + (disable_btn,) * 7
|
105 |
-
|
106 |
-
def clear_history(request: gr.Request):
    """Reset the conversation to a fresh default state and disable all buttons."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    return (fresh_state, None, None) + (disable_btn,) * 7
|
110 |
-
|
111 |
-
def send_data(state, image, image_process_mode, text_caption, task, request: gr.Request):
    """Validate the user's input for *task* and push it onto the conversation.

    On missing input the state is flagged with ``skip_next`` so the follow-up
    generation step becomes a no-op. Otherwise the human turn and a placeholder
    assistant turn ("▌") are appended and the current SVG code/render shown.
    """
    logger.info(f"send_data. ip: {request.client.host}.")

    if task == 'Image2SVG':
        if image is None:
            state.skip_next = True
            return (state, None, None, image) + (no_change_btn,) * 7
        # Human turn carries the image together with its preprocessing mode.
        state.append_message(state.roles[0], (image, image_process_mode))
    else:
        if text_caption is None:
            state.skip_next = True
            return (state, None, None, image) + (no_change_btn,) * 7
        state.append_message(state.roles[0], text_caption)

    state.append_message(state.roles[1], "▌")
    state.skip_next = False
    msg = state.to_gradio_svg_code()[0][1]
    return (state, msg, state.to_gradio_svg_render(), image) + (no_change_btn,) * 7
|
135 |
-
|
136 |
-
def download_files(state, request: gr.Request):
    """Materialise the current SVG string and rendered image for download.

    The results are currently discarded.
    """
    logger.info(f"download_files. ip: {request.client.host}")
    svg_str, image = state.download_files()
    # TODO: Figure out how to deliver the SVG to the user's browser; the
    # files are produced above but never sent anywhere yet.
|
141 |
-
|
142 |
-
def update_task(task):
    """Return sampling presets and the model dropdown for *task*.

    Returns (length_penalty, temperature, top_p, dropdown_update).
    """
    dropdown_update = get_models_dropdown_from_task(task)
    length_penalty = 1.0 if task == "Text2SVG" else 0.6
    return length_penalty, 0.9, 0.95, dropdown_update
|
149 |
-
|
150 |
-
|
151 |
-
def stop_sampling(state, image, request: gr.Request):
    """Flag the ongoing generation for cancellation and disable all buttons."""
    logger.info(f"stop_sampling. ip: {request.client.host}")
    state.stop_sampling = True
    return (state, None, None, image) + (disable_btn,) * 7
|
155 |
-
|
156 |
-
def http_bot(state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_new_tokens, request: gr.Request):
    """Gradio generator callback: stream a generation from a worker into the UI.

    Yields tuples of (state, svg code text, rendered SVG) plus seven button
    updates as chunks arrive from the worker's streaming endpoint, then logs
    the finished conversation to today's log file.

    Fixes over the previous revision:
      * ``st`` was referenced before assignment when the very first stream
        chunk carried a non-zero error code (NameError);
      * ``output`` was referenced in the final log line even when the stream
        produced no chunks (NameError).
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs.
        yield (state, None, None) + (no_change_btn,) * 7
        return

    # Query worker address from the controller.
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker for this model.
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, None, None, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
        return

    # Construct prompt.
    if task_selector == "Image2SVG":
        prompt = state.get_image_prompt()
    else:
        prompt = text_caption

    # Build the generation request payload.
    pload = {
        "model": model_name,
        "prompt": prompt,
        "num_beams": int(num_beams),
        "temperature": float(temperature),
        "len_penalty": float(len_penalty),
        "top_p": float(top_p),
        # Leave room for the CLIP visual-query tokens in the context window.
        "max_new_tokens": min(int(max_new_tokens), 8192-CLIP_QUERY_LENGTH),
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)

    output = ""  # last streamed text; used by the final log line below
    try:
        # Stream output.
        if state.stop_sampling:
            # NOTE(review): index [1] looks like it should be [-1] — confirm.
            state.messages[1][-1] = "▌"
            yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, enable_btn)
            return

        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=10)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"].strip()
                    state.messages[-1][-1] = output + "▌"
                    st = state.to_gradio_svg_code()
                    yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, enable_btn, enable_btn)
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    # Fix: compute the SVG code view here too; it was previously
                    # unbound when the first chunk already carried an error.
                    st = state.to_gradio_svg_code()
                    yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
        return

    yield (state, state.messages[-1][-1], state.to_gradio_svg_render()) + (enable_btn,) * 7

    finish_tstamp = time.time()
    logger.info(f"{output}")

    # Persist the finished exchange to today's conversation log.
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "svg": state.messages[-1][-1],
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
|
248 |
-
|
249 |
-
# Header shown above the demo (hidden in --embed mode).
title_markdown = ("""
# 💫 StarVector: Generating Scalable Vector Graphics Code from Images and Text

[[Project Page](https://starvector.github.io)] [[Code](https://github.com/joanrod/star-vector)] [[Model](https://huggingface.co/joanrodai/starvector-1.4b)] | 📚 [[StarVector](https://arxiv.org/abs/2312.11556)]""")

# One-line usage instructions rendered under the title.
sub_title_markdown = ("""**How does it work?** Select the task you want to perform, and the model will be automatically set. For **Text2SVG**, introduce a prompt in Text Caption. For **Image2SVG**, select an image and vectorize it. \
**Note**: The current model works on vector-like images like icons and or vector-like designs.""")
# Terms-of-use notice (hidden in --embed mode).
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")

# License footnote (hidden in --embed mode).
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
""")

# Custom CSS injected into the gr.Blocks page. The ".ͼ1"/".ͼ2p" selectors
# target CodeMirror-generated class names inside the gr.Code widget.
block_css = """

#buttons button {
    min-width: min(120px,100%);
}

.gradio-container{
    max-width: 1200px!important
}

.ͼ1 .cm-content {
    white-space: unset !important;
    flex-shrink: unset !important;
}

.ͼ2p .cm-scroller {
    max-height: 200px;
    overflow: scroll;
}

#svg_render{
    padding: 20px !important;
}

#submit_btn{
    max-height: 40px;
}

.selector{
    max-height: 100px;
}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
|
303 |
-
def build_demo(embed_mode):
    """Assemble the StarVector gradio Blocks app and wire up all callbacks.

    Args:
        embed_mode: when True, hides title/ToS markdown for iframe embedding.

    Returns:
        The configured ``gr.Blocks`` demo (not yet queued or launched).
    """
    # Output widgets are created up-front and .render()-ed later so they can
    # be placed inside the right column.
    svg_render = gr.Image(label="Rendered SVG", elem_id='svg_render', height=300)
    svg_code = gr.Code(label="SVG Code", elem_id='svg_code', interactive=True, lines=5)

    with gr.Blocks(title="StarVector", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        if not embed_mode:
            gr.Markdown(title_markdown)
            gr.Markdown(sub_title_markdown)
        with gr.Row():
            # Left column: task/model selection and inputs.
            with gr.Column(scale=4):
                task_selector = gr.Dropdown(
                    choices=["Image2SVG", "Text2SVG"],
                    value="Image2SVG",
                    label="Task",
                    interactive=True,
                    show_label=True,
                    container=True,
                    elem_id="task_selector",
                    elem_classes=["selector"],
                )
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if len(models) > 0 else "",
                    label="Model",
                    interactive=True,
                    show_label=True,
                    container=True,
                    elem_classes=["selector"],
                )

                imagebox = gr.Image(type="pil", visible=True, elem_id="imagebox")
                image_process_mode = gr.Radio(
                    ["Resize", "Pad", "Default"],
                    value="Pad",
                    label="Preprocess for non-square image", visible=False)

                # Text input (used for the Text2SVG task).
                text_caption = gr.Textbox(label="Text Caption", visible=True, value="The icon of a yellow star", elem_id="text_caption")

                # Sample images shipped next to this module.
                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/sample-4.png"],
                    [f"{cur_dir}/examples/sample-7.png"],
                    [f"{cur_dir}/examples/sample-16.png"],
                    [f"{cur_dir}/examples/sample-17.png"],
                    [f"{cur_dir}/examples/sample-18.png"],
                    [f"{cur_dir}/examples/sample-0.png"],
                    [f"{cur_dir}/examples/sample-1.png"],
                    [f"{cur_dir}/examples/sample-6.png"],
                ], inputs=[imagebox], elem_id="examples")

                submit_btn = gr.Button(value="Send", variant="primary", elem_id="submit_btn", interactive=True)

                # Sampling hyperparameters; defaults are overridden per-task
                # by update_task().
                with gr.Accordion("Parameters", open=False):
                    num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1, interactive=True, label="Num Beams", visible=False,)
                    temperature = gr.Slider(minimum=0.0, maximum=2.0, value=0.9, step=0.05, interactive=True, label="Temperature",)
                    len_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=0.6, step=0.05, interactive=True, label="Length Penalty",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=8192, value=8192, step=64, interactive=True, label="Max output tokens",)

            # Right column: streamed SVG code and its rendering.
            with gr.Column(scale=9):
                with gr.Row():
                    svg_code.render()
                with gr.Row():
                    svg_render.render()

                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False, visible=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False, visible=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
                    download_btn = gr.Button(value="Download SVG", interactive=False, visible=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        # Hidden JSON component that receives the page's URL query params.
        url_params = gr.JSON(visible=False)

        # Register listeners. btn_list order matches the 7 button updates
        # yielded by the callbacks above.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn, stop_btn, download_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [upvote_btn, downvote_btn, flag_btn],
            queue=False
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, svg_code, svg_render, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
            [state, svg_code, svg_render] + btn_list)

        submit_btn.click(
            send_data,
            [state, imagebox, image_process_mode, text_caption, task_selector],
            [state, svg_code, svg_render, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, task_selector, text_caption, model_selector, num_beams, temperature, len_penalty, top_p, max_output_tokens],
            [state, svg_code, svg_render] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, svg_code, svg_render] + btn_list,
            queue=False
        )

        stop_btn.click(
            stop_sampling,
            [state, imagebox],
            [state, imagebox] + btn_list,
            queue=False
        ).then(
            clear_history,
            None,
            [state, svg_code, svg_render] + btn_list,
            queue=False
        )

        download_btn.click(
            download_files,
            [state],
            None,
            queue=False
        )
        # Toggle image/text inputs client-side when the task changes.
        task_selector.change(
            update_task,
            inputs=[task_selector],
            outputs=[len_penalty, temperature, top_p, model_selector],
            queue=False,
            _js="""
            function(task) {
                var imageBoxElement = document.getElementById("imagebox");
                var textCaptionElement = document.getElementById("text_caption");
                var examplesElement = document.getElementById("examples");
                if (task === "Text2SVG") {
                    imageBoxElement.style.display = "none";
                    textCaptionElement.style.display = "block";
                    examplesElement.style.display = "none";
                } else if (task === "Image2SVG") {
                    imageBoxElement.style.display = "block";
                    textCaptionElement.style.display = "none";
                    examplesElement.style.display = "block";
                }
                return task;
            }
            """
        )

        if args.model_list_mode == "once":
            # Fetch models once at page load; URL params may preselect one.
            demo.load(
                load_demo,
                [url_params, task_selector],
                [state, model_selector],
                _js="""
                function() {
                    const params = new URLSearchParams(window.location.search);
                    url_params = Object.fromEntries(params);
                    console.log(url_params);
                    return url_params;
                }
                """,
                queue=False
            )
        elif args.model_list_mode == "reload":
            # Re-query the controller on each page load; the JS keeps the SVG
            # code view auto-scrolled to the bottom while streaming.
            demo.load(
                load_demo_refresh_model_list,
                [task_selector],
                [state, model_selector],
                _js="""
                function(task) {
                    var textCaptionElement = document.getElementById("text_caption");
                    var autoScrollBottom = true;
                    textCaptionElement.style.display = "none";
                    function updateScroll(){
                        if (autoScrollBottom) {
                            var element = document.getElementsByClassName("cm-scroller")[0];
                            element.scrollTop = element.scrollHeight;
                        }
                    }
                    function handleScroll() {
                        var element = document.getElementsByClassName("cm-scroller")[0];
                        //if (element.scrollHeight - element.scrollTop === element.clientHeight) {
                        if (element.scrollHeight - (element.scrollTop + element.clientHeight) < 0.2*(element.scrollTop)) {
                            // User has scrolled to the bottom, enable auto-scrolling
                            autoScrollBottom = true;
                            console.log("bottom");
                        } else {
                            console.log("not bottom");
                            // User has scrolled away from the bottom, disable auto-scrolling
                            autoScrollBottom = false;
                        }
                    }
                    setInterval(updateScroll,500);
                    var element = document.getElementsByClassName("cm-scroller")[0];
                    element.addEventListener("scroll", handleScroll);

                    return task;
                }

                """,
                queue=False,
            )

        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
|
535 |
-
|
536 |
-
if __name__ == "__main__":

    # CLI entry point: parse flags, discover served models, launch the demo.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Module-level `models` is read by build_demo() and load_demo().
    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed)
    demo.queue(
        concurrency_count=args.concurrency_count,
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
starvector/serve/model_worker.py
DELETED
@@ -1,269 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
A model worker executes the model.
|
3 |
-
"""
|
4 |
-
import argparse
|
5 |
-
import asyncio
|
6 |
-
import json
|
7 |
-
import time
|
8 |
-
import threading
|
9 |
-
import uuid
|
10 |
-
from fastapi import FastAPI, Request, BackgroundTasks
|
11 |
-
from fastapi.responses import StreamingResponse
|
12 |
-
import requests
|
13 |
-
import torch
|
14 |
-
import uvicorn
|
15 |
-
from functools import partial
|
16 |
-
from starvector.serve.constants import WORKER_HEART_BEAT_INTERVAL, CLIP_QUERY_LENGTH
|
17 |
-
from starvector.serve.util import (build_logger, server_error_msg,
|
18 |
-
pretty_print_semaphore)
|
19 |
-
from starvector.model.builder import load_pretrained_model
|
20 |
-
from starvector.serve.util import process_images, load_image_from_base64
|
21 |
-
from threading import Thread
|
22 |
-
from transformers import TextIteratorStreamer
|
23 |
-
|
24 |
-
GB = 1 << 30
|
25 |
-
|
26 |
-
worker_id = str(uuid.uuid4())[:6]
|
27 |
-
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
28 |
-
global_counter = 0
|
29 |
-
model_semaphore = None
|
30 |
-
|
31 |
-
def heart_beat_worker(controller):
    """Daemon loop: ping the controller every WORKER_HEART_BEAT_INTERVAL seconds."""
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()
|
35 |
-
|
36 |
-
class ModelWorker:
    """Hosts one StarVector model and streams generations to HTTP clients.

    Registers itself with the controller, sends periodic heart beats, and
    exposes streaming generation via generate_stream_gate().
    """

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            # Derive a display name from the path; include the parent dir
            # for "checkpoint-NNN" style paths so names stay unique.
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        # Infer the served task from the model name.
        # NOTE(review): if the name contains neither substring, self.task is
        # never set and generate_stream will fail — confirm naming convention.
        if "text2svg" in self.model_name.lower():
            self.task = "Text2SVG"
        elif "im2svg" in self.model_name.lower():
            self.task = "Image2SVG"

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, device=self.device, load_in_8bit=load_8bit, load_in_4bit=load_4bit)
        self.model.to(torch.bfloat16)
        self.is_multimodal = 'starvector' in self.model_name.lower()

        if not no_register:
            self.register_to_controller()
            # Background thread keeps our registration alive at the controller.
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker (and its status) to the controller."""
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        """Send a heart beat; re-register if the controller forgot us.

        Retries forever on connection errors (5s back-off).
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return requests in flight + waiting (0 before first request)."""
        if model_semaphore is None:
            return 0
        else:
            # Relies on asyncio.Semaphore internals (_value/_waiters).
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Status payload reported to the controller for dispatching."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        """Yield b"\\0"-delimited JSON chunks of incrementally generated SVG.

        params keys: "prompt", optional "images" (base64 list, Image2SVG),
        "num_beams", "temperature", "len_penalty", "top_p", "max_new_tokens".
        """
        tokenizer, model, image_processor, task = self.tokenizer, self.model, self.image_processor, self.task

        num_beams = int(params.get("num_beams", 1))
        temperature = float(params.get("temperature", 1.0))
        len_penalty = float(params.get("len_penalty", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 8192)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15)
        prompt = params["prompt"]

        if task == "Image2SVG":
            # NOTE(review): only the last entry of `images` survives this
            # loop, and `image` is unbound if `images` is empty — confirm
            # callers always send exactly one image.
            images = params.get("images", None)
            for b64_image in images:
                if b64_image is not None and self.is_multimodal:
                    image = load_image_from_base64(b64_image)
                    image = process_images(image, image_processor)
                    # NOTE(review): cast to float16 while the model was moved
                    # to bfloat16 in __init__ — confirm intended dtype.
                    image = image.to(self.model.device, dtype=torch.float16)
                else:
                    image = None

            max_new_tokens = min(int(params.get("max_new_tokens", 256)), 8192)
            # Leave room for the CLIP visual-query tokens.
            max_new_tokens = min(max_new_tokens, max_context_length - CLIP_QUERY_LENGTH)
            pre_pend = prompt
            batch = {}
            batch["image"] = image
            generate_method = model.model.generate_im2svg
        else:
            max_new_tokens = min(int(params.get("max_new_tokens", 128)), 8192)
            pre_pend = ""
            batch = {}
            batch['caption'] = [prompt]
            # White PIL image
            batch['image'] = torch.zeros((3, 256, 256), dtype=torch.float16).to(self.model.device)
            generate_method = model.model.generate_text2svg

        if max_new_tokens < 1:
            yield json.dumps({"text": prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
            return

        # Generation runs on a worker thread; the streamer feeds this generator.
        thread = Thread(target=generate_method, kwargs=dict(
            batch=batch,
            prompt=prompt,
            use_nucleus_sampling=True,
            num_beams=num_beams,
            temperature=temperature,
            length_penalty=len_penalty,
            top_p=top_p,
            max_length=max_new_tokens,
            streamer=streamer,
        ))
        thread.start()

        generated_text = pre_pend
        for new_text in streamer:
            if new_text == " ":
                continue
            generated_text += new_text
            # if generated_text.endswith(stop_str):
            #     generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Wrap generate_stream(), converting failures into error chunks."""
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
|
208 |
-
|
209 |
-
app = FastAPI()
|
210 |
-
|
211 |
-
def release_model_semaphore(fn=None):
|
212 |
-
model_semaphore.release()
|
213 |
-
if fn is not None:
|
214 |
-
fn()
|
215 |
-
|
216 |
-
@app.post("/worker_generate_stream")
|
217 |
-
async def generate_stream(request: Request):
|
218 |
-
global model_semaphore, global_counter
|
219 |
-
global_counter += 1
|
220 |
-
params = await request.json()
|
221 |
-
|
222 |
-
if model_semaphore is None:
|
223 |
-
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
|
224 |
-
await model_semaphore.acquire()
|
225 |
-
worker.send_heart_beat()
|
226 |
-
generator = worker.generate_stream_gate(params)
|
227 |
-
background_tasks = BackgroundTasks()
|
228 |
-
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
|
229 |
-
return StreamingResponse(generator, background=background_tasks)
|
230 |
-
|
231 |
-
@app.post("/worker_get_status")
|
232 |
-
async def get_status(request: Request):
|
233 |
-
return worker.get_status()
|
234 |
-
|
235 |
-
if __name__ == "__main__":
|
236 |
-
parser = argparse.ArgumentParser()
|
237 |
-
parser.add_argument("--host", type=str, default="localhost")
|
238 |
-
parser.add_argument("--port", type=int, default=21002)
|
239 |
-
parser.add_argument("--worker-address", type=str,
|
240 |
-
default="http://localhost:21002")
|
241 |
-
parser.add_argument("--controller-address", type=str,
|
242 |
-
default="http://localhost:21001")
|
243 |
-
parser.add_argument("--model-path", type=str, default="joanrodai/starvector-1.4b")
|
244 |
-
parser.add_argument("--model-base", type=str, default=None)
|
245 |
-
parser.add_argument("--model-name", type=str)
|
246 |
-
parser.add_argument("--device", type=str, default="cuda")
|
247 |
-
parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.")
|
248 |
-
parser.add_argument("--limit-model-concurrency", type=int, default=5)
|
249 |
-
parser.add_argument("--stream-interval", type=int, default=1)
|
250 |
-
parser.add_argument("--no-register", action="store_true")
|
251 |
-
parser.add_argument("--load-8bit", action="store_true")
|
252 |
-
parser.add_argument("--load-4bit", action="store_true")
|
253 |
-
args = parser.parse_args()
|
254 |
-
logger.info(f"args: {args}")
|
255 |
-
|
256 |
-
if args.multi_modal:
|
257 |
-
logger.warning("Multimodal mode is automatically detected with model name, please make sure `starvector` is included in the model path.")
|
258 |
-
|
259 |
-
worker = ModelWorker(args.controller_address,
|
260 |
-
args.worker_address,
|
261 |
-
worker_id,
|
262 |
-
args.no_register,
|
263 |
-
args.model_path,
|
264 |
-
args.model_base,
|
265 |
-
args.model_name,
|
266 |
-
args.load_8bit,
|
267 |
-
args.load_4bit,
|
268 |
-
args.device)
|
269 |
-
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
starvector/serve/vllm_api_gradio/gradio_web_server.py
CHANGED
@@ -231,7 +231,6 @@ def http_bot(state, task_selector, text_caption, model_selector, num_beams, temp
|
|
231 |
|
232 |
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
233 |
return
|
234 |
-
time.sleep(0.01)
|
235 |
except requests.exceptions.RequestException as e:
|
236 |
state.messages[-1][-1] = server_error_msg
|
237 |
yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
|
|
231 |
|
232 |
yield (state, st[-1][1], state.to_gradio_svg_render()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|
233 |
return
|
|
|
234 |
except requests.exceptions.RequestException as e:
|
235 |
state.messages[-1][-1] = server_error_msg
|
236 |
yield (state, None, None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn, disable_btn, disable_btn)
|