Chanjeans commited on
Commit
0a5824a
ยท
verified ยท
1 Parent(s): 7e3cc4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -35
app.py CHANGED
@@ -15,7 +15,7 @@ import uvicorn
15
  import gradio as gr
16
  from threading import Thread
17
  from fastapi.middleware.cors import CORSMiddleware
18
- from transformers import AutoTokenizer, AutoModelForCausalLM
19
  #####################################
20
  # 1) ์•ฑ ๋ฐ ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
21
  #####################################
@@ -1114,36 +1114,50 @@ def recommend_content_based(user_profile: dict, top_n=5):
1114
  #####################################
1115
  # 5) ์ฑ—๋ด‡ ๋กœ์ง
1116
  #####################################
1117
- tokenizer = AutoTokenizer.from_pretrained("Chanjeans/tfchatbot_2")
1118
- model = AutoModelForCausalLM.from_pretrained("Chanjeans/tfchatbot_2")
1119
- model.eval()
1120
- print("Model loaded successfully.")
1121
 
1122
- def chat_response(user_input, mode="emotion"):
1123
  if mode not in ["emotion", "rational"]:
1124
  raise HTTPException(status_code=400, detail="mode๋Š” 'emotion' ๋˜๋Š” 'rational'์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
1125
 
1126
  prompt = f"<{mode}><usr>{user_input}</usr><sys>"
 
 
 
 
 
 
 
 
 
 
 
 
1127
 
1128
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
1129
-
1130
- with torch.no_grad():
1131
- outputs = model.generate(
1132
- **inputs,
1133
- max_new_tokens=128,
1134
- temperature=0.7,
1135
- top_p=0.9,
1136
- top_k=50,
1137
- repetition_penalty=1.2,
1138
- do_sample=True
1139
- )
1140
-
1141
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
1142
- # prompt ๋ถ€๋ถ„ ์ œ๊ฑฐ (๋ถˆํ•„์š”ํ•œ ํ”„๋กฌํ”„ํŠธ๊นŒ์ง€ ๋ฐ˜ํ™˜๋˜์ง€ ์•Š๋„๋ก)
1143
- response_text = generated_text.replace(prompt, "").strip()
1144
-
1145
- return response_text
1146
-
 
 
 
1147
 
1148
 
1149
  #์šฐ์šธ๋ถ„๋ฅ˜ ๋ชจ๋ธ ์ถ”๊ฐ€
@@ -1357,13 +1371,4 @@ def chat_or_recommend(req: ChatOrRecommendRequest):
1357
  if recommendation_msg:
1358
  response_dict["recommendations"] = recommendations_list
1359
 
1360
- return response_dict
1361
-
1362
-
1363
- #def run_fastapi():
1364
- # uvicorn.run(app, host="0.0.0.0", port=7860)
1365
-
1366
-
1367
- #if __name__ == "__main__":
1368
- # Thread(target=run_fastapi).start()
1369
- # iface.launch(server_name="0.0.0.0", server_port=7861)
 
15
  import gradio as gr
16
  from threading import Thread
17
  from fastapi.middleware.cors import CORSMiddleware
18
+
19
  #####################################
20
  # 1) ์•ฑ ๋ฐ ๋ชจ๋ธ ์ดˆ๊ธฐํ™”
21
  #####################################
 
1114
  #####################################
1115
  # 5) ์ฑ—๋ด‡ ๋กœ์ง
1116
  #####################################
1117
+ HF_API_KEY = os.environ.get("HF_API_KEY", "YOUR_HF_API_KEY")
1118
+ API_URL = "https://api-inference.huggingface.co/models/Chanjeans/tfchatbot_2"
1119
+ HEADERS = {"Authorization": f"Bearer {HF_API_KEY}"}
 
1120
 
1121
+ def chat_response(user_input, mode="emotion", max_retries=5):
1122
  if mode not in ["emotion", "rational"]:
1123
  raise HTTPException(status_code=400, detail="mode๋Š” 'emotion' ๋˜๋Š” 'rational'์ด์–ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.")
1124
 
1125
  prompt = f"<{mode}><usr>{user_input}</usr><sys>"
1126
+ payload = {
1127
+ "inputs": prompt,
1128
+ "parameters": {
1129
+ "max_new_tokens": 128,
1130
+ "temperature": 0.7,
1131
+ "top_p": 0.9,
1132
+ "top_k": 50,
1133
+ "repetition_penalty": 1.2,
1134
+ "do_sample": True
1135
+ },
1136
+ "options": {"wait_for_model": True}
1137
+ }
1138
 
1139
+ for attempt in range(max_retries):
1140
+ response = requests.post(API_URL, headers=HEADERS, json=payload)
1141
+ if response.status_code == 200:
1142
+ try:
1143
+ result = response.json()
1144
+ if isinstance(result, list) and "generated_text" in result[0]:
1145
+ generated_text = result[0]["generated_text"]
1146
+ return generated_text.replace(prompt, "").strip()
1147
+ else:
1148
+ return "์‘๋‹ต ํ˜•์‹์ด ์˜ˆ์ƒ๊ณผ ๋‹ค๋ฆ…๋‹ˆ๋‹ค."
1149
+ except Exception as e:
1150
+ return f"JSON ํŒŒ์‹ฑ ์˜ค๋ฅ˜: {e}"
1151
+
1152
+ elif response.status_code == 503:
1153
+ # ๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘
1154
+ error_info = response.json()
1155
+ estimated_time = error_info.get("estimated_time", 15)
1156
+ time.sleep(min(estimated_time, 15))
1157
+ else:
1158
+ return f"API Error: {response.status_code}, {response.text}"
1159
+
1160
+ return "๐Ÿšจ ๋ชจ๋ธ ๋กœ๋”ฉ์ด ๋„ˆ๋ฌด ์˜ค๋ž˜ ๊ฑธ๋ฆฝ๋‹ˆ๋‹ค. ์ž ์‹œ ํ›„ ๋‹ค์‹œ ์‹œ๋„ํ•˜์„ธ์š”."
1161
 
1162
 
1163
  #์šฐ์šธ๋ถ„๋ฅ˜ ๋ชจ๋ธ ์ถ”๊ฐ€
 
1371
  if recommendation_msg:
1372
  response_dict["recommendations"] = recommendations_list
1373
 
1374
+ return response_dict