zhiyuan8 commited on
Commit
678c4e9
·
verified ·
1 Parent(s): e05e922

Upload wkv.py

Browse files
Files changed (1):
  1. Trained_20G/wkv.py (+7 -4)
Trained_20G/wkv.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  from einops import rearrange
3
 
@@ -41,6 +42,9 @@ else:
41
  def compile_decorator(func):
42
  return func
43
 
 
 
 
44
 
45
  class Rwkv_Tmix_x070(nn.Module):
46
  def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
@@ -205,23 +209,22 @@ class Rwkv_Tmix_x070(nn.Module):
205
  output_final_state,
206
  cu_seqlens
207
  ):
208
- if r.device.type == "cpu":
209
  r, w, k, v, a, b = map(lambda x: rearrange(
210
  x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
211
  o, state = native_recurrent_rwkv7(
212
  r=r, k=k, v=v, w=w,
213
  a=a, b=b,
214
  scale=1.0,
215
- initial_state=s.transpose(-1, -2),
216
  output_final_state=True,
217
  head_first=True,
218
  )
219
- state = state.transpose(-1, -2)
220
  x = rearrange(o, "b h l d -> b l (h d)")
221
  else:
222
  r, w, k, v, a, b = map(lambda x: rearrange(
223
  x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
224
- wkv7_func = chunk_rwkv7 if r.shape[1] != 1 else fused_recurrent_rwkv7
225
  o, state = wkv7_func(
226
  r=r, k=k, v=v, w=w,
227
  a=a, b=b,
 
1
+ import os
2
  import torch
3
  from einops import rearrange
4
 
 
42
def compile_decorator(func):
    """Identity decorator: return *func* unchanged (no compilation applied).

    NOTE(review): the diff hunk context shows this is defined under an
    ``else:`` branch not visible here — presumably the fallback when the
    real compiling decorator (e.g. ``torch.compile``) is unavailable;
    confirm against the full file.
    """
    return func
44
 
45
# Backend selection for the WKV7 kernel, controlled by the WKV_MODE env var:
#   "fused"   -> fused_recurrent_rwkv7 (default)
#   "chunk"   -> chunk_rwkv7
#   "pytorch" -> native_recurrent_rwkv7 (pure-PyTorch reference path)
wkv_mode = os.environ.get("WKV_MODE", "fused").lower()
# Validate with a real exception: a bare `assert` is stripped under
# `python -O`, which would silently let an invalid mode through.
if wkv_mode not in ("fused", "chunk", "pytorch"):
    raise ValueError(
        f"WKV_MODE must be one of 'fused', 'chunk', 'pytorch'; got {wkv_mode!r}"
    )
48
 
49
  class Rwkv_Tmix_x070(nn.Module):
50
  def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
 
209
  output_final_state,
210
  cu_seqlens
211
  ):
212
+ if wkv_mode == 'pytorch':
213
  r, w, k, v, a, b = map(lambda x: rearrange(
214
  x, 'b l (h d) -> b h l d', h=self.n_head), (r, w, k, v, a, b))
215
  o, state = native_recurrent_rwkv7(
216
  r=r, k=k, v=v, w=w,
217
  a=a, b=b,
218
  scale=1.0,
219
+ initial_state=s,
220
  output_final_state=True,
221
  head_first=True,
222
  )
 
223
  x = rearrange(o, "b h l d -> b l (h d)")
224
  else:
225
  r, w, k, v, a, b = map(lambda x: rearrange(
226
  x, 'b l (h d) -> b l h d', h=self.n_head), (r, w, k, v, a, b))
227
+ wkv7_func = chunk_rwkv7 if wkv_mode == 'chunk' else fused_recurrent_rwkv7
228
  o, state = wkv7_func(
229
  r=r, k=k, v=v, w=w,
230
  a=a, b=b,