Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -15,8 +15,8 @@ WEBSOCKET_URL = "wss://4d24-196-75-8-109.ngrok-free.app/ws"
|
|
15 |
|
16 |
class Item(BaseModel):
|
17 |
prompt: str
|
18 |
-
history: list
|
19 |
-
system_prompt: str
|
20 |
temperature: float = 0.0
|
21 |
max_new_tokens: int = 1048
|
22 |
top_p: float = 0.15
|
@@ -50,19 +50,29 @@ async def generate_stream(item: Item):
|
|
50 |
yield response.token.text
|
51 |
|
52 |
async def websocket_client():
|
|
|
53 |
while True:
|
54 |
try:
|
55 |
async with websockets.connect(WEBSOCKET_URL) as websocket:
|
56 |
logger.info("WebSocket connection established")
|
57 |
while True:
|
|
|
|
|
|
|
58 |
message = await websocket.recv()
|
59 |
data = json.loads(message)
|
60 |
logger.info(f"Received: {data}")
|
61 |
|
62 |
if data.get('type') == 'prompt':
|
|
|
63 |
prompt_id = data.get('id')
|
64 |
-
|
65 |
|
|
|
|
|
|
|
|
|
|
|
66 |
async for chunk in generate_stream(item):
|
67 |
await websocket.send(json.dumps({
|
68 |
"type": "chunk",
|
@@ -70,12 +80,18 @@ async def websocket_client():
|
|
70 |
"chunk": chunk
|
71 |
}))
|
72 |
logger.info(f"Sent chunk: {chunk}")
|
73 |
-
|
74 |
await websocket.send(json.dumps({
|
75 |
"type": "completed",
|
76 |
"id": prompt_id
|
77 |
}))
|
78 |
logger.info("Generation completed")
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
except websockets.exceptions.ConnectionClosed:
|
81 |
logger.error("WebSocket connection closed. Retrying in 5 seconds...")
|
|
|
15 |
|
16 |
class Item(BaseModel):
|
17 |
prompt: str
|
18 |
+
history: list = []
|
19 |
+
system_prompt: str = ""
|
20 |
temperature: float = 0.0
|
21 |
max_new_tokens: int = 1048
|
22 |
top_p: float = 0.15
|
|
|
50 |
yield response.token.text
|
51 |
|
52 |
async def websocket_client():
|
53 |
+
call_count = 0
|
54 |
while True:
|
55 |
try:
|
56 |
async with websockets.connect(WEBSOCKET_URL) as websocket:
|
57 |
logger.info("WebSocket connection established")
|
58 |
while True:
|
59 |
+
# Request a prompt
|
60 |
+
await websocket.send(json.dumps({"type": "getPrompt"}))
|
61 |
+
|
62 |
message = await websocket.recv()
|
63 |
data = json.loads(message)
|
64 |
logger.info(f"Received: {data}")
|
65 |
|
66 |
if data.get('type') == 'prompt':
|
67 |
+
call_count += 1
|
68 |
prompt_id = data.get('id')
|
69 |
+
prompt_text = data.get('prompt')
|
70 |
|
71 |
+
logger.info(f"Processing prompt: {prompt_text}")
|
72 |
+
logger.info(f"Call count: {call_count}")
|
73 |
+
|
74 |
+
item = Item(prompt=prompt_text)
|
75 |
+
|
76 |
async for chunk in generate_stream(item):
|
77 |
await websocket.send(json.dumps({
|
78 |
"type": "chunk",
|
|
|
80 |
"chunk": chunk
|
81 |
}))
|
82 |
logger.info(f"Sent chunk: {chunk}")
|
83 |
+
|
84 |
await websocket.send(json.dumps({
|
85 |
"type": "completed",
|
86 |
"id": prompt_id
|
87 |
}))
|
88 |
logger.info("Generation completed")
|
89 |
+
elif data.get('type') == 'error':
|
90 |
+
logger.info(f"Received error: {data.get('message')}")
|
91 |
+
# Wait a bit before requesting a new prompt
|
92 |
+
await asyncio.sleep(5)
|
93 |
+
else:
|
94 |
+
logger.info(f"Received unexpected message type: {data.get('type')}")
|
95 |
|
96 |
except websockets.exceptions.ConnectionClosed:
|
97 |
logger.error("WebSocket connection closed. Retrying in 5 seconds...")
|