Upload wkv.py
Browse files- Trained_20G/wkv.py +7 -4
Trained_20G/wkv.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import torch
|
2 |
from einops import rearrange
|
3 |
|
@@ -41,6 +42,9 @@ else:
|
|
41 |
def compile_decorator(func):
|
42 |
return func
|
43 |
|
|
|
|
|
|
|
44 |
|
45 |
class Rwkv_Tmix_x070(nn.Module):
|
46 |
def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
|
@@ -205,23 +209,22 @@ class Rwkv_Tmix_x070(nn.Module):
|
|
205 |
output_final_state,
|
206 |
cu_seqlens
|
207 |
):
|
208 |
-
if
|
209 |
r, w, k, v, a, b = map(lambda x: rearrange(
|
210 |
x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
|
211 |
o, state = native_recurrent_rwkv7(
|
212 |
r=r, k=k, v=v, w=w,
|
213 |
a=a, b=b,
|
214 |
scale=1.0,
|
215 |
-
initial_state=s
|
216 |
output_final_state=True,
|
217 |
head_first=True,
|
218 |
)
|
219 |
-
state = state.transpose(-1, -2)
|
220 |
x = rearrange(o, "b h l d -> b l (h d)")
|
221 |
else:
|
222 |
r, w, k, v, a, b = map(lambda x: rearrange(
|
223 |
x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
|
224 |
-
wkv7_func = chunk_rwkv7 if
|
225 |
o, state = wkv7_func(
|
226 |
r=r, k=k, v=v, w=w,
|
227 |
a=a, b=b,
|
|
|
1 |
+
import os
|
2 |
import torch
|
3 |
from einops import rearrange
|
4 |
|
|
|
42 |
def compile_decorator(func):
|
43 |
return func
|
44 |
|
45 |
+
wkv_mode = os.environ.get("WKV_MODE", "fused")
|
46 |
+
wkv_mode = wkv_mode.lower()
|
47 |
+
assert wkv_mode in ['fused', 'chunk', 'pytorch']
|
48 |
|
49 |
class Rwkv_Tmix_x070(nn.Module):
|
50 |
def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
|
|
|
209 |
output_final_state,
|
210 |
cu_seqlens
|
211 |
):
|
212 |
+
if wkv_mode == 'pytorch':
|
213 |
r, w, k, v, a, b = map(lambda x: rearrange(
|
214 |
x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
|
215 |
o, state = native_recurrent_rwkv7(
|
216 |
r=r, k=k, v=v, w=w,
|
217 |
a=a, b=b,
|
218 |
scale=1.0,
|
219 |
+
initial_state=s,
|
220 |
output_final_state=True,
|
221 |
head_first=True,
|
222 |
)
|
|
|
223 |
x = rearrange(o, "b h l d -> b l (h d)")
|
224 |
else:
|
225 |
r, w, k, v, a, b = map(lambda x: rearrange(
|
226 |
x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
|
227 |
+
wkv7_func = chunk_rwkv7 if wkv_mode == 'chunk' else fused_recurrent_rwkv7
|
228 |
o, state = wkv7_func(
|
229 |
r=r, k=k, v=v, w=w,
|
230 |
a=a, b=b,
|