Jamiiwej2903 commited on
Commit
2c26da0
·
verified ·
1 Parent(s): ef47b5e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -4
main.py CHANGED
@@ -15,8 +15,8 @@ WEBSOCKET_URL = "wss://4d24-196-75-8-109.ngrok-free.app/ws"
15
 
16
  class Item(BaseModel):
17
  prompt: str
18
- history: list
19
- system_prompt: str
20
  temperature: float = 0.0
21
  max_new_tokens: int = 1048
22
  top_p: float = 0.15
@@ -50,19 +50,29 @@ async def generate_stream(item: Item):
50
  yield response.token.text
51
 
52
  async def websocket_client():
 
53
  while True:
54
  try:
55
  async with websockets.connect(WEBSOCKET_URL) as websocket:
56
  logger.info("WebSocket connection established")
57
  while True:
 
 
 
58
  message = await websocket.recv()
59
  data = json.loads(message)
60
  logger.info(f"Received: {data}")
61
 
62
  if data.get('type') == 'prompt':
 
63
  prompt_id = data.get('id')
64
- item = Item(**data['prompt'])
65
 
 
 
 
 
 
66
  async for chunk in generate_stream(item):
67
  await websocket.send(json.dumps({
68
  "type": "chunk",
@@ -70,12 +80,18 @@ async def websocket_client():
70
  "chunk": chunk
71
  }))
72
  logger.info(f"Sent chunk: {chunk}")
73
-
74
  await websocket.send(json.dumps({
75
  "type": "completed",
76
  "id": prompt_id
77
  }))
78
  logger.info("Generation completed")
 
 
 
 
 
 
79
 
80
  except websockets.exceptions.ConnectionClosed:
81
  logger.error("WebSocket connection closed. Retrying in 5 seconds...")
 
15
 
16
  class Item(BaseModel):
17
  prompt: str
18
+ history: list = []
19
+ system_prompt: str = ""
20
  temperature: float = 0.0
21
  max_new_tokens: int = 1048
22
  top_p: float = 0.15
 
50
  yield response.token.text
51
 
52
  async def websocket_client():
53
+ call_count = 0
54
  while True:
55
  try:
56
  async with websockets.connect(WEBSOCKET_URL) as websocket:
57
  logger.info("WebSocket connection established")
58
  while True:
59
+ # Request a prompt
60
+ await websocket.send(json.dumps({"type": "getPrompt"}))
61
+
62
  message = await websocket.recv()
63
  data = json.loads(message)
64
  logger.info(f"Received: {data}")
65
 
66
  if data.get('type') == 'prompt':
67
+ call_count += 1
68
  prompt_id = data.get('id')
69
+ prompt_text = data.get('prompt')
70
 
71
+ logger.info(f"Processing prompt: {prompt_text}")
72
+ logger.info(f"Call count: {call_count}")
73
+
74
+ item = Item(prompt=prompt_text)
75
+
76
  async for chunk in generate_stream(item):
77
  await websocket.send(json.dumps({
78
  "type": "chunk",
 
80
  "chunk": chunk
81
  }))
82
  logger.info(f"Sent chunk: {chunk}")
83
+
84
  await websocket.send(json.dumps({
85
  "type": "completed",
86
  "id": prompt_id
87
  }))
88
  logger.info("Generation completed")
89
+ elif data.get('type') == 'error':
90
+ logger.info(f"Received error: {data.get('message')}")
91
+ # Wait a bit before requesting a new prompt
92
+ await asyncio.sleep(5)
93
+ else:
94
+ logger.info(f"Received unexpected message type: {data.get('type')}")
95
 
96
  except websockets.exceptions.ConnectionClosed:
97
  logger.error("WebSocket connection closed. Retrying in 5 seconds...")