- Structured State Spaces (S4)
- State Space Models (SSM) mechanics
- Structured State Spaces (S4) for long-term dependencies
- S4: a special state space model
- S4 variants and simplifications
- Modeling signals with S4
- Selective State Space Model
- Mamba
- Links

Structured State Spaces (S4)
State Space Models (SSM) mechanics
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2Faa5e92a8-239d-4e1e-a218-47b0ad817d62%2FUntitled.png?table=block&id=d5283cf4-195d-4e0d-b24d-4d4f5ea790cf&t=d5283cf4-195d-4e0d-b24d-4d4f5ea790cf&width=528&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F910804eb-71a9-4f4e-8760-baf88b6df7ab%2FUntitled.png?table=block&id=e2cbc1d5-8a3a-492e-878a-7dd67ca86d10&t=e2cbc1d5-8a3a-492e-878a-7dd67ca86d10&width=384&cache=v2)
Compared with a linear RNN, which simply jumps to the next state, the SSM is defined through a derivative, so the state evolves continuously (a step size Δ can be set to control the discretization).
Transformers train in parallel but still decode serially at inference; an RNN must consume the sequence token by token during training, but at inference each step only needs the previous state. An SSM combines the two: Transformer-style parallel training with RNN-style inference.
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2Fee2c9179-61d7-47dd-85f6-a8233ba70d87%2FUntitled.png?table=block&id=844af280-30dd-4954-8612-4a0d982fe4bb&t=844af280-30dd-4954-8612-4a0d982fe4bb&width=708&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F6db3a86f-0378-4c95-9158-fd1cf7f855a4%2FUntitled.png?table=block&id=0b472481-ce47-497d-a521-e87361701b2a&t=0b472481-ce47-497d-a521-e87361701b2a&width=1920&cache=v2)
The key properties are: continuous (by definition), recurrent (the model can be rewritten in a recursive form and computed step by step), and convolutional (it can be computed in parallel as a convolution for fast training).
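The continuous-to-recurrent step works by discretizing the ODE with the step size Δ. A minimal sketch using the bilinear (Tustin) rule commonly used for S4, on toy matrices (NumPy only):

```python
import numpy as np

def discretize_bilinear(A, B, delta):
    """Bilinear (Tustin) discretization of x'(t) = A x(t) + B u(t).

    Returns (A_bar, B_bar) for the discrete recurrence
    x_k = A_bar x_{k-1} + B_bar u_k, where `delta` is the step size.
    """
    N = A.shape[0]
    I = np.eye(N)
    inv = np.linalg.inv(I - delta / 2 * A)
    A_bar = inv @ (I + delta / 2 * A)
    B_bar = inv @ (delta * B)
    return A_bar, B_bar
```

As Δ → 0, Ā → I and B̄ → ΔB, so the discrete system tracks the continuous one arbitrarily closely; Δ is exactly the configurable step size mentioned above.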
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F24b1e64e-61fd-42c5-a002-20a4bc8424eb%2FUntitled.png?table=block&id=37f23ebb-c15b-4cab-81f8-3f37bdf43f0d&t=37f23ebb-c15b-4cab-81f8-3f37bdf43f0d&width=1621&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2Fe9443de0-15f6-4a02-b9cd-f59775036005%2FUntitled.png?table=block&id=5925d117-9171-467d-9f9c-8d9b51f69258&t=5925d117-9171-467d-9f9c-8d9b51f69258&width=314.5&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2Ffb960fb8-f848-4dbf-870a-6392699a821e%2FUntitled.png?table=block&id=e51916cc-6190-473c-8d23-e73b523402fc&t=e51916cc-6190-473c-8d23-e73b523402fc&width=314.5&cache=v2)
One can construct a convolution kernel K (for a fixed sequence length L, K is also fixed) so that y can be computed from u without materializing the state x step by step. Compared with the RNN-style recursion, this makes y computable in a parallelizable, near-linear way (the convolution can be evaluated with the FFT, which is fast).
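The recurrent and convolutional views produce the same outputs; a small NumPy check on a toy random system (the kernel is built directly here rather than via FFT):

```python
import numpy as np

rng = np.random.default_rng(0)
N, L = 4, 16                        # state size, sequence length
A = rng.normal(size=(N, N)) * 0.3   # toy discrete-time A_bar
B = rng.normal(size=(N, 1))
C = rng.normal(size=(1, N))
u = rng.normal(size=L)

# Recurrent view: x_k = A x_{k-1} + B u_k,  y_k = C x_k
x = np.zeros((N, 1))
y_rec = []
for k in range(L):
    x = A @ x + B * u[k]
    y_rec.append((C @ x).item())
y_rec = np.array(y_rec)

# Convolutional view: K = (CB, CAB, ..., CA^{L-1}B), y = causal conv of K and u
K = np.array([(C @ np.linalg.matrix_power(A, j) @ B).item() for j in range(L)])
y_conv = np.array([K[:k + 1][::-1] @ u[:k + 1] for k in range(L)])
```

In practice the convolution is applied with an FFT, giving O(L log L) parallel work instead of L strictly sequential recurrence steps.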
Structured State Spaces (S4) for long-term dependencies
S4: a special state space model
SSM + HiPPO + Structured Matrices = S4
Computing the SSM naively is slow:![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2Fdcb873c4-788e-4943-8206-94849e9469d3%2FUntitled.png?table=block&id=e410eafa-0e1c-4166-a222-13e6986b5fd7&t=e410eafa-0e1c-4166-a222-13e6986b5fd7&width=432&cache=v2)
- The hidden state x has k times the dimension of the input u, and all the arithmetic runs on x, so it costs k times more memory and compute.
- Computing the convolution kernel K is even more expensive.

Using a structured state space solves these problems effectively.
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F4b708845-acb8-43e5-8cfa-f71ccee0d078%2FUntitled.png?table=block&id=679bd413-d41b-4165-a6f3-0c66788026a8&t=679bd413-d41b-4165-a6f3-0c66788026a8&width=432&cache=v2)
S4 variants and simplifications
HiPPO can be used for online function reconstruction (recovering the input u in real time from the state x). The HiPPO parameterization also avoids the exploding-gradient problem of vanilla RNNs. The expensive kernel computation can be addressed with the S4D trick of replacing A with an (approximately) diagonal matrix.
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F6704ef2e-b8a9-47c8-9f30-7f15262b492a%2FUntitled.png?table=block&id=f556fc2c-91ed-4574-94dd-579b3c09f328&t=f556fc2c-91ed-4574-94dd-579b3c09f328&width=432&cache=v2)
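As a sketch, the HiPPO-LegS transition matrix can be built directly from its closed form (sign conventions vary between write-ups; here the continuous SSM is taken as x'(t) = -A x(t) + B u(t)):

```python
import numpy as np

def hippo_legs(N):
    """HiPPO-LegS transition matrix and input vector.

    With x'(t) = -A x(t) + B u(t), the state x maintains an online
    polynomial reconstruction of the input's history.
    """
    n = np.arange(N)
    A = np.sqrt(2 * n[:, None] + 1) * np.sqrt(2 * n[None, :] + 1)
    A = np.tril(A, -1)        # sqrt((2n+1)(2k+1)) strictly below the diagonal
    A += np.diag(n + 1.0)     # n + 1 on the diagonal; zeros above it
    B = np.sqrt(2 * n + 1.0)  # input projection
    return A, B
```

The lower-triangular structure is what S4 and S4D exploit: A is normal-plus-low-rank, and close to diagonalizable, which makes the kernel computation cheap.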
Modeling signals with S4
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2Fcb348f67-6cb1-420b-88c1-469c724227d2%2FUntitled.png?table=block&id=60a2f504-582e-40b8-8096-99015e7559ae&t=60a2f504-582e-40b8-8096-99015e7559ae&width=528&cache=v2)
Where S4 beats WaveNet at speech modeling is its inductive bias: it captures continuous signals better and can take in a very large amount of context.
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F1d3344cf-a733-4819-a565-f85cb6b2ccc4%2FUntitled.png?table=block&id=a6163f42-4c1a-4f1b-99e5-4fcafdeea9ec&t=a6163f42-4c1a-4f1b-99e5-4fcafdeea9ec&width=672&cache=v2)
Selective State Space Model
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F4b2fd7ec-ce06-4dd6-8508-415737226c75%2FUntitled.png?table=block&id=9ee2a439-063b-4d10-af09-0c4cb6feb588&t=9ee2a439-063b-4d10-af09-0c4cb6feb588&width=576&cache=v2)
Whereas a plain SSM processes every input x identically, a selective SSM can account for each input's importance (by adding input-dependent control, similar in spirit to attention).
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F40fef86d-cfae-46e5-9947-8f7ea041592d%2FUntitled.png?table=block&id=6151361e-27ed-4b03-83c1-62d8988cd105&t=6151361e-27ed-4b03-83c1-62d8988cd105&width=480&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F017ff8e2-aeda-47a5-81cc-aba94a1a69d8%2FUntitled.png?table=block&id=abbfcea2-13ba-4422-81d8-c149125f3155&t=abbfcea2-13ba-4422-81d8-c149125f3155&width=1920&cache=v2)
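A minimal sequential sketch of the selective mechanism: Δ, B, and C are produced per token by (here, hypothetical random) projections of the input, with a diagonal negative A discretized ZOH-style as in S4D:

```python
import numpy as np

rng = np.random.default_rng(1)
D, N, L = 8, 4, 10                  # channels, state size, sequence length
u = rng.normal(size=(L, D))         # toy input sequence

# Hypothetical projections that make the SSM parameters input-dependent:
W_delta = rng.normal(size=(D, D)) * 0.1
W_B = rng.normal(size=(D, N)) * 0.1
W_C = rng.normal(size=(D, N)) * 0.1
A = -np.exp(rng.normal(size=(D, N)))   # diagonal, negative A per channel

softplus = lambda z: np.log1p(np.exp(z))

x = np.zeros((D, N))                # hidden state
ys = np.zeros((L, D))
for t in range(L):
    delta = softplus(u[t] @ W_delta)[:, None]  # (D,1): step size per channel
    B_t = u[t] @ W_B                           # (N,): input-dependent B
    C_t = u[t] @ W_C                           # (N,): input-dependent C
    A_bar = np.exp(delta * A)                  # ZOH discretization (diagonal A)
    B_bar = delta * B_t[None, :]               # simplified (Euler-style) B_bar
    x = A_bar * x + B_bar * u[t][:, None]      # selective recurrence
    ys[t] = x @ C_t                            # per-token readout
```

Because Δ, B, and C now change with each token, the model can choose to store or ignore each input, which is exactly what the fixed-kernel convolutional form cannot express.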
A scan algorithm (all-prefix-sums) is used to handle the input-dependent control: since the parameters now vary per input, the convolutional shortcut no longer applies, and the scan computes the recurrence correctly while remaining parallelizable.
However, a pure-PyTorch scan is still slow, so the approach below was proposed (Mamba also exploits hardware-aware acceleration).
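The all-prefix-sums idea for the linear recurrence h_t = a_t h_{t-1} + b_t: the pairs (a, b) compose associatively via (a1, b1) ∘ (a2, b2) = (a1·a2, a2·b1 + b2), so a Hillis–Steele doubling scan computes every h_t in O(log L) rounds of vectorized work. A NumPy sketch:

```python
import numpy as np

def linear_scan_parallel(a, b):
    """All-prefix-sums for h_t = a_t * h_{t-1} + b_t, with h_{-1} = 0.

    Hillis-Steele doubling: each round combines every element with the
    partial result `shift` positions back, doubling the covered span.
    """
    a, b = a.astype(float).copy(), b.astype(float).copy()
    L = len(a)
    shift = 1
    while shift < L:
        b[shift:] = a[shift:] * b[:-shift] + b[shift:]
        a[shift:] = a[shift:] * a[:-shift]
        shift *= 2
    return b   # b[t] now holds h_t
```

Each round is one batched elementwise operation, which is why the recurrence can run fast on a GPU even though it looks inherently sequential.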
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2Fb8821f6c-f886-4dad-b845-d6aa7d8dc5c8%2FUntitled.png?table=block&id=ebcbbefe-0b5d-419b-9567-edd92f966033&t=ebcbbefe-0b5d-419b-9567-edd92f966033&width=1920&cache=v2)
Mamba
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F28081c95-fe80-4836-af4c-1dab644b515c%2FUntitled.png?table=block&id=b7734ec2-ea5b-4395-b058-36258caab055&t=b7734ec2-ea5b-4395-b058-36258caab055&width=708&cache=v2)
![notion image](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F88714bfd-1567-4105-bb04-56e4a7201052%2FUntitled.png?table=block&id=023a8ccf-26c3-432e-9ea4-95d27b16ab29&t=023a8ccf-26c3-432e-9ea4-95d27b16ab29&width=432&cache=v2)
The block's output is the elementwise product of the SSM branch's output (which has fused in information from the previous tokens) and the current token's embedding.
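A sketch of that gating (everything here is a hypothetical stand-in: random projection weights, and a simple decaying average in place of the real selective-scan branch):

```python
import numpy as np

rng = np.random.default_rng(3)
D, L = 8, 10
u = rng.normal(size=(L, D))            # token embeddings (toy data)

# Hypothetical projection weights for the two branches of the block.
W_ssm = rng.normal(size=(D, D)) * 0.1
W_gate = rng.normal(size=(D, D)) * 0.1
W_out = rng.normal(size=(D, D)) * 0.1

silu = lambda z: z / (1.0 + np.exp(-z))

def seq_mix(x, decay=0.9):
    """Stand-in for the selective-SSM branch: a running average that
    mixes each position with the tokens before it."""
    out = np.zeros_like(x)
    state = np.zeros(x.shape[1])
    for t in range(x.shape[0]):
        state = decay * state + (1 - decay) * x[t]
        out[t] = state
    return out

seq = seq_mix(silu(u @ W_ssm))         # branch that sees the past
gate = silu(u @ W_gate)                # per-token gate from the embedding
y = (seq * gate) @ W_out               # elementwise product, then project
```

The gate lets the current token scale, per channel, how much of the accumulated history passes through, playing a role loosely analogous to the output gate of an LSTM.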
Links
- MambaStock (zshicode, updated May 8, 2024)
Transformers handle discrete, short data better than S4; S4, being a continuous model, does better when the data is continuous and needs a lot of context.
![u is the input, x is the hidden state, y is the output](https://www.notion.so/image/https%3A%2F%2Fprod-files-secure.s3.us-west-2.amazonaws.com%2F67ce5aa4-f53a-4833-a23e-ad62e6562f0a%2F65d3c046-ead1-43da-a393-dd46ab5b6b88%2FUntitled.png?table=block&id=85072a01-5c09-4e73-8192-18bca4e5bb97&t=85072a01-5c09-4e73-8192-18bca4e5bb97&width=336&cache=v2)
This one explains it more clearly.