Update app.py
app.py CHANGED
@@ -1,21 +1,32 @@
+import os
 import gradio as gr
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer  # Import the tokenizer
 from langchain.memory import ConversationBufferMemory
 from langchain.schema import HumanMessage, AIMessage
 
+# Load HF token from environment variables.
+HF_TOKEN = os.getenv("HF_TOKEN")
+if not HF_TOKEN:
+    raise ValueError("HF_TOKEN environment variable not set")
+
 # Use the appropriate tokenizer for your model.
-tokenizer = AutoTokenizer.from_pretrained("
-
+tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-32B")
+
+# Instantiate the client with the new inference mechanism.
+client = InferenceClient(
+    provider="novita",
+    api_key=HF_TOKEN
+)
 
-# Define a maximum context length (tokens).
-MAX_CONTEXT_LENGTH = 4096
+# Define a maximum context length (tokens). Adjust this based on your model's requirements.
+MAX_CONTEXT_LENGTH = 4096
 
-# Read the default prompt from a file
+# Read the default prompt from a file.
 with open("prompt.txt", "r") as file:
     nvc_prompt_template = file.read()
 
-# Initialize LangChain Conversation Memory
+# Initialize LangChain Conversation Memory.
 memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
 
 def count_tokens(text: str) -> int:
@@ -38,7 +49,7 @@ def truncate_memory(memory, system_message: str, max_length: int):
     system_tokens = count_tokens(system_message)
     current_length = system_tokens
 
-    # Iterate backwards through the memory (newest to oldest)
+    # Iterate backwards through the memory (newest to oldest).
     for msg in reversed(memory.chat_memory.messages):
         tokens = count_tokens(msg.content)
         if current_length + tokens <= max_length:
@@ -52,7 +63,7 @@ def truncate_memory(memory, system_message: str, max_length: int):
 
 def respond(
     message,
-    history: list[tuple[str, str]],  # Required by Gradio but we now use LangChain memory
+    history: list[tuple[str, str]],  # Required by Gradio but we now use LangChain memory.
     system_message,
     max_tokens,
     temperature,
@@ -83,13 +94,15 @@ def respond(
 
     response = ""
    try:
-
-
+        stream = client.chat.completions.create(
+            model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+            messages=messages,
             max_tokens=max_tokens,
             stream=True,
             temperature=temperature,
             top_p=top_p,
-        )
+        )
+        for chunk in stream:
             token = chunk.choices[0].delta.content
             response += token
             yield response
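
For reference, here is a minimal, self-contained sketch of the streaming inference path this update introduces, assuming a recent huggingface_hub release with inference-provider support and an HF_TOKEN variable set in the environment. The "novita" provider and the deepseek-ai/DeepSeek-R1-Distill-Qwen-32B model ID come from the diff above; the example messages and the max_tokens, temperature, and top_p values are placeholders, not the app's actual settings (those come from prompt.txt and the Gradio controls).

import os
from huggingface_hub import InferenceClient

# Read the token from the environment, mirroring the HF_TOKEN check added in app.py.
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable not set")

# Route requests through the "novita" inference provider, as in the updated app.
client = InferenceClient(provider="novita", api_key=HF_TOKEN)

# Placeholder conversation; in app.py the system prompt is loaded from prompt.txt
# and earlier turns come from the LangChain ConversationBufferMemory.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# Stream the completion and emit tokens as they arrive, the same pattern
# respond() uses to yield partial responses to Gradio.
stream = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
    messages=messages,
    max_tokens=256,
    stream=True,
    temperature=0.7,
    top_p=0.95,
)
for chunk in stream:
    token = chunk.choices[0].delta.content
    if token:  # delta.content can be None on some chunks
        print(token, end="", flush=True)

The None guard on the streamed delta is worth noting: if the provider emits empty chunks, the app's respond() loop would also need a similar check before concatenating tokens.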