<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom"><author><name>Isak Falk</name></author>
<title>Notes of Isak Falk</title>
<generator>Emacs webfeeder.el</generator>
<link href="https://isakfalk.com/"/>
<link href="https://isakfalk.com/atom.xml" rel="self"/>
<id>https://isakfalk.com/atom.xml</id>
<updated>2025-12-28T21:51:24-05:00</updated>
<entry>
  <title>NeurIPS 2023 Retrospect</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">NeurIPS 2023 Retrospect</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#introduction">Introduction</a></li>
 <li> <a href="#mekrr">MEKRR</a></li>
 <li> <a href="#expo-hall-and-companies">Expo hall and companies</a>
 <ul> <li> <a href="#what-i-would-have-done-differently">What I would have done differently</a></li>
</ul></li>
 <li> <a href="#meetings-and-parties">Meetings and parties</a></li>
 <li> <a href="#the-actual-research-conference">The actual research conference</a>
 <ul> <li> <a href="#mekrr-poster">MEKRR Poster</a></li>
</ul></li>
</ul></div>
</details></nav> <div id="outline-container-introduction" class="outline-2">
 <h2 id="introduction"> <a href="#introduction">Introduction</a></h2>
 <div class="outline-text-2" id="text-introduction">
 <p>
Thus another NeurIPS has ended. This is my third one, and each one I've
attended has felt bigger than the last. Here I reflect on my experience this
time around. Some context: I have just finished my PhD and I'm on the job
market, ideally looking for an ML research position in industry. While the main
reason I went was to present my paper, I also attended to see what the job
market is currently like and to make connections with companies and researchers
in industry and academia.
</p>

 <p>
I presented  <a href="#mekrr">our approach</a> (which we call MEKRR, pronounced "maker") for successfully
transferring pretrained GNNs to the task of predicting energies of atomistic systems.
It combines GNN feature representations with kernel mean embeddings and ridge
regression. Do have a look at the  <a href="https://arxiv.org/abs/2306.01589">arxiv paper</a> (I will update this link with the
conference paper once it's available in the official NeurIPS
proceedings) or the  <a href="https://github.com/IsakFalk/atomistic_transfer_mekrr">code base</a> :)
</p>
</div>
</div>
 <div id="outline-container-mekrr" class="outline-2">
 <h2 id="mekrr"> <a href="#mekrr">MEKRR</a></h2>
 <div class="outline-text-2" id="text-mekrr">
 <p>
I will write more about MEKRR in a separate note and link it here later.
The gist of it is that taking features learned by GNNs trained on an upstream
dataset and running KRR on these features, together with some kernel tricks for
dealing with sets / point clouds, works  <strong>really</strong> well. Will kernels make a
come-back?  <em>Probably not</em>. Can kernels, given a strong feature map, improve performance
on small to medium size datasets compared to fine-tuning?  <em>Probably yes</em>!
</p>
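<p>
As a rough sketch of the idea (an assumption-laden toy, not the paper's actual implementation: the per-atom features stand in for a pretrained GNN's representations, and the linear kernel and regularization strength are placeholders), each configuration is a set of per-atom feature vectors, embedded by its mean, with kernel ridge regression on top:
</p>

```python
import numpy as np

def mean_embedding(atom_features):
    # A configuration is a set / point cloud of per-atom feature
    # vectors (here standing in for pretrained GNN features);
    # embed the whole set as the mean of its atoms' features.
    return atom_features.mean(axis=0)

def fit_krr(configs, energies, lam=1e-6):
    # One mean embedding per configuration, then kernel ridge
    # regression with a linear kernel on the embeddings.
    Z = np.stack([mean_embedding(f) for f in configs])   # (n, d)
    K = Z @ Z.T                                          # (n, n) Gram matrix
    alpha = np.linalg.solve(K + lam * np.eye(len(K)), energies)
    return Z, alpha

def predict_energy(Z_train, alpha, new_config):
    # Linear kernel: k(new, train_i) = <z_new, z_i>.
    z = mean_embedding(new_config)
    return (Z_train @ z) @ alpha
```

<p>
One nice property: configurations with different numbers of atoms pose no problem, since the mean embedding maps every set to a fixed-length vector before the kernel is applied.
</p>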
</div>
</div>
 <div id="outline-container-expo-hall-and-companies" class="outline-2">
 <h2 id="expo-hall-and-companies"> <a href="#expo-hall-and-companies">Expo hall and companies</a></h2>
 <div class="outline-text-2" id="text-expo-hall-and-companies">
 <p>
The expo hall was bustling. My sense is that the companies can be categorized
into
</p>
 <ol class="org-ol"> <li>big Tech such as MAMAA (previously known as FAANG) and some older tech
companies such as IBM,</li>
 <li>trading companies such as Jane Street or DE Shaw,</li>
 <li>lots of smaller companies serving LLMs and other models as a service, or
speeding up inference using quantization and other postprocessing techniques,</li>
 <li>the rest, which includes peripheral companies using ML (Sony, Disney), biology
/ drugs / medicine companies, and publishing houses.</li>
</ol> <p>
In general it was a pretty good place to get in touch with companies. I made
some genuinely good contacts which I will cherish whether or not they lead to a
job. I disliked talking to recruiters who in the end just forwarded me to their
company's general recruitment page; that feels like a waste of time on both
their part and mine.
</p>
</div>
 <div id="outline-container-what-i-would-have-done-differently" class="outline-3">
 <h3 id="what-i-would-have-done-differently"> <a href="#what-i-would-have-done-differently">What I would have done differently</a></h3>
 <div class="outline-text-3" id="text-what-i-would-have-done-differently">
 <ul class="org-ul"> <li>Go through all the booths on the expo day instead of spacing them out over the week</li>
 <li>Get over my fear of talking and engage with people instead of circling around
wasting time
 <ul class="org-ul"> <li>Tip: Use the booths which you are not really interested in or have low
engagement to warm up. It's fun to see what people are up to and they
probably enjoy people actually talking to them.</li>
</ul></li>
</ul> <p>
I spent a couple of hours each day going through the hall in a systematic manner
(my wife remarked that I came home with a lot of good swag, mostly really
high-quality socks, which is simply explained by my visiting  <strong>almost every</strong> booth…).
I think this is worthwhile, but it drains your energy and takes up a lot of time.
Looking back, I wish I had done all of it on the Sunday when I arrived so
that I could have focused on the posters and talks on the other days. This time I
hardly spent any time on the research part of the conference, which I
slightly regret. But I met the companies and engaged in some way or another
with about 80% of the booths, which I count as a success.
</p>

 <p>
Another part is getting over yourself and just talking to people. At the start of
the conference I was quite shy and wondered why people would want to engage
with me, but honestly, the whole reason these companies are present at
NeurIPS is to talk to the attendees (even if the chat does not actually lead
anywhere), at least out of courtesy. I should work on overcoming this fear and just
throw myself out there. Talking to low-stakes companies first helped me get over
this barrier, I felt.
</p>
</div>
</div>
</div>
 <div id="outline-container-meetings-and-parties" class="outline-2">
 <h2 id="meetings-and-parties"> <a href="#meetings-and-parties">Meetings and parties</a></h2>
 <div class="outline-text-2" id="text-meetings-and-parties">
 <p>
I went to a couple of parties, which was great: a good time to reconnect with
people I haven't seen in a while and to connect with new people. One party was
thrown by one of the UK initiatives for AI Safety and it was interesting
chatting about the state of things. It seemed like the onus was on first making
people aware of the problem and of the ways it can be approached and
potentially solved, working similarly to a think-tank. After this party ended we
went to the Cohere party, thrown in an amazing multi-layered
building; really cool party, would go again.
</p>

 <p>
While the above party was more general and open, I also went to an open bar
hosted by Imbue. This was super cozy and intimate. I made some great contacts
there and spoke to many of the people on the team. The party was hosted outside
of the usual district, which probably meant fewer people attended, but on the
other hand, the people there were really there for a reason. I thoroughly
enjoyed myself; it was one of the highlights of the conference for me.
</p>
</div>
</div>
 <div id="outline-container-the-actual-research-conference" class="outline-2">
 <h2 id="the-actual-research-conference"> <a href="#the-actual-research-conference">The actual research conference</a></h2>
 <div class="outline-text-2" id="text-the-actual-research-conference">
 <p>
LLMs were everywhere. I don't think this should have come as a surprise to
anyone. In one sense I feel like ChatGPT and its kind have been the first to
actually deliver on the promise of AI to the consumer (at least during my
lifetime; I would be interested to hear contrasting viewpoints), so it's only
natural that research trails this development, as we have a pretty poor
understanding of what actually goes on inside these models. On the other hand,
there's the question of what the role of academia and publishing should be: to
always cater to industry, or to do high-risk research which enables the Next Big
Thing (TM)? It's pretty hard to do that kind of research when a lot of it comes down
to compute.
</p>
</div>
 <div id="outline-container-mekrr-poster" class="outline-3">
 <h3 id="mekrr-poster"> <a href="#mekrr-poster">MEKRR Poster</a></h3>
 <div class="outline-text-3" id="text-mekrr-poster">
 <p>
Presenting my poster was a blast. We had a lot of activity and I got great
feedback. Someone even came and wanted a selfie with me and the poster! This is
the highest flattery I've received in my academic career by far. Some
senior researchers found this work interesting and really engaged with
me. This cemented my confidence in this line of work and I hope that others will
continue investigating how kernels can fit into a neural network world.
</p>
</div>
</div>
</div>
</main>]]></content>
  <link href="https://isakfalk.com/notes/NeurIPS2023retrospect.html"/>
  <id>https://isakfalk.com/notes/NeurIPS2023retrospect.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>RC Half-batch Retrospect</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">RC Half-batch Retrospect</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#morning-me-is-very-ambitious">Morning me is very ambitious</a></li>
 <li> <a href="#reflecting-on-inperson-reflection">Reflecting on in-person reflection</a></li>
 <li> <a href="#less-social-meetings-more-coding">Less social meetings, more coding</a></li>
 <li> <a href="#finished-setting-up-my-workstation">Finished setting up my workstation</a></li>
</ul></div>
</details></nav> <p>
Since I've been lazy and not written updates for RC weeks 5 and 6, I decided to
ease the burden and roll them both into a half-batch retrospect instead, so that
I can bring myself to actually write it and reduce the friction of writing.
</p>
 <div id="outline-container-morning-me-is-very-ambitious" class="outline-2">
 <h2 id="morning-me-is-very-ambitious"> <a href="#morning-me-is-very-ambitious">Morning me is very ambitious</a></h2>
 <div class="outline-text-2" id="text-morning-me-is-very-ambitious">
 <p>
When I plan my days I list a lot of tasks that I would like to get finished;
alas, at the end of the day, most of these tasks remain unfinished. I need to be
more realistic. Additionally, I should do some meta-planning where I use certain
days for certain related tasks instead of mixing many unrelated ones (e.g.
admin with writing code and reading papers). This is something that I
continue to work on.
</p>
</div>
</div>
 <div id="outline-container-reflecting-on-inperson-reflection" class="outline-2">
 <h2 id="reflecting-on-inperson-reflection"> <a href="#reflecting-on-inperson-reflection">Reflecting on in-person reflection</a></h2>
 <div class="outline-text-2" id="text-reflecting-on-inperson-reflection">
 <p>
I had a chat with some of the RC admins about how the batch has gone so far. It was
super chill but very insightful. While some things have not gone as I wanted
(does anything ever go according to plan? I don't think so!), it was reassuring to
get some outside points of view, which also highlighted what I did accomplish,
like setting up this website and the more meta-level skills of building volitional
muscles etc.
</p>
</div>
</div>
 <div id="outline-container-less-social-meetings-more-coding" class="outline-2">
 <h2 id="less-social-meetings-more-coding"> <a href="#less-social-meetings-more-coding">Less social meetings, more coding</a></h2>
 <div class="outline-text-2" id="text-less-social-meetings-more-coding">
 <p>
The first sentence of my Week 6, Day 5 update reads
</p>
 <p class="verse">
Removed a lot of coffee chat slots, limiting them to just Fridays from now on. Think I need to go more into deep work mode. <br></br></p>
 <p>
I think this sets the tone for the later part of the batch for me. The social
aspect of RC has been great fun and I've made loads of friends and connections,
but all of these scheduled coffee chats take energy and yank you out of your
flow. Once per week will be enough from here on.
</p>

 <p>
I will focus less on the social aspects of RC and more on getting
better at coding. I want to finish my  <a href="RC-week-4.html#machine-learning-project">project</a> before the end of the batch, and if I
finish it before then I may try to create a JAX package. Here's to the next year and the next RC batch!
</p>
</div>
</div>
 <div id="outline-container-finished-setting-up-my-workstation" class="outline-2">
 <h2 id="finished-setting-up-my-workstation"> <a href="#finished-setting-up-my-workstation">Finished setting up my workstation</a></h2>
 <div class="outline-text-2" id="text-finished-setting-up-my-workstation">
 <p>
I managed to get my workstation set up properly. Now I have a beefy GPU which I
can use on demand from my laptop. I also went a bit crazy and set up a VPN so
that I can use Jupyter together with Emacs. For this I relied a lot on  <a href="https://martibosch.github.io/jupyter-emacs-universe/">this blog
post on emacs and jupyter</a>, which saved me many hours. I found the  <a href="https://github.com/astoff/code-cells.el">code-cells</a>
Emacs package to be very well-engineered, striking a good balance between
lightweight and feature-complete.
</p>
</div>
</div>
</main>]]></content>
  <link href="https://isakfalk.com/notes/RC-halfbatch.html"/>
  <id>https://isakfalk.com/notes/RC-halfbatch.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>RC Retrospect</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">RC Retrospect</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#rc-is-a-place-and-a-context">RC is a place and a context</a></li>
 <li> <a href="#thinking-back-on-what-i-did">Thinking back on what I did</a></li>
 <li> <a href="#so-long-see-you-around">So long, see you around</a></li>
</ul></div>
</details></nav> <p>
This is a long overdue retrospect of my time at RC. I've been super busy with
job searching and preparing, but the other day I looked at the niceties sent
to me by email, and they filled me with absolute joy (thanks to all who wrote
them; they really made me a little teary-eyed and emotional 🥲).
</p>
 <div id="outline-container-rc-is-a-place-and-a-context" class="outline-2">
 <h2 id="rc-is-a-place-and-a-context"> <a href="#rc-is-a-place-and-a-context">RC is a place and a context</a></h2>
 <div class="outline-text-2" id="text-rc-is-a-place-and-a-context">
 <p>
We have the hub in Brooklyn. This is the physical place of RC. But RC is so much
more than that, and it's not static either, but a changing, living entity. Certain
things change slowly, like the hub and the tools we use, together with the
spectacular and kind admins. Other things change rapidly (almost, like, every
6 weeks… <sup> <a id="fnr.1" class="footref" href="#fn.1" role="doc-backlink">1</a></sup>). Being out of batch and still going to the hub, it feels…
different. Not in a bad way: seeing so many new people is great, but also weird.
Life moves on, people never-graduate and may not be around in the same way as
before. The context changes.
</p>

 <p>
RC is a context. Each batch is its own context, with different people coming from
different backgrounds with different personalities and tastes. Each batch is big
enough that I would say the average stays reasonably similar from batch to
batch, but to be clear, RC is not about averages; it's about the individual
connections you make and the context you bring to your batch. From my own point
of view, RC was a perfect thing to do while I waited for my employment
authorization to come through. But really, it has turned out to be so much more than
that. RC is and was a social context for me in New York: somewhere to go and
hang out, interact with others and learn new things in an open and warm
environment.
</p>

 <p>
RC is remarkable. I think it attracts a certain kind of person, wanting to become
a better programmer and learn new things for the joy of it, rather than for pure
career progression (although one may lead to the other; the other way around,
maybe not so much). The directives are pretty clear: work at the edge of your
abilities, build your volitional muscles and learn generously. But as much as RC
asks you to apply these self-directives, I think there is a feedback loop in the
other direction as well, where the context allows you to apply them successfully.
</p>

 <p>
RC as a place and a context has allowed me to learn many things, make many friends
and better understand what I want from a social technology context.
</p>
</div>
</div>
 <div id="outline-container-thinking-back-on-what-i-did" class="outline-2">
 <h2 id="thinking-back-on-what-i-did"> <a href="#thinking-back-on-what-i-did">Thinking back on what I did</a></h2>
 <div class="outline-text-2" id="text-thinking-back-on-what-i-did">
 <p>
My focus shifted wildly during my batch. In my application I wrote that I
wanted to implement a machine learning library using Scheme <sup> <a id="fnr.2" class="footref" href="#fn.2" role="doc-backlink">2</a></sup>. I quickly let go
of this idea and proceeded to socialize and get into the groove of the RC spirit.
For a large part of the first half of my batch I did a lot of coffee chats,
learned about HTML and CSS, and learned how to use  <a href="building-this-website.html">org-mode and emacs to create this
blog</a>. A personal reflection is that I can be pretty harsh on myself. Looking
back at my previous RC notes ( <a href="RC-week-1.html">RC-week-1.html</a>,  <a href="RC-week-2.html">RC-week-2.html</a>,
 <a href="RC-week-3.html">RC-week-3.html</a>,  <a href="RC-week-4.html">RC-week-4.html</a> and  <a href="RC-halfbatch.html">RC-halfbatch.html</a>), I summarize
them below.
</p>

 <p>
First off, I thought that my web-development learning would be swift and that I
would go on to do "bigger" things. In reality this part took longer; on the other
hand, I actually built this website by hand (a labour of love!), and in itself
that is something beautiful. My  <a href="RC-week-2.html#getting-focused-on-project">main project</a> (which did not materialize during
the batch, but I have some vague sense that it will be finished in
the future) started forming in my mind early, and I still have the data around. I
did not finish it. I did do several presentations and non-presentations, though I
could maybe have done better. However, I did still present! Watching the
presentations has really been inspiring to me, reinforcing the rule of learning
in the open and sharing what we've learned.
</p>

 <p>
For impossible day I learned how to use  <a href="https://flask.palletsprojects.com/en/3.0.x/">Flask</a> to set up a web server, which
was great. I got into the weeds on how to use databases, minimal HTML and the
elements used for functionality through Flask. One more step towards becoming a full-stack ML
developer (I jest… or do I?). This was a step forward, but I still  <a href="RC-week-4.html#zigzags-in-my-road">felt a bit
unfocused</a>. As a side-project I learned some algorithms and data structures, and
I can say that this has actually been a success and will hopefully help me land
a job soon.
</p>

 <p>
A thing I did not go into much is that I levelled up my toolset and personal
workflow. I will probably write a note about this later, but in short:
</p>
 <ul class="org-ul"> <li>Fixed a satisfactory org-mode + Jupyter kernel workflow which  <strong>actually works</strong>.</li>
 <li>Set up a VPN so I can connect to devices on my home network, which is what made
the above computational notebook setup possible in the first place.</li>
 <li>Bought a GPU and installed it.</li>
</ul> <p>
All in all a great outcome, and possibly in the spirit of RC?
</p>

 <p>
I also learned JAX and diffusion models, together with some mechanistic
interpretability, mostly for the family of LLMs and transformers. This was led
by the great  <a href="https://www.changlinli.com/">Changlin Li</a>. Dipping into ML in this way has been great and
re-ignited a passion.
</p>

 <p>
So where does this leave us (me)? Three months seems long, but it's short.
Chipping away day by day is the only way to keep going, though, and it does lead
to returns, essentially yielding compound interest on what you have learned and
on your knowledge. I feel like a more apt and capable technologist; I better
understand how computers work and how to use the web to share more openly. This was
one of my goals coming into RC, so it has been a great success. And this by
itself made RC worth it.
</p>
</div>
</div>
 <div id="outline-container-so-long-see-you-around" class="outline-2">
 <h2 id="so-long-see-you-around"> <a href="#so-long-see-you-around">So long, see you around</a></h2>
 <div class="outline-text-2" id="text-so-long-see-you-around">
 <p>
I will stick around, at least for the foreseeable future. I will also hang
around on the Zulip, so if you want to reach me, either email me (you can find
my contact in the footer) or reach out on Zulip. I promise I won't bite 😄.
</p>
</div>
</div>
 <div id="footnotes" class="Footnotes">
 <div id="text-footnotes">

 <div class="footdef"> <sup> <a id="fn.1" class="footnum" href="#fnr.1" role="doc-backlink">1</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">
This is a boring joke, because a batch starts and ends every 6 weeks.
</p></div></div>

 <div class="footdef"> <sup> <a id="fn.2" class="footnum" href="#fnr.2" role="doc-backlink">2</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">
I have an
interesting relationship with Lisp through Emacs and wanted to learn more about
Scheme, which seemed to capture the essence of Lisps.
</p></div></div>


</div>
</div></main>]]></content>
  <link href="https://isakfalk.com/notes/RC-retrospect.html"/>
  <id>https://isakfalk.com/notes/RC-retrospect.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>RC: Week 1</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">RC: Week 1</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#hello-rc">Hello RC!</a></li>
 <li> <a href="#a-truly-hybrid-week">A truly hybrid week</a></li>
 <li> <a href="#socializing-is-fun-but-how-am-i-supposed-to-remember-all-of-these-people">Socializing is fun, but how am I supposed to remember all of these people?</a></li>
 <li> <a href="#the-manual-and-online-resources">The Manual and online resources</a></li>
 <li> <a href="#going-forward">Going forward</a></li>
</ul></div>
</details></nav> <div id="outline-container-hello-rc" class="outline-2">
 <h2 id="hello-rc"> <a href="#hello-rc">Hello RC!</a></h2>
 <div class="outline-text-2" id="text-hello-rc">
 <p>
I just started the Winter 1 batch of 2023 at the  <a href="https://www.recurse.com/">Recurse Center</a> (RC)! I'm very
happy; I wasn't sure that I was going to make it, as I applied very late and
worried that I would miss the deadline for this batch. However, the process went
super smoothly and here I am at the end of week 1 (actually the start of week 3, but
I only just got my website up and running).
</p>

 <p>
I will be keeping notes on RC here, mostly weekly retrospects where I reflect
and potentially plan for the coming week, together with short updates about
what I'm doing.
</p>
</div>
</div>
 <div id="outline-container-a-truly-hybrid-week" class="outline-2">
 <h2 id="a-truly-hybrid-week"> <a href="#a-truly-hybrid-week">A truly hybrid week</a></h2>
 <div class="outline-text-2" id="text-a-truly-hybrid-week">
 <p>
RC operates a truly hybrid retreat where participants come from all over the
world (and all timezones). As of writing this (day 2 of week 3; I'm a bit out of
sync but aim to catch up to the current day soon!) I am in Sweden sorting
out my visa. Week 1 was split into two sections, as the hub wasn't opened for the
Winter 1 2023 (or just W1'23) batch until the Wednesday of that week. So Monday
and Tuesday were remote, while for the rest of the week I was present in the hub.
</p>

 <p>
I think this format works really well, since it is designed with hybrid in
mind rather than tacked onto an already existing model. The tools really
help here: I find Zulip a joy to use once you get the hang of it, compared to
for example Slack, and the virtual RC hub is fun and interactive. However, being
physically in the hub is also great; the space is very cool and inspiring (it
really has a nice, comfy hacker vibe).
</p>
</div>
</div>
 <div id="outline-container-socializing-is-fun-but-how-am-i-supposed-to-remember-all-of-these-people" class="outline-2">
 <h2 id="socializing-is-fun-but-how-am-i-supposed-to-remember-all-of-these-people"> <a href="#socializing-is-fun-but-how-am-i-supposed-to-remember-all-of-these-people">Socializing is fun, but how am I supposed to remember all of these people?</a></h2>
 <div class="outline-text-2" id="text-socializing-is-fun-but-how-am-i-supposed-to-remember-all-of-these-people">
 <p>
This week was spent mainly getting to know the people of the batch: who they are,
what they do and what their plans are. According to the directory of
participants of my batch we are  <strong>45 people</strong>, and additionally there are people
from the previous batch who are also around, making  <strong>93 people in total</strong>. The
directory really helps with keeping track, since I can study who I've met and make
some quick mental notes, but I have several times re-introduced myself to others at
the hub.
</p>

 <p>
I think taking an organic approach to this is best:
</p>
 <ul class="org-ul"> <li>Try to meet as many people as you can.
 <ul class="org-ul"> <li>Go to events and be ready to try new things out</li>
 <li>Introduce yourself even if you may not have the time to chat properly as
this will set a context for further interaction and make it less awkward to
talk</li>
</ul></li>
 <li>Don't worry about meeting everyone; just be nice and things will sort themselves
out in time</li>
 <li>Naturally you will gravitate towards some people, either because you vibe or just
because you share similar interests. However, don't let this make you insular;
form inclusive contexts rather than insular cliques</li>
</ul> <p>
Typing out the above, it also strikes me that everyone present has a
responsibility to make this retreat inclusive and enjoyable. These are basically
some additional aspirations on top of the  <a href="https://www.recurse.com/manual#sub-sec-social-rules">RC social rules</a>.
</p>
</div>
</div>
 <div id="outline-container-the-manual-and-online-resources" class="outline-2">
 <h2 id="the-manual-and-online-resources"> <a href="#the-manual-and-online-resources">The Manual and online resources</a></h2>
 <div class="outline-text-2" id="text-the-manual-and-online-resources">
 <p>
RC has an  <a href="https://www.recurse.com/manual">online manual</a> which lays out what
RC is, the environment, and how to make the most out of your batch, including
logistics, planning your stay and the philosophy of this
 <a href="https://www.recurse.com/manual#sec-welcome">unusual experience</a>. The manual
contains a lot of material, but I found it really comprehensive and great for
setting the tone for the retreat. In any case, it can be used as a reference as
you go through your batch.
</p>

 <p>
There are also other resources available internally which I've found really
helpful, such as the directory of never-graduated RC alumni. The wiki is great and
contains a lot of projects that previous RC people have created.
</p>

 <p>
Finally, I think it's so refreshing to have things written down. Often we don't
find the time to actually document processes, which leads to repetition that
can become tiring, or worse, to processes that remain uncertain or
implicit.
</p>
</div>
</div>
 <div id="outline-container-going-forward" class="outline-2">
 <h2 id="going-forward"> <a href="#going-forward">Going forward</a></h2>
 <div class="outline-text-2" id="text-going-forward">
 <p>
At the end of the first week I felt excited, a bit overwhelmed, and slightly
anxious about what my project would be. The
 <a href="https://www.recurse.com/self-directives">self-directives</a> were mentioned a
lot and we also did some exercises on building our volitional muscles, which I
appreciated since I have a tendency to become a bit paralyzed when trying to
scope out new projects; I should just get to it rather than doing
background reading perpetually!
</p>

 <p>
It was a joy to meet everyone during the first week, and I'm looking forward to
meeting more people in the coming weeks, working together on different projects
and pair programming!
</p>
</div>
</div>
</main>]]></content>
  <link href="https://isakfalk.com/notes/RC-week-1.html"/>
  <id>https://isakfalk.com/notes/RC-week-1.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>RC: Week 2</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">RC: Week 2</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#building-a-social-connection">Building a social connection</a></li>
 <li> <a href="#building-this-website">Building this website</a></li>
 <li> <a href="#getting-focused-on-project">Getting focused on project</a></li>
 <li> <a href="#setting-intention-on-events-to-attend">Setting intention on events to attend</a></li>
</ul></div>
</details></nav> <div id="outline-container-building-a-social-connection" class="outline-2">
 <h2 id="building-a-social-connection"> <a href="#building-a-social-connection">Building a social connection</a></h2>
 <div class="outline-text-2" id="text-building-a-social-connection">
 <p>
This week I kept talking to and meeting people, through a combination of scheduled
chats and impromptu meetings in the hub, especially the kitchen (especially
especially the coffee machine). I found the coffee bot to be a good way of
having an opt-in coffee chat with people I hadn't met yet, or of getting to
know people I had already introduced myself to a bit better.
</p>

 <p>
I continue to be impressed and happy about the sheer breadth and diversity of
people in terms of backgrounds and talents. It's really great to learn what
people are working on and to discover fields which you may not even have known
existed, or knew about but had no real clear picture of. I hope I'll be able
to meet everyone in my batch by the end of the retreat.
</p>
</div>
</div>
 <div id="outline-container-building-this-website" class="outline-2">
 <h2 id="building-this-website"> <a href="#building-this-website">Building this website</a></h2>
 <div class="outline-text-2" id="text-building-this-website">
 <p>
During this week I started learning about web development, HTML and CSS, in
order to build this website and have somewhere to publish my notes. I've started
this several times in the past but never reached the finish line, so I see
this as a success in itself, a small win to get started.
</p>

 <p>
I will not go into the technicalities of this website and how I've decided to
structure my notes and pages here. For that, see my note on  <a href="building-this-website.html">how I build
this website</a>!
</p>
</div>
</div>
 <div id="outline-container-getting-focused-on-project" class="outline-2">
 <h2 id="getting-focused-on-project"> <a href="#getting-focused-on-project">Getting focused on project</a></h2>
 <div class="outline-text-2" id="text-getting-focused-on-project">
 <p>
During this week an idea came to me: to build a "this RC
does not exist" <sup> <a id="fnr.1" class="footref" href="#fn.1" role="doc-backlink">1</a></sup> by taking photos of the hub and then fine-tuning an image
generative model on this dataset. I already know a lot of machine learning, but
generative modelling has never been one of my areas of focus.
</p>

 <p>
I am excited about this project as it would allow me to build a full ML
pipeline, from data collection all the way to a user-facing website serving
images. Let's see how it goes. I have already started collecting images,
so it's a start!
</p>
</div>
</div>
 <div id="outline-container-setting-intention-on-events-to-attend" class="outline-2">
 <h2 id="setting-intention-on-events-to-attend"> <a href="#setting-intention-on-events-to-attend">Setting intention on events to attend</a></h2>
 <div class="outline-text-2" id="text-setting-intention-on-events-to-attend">
 <p>
There is so much you could potentially do at RC. It's easy to form study groups
or hold events around almost any topic of choice. It's a real smorgasbord of
opportunity: dive deep into topics you already know, or pick up something new
and go from beginner to intermediate in three months. Of course, the effect of
focusing on too many things is that you won't get far into any one topic,
in addition to the events taking energy and time away from actually working on projects.
</p>

 <p>
During this week I realized that I cannot attend everything and that I should be
a bit more mindful about where I spend my time and energy in order to finish a
project before the end of the batch. Some events I want to attend from time to
time even if they aren't related to my area of expertise (like the creative
coding session, where we code something in 90 minutes from a small prompt to
work the creative muscles), but others I will have to leave for now, as
otherwise I will spread myself too thin.
</p>
</div>
</div>
 <div id="footnotes" class="Footnotes">
 <div id="text-footnotes">

 <div class="footdef"> <sup> <a id="fn.1" class="footnum" href="#fnr.1" role="doc-backlink">1</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">
Inspired by the family of "This X does not exist", like  <a href="https://thisrentaldoesnotexist.com/about/">this rental does not
exist</a>.
</p></div></div>


</div>
</div></main>]]></content>
  <link href="https://isakfalk.com/notes/RC-week-2.html"/>
  <id>https://isakfalk.com/notes/RC-week-2.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>RC: Week 3</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">RC: Week 3</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#this-is-now-a-remote-rc-for-me">This is now a remote RC for me</a></li>
 <li> <a href="#sharing-and-pairing">Sharing and pairing</a>
 <ul> <li> <a href="#presentations">Presentations</a></li>
 <li> <a href="#pairing">Pairing</a></li>
</ul></li>
 <li> <a href="#impossible-day">Impossible day</a></li>
</ul></div>
</details></nav> <div id="outline-container-this-is-now-a-remote-rc-for-me" class="outline-2">
 <h2 id="this-is-now-a-remote-rc-for-me"> <a href="#this-is-now-a-remote-rc-for-me">This is now a remote RC for me</a></h2>
 <div class="outline-text-2" id="text-this-is-now-a-remote-rc-for-me">
 <p>
I moved to NY just a couple of months before RC started. In order to get settled
I need to sort out my J-2 visa, as my partner already has a J-1 visa. To do this
I had to return to Sweden (where I am currently staying in the area of Malmo and
Lund, where some of my siblings live). While it's been great seeing family, it
has been harder to stay focused on the RC batch, mostly due to having to find a
new routine while being fully remote.
</p>

 <p>
Some of the things I think are different:
</p>
 <ol class="org-ol"> <li>It's much harder to meet new people since you actively have to reach out or
use the RC virtual space rather than randomly meeting people in the kitchen
of the hub.</li>
 <li>With people around I can use their energy to energize myself, and I didn't
realize how much this was a thing until I went to Sweden.</li>
 <li>Generally, contacting people for pair programming and other activities just has
a bit more friction, which tires me out more.</li>
 <li>I'm on vampire-time where I stay up until 05:00 and get up around 12:00.</li>
</ol> <p>
Next week I will see if I can come up with ways to mitigate these points.
</p>
</div>
</div>
 <div id="outline-container-sharing-and-pairing" class="outline-2">
 <h2 id="sharing-and-pairing"> <a href="#sharing-and-pairing">Sharing and pairing</a></h2>
 <div class="outline-text-2" id="text-sharing-and-pairing">
 <p>
Previously I've often coded in isolation, and one thing I want to get out of RC is the social aspect of coding and learning together with others.
</p>
</div>
 <div id="outline-container-presentations" class="outline-3">
 <h3 id="presentations"> <a href="#presentations">Presentations</a></h3>
 <div class="outline-text-3" id="text-presentations">
 <p>
I don't think I'm a very good presenter, but I think I could be. One of my aims
with RC is to overcome my (not super big) anxiety around speaking and giving
talks and presentations. My goal for the rest of the batch is to give a
presentation at least once a week. This week's presentation was a small intro to
how meditation is often categorized in a Buddhist setting. I used  <a href="https://revealjs.com/">reveal.js</a>, which
outputs to HTML, so I will try to put the slides I generate throughout RC on this
webpage somehow!
</p>
</div>
</div>
 <div id="outline-container-pairing" class="outline-3">
 <h3 id="pairing"> <a href="#pairing">Pairing</a></h3>
 <div class="outline-text-3" id="text-pairing">
 <p>
I've paired with several people at this point and I really enjoy it. It's a bit
of a double-edged sword, as it can really drain you: your pairing partner
usually keeps you focused on the task, and subjectively it feels like you engage
your system 2 more than your mindless system 1. I think that on the whole this
reduces the work needed in the end and also adds a nice social aspect to coding!
</p>
</div>
</div>
</div>
 <div id="outline-container-impossible-day" class="outline-2">
 <h2 id="impossible-day"> <a href="#impossible-day">Impossible day</a></h2>
 <div class="outline-text-2" id="text-impossible-day">
 <p>
This week we did the so-called "impossible day", where we set our goal on doing
something impossible, that is, outside of what we expect to be able to do.
I've slowly been getting to grips with web development, and wanted to make a
side project to my generative ML project, where I  <a href="RC-week-2.html#getting-focused-on-project">generate novel images
of the RC hub</a>, in the form of an online portal where people can upload images
of the hub which I store in a database. Just days before, we had gotten $100 in
credits at  <a href="https://render.com/">render.com</a>, which I wanted to use to improve my understanding of how
to actually build a functioning web app, deploy it, and make it available
using a Recurse sub-domain. Here is the  <a href="https://rcnexists.recurse.com/">resulting website</a>!
</p>

 <p>
With help from some good web-dev people through pair programming, I actually
managed to get something together and deploy it successfully (although it's
extremely bare-bones and barely working). All in all, I managed to
</p>
 <ul class="org-ul"> <li>Learn how to use the  <a href="https://flask.palletsprojects.com/en/3.0.x/">flask framework</a>.</li>
 <li>Set up a PostgreSQL database on render.com to store the images and integrate it into the code.</li>
 <li>Accept user input using some  <code>input</code> HTML tags.</li>
 <li>Set up a Recurse sub-domain to point to the deployed render server.</li>
</ul> <p>
This was a great experience and I am thankful to those that helped me out!
</p>
</div>
</div>
</main>]]></content>
  <link href="https://isakfalk.com/notes/RC-week-3.html"/>
  <id>https://isakfalk.com/notes/RC-week-3.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>RC: Week 4</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">RC: Week 4</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#zigzags-in-my-road">Zig-zags in my road</a></li>
 <li> <a href="#machine-learning-project">Machine learning project</a></li>
 <li> <a href="#learning-algorithms">Learning algorithms</a></li>
</ul></div>
</details></nav> <div id="outline-container-zigzags-in-my-road" class="outline-2">
 <h2 id="zigzags-in-my-road"> <a href="#zigzags-in-my-road">Zig-zags in my road</a></h2>
 <div class="outline-text-2" id="text-zigzags-in-my-road">
 <p>
This week was a bit all over the place. I felt like I got some things done, but
on the other hand I also lacked some focus with respect to the project. I often
feel that I overestimate what I can get done in a day; I see this in my daily
check-ins, where several points on my todo list for the day go unticked due to
unforeseen issues, social meetings or events. At the very least I usually tick
some off, which is still a win, but it makes me wonder if I should just
internalize this and make my estimates more accurate.
</p>

 <p>
I also wonder if I'm maybe being too hard on myself. Yes, there are some things
which I would have completed had I worked harder, but I also did other things
(such as the above social meetings and events) which also feel important.
Finding the balance is just inherently hard, I believe!
</p>
</div>
</div>
 <div id="outline-container-machine-learning-project" class="outline-2">
 <h2 id="machine-learning-project"> <a href="#machine-learning-project">Machine learning project</a></h2>
 <div class="outline-text-2" id="text-machine-learning-project">
 <p>
Slowly  <a href="RC-week-2.html#getting-focused-on-project">this project</a> is ramping up, and this week was interesting as I got to
learn some new data science tools I haven't worked with before. The dataset I
have collected comes without labels since I collected it myself, so I thought I
should label it for whatever downstream task I decide to use it for. I
settled on the  <a href="https://labelstud.io/">Label Studio</a> Python library, which seems very powerful and lets
you define your own annotation UI through XML (I've been learning HTML / CSS for
this website, so using XML was not as much of a pain as I thought it would be).
</p>
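<p>
As a taste of what this looks like, a Label Studio labeling config is a small
XML document built from tags like <code>Image</code> and <code>Choices</code>.
The choice values below are made up for illustration, not my actual label set:
</p>

```xml
<View>
  <Image name="image" value="$image"/>
  <Choices name="category" toName="image" choice="single">
    <Choice value="sign"/>
    <Choice value="computer"/>
    <Choice value="other"/>
  </Choices>
</View>
```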

 <p>
The dataset now consists of about 350 images from around the hub, many of them
images of signs or text, which I will try to fit a generative image model to.
Here are some example images:
</p>


 <figure id="org58ca305"> <img src="../assets/images/rc_nexists/5.webp" alt="5.webp"></img> <figcaption> <span class="figure-number">Figure 1: </span>A sign of 5, looks a bit industrial</figcaption></figure> <figure id="org5356f6b"> <img src="../assets/images/rc_nexists/next.webp" alt="next.webp"></img> <figcaption> <span class="figure-number">Figure 2: </span>Some really old computers from the hub</figcaption></figure></div>
</div>
 <div id="outline-container-learning-algorithms" class="outline-2">
 <h2 id="learning-algorithms"> <a href="#learning-algorithms">Learning algorithms</a></h2>
 <div class="outline-text-2" id="text-learning-algorithms">
 <p>
I've been following the  <a href="https://ocw.mit.edu/courses/6-006-introduction-to-algorithms-spring-2020/">MIT 6.006 course</a>, as I never had a proper introduction
to algorithms before, which I felt has been holding me back when programming: I
did not have a good grasp of fundamental algorithms (sorting, shortest
path, etc.) and data structures (arrays, dictionaries, sets, etc.). I'm really
enjoying myself so far, a bit more than halfway through the course.
</p>

 <p>
In addition to the above, I've been slowly doing LeetCode problems together with
some other RC attendees and it's been great fun. I hope to get some more problems under my belt and tackle mediums and hards in due time!
</p>
</div>
</div>
</main>]]></content>
  <link href="https://isakfalk.com/notes/RC-week-4.html"/>
  <id>https://isakfalk.com/notes/RC-week-4.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>Building this website</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">Building this website</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#starting-out">Starting out</a></li>
 <li> <a href="#how-to-structure-posts">How to structure posts</a></li>
</ul></div>
</details></nav> <div id="outline-container-starting-out" class="outline-2">
 <h2 id="starting-out"> <a href="#starting-out">Starting out</a></h2>
 <div class="outline-text-2" id="text-starting-out">
 <p>
What spurred me to redo my website and blog is the fact that I got accepted to
RC and they encourage us to learn openly. This is something that I've wanted to
do for a long time anyway, so I felt that now is the time to get this website up
and running.
</p>

 <p>
Basically, I know nothing (or  <strong>very little</strong>) about web development, and would
like to get up to speed to where I am at least  <em>comfortable</em> adjusting my website.
I plan on using Emacs to the extent possible, and I'll put the source of how I
build the website, together with the source of the actual notes and static
pages + assets, online at  <a href="https://sr.ht/">sourcehut</a>, eventually building it there too.
</p>
</div>
</div>
 <div id="outline-container-how-to-structure-posts" class="outline-2">
 <h2 id="how-to-structure-posts"> <a href="#how-to-structure-posts">How to structure posts</a></h2>
 <div class="outline-text-2" id="text-how-to-structure-posts">
 <p>
Each org mode file will have some options relating to metadata which will be
taken care of when the site is published. In practice, this means that the
resulting  <code>HTML</code> will take these options into account somehow.
</p>

 <p>
Posts should carry the following metadata
</p>
 <ul class="org-ul"> <li>Time created</li>
 <li>Latest time changed</li>
 <li>Tags</li>
</ul> <p>
and rely as much as possible on  <a href="https://orgmode.org/manual/Export-Settings.html">already defined options of org mode</a>.
</p>
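<p>
A sketch of what the top of such an org file might look like. The exact
keywords are my working assumption, leaning on the standard org export options
(<code>#+title</code>, <code>#+date</code>, <code>#+filetags</code>):
</p>

```org
#+title: Building this website
#+date: <2024-01-15>
#+filetags: :web:emacs:
```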
</div>
</div>
</main>]]></content>
  <link href="https://isakfalk.com/notes/building-this-website.html"/>
  <id>https://isakfalk.com/notes/building-this-website.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>Diffusion and score-based generative modeling</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">Diffusion and score-based generative modeling</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#diffusion-and-scorebased-generative-modeling">Diffusion and score-based generative modeling</a>
 <ul> <li> <a href="#setup">Setup</a></li>
 <li> <a href="#score-matching">Score matching</a></li>
 <li> <a href="#generating-samples">Generating samples</a></li>
</ul></li>
 <li> <a href="#implementation">Implementation</a>
 <ul> <li> <a href="#imports">Imports</a></li>
 <li> <a href="#tractable-mixture-models">Tractable mixture models</a></li>
 <li> <a href="#learning-the-score-function">Learning the score function</a></li>
</ul></li>
 <li> <a href="#conclusion">Conclusion</a></li>
 <li> <a href="#reference">Reference</a></li>
</ul></div>
</details></nav> <div id="outline-container-diffusion-and-scorebased-generative-modeling" class="outline-2">
 <h2 id="diffusion-and-scorebased-generative-modeling"> <a href="#diffusion-and-scorebased-generative-modeling">Diffusion and score-based generative modeling</a></h2>
 <div class="outline-text-2" id="text-diffusion-and-scorebased-generative-modeling">
 <p>
There are great blog posts on diffusion and score matching elsewhere; in
particular, see  <a href="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/">Lilian Weng's literature review</a> and the great exposition by  <a href="https://yang-song.net/blog/2021/score/">Yang
Song on learning score functions for generative modeling</a>. Here I will mainly
lean on Yang Song's blog post and his and his collaborators' paper
Generative Modeling by Estimating Gradients of the Data Distribution
( <a href="#citeproc_bib_item_1">Song and Ermon 2020</a>), as I find it very comprehensive and
well-written.
</p>

 <p>
Some of the sections are pretty technical, for the actual implementation you
only need to
</p>
 <ul class="org-ul"> <li>Read the  <a href="#setup">Setup section</a>,</li>
 <li>Understand how the loss \(\ell(\theta; \sigma)\) defined in
\ref{eq:simplified-score-matching-objective} is used to build the optimization
objective \(\hat{\mathcal{E}}(\theta; (\sigma_{l})_{l=1}^{L})\) defined in
\ref{eq:aggregated-final-empirical-risk} which we train on to produce the
estimator \(\hat{\theta}_{n}\) which are the learned parameters of the score
network \(s_{\theta}\),</li>
 <li>Read the  <a href="#generating-samples">Generating samples section</a> to understand how to generate samples using \(s_{\hat{\theta}_{n}}\),</li>
</ul> <p>
and you should then be able to follow along with the  <a href="#implementation">implementation section</a>.
</p>
</div>
 <div id="outline-container-setup" class="outline-3">
 <h3 id="setup"> <a href="#setup">Setup</a></h3>
 <div class="outline-text-3" id="text-setup">
 <p>
To start with, we assume that we have a dataset of iid samples
\((x_{i})_{i=1}^{n}\) sampled from some unknown data distribution \(p^{\ast}\),
where the datapoints live in some space \(\mathcal{X}\) which we will take to be
some Euclidean vector space (for example, \(\mathbb{R}^{D}\) for a vector or
\(\mathbb{R}^{H \times W \times C}\) for an image with width \(W\), height \(H\)
and \(C\) color channels). Everything is nice so we assume that \(p^{\ast}\) has
a pdf and identify the distribution with this pdf (so \(p^{\ast}(x)\) is the
density at \(x\)). The goal is to learn a model which would allow us to sample
from \(p^{\ast}\). One way to do this would be to model \(p^{\ast}\) directly,
but as fortune has it, it is enough to learn a model of the  <strong>score function</strong>
\(s^{\ast}(x) = \nabla_{x} \log p^{\ast}(x)\) to accomplish this.
</p>

 <p>
Learning the score function using score matching allows for much easier training
and modelling than trying to learn a model of \(p^{\ast}\) directly. This is
because we do not have to learn a properly normalized distribution, but only one
up to a constant. If we rewrite \(p^{\ast}(x) = \exp(-f^{\ast}(x))/Z^{\ast}\), the score function takes
the form \(-\nabla_{x} f^{\ast}(x)\) since \[ \nabla_{x} \log p^{\ast}(x) =
-\nabla_{x} f^{\ast}(x) - \nabla_{x}\log Z^{\ast} = -\nabla_{x} f^{\ast}(x) \] as \(Z^{\ast}\)
is independent of \(x\).
</p>
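<p>
This is easy to check numerically. The following sketch (my own illustration,
not from the paper) compares the autodiff gradient of the normalized standard
Gaussian log-density with that of the unnormalized negative energy
\(-f(x) = -x^{2}/2\); the normalizer drops out:
</p>

```python
import jax
from jax.scipy.stats import norm

# Normalized log-density of a standard 1D Gaussian: includes the -log Z term
log_p = lambda x: norm.logpdf(x)
# Unnormalized version: drop the constant -log Z entirely
neg_f = lambda x: -0.5 * x**2

x = 1.3
score_p = jax.grad(log_p)(x)  # d/dx log p*(x)
score_f = jax.grad(neg_f)(x)  # d/dx (-f(x))
# Both equal -x: the score never sees the normalization constant
```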
</div>
</div>
 <div id="outline-container-score-matching" class="outline-3">
 <h3 id="score-matching"> <a href="#score-matching">Score matching</a></h3>
 <div class="outline-text-3" id="text-score-matching">
 <p>
Score matching aims to minimize the least-squares objective
</p>
\begin{equation}
\label{eq:lsq-score-matching-objective}
\frac{1}{2}\mathbb{E}_{X \sim p^{\ast}}\|s_{\theta}(X) - s^{\ast}(X)\|^{2}
\end{equation}
 <p>
where \(s_{\theta}: \mathcal{X} \to \mathcal{X}\) is a model of the score
function, for example a neural network. Of course, we don't know \(s^{\ast}\), so
this objective cannot be evaluated directly, but it can be shown to be equal, up to a constant independent of \(\theta\), to
</p>
\begin{equation}
\label{eq:tr-score-matching-objective}
\mathbb{E}_{X \sim p^{\ast}}\left[\mathrm{tr}\left(\nabla_{x} s_{\theta}(X))\right) + \frac{1}{2}\|s_{\theta}(X)\|_{2}^{2}\right].
\end{equation}
 <p>
In practice, we replace the distribution \(p^{\ast}\) by the empirical version
\(\hat{p}_{n}\) using the train dataset \((x_{i})_{i=1}^{n}\). When the input
dimension is large, the trace computation becomes too computationally burdensome,
so we rely on other approximations. We will use denoising score matching, but
there are other ways; in ( <a href="#citeproc_bib_item_1">Song and Ermon 2020</a>) they also
mention sliced score matching as an alternative.
</p>

 <p>
To get to the point, denoising score matching replaces the distribution
\(p^{\ast}\) with a smoothed version \(q_{\sigma}(x) = \mathbb{E}_{X' \sim
p^{\ast}}\left[q_{\sigma}(x|X')\right]\) where \(q_{\sigma}(\cdot|x')\) is some
symmetric bell-curved distribution, for example a Gaussian with standard
deviation \(\sigma\) and mean \(x'\). Intuitively, the scale parameter \(\sigma\)
allows us to trade off some bias for variance by interpolating between the true
(empirical) distribution as \(\sigma \to 0\) and a uniform distribution as
\(\sigma \to \infty\) <sup> <a id="fnr.1" class="footref" href="#fn.1" role="doc-backlink">1</a></sup>, in addition to making training possible, as it gives
the resulting smoothed empirical distribution full support on
\(\mathcal{X}\) (so it is never zero anywhere). Without this smoothing,
\(\hat{p}_{n}\) is always zero on points outside of the train set, which
comes with all kinds of problems. Choosing \(q_{\sigma}(x | x')\) to be an
isotropic Gaussian pdf / distribution with covariance matrix \(\sigma^{2} I\) and
mean \(x'\) simplifies objective \ref{eq:tr-score-matching-objective} to
</p>
\begin{equation}
\label{eq:simplified-score-matching-objective}
\mathcal{L}(\theta; \sigma) = \frac{1}{2}\mathbb{E}_{X \sim p^{\ast}}\mathbb{E}_{X' \sim q_{\sigma}(\cdot | X)}\left[\|s_{\theta}(X', \sigma) - (- (X' - X)/\sigma^{2})\|_{2}^{2}\right]
\end{equation}
 <p>
where both the risk \(\mathcal{L}\) and the score model \(s_{\theta}\) are now
indexed by \(\sigma\). We may think of this as parameterizing a family of score
models by \(\sigma\) for some fixed \(\theta\). Let's call the empirical risk
\(\ell(\theta; \sigma)\) where we replace \(p^{\ast}\) with the empirical
distribution \(\hat{p}_{n}\). The final objective defining the Noise
Conditional Score Network above average losses over a geometrically spaced grid
of scales \(\sigma\). For such a grid \((\sigma_{l})_{l=1}^{L}\) we have
</p>
\begin{equation}
\label{eq:aggregated-final-empirical-risk}
\hat{\mathcal{E}}(\theta; (\sigma_{l})_{l=1}^{L}) = \frac{1}{L}\sum_{l=1}^{L}\lambda(\sigma_{l})\ell(\theta; \sigma_{l})
\end{equation}
 <p>
where \(\lambda\) is some weighting function which we fix to be \(\lambda(\sigma) = \sigma^{2}\)
according to the heuristic in ( <a href="#citeproc_bib_item_1">Song and Ermon 2020</a>). Let us call the learned parameters \(\hat{\theta}_{n}\).
</p>
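<p>
To make the objective concrete, here is a short JAX sketch of
\(\ell(\theta; \sigma)\) and \(\hat{\mathcal{E}}\). The
<code>score_net(params, x, sigma)</code> signature is my own assumption for
illustration; any model mapping a perturbed batch and a scale to scores works:
</p>

```python
import jax
import jax.numpy as jnp

def dsm_loss(score_net, params, x, key, sigma):
    """Empirical denoising score-matching loss l(theta; sigma) on a batch x."""
    noise = jax.random.normal(key, x.shape)
    x_tilde = x + sigma * noise             # sample X' ~ q_sigma(. | X)
    target = -(x_tilde - x) / sigma**2      # score of the Gaussian kernel in x'
    pred = score_net(params, x_tilde, sigma)
    return 0.5 * jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))

def ncsn_objective(score_net, params, x, key, sigmas):
    """Average lambda(sigma_l) * l(theta; sigma_l) over the grid, lambda(s) = s^2."""
    keys = jax.random.split(key, len(sigmas))
    losses = [s**2 * dsm_loss(score_net, params, x, k, s)
              for s, k in zip(sigmas, keys)]
    return jnp.mean(jnp.stack(losses))
```

In a full training loop one would differentiate <code>ncsn_objective</code>
with respect to <code>params</code> and take optimizer steps as usual.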
</div>
</div>
 <div id="outline-container-generating-samples" class="outline-3">
 <h3 id="generating-samples"> <a href="#generating-samples">Generating samples</a></h3>
 <div class="outline-text-3" id="text-generating-samples">
 <p>
We can use Langevin dynamics to produce samples from the learned score model \(s_{\hat{\theta}_{n}}\).
In general, Langevin dynamics allows us to sample from some distribution \(p\) as
long as we can evaluate the score function \(\nabla_{x} \log p(x)\). Fixing a
step size (or more generally, a schedule) \(\eta\) and some prior distribution
\(\pi\), we sample an initial value \(x_{0} \sim \pi\) and iterate using
</p>
\begin{equation}
\label{eq:langevin-dynamics}
x_{t+1} = x_{t} + \frac{\eta}{2}\nabla_{x}\log p(x_{t}) + \sqrt{\eta}Z_{t}
\end{equation}
 <p>
where \(Z_{t}\)'s are sampled iid from a unit Gaussian. Replacing
\(\nabla_{x}\log p(x)\) with \(s_{\hat{\theta}_{n}}(x)\) we can generate samples
hopefully resembling those from \(p^{\ast}\).
</p>
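<p>
The iteration above is a few lines in JAX. A sketch, assuming
<code>score_fn</code> is any callable returning the score at a point (for
example the trained network with its parameters baked in):
</p>

```python
import jax
import jax.numpy as jnp

def langevin_sample(score_fn, key, x0, eta, n_steps):
    """Unadjusted Langevin dynamics: x_{t+1} = x_t + eta/2 * score(x_t) + sqrt(eta) * z_t."""
    def step(x, key):
        z = jax.random.normal(key, x.shape)           # z_t ~ N(0, I), iid per step
        x_next = x + 0.5 * eta * score_fn(x) + jnp.sqrt(eta) * z
        return x_next, x_next
    keys = jax.random.split(key, n_steps)
    x_final, _ = jax.lax.scan(step, x0, keys)
    return x_final
```

Plugging in the known score of a standard Gaussian, \(s(x) = -x\), the chain
drifts toward the origin and settles into samples that look Gaussian.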

 <p>
More generally, for any procedure which produces samples from a distribution
\(p\) using only the score function, we can plug in the learned \(s_{\hat{\theta}_{n}}\)
and produce samples, following the  <a href="http://www.math.chalmers.se/Stat/Grundutb/GU/MSA100/A11/lecture6.pdf">plug-in estimator method</a>.
This is pretty nice: we can tap into all the work which has been done in the
field of MCMC <sup> <a id="fnr.2" class="footref" href="#fn.2" role="doc-backlink">2</a></sup>, for example Hamiltonian Monte Carlo or NUTS. The decoupling of
training and inference brings many benefits, as we can repurpose
\(s_{\hat{\theta}_{n}}\) for other downstream tasks.
</p>
</div>
</div>
</div>
 <div id="outline-container-implementation" class="outline-2">
 <h2 id="implementation"> <a href="#implementation">Implementation</a></h2>
 <div class="outline-text-2" id="text-implementation">
</div>
 <div id="outline-container-imports" class="outline-3">
 <h3 id="imports"> <a href="#imports">Imports</a></h3>
 <div class="outline-text-3" id="text-imports">
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">import</span> functools
 <span class="org-keyword">import</span> math

 <span class="org-keyword">import</span> matplotlib.pyplot  <span class="org-keyword">as</span> plt
 <span class="org-keyword">import</span> numpy  <span class="org-keyword">as</span> np
 <span class="org-keyword">import</span> seaborn  <span class="org-keyword">as</span> sns
sns.set_style( <span class="org-string">"white"</span>)

 <span class="org-keyword">import</span> jax.numpy  <span class="org-keyword">as</span> jnp
 <span class="org-keyword">from</span> jax  <span class="org-keyword">import</span> grad, vmap, lax
 <span class="org-keyword">from</span> jax  <span class="org-keyword">import</span> random
 <span class="org-keyword">from</span> jax  <span class="org-keyword">import</span> value_and_grad
 <span class="org-keyword">import</span> jax.tree_util  <span class="org-keyword">as</span> jtu
 <span class="org-keyword">import</span> jax

 <span class="org-keyword">import</span> equinox  <span class="org-keyword">as</span> eqx
 <span class="org-keyword">import</span> optax
 <span class="org-keyword">from</span> jaxtyping  <span class="org-keyword">import</span> Array, Float, Int, PyTree

 <span class="org-keyword">import</span> tensorflow_datasets  <span class="org-keyword">as</span> tfds
 <span class="org-keyword">from</span> tensorflow_probability.substrates  <span class="org-keyword">import</span> jax  <span class="org-keyword">as</span> tfp
 <span class="org-variable-name">tfd</span>  <span class="org-operator">=</span> tfp.distributions
 <span class="org-variable-name">tfb</span>  <span class="org-operator">=</span> tfp.bijectors
 <span class="org-variable-name">tfpk</span>  <span class="org-operator">=</span> tfp.math.psd_kernels
</pre>
</div>

</div>
</div>
 <div id="outline-container-tractable-mixture-models" class="outline-3">
 <h3 id="tractable-mixture-models"> <a href="#tractable-mixture-models">Tractable mixture models</a></h3>
 <div class="outline-text-3" id="text-tractable-mixture-models">
 <p>
For some very simple models, such as mixtures of tractable base models or  <a href="https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors/Bijector">using
bijectors</a>, we don't even need to learn the score function, since it's available
to us in closed form. The simplest way to get some intuition for this is to
visualize the log-probability function \(\log p(x)\), using for
example level sets, and the vector field corresponding to \(s(x)\) in 2 dimensions.
</p>

 <div class="org-src-container">
 <pre class="src src-python" id="orgb317737"> <span class="org-keyword">def</span>  <span class="org-function-name">plot_logdistribution</span>(fig, ax, distribution, xlim <span class="org-operator">=</span>( <span class="org-operator">-</span>1.0, 1.0), ylim <span class="org-operator">=</span>( <span class="org-operator">-</span>1.0, 1.0), n_contour <span class="org-operator">=</span>100, n_quiver <span class="org-operator">=</span>10):
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Define the grid for contour
</span>     <span class="org-variable-name">x</span>  <span class="org-operator">=</span> np.linspace( <span class="org-operator">*</span>xlim, n_contour)
     <span class="org-variable-name">y</span>  <span class="org-operator">=</span> np.linspace( <span class="org-operator">*</span>ylim, n_contour)
     <span class="org-variable-name">X</span>,  <span class="org-variable-name">Y</span>  <span class="org-operator">=</span> np.meshgrid(x, y)
     <span class="org-variable-name">XY</span>  <span class="org-operator">=</span> np.stack([X.ravel(), Y.ravel()], axis <span class="org-operator">=-</span>1)

     <span class="org-comment-delimiter"># </span> <span class="org-comment">Compute the log-distribution
</span>     <span class="org-variable-name">Z</span>  <span class="org-operator">=</span> distribution.log_prob(XY).reshape(n_contour, n_contour)
     <span class="org-variable-name">cont</span>  <span class="org-operator">=</span> ax.contour(X, Y, Z)
    plt.colorbar(cont, ax <span class="org-operator">=</span>ax)

     <span class="org-comment-delimiter"># </span> <span class="org-comment">Compute the gradients
</span>     <span class="org-variable-name">x</span>  <span class="org-operator">=</span> np.linspace( <span class="org-operator">*</span>xlim, n_quiver)
     <span class="org-variable-name">y</span>  <span class="org-operator">=</span> np.linspace( <span class="org-operator">*</span>ylim, n_quiver)
     <span class="org-variable-name">X</span>,  <span class="org-variable-name">Y</span>  <span class="org-operator">=</span> np.meshgrid(x, y)
     <span class="org-variable-name">XY</span>  <span class="org-operator">=</span> np.stack([X.ravel(), Y.ravel()], axis <span class="org-operator">=-</span>1)
     <span class="org-variable-name">grads</span>  <span class="org-operator">=</span> vmap(grad(distribution.log_prob))(XY)
     <span class="org-variable-name">grad_X</span>  <span class="org-operator">=</span> grads[:, 0].reshape(n_quiver, n_quiver)
     <span class="org-variable-name">grad_Y</span>  <span class="org-operator">=</span> grads[:, 1].reshape(n_quiver, n_quiver)
    ax.quiver(X, Y, grad_X, grad_Y)
     <span class="org-keyword">return</span> fig, ax
</pre>
</div>

 <p>
We show the level sets of the log-distribution together with a quiver plot (the vector field) of the score function:
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">key</span>  <span class="org-operator">=</span> random.PRNGKey(0)   <span class="org-comment-delimiter"># </span> <span class="org-comment">Use a different key for different runs
</span>
 <span class="org-comment-delimiter"># </span> <span class="org-comment">Define a 2-component Gaussian Mixture model
</span> <span class="org-variable-name">num_components</span>  <span class="org-operator">=</span> 2
 <span class="org-variable-name">component_means</span>  <span class="org-operator">=</span> [(0.5, 0.5), ( <span class="org-operator">-</span>0.5,  <span class="org-operator">-</span>0.5)]
 <span class="org-variable-name">sd</span>  <span class="org-operator">=</span> 0.4
 <span class="org-variable-name">component_sds</span>  <span class="org-operator">=</span> [(sd, sd), (sd, sd)]
 <span class="org-variable-name">p1</span>  <span class="org-operator">=</span> 0.5
 <span class="org-variable-name">component_probs</span>  <span class="org-operator">=</span> [p1, 1  <span class="org-operator">-</span> p1]

 <span class="org-variable-name">mixture_dist</span>  <span class="org-operator">=</span> tfd.Categorical(probs <span class="org-operator">=</span>component_probs)
 <span class="org-variable-name">component_dist</span>  <span class="org-operator">=</span> tfd.MultivariateNormalDiag(loc <span class="org-operator">=</span>component_means, scale_diag <span class="org-operator">=</span>component_sds)
 <span class="org-variable-name">mixture_model</span>  <span class="org-operator">=</span> tfd.MixtureSameFamily(
    mixture_distribution <span class="org-operator">=</span>mixture_dist,
    components_distribution <span class="org-operator">=</span>component_dist,
    name <span class="org-operator">=</span> <span class="org-string">"MoG"</span>
)

 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plt.subplots(figsize <span class="org-operator">=</span>(8, 6))
 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plot_logdistribution(fig, ax, mixture_model)
</pre>
</div>


 <figure id="org3fb9668"> <img src="../assets/images/diffusion_models/mog-quiver-plot.webp" alt="Quiver plot of the score function of a mixture of Gaussians" width="800" loading="lazy"></img> <figcaption> <span class="figure-number">Figure 1: </span>The score function points towards the component means (the peaks)</figcaption></figure></div>
 <div id="outline-container-generating-samples-1" class="outline-4">
 <h4 id="generating-samples-1"> <a href="#generating-samples-1">Generating samples</a></h4>
 <div class="outline-text-4" id="text-generating-samples-1">
 <p>
We already know from  <a href="#generating-samples-1">the previous section on generating samples</a> how to do this,
and the implementation is straightforward.
</p>

 <p>
Let's quickly define a function to plot the distribution. We will use this as a
background for the particle systems evolving according to the Langevin dynamics.
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">def</span>  <span class="org-function-name">plot_distribution</span>(fig, ax, distribution, xlim <span class="org-operator">=</span>( <span class="org-operator">-</span>1.0, 1.0), ylim <span class="org-operator">=</span>( <span class="org-operator">-</span>1.0, 1.0), n_contour <span class="org-operator">=</span>100):
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Define the grid for contour
</span>     <span class="org-variable-name">x</span>  <span class="org-operator">=</span> np.linspace( <span class="org-operator">*</span>xlim, n_contour)
     <span class="org-variable-name">y</span>  <span class="org-operator">=</span> np.linspace( <span class="org-operator">*</span>ylim, n_contour)
     <span class="org-variable-name">X</span>,  <span class="org-variable-name">Y</span>  <span class="org-operator">=</span> np.meshgrid(x, y)
     <span class="org-variable-name">XY</span>  <span class="org-operator">=</span> np.stack([X.ravel(), Y.ravel()], axis <span class="org-operator">=-</span>1)

     <span class="org-comment-delimiter"># </span> <span class="org-comment">Compute the distribution
</span>     <span class="org-variable-name">Z</span>  <span class="org-operator">=</span> distribution.prob(XY).reshape(n_contour, n_contour)
     <span class="org-variable-name">cont</span>  <span class="org-operator">=</span> ax.contour(X, Y, Z)
     <span class="org-keyword">return</span> fig, ax
</pre>
</div>

 <p>
We define the update step (it returns a tuple since we use  <kbd>lax.scan</kbd> later) and evolve a particle over many steps;  <kbd>lax.scan</kbd> simply makes this loop efficient.
</p>
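 <p>
As a minimal standalone sketch (not from the original post) of the  <kbd>lax.scan</kbd> convention: the scanned function returns a  <kbd>(carry, y)</kbd> pair, and  <kbd>lax.scan</kbd> threads the carry through the steps while stacking the  <kbd>y</kbd> outputs, which is why  <kbd>update_x</kbd> returns  <kbd>(xp1, xp1)</kbd>.
 </p>

```python
import jax.numpy as jnp
from jax import lax

def step(carry, z):
    # Return (next_carry, output_to_stack): the carry is threaded through
    # the loop, the second element is collected into an array
    new = carry + z
    return new, new

final, trajectory = lax.scan(step, jnp.array(0.0), jnp.ones(5))
# final is 5.0; trajectory is the running sum [1., 2., 3., 4., 5.]
```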
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">def</span>  <span class="org-function-name">update_x</span>(x, z, distribution, step_size):
     <span class="org-variable-name">g</span>  <span class="org-operator">=</span> grad(distribution.log_prob)(x)
     <span class="org-variable-name">xp1</span>  <span class="org-operator">=</span> x  <span class="org-operator">+</span> (step_size  <span class="org-operator">/</span> 2)  <span class="org-operator">*</span> g  <span class="org-operator">+</span> jnp.sqrt(step_size)  <span class="org-operator">*</span> z
     <span class="org-keyword">return</span> xp1, xp1

 <span class="org-variable-name">step_size</span>  <span class="org-operator">=</span> 0.01
 <span class="org-variable-name">num_steps</span>  <span class="org-operator">=</span> 200
 <span class="org-variable-name">key</span>  <span class="org-operator">=</span> random.PRNGKey(0)
 <span class="org-variable-name">z_key</span>,  <span class="org-variable-name">x0_key</span>,  <span class="org-variable-name">key</span>  <span class="org-operator">=</span> random.split(key, 3)
 <span class="org-variable-name">z</span>  <span class="org-operator">=</span> random.normal(z_key, shape <span class="org-operator">=</span>(num_steps, 2))
 <span class="org-variable-name">x0</span>  <span class="org-operator">=</span> random.normal(x0_key, shape <span class="org-operator">=</span>(2,))  <span class="org-operator">*</span> 0.5
 <span class="org-variable-name">update_fun</span>  <span class="org-operator">=</span> functools.partial(update_x, distribution <span class="org-operator">=</span>mixture_model, step_size <span class="org-operator">=</span>step_size)
 <span class="org-variable-name">final</span>,  <span class="org-variable-name">result</span>  <span class="org-operator">=</span> lax.scan(update_fun, x0, z)
</pre>
</div>

 <p>
Let's look at the result. To see the path of the particle more
clearly, I'll outline it by drawing a line between consecutive points in
 <kbd>result</kbd>.
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plt.subplots(figsize <span class="org-operator">=</span>(6, 6))
 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plot_distribution(fig, ax, mixture_model, xlim <span class="org-operator">=</span>( <span class="org-operator">-</span>1.5, 1.5), ylim <span class="org-operator">=</span>( <span class="org-operator">-</span>1.5, 1.5))
ax.plot(result[:, 0], result[:, 1], marker <span class="org-operator">=</span> <span class="org-string">"."</span>, linewidth <span class="org-operator">=</span>1.0)
</pre>
</div>


 <figure id="org2842830"> <img src="../assets/images/diffusion_models/mog-langevin-dynamics.webp" alt="Mixture of Gaussians Langevin dynamics of a particle" width="800" loading="lazy"></img> <figcaption> <span class="figure-number">Figure 2: </span>Path of a particle following the Langevin dynamics of the mixture model</figcaption></figure> <p>
Finally, let's make a video for good measure. It shows the above plot evolving
in time, following the particle according to the Langevin dynamics.
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-comment-delimiter"># </span> <span class="org-comment">We will animate this using the FuncAnimation class from matplotlib
</span> <span class="org-keyword">from</span> matplotlib.animation  <span class="org-keyword">import</span> FuncAnimation
 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plt.subplots(figsize <span class="org-operator">=</span>(6, 6))
 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plot_distribution(fig, ax, mixture_model, xlim <span class="org-operator">=</span>( <span class="org-operator">-</span>1.5, 1.5), ylim <span class="org-operator">=</span>( <span class="org-operator">-</span>1.5, 1.5))

 <span class="org-comment-delimiter"># </span> <span class="org-comment">Initialize the line plot
</span> <span class="org-variable-name">line</span>,  <span class="org-operator">=</span> ax.plot([], [], marker <span class="org-operator">=</span> <span class="org-string">'.'</span>, linewidth <span class="org-operator">=</span>1.0)
 <span class="org-comment-delimiter"># </span> <span class="org-comment">Initialize the particle positions
</span> <span class="org-variable-name">positions</span>  <span class="org-operator">=</span> result

 <span class="org-comment-delimiter"># </span> <span class="org-comment">Function to update the line plot
</span> <span class="org-keyword">def</span>  <span class="org-function-name">update</span>(frame):
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Update the line plot data
</span>    line.set_data(positions[:frame <span class="org-operator">+</span>1, 0], positions[:frame <span class="org-operator">+</span>1, 1])
     <span class="org-keyword">return</span> line,

 <span class="org-comment-delimiter"># </span> <span class="org-comment">Create the FuncAnimation
</span> <span class="org-variable-name">animation</span>  <span class="org-operator">=</span> FuncAnimation(fig, update, frames <span class="org-operator">=</span> <span class="org-builtin">len</span>(positions),
                          interval <span class="org-operator">=</span>100, repeat <span class="org-operator">=</span> <span class="org-constant">False</span>)

 <span class="org-comment-delimiter"># </span> <span class="org-comment">Save using ffmpeg
</span>animation.save( <span class="org-string">"mog-langevin-dynamics.mp4"</span>, writer <span class="org-operator">=</span> <span class="org-string">"ffmpeg"</span>, dpi <span class="org-operator">=</span>200)
</pre>
</div>

 <video controls="controls" id="org5e12ef8"> <source src="../assets/videos/diffusion_models/mog-langevin-dynamics.webm" type="video/webm"></source> <p>
Your browser does not support the video tag.
</p>
</video></div>
</div>
</div>
 <div id="outline-container-learning-the-score-function" class="outline-3">
 <h3 id="learning-the-score-function"> <a href="#learning-the-score-function">Learning the score function</a></h3>
 <div class="outline-text-3" id="text-learning-the-score-function">
 <p>
The reason we didn't have to learn the score function in  <a href="#tractable-mixture-models">Tractable mixture
models</a> was because we restricted ourselves to a distribution with a tractable
score function. In reality this is seldom the case, and even when it is
possible in principle, computing the score directly may be too expensive.
Additionally, if we have a set of points which we interpret as an empirical
distribution, then the score function is not even well-defined, as there is no
density. We have to resort to learning it in some way.
</p>
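 <p>
One way to see why learning is still feasible: if we smooth the empirical distribution with a Gaussian kernel, the smoothed density has a tractable score, which is the quantity that denoising score matching targets. A minimal sketch (not from the original post; the function name  <kbd>smoothed_log_prob</kbd> is mine, and the Gaussian normalizing constant is dropped since it is constant in \(x\) and does not affect the score):
 </p>

```python
import jax
import jax.numpy as jnp

def smoothed_log_prob(x, data, sigma):
    # log of (1/n) * sum_i exp(-||x - x_i||^2 / (2 sigma^2)), dropping the
    # Gaussian normalizing constant (constant in x, so the score is unchanged)
    sq_dists = jnp.sum((x - data) ** 2, axis=-1)
    return jax.scipy.special.logsumexp(-sq_dists / (2 * sigma**2)) - jnp.log(data.shape[0])

data = jnp.array([[0.0, 0.0], [1.0, 1.0]])
score_fn = jax.grad(smoothed_log_prob)
# By symmetry the score vanishes at the midpoint between the two data points
s = score_fn(jnp.array([0.5, 0.5]), data, 0.5)
```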

 <p>
First, we will use the MNIST dataset, viewing an image as a discrete
distribution by normalizing the pixel intensities by the total intensity of all
the pixels in the image. Since each pixel intensity is a value between 0 and 1,
this gives a distribution over pixel coordinates. To make this point clear, we
assume that the underlying space is a 2D Cartesian square, \(\mathcal{X} = [0, 1]^{2}\) <sup> <a id="fnr.3" class="footref" href="#fn.3" role="doc-backlink">3</a></sup>,
with each pixel coordinate normalized to lie between 0 and 1. An image is then a
collection of coordinate pairs and pixel intensity values; in the case of MNIST,
where images are \(28 \times 28\), we have pixel coordinates \((i, j)\) with \(i,
j \in \{(l + 1/2) / 28\}_{l=0}^{27}\) and corresponding pixel intensities
\(I(i, j) \in [0, 1]\). This yields the empirical distribution
\(\hat{p}(i, j) = I(i, j) / \sum_{i', j'}I(i', j') \propto I(i, j)\).
</p>
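 <p>
As a toy sketch of the normalization above (the numbers are illustrative, not from the post):
 </p>

```python
import numpy as np

# A 2x2 "image": pixel intensities I(i, j)
intensity = np.array([[0.0, 1.0],
                      [1.0, 2.0]])

# p_hat(i, j) = I(i, j) / sum_{i', j'} I(i', j');
# sums to one and is proportional to the intensities
p_hat = intensity / intensity.sum()
```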

 <p>
Behold, the first image of the MNIST training dataset!
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-comment-delimiter"># </span> <span class="org-comment">Get an image from mnist
</span> <span class="org-keyword">import</span> torchvision

 <span class="org-variable-name">mnist</span>  <span class="org-operator">=</span> torchvision.datasets.MNIST( <span class="org-string">"~/data"</span>, download <span class="org-operator">=</span> <span class="org-constant">True</span>)
 <span class="org-variable-name">mnist_images</span>  <span class="org-operator">=</span> mnist.data.numpy()
 <span class="org-variable-name">image</span>  <span class="org-operator">=</span> mnist_images[0]
 <span class="org-variable-name">image</span>  <span class="org-operator">=</span> image.astype( <span class="org-builtin">float</span>)  <span class="org-operator">/</span> 255.0

 <span class="org-keyword">def</span>  <span class="org-function-name">create_sample_fn</span>(image):
     <span class="org-doc">"""Generate a function that samples from the image distribution"""</span>
     <span class="org-keyword">def</span>  <span class="org-function-name">sample</span>(num_samples, key):
         <span class="org-variable-name">h</span>,  <span class="org-variable-name">w</span>  <span class="org-operator">=</span> image.shape
         <span class="org-comment-delimiter"># </span> <span class="org-comment">Note that random.categorical takes as inputs logits which is why we do not have to normalize
</span>         <span class="org-keyword">return</span> jnp.array(
            [ <span class="org-builtin">divmod</span>(x.item(), w)  <span class="org-keyword">for</span> x  <span class="org-keyword">in</span> random.categorical(logits <span class="org-operator">=</span>jnp.log(image.ravel()), key <span class="org-operator">=</span>key, shape <span class="org-operator">=</span>(num_samples,))]
        )
     <span class="org-keyword">return</span> sample

 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plt.subplots()
 <span class="org-variable-name">im</span>  <span class="org-operator">=</span> ax.imshow(image, cmap <span class="org-operator">=</span> <span class="org-string">"gray"</span>)
plt.colorbar(im)
ax.axis( <span class="org-string">"off"</span>)
</pre>
</div>


 <figure id="org5959300"> <img src="../assets/images/diffusion_models/5-mnist-imshow.webp" alt="5-mnist-imshow.webp"></img></figure> <p>
Let's check the histogram when we sample many times from the distribution
defined by the image; as the sample size grows, it should resemble the image.
Note that the histogram will show the image rotated, since the row index ends up
on the horizontal axis.
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">key</span>  <span class="org-operator">=</span> jax.random.PRNGKey( <span class="org-builtin">sum</span>( <span class="org-builtin">ord</span>(c)  <span class="org-keyword">for</span> c  <span class="org-keyword">in</span>  <span class="org-string">"five"</span>))
 <span class="org-variable-name">sample</span>  <span class="org-operator">=</span> create_sample_fn(image)
 <span class="org-variable-name">x</span>  <span class="org-operator">=</span> sample(10000, key)
 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plt.subplots(figsize <span class="org-operator">=</span>(4, 4))
 <span class="org-variable-name">h</span>  <span class="org-operator">=</span> ax.hist2d(x[:, 0], x[:, 1], cmap <span class="org-operator">=</span> <span class="org-string">"gray"</span>)
ax.axis( <span class="org-string">"off"</span>)
</pre>
</div>


 <figure id="org437c1ea"> <img src="../assets/images/diffusion_models/5-mnist-hist2d.webp" alt="5-mnist-hist2d.webp"></img></figure> <p>
Now we set up the training. The architecture here combines a few ingredients:
</p>
 <ol class="org-ol"> <li>We use the insights from ( <a href="#citeproc_bib_item_2">Tancik et al. 2020</a>) which roughly
says that using a pre-processing Fourier feature map before the MLP is
helpful for learning high-frequency mappings for coordinate-based inputs. We
add a residual connection here, so that the input to the MLP is
 <kbd>jnp.concatenate((f_layer(x), x))</kbd>.</li>
 <li>The  <kbd>RFLayer</kbd> has noise-level specific parameters  <kbd>alpha, beta</kbd> which
linearly transform the random features and we learn one such transformation
for each noise level (the rest of the architecture is shared, like the MLP
and the original random feature mappings).</li>
 <li>We freeze the random feature parameters  <kbd>B_cos, B_sin</kbd> by following the  <a href="https://docs.kidger.site/equinox/examples/frozen_layer/">guide
on how to freeze layers in the Equinox documentation</a>.</li>
 <li>The rest of the training is done using the objective \(\hat{\mathcal{E}}(\theta; (\sigma_{l})_{l=1}^{L})\).</li>
</ol> <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">class</span>  <span class="org-type">RFLayer</span>(eqx.Module):
     <span class="org-doc">"""Random Feature layer with learnable linear output transformations alpha, beta"""</span>
    B_cos: jax.Array
    B_sin: jax.Array
    alpha: jax.Array
    beta: jax.Array
    num_noise_levels:  <span class="org-builtin">int</span>
    sigma:  <span class="org-builtin">float</span>

     <span class="org-keyword">def</span>  <span class="org-function-name">__init__</span>( <span class="org-keyword">self</span>, in_size:  <span class="org-builtin">int</span>, num_rf:  <span class="org-builtin">int</span>, num_noise_levels:  <span class="org-builtin">int</span>, key, sigma:  <span class="org-builtin">float</span>  <span class="org-operator">=</span> 1.0):
         <span class="org-variable-name">cos_key</span>,  <span class="org-variable-name">sin_key</span>  <span class="org-operator">=</span> random.split(key, 2)
         <span class="org-keyword">self</span>. <span class="org-variable-name">B_cos</span>  <span class="org-operator">=</span> random.normal(cos_key, (num_rf, in_size))  <span class="org-operator">*</span> sigma
         <span class="org-keyword">self</span>. <span class="org-variable-name">B_sin</span>  <span class="org-operator">=</span> random.normal(sin_key, (num_rf, in_size))  <span class="org-operator">*</span> sigma
         <span class="org-keyword">self</span>. <span class="org-variable-name">sigma</span>  <span class="org-operator">=</span> sigma
         <span class="org-keyword">self</span>. <span class="org-variable-name">num_noise_levels</span>  <span class="org-operator">=</span> num_noise_levels
         <span class="org-keyword">self</span>. <span class="org-variable-name">alpha</span>  <span class="org-operator">=</span> jnp.ones(num_noise_levels)
         <span class="org-keyword">self</span>. <span class="org-variable-name">beta</span>  <span class="org-operator">=</span> jnp.zeros(num_noise_levels)

     <span class="org-keyword">def</span>  <span class="org-function-name">__call__</span>( <span class="org-keyword">self</span>, x: jax.Array, noise_level_idx:  <span class="org-builtin">int</span>)  <span class="org-operator">-></span> jax.Array:
         <span class="org-variable-name">rf_features</span>  <span class="org-operator">=</span> jnp.concatenate(
            [jnp.cos(2  <span class="org-operator">*</span> math.pi  <span class="org-operator">*</span>  <span class="org-keyword">self</span>.B_cos @ x), jnp.sin(2  <span class="org-operator">*</span> math.pi  <span class="org-operator">*</span>  <span class="org-keyword">self</span>.B_sin @ x)], axis <span class="org-operator">=-</span>1
        )
         <span class="org-keyword">return</span>  <span class="org-keyword">self</span>.alpha[noise_level_idx]  <span class="org-operator">*</span> rf_features  <span class="org-operator">+</span>  <span class="org-keyword">self</span>.beta[noise_level_idx]

 <span class="org-keyword">class</span>  <span class="org-type">Model</span>(eqx.Module):
    rf_layer: RFLayer
    mlp: eqx.nn.MLP

     <span class="org-keyword">def</span>  <span class="org-function-name">__init__</span>( <span class="org-keyword">self</span>, in_size:  <span class="org-builtin">int</span>, num_rf:  <span class="org-builtin">int</span>, width_size:  <span class="org-builtin">int</span>, depth:  <span class="org-builtin">int</span>, out_size:  <span class="org-builtin">int</span>, num_noise_levels:  <span class="org-builtin">int</span>, key):
         <span class="org-keyword">self</span>. <span class="org-variable-name">rf_layer</span>  <span class="org-operator">=</span> RFLayer(in_size, num_rf, num_noise_levels, key)
         <span class="org-keyword">self</span>. <span class="org-variable-name">mlp</span>  <span class="org-operator">=</span> eqx.nn.MLP(in_size <span class="org-operator">=</span>num_rf  <span class="org-operator">*</span> 2  <span class="org-operator">+</span> 2,
                              width_size <span class="org-operator">=</span>width_size,
                              depth <span class="org-operator">=</span>depth,
                              out_size <span class="org-operator">=</span>out_size,
                              activation <span class="org-operator">=</span>jax.nn.softplus,
                              key <span class="org-operator">=</span>key)

     <span class="org-keyword">def</span>  <span class="org-function-name">__call__</span>( <span class="org-keyword">self</span>, x: jax.Array, noise_level_idx:  <span class="org-builtin">int</span>)  <span class="org-operator">-></span> jax.Array:
         <span class="org-variable-name">x</span>  <span class="org-operator">-=</span> 0.5
         <span class="org-variable-name">x</span>  <span class="org-operator">=</span> jnp.concatenate(( <span class="org-keyword">self</span>.rf_layer(x, noise_level_idx), x))  <span class="org-comment-delimiter">#</span> <span class="org-comment">Residual connection
</span>         <span class="org-keyword">return</span>  <span class="org-keyword">self</span>.mlp(x)

 <span class="org-comment-delimiter"># </span> <span class="org-comment">Define the objective function
</span> <span class="org-keyword">def</span>  <span class="org-function-name">one_sample_loss</span>(model, x, sigmas, key):
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Sample one gaussian for each noise level
</span>     <span class="org-variable-name">perturbations</span>  <span class="org-operator">=</span> random.normal(key, (sigmas.shape[0], x.shape[0]))  <span class="org-operator">*</span> jnp.expand_dims(sigmas, 1)
     <span class="org-variable-name">x_bars</span>  <span class="org-operator">=</span> x  <span class="org-operator">+</span> perturbations
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Predict over all noise levels
</span>     <span class="org-variable-name">scores_pred</span>  <span class="org-operator">=</span> vmap(model)(x_bars, jnp.arange(sigmas.shape[0]))
     <span class="org-variable-name">scores</span>  <span class="org-operator">=</span>  <span class="org-operator">-</span>(x_bars  <span class="org-operator">-</span> x)  <span class="org-operator">/</span> jnp.expand_dims(sigmas  <span class="org-operator">**</span> 2, 1)
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Vectorized version of (x_bar[i] - x) / sigma[i] ** 2
</span>     <span class="org-comment-delimiter"># </span> <span class="org-comment">mean(sigmas[i]**2 * mse(score_pred[i], scores[i]) for i in range(len(sigmas))))
</span>     <span class="org-variable-name">result</span>  <span class="org-operator">=</span> jnp.mean(jnp.square(scores_pred  <span class="org-operator">-</span> scores).mean( <span class="org-operator">-</span>1)  <span class="org-operator">*</span> sigmas  <span class="org-operator">**</span> 2)
     <span class="org-keyword">return</span> result

 <span class="org-keyword">def</span>  <span class="org-function-name">loss</span>(diff_model, static_model, xs, sigmas, keys):
     <span class="org-doc">"""Objective function, we separate the parameters into active and frozen parameters"""</span>
     <span class="org-variable-name">model</span>  <span class="org-operator">=</span> eqx.combine(static_model, diff_model)
     <span class="org-variable-name">batch_loss</span>  <span class="org-operator">=</span> vmap(one_sample_loss, ( <span class="org-constant">None</span>, 0,  <span class="org-constant">None</span>, 0))
     <span class="org-keyword">return</span> jnp.mean(batch_loss(model, xs, sigmas, keys))

 <span class="org-keyword">def</span>  <span class="org-function-name">train</span>(
        model: eqx.Module,
        filter_spec: PyTree,
        sample,
        optim: optax.GradientTransformation,
        steps:  <span class="org-builtin">int</span>,
        batch_size:  <span class="org-builtin">int</span>,
        print_every:  <span class="org-builtin">int</span>,
        sigmas: Float[Array,  <span class="org-string">"..."</span>],
        key
)  <span class="org-operator">-></span> eqx.Module:
     <span class="org-type">@eqx.filter_jit</span>
     <span class="org-keyword">def</span>  <span class="org-function-name">make_step</span>(
            model: eqx.Module,
            xs: Float[Array,  <span class="org-string">"batch_size 2"</span>],
            opt_state: PyTree,
            keys: Float[Array,  <span class="org-string">"batch_size"</span>],
    ):
         <span class="org-variable-name">diff_model</span>,  <span class="org-variable-name">static_model</span>  <span class="org-operator">=</span> eqx.partition(model, filter_spec)
         <span class="org-variable-name">loss_value</span>,  <span class="org-variable-name">grads</span>  <span class="org-operator">=</span> eqx.filter_value_and_grad(loss)(diff_model, static_model, xs, sigmas, keys)
         <span class="org-variable-name">updates</span>,  <span class="org-variable-name">opt_state</span>  <span class="org-operator">=</span> optim.update(grads, opt_state)
         <span class="org-variable-name">model</span>  <span class="org-operator">=</span> eqx.apply_updates(model, updates)
         <span class="org-keyword">return</span> model, opt_state, loss_value

     <span class="org-variable-name">original_model</span>  <span class="org-operator">=</span> model
     <span class="org-variable-name">opt_state</span>  <span class="org-operator">=</span> optim.init(eqx. <span class="org-builtin">filter</span>(model, eqx.is_inexact_array))
     <span class="org-keyword">for</span> step  <span class="org-keyword">in</span>  <span class="org-builtin">range</span>(steps):
         <span class="org-operator">*</span> <span class="org-variable-name">loss_keys</span>,  <span class="org-variable-name">sample_key</span>,  <span class="org-variable-name">key</span>  <span class="org-operator">=</span> random.split(key, batch_size  <span class="org-operator">+</span> 2)
         <span class="org-variable-name">loss_keys</span>  <span class="org-operator">=</span> jnp.stack(loss_keys)
         <span class="org-variable-name">xs</span>  <span class="org-operator">=</span> sample(batch_size, sample_key)  <span class="org-operator">/</span> 27
         <span class="org-variable-name">model</span>,  <span class="org-variable-name">opt_state</span>,  <span class="org-variable-name">loss_value</span>  <span class="org-operator">=</span> make_step(model, xs, opt_state, loss_keys)
         <span class="org-keyword">if</span> step  <span class="org-operator">%</span> print_every  <span class="org-operator">==</span> 0:
             <span class="org-builtin">print</span>(f <span class="org-string">"Step </span>{step} <span class="org-string">, Loss </span>{loss_value} <span class="org-string">"</span>)

     <span class="org-keyword">return</span> model
</pre>
</div>

 <p>
Now let's train it. We squint and choose some good hyperparameters and pray to the ML-gods for an auspicious training run (actually I did some hand-tuning).
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">sigmas</span>  <span class="org-operator">=</span> jnp.geomspace(0.0001, 1, 30, endpoint <span class="org-operator">=</span> <span class="org-constant">True</span>)
 <span class="org-variable-name">DEPTH</span>  <span class="org-operator">=</span> 3
 <span class="org-variable-name">WIDTH_SIZE</span>  <span class="org-operator">=</span> 128
 <span class="org-variable-name">NUM_RF</span>  <span class="org-operator">=</span> 256
 <span class="org-variable-name">BATCH_SIZE</span>  <span class="org-operator">=</span> 128
 <span class="org-variable-name">STEPS</span>  <span class="org-operator">=</span> 5  <span class="org-operator">*</span> 10  <span class="org-operator">**</span> 4
 <span class="org-variable-name">PRINT_EVERY</span>  <span class="org-operator">=</span> 5000

 <span class="org-variable-name">model</span>  <span class="org-operator">=</span> Model(in_size <span class="org-operator">=</span>2,
              num_rf <span class="org-operator">=</span>NUM_RF,
              width_size <span class="org-operator">=</span>WIDTH_SIZE,
              depth <span class="org-operator">=</span>DEPTH,
              out_size <span class="org-operator">=</span>2,
              num_noise_levels <span class="org-operator">=</span> <span class="org-builtin">len</span>(sigmas),
              key <span class="org-operator">=</span>random.PRNGKey(0))

 <span class="org-variable-name">LEARNING_RATE</span>  <span class="org-operator">=</span> 1e <span class="org-operator">-</span>3
 <span class="org-variable-name">optim</span>  <span class="org-operator">=</span> optax.adam(LEARNING_RATE)

 <span class="org-comment-delimiter"># </span> <span class="org-comment">The filter spec is a pytree of the same shape as the parameters
</span> <span class="org-comment-delimiter"># </span> <span class="org-comment">True and False represent whether this part of the pytree will be updated
</span> <span class="org-comment-delimiter"># </span> <span class="org-comment">using the optimizer by splitting the parameters into diff_model and static_model
</span> <span class="org-variable-name">filter_spec</span>  <span class="org-operator">=</span> jtu.tree_map( <span class="org-keyword">lambda</span> x:  <span class="org-constant">True</span>  <span class="org-keyword">if</span>  <span class="org-builtin">isinstance</span>(x, jax.Array)  <span class="org-keyword">else</span>  <span class="org-constant">False</span>, model)
 <span class="org-variable-name">filter_spec</span>  <span class="org-operator">=</span> eqx.tree_at(
     <span class="org-keyword">lambda</span> tree: (tree.rf_layer.B_cos, tree.rf_layer.B_sin),
    filter_spec,
    replace <span class="org-operator">=</span>( <span class="org-constant">False</span>,  <span class="org-constant">False</span>),
)
 <span class="org-variable-name">model</span>  <span class="org-operator">=</span> train(model, filter_spec, sample, optim, STEPS, BATCH_SIZE, PRINT_EVERY, sigmas, key)
</pre>
</div>

 <pre class="example" id="orge2599fb">
Step 0, Loss 1.0177688598632812
Step 5000, Loss 0.7470076084136963
Step 10000, Loss 0.6700457334518433
Step 15000, Loss 0.6010410785675049
Step 20000, Loss 0.5470178127288818
Step 25000, Loss 0.5063308477401733
Step 30000, Loss 0.47549256682395935
Step 35000, Loss 0.4591177701950073
Step 40000, Loss 0.4523712992668152
Step 45000, Loss 0.43943890929222107
</pre>

 <p>
Let's visualize the vector field for this new model by repurposing the  <a href="#orgb317737"> <kbd>plot_logdistribution</kbd> function</a> to plot just the vector field. Since we don't have an actual density, we will not plot the level sets.
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">def</span>  <span class="org-function-name">plot_vector_field</span>(fig, ax, score_fun, xlim <span class="org-operator">=</span>(0.0, 1.0), ylim <span class="org-operator">=</span>(0.0, 1.0), n_quiver <span class="org-operator">=</span>10):
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Compute the gradients
</span>     <span class="org-variable-name">x</span>  <span class="org-operator">=</span> np.linspace( <span class="org-operator">*</span>xlim, n_quiver)
     <span class="org-variable-name">y</span>  <span class="org-operator">=</span> np.linspace( <span class="org-operator">*</span>ylim, n_quiver)
     <span class="org-variable-name">X</span>,  <span class="org-variable-name">Y</span>  <span class="org-operator">=</span> np.meshgrid(x, y)
     <span class="org-variable-name">XY</span>  <span class="org-operator">=</span> np.stack([X.ravel(), Y.ravel()], axis <span class="org-operator">=-</span>1)
     <span class="org-variable-name">grads</span>  <span class="org-operator">=</span> vmap(score_fun)(XY)
     <span class="org-variable-name">grad_X</span>  <span class="org-operator">=</span> grads[:, 0].reshape(n_quiver, n_quiver)
     <span class="org-variable-name">grad_Y</span>  <span class="org-operator">=</span> grads[:, 1].reshape(n_quiver, n_quiver)
    ax.quiver(X, Y, grad_X, grad_Y)
     <span class="org-keyword">return</span> fig, ax

 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plt.subplots(3, 3, figsize <span class="org-operator">=</span>(3  <span class="org-operator">*</span> 2, 3  <span class="org-operator">*</span> 2))
 <span class="org-keyword">for</span> axis, i  <span class="org-keyword">in</span>  <span class="org-builtin">zip</span>(ax.ravel(),  <span class="org-builtin">range</span>(0, 30, 3)):
    axis.axis( <span class="org-string">'off'</span>)
    axis.set_aspect( <span class="org-string">'equal'</span>)
    axis.set_title(f <span class="org-string">"noise level </span>{i} <span class="org-string">: </span>{sigmas[i]:.2f} <span class="org-string">"</span>)
    plot_vector_field(fig, axis, functools.partial(model, noise_level_idx <span class="org-operator">=</span>i), n_quiver <span class="org-operator">=</span>15)

plt.tight_layout()
</pre>
</div>


 <figure id="org964b1f9"> <img src="../assets/images/diffusion_models/5-mnist-vector-fields.webp" alt="5-mnist-vector-fields.webp"></img> <figcaption> <span class="figure-number">Figure 3: </span>The score function for all noise levels used to train the model. The sweet spot seems to be around index 18 (I chose 17 after inspecting all noise levels manually).</figcaption></figure> <p>
Let's see what we have learned. We define the update step (it returns a tuple because  <kbd>lax.scan</kbd>, used below, expects both a carry and a per-step output)
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-type">@eqx.filter_jit</span>
 <span class="org-keyword">def</span>  <span class="org-function-name">update_x</span>(x, z, model, step_size):
     <span class="org-variable-name">g</span>  <span class="org-operator">=</span> model(x)
     <span class="org-variable-name">xp1</span>  <span class="org-operator">=</span> x  <span class="org-operator">+</span> (step_size  <span class="org-operator">/</span> 2)  <span class="org-operator">*</span> g  <span class="org-operator">+</span> jnp.sqrt(step_size)  <span class="org-operator">*</span> z
     <span class="org-keyword">return</span> xp1, xp1
</pre>
</div>

 <p>
and evolve a particle over many steps;  <kbd>lax.scan</kbd> simply makes this loop efficient
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">step_size</span>  <span class="org-operator">=</span> 0.001
 <span class="org-variable-name">num_steps</span>  <span class="org-operator">=</span> 400_000
 <span class="org-variable-name">key</span>  <span class="org-operator">=</span> random.PRNGKey(0)
 <span class="org-variable-name">z_key</span>,  <span class="org-variable-name">x0_key</span>,  <span class="org-variable-name">key</span>  <span class="org-operator">=</span> random.split(key, 3)
 <span class="org-variable-name">z</span>  <span class="org-operator">=</span> random.normal(z_key, shape <span class="org-operator">=</span>(num_steps, 2))
 <span class="org-variable-name">x0</span>  <span class="org-operator">=</span> jnp.ones(2,)  <span class="org-operator">*</span> 0.5
 <span class="org-variable-name">score_model</span>  <span class="org-operator">=</span> functools.partial(model, noise_level_idx <span class="org-operator">=</span>17)
 <span class="org-variable-name">update_fun</span>  <span class="org-operator">=</span> functools.partial(update_x, model <span class="org-operator">=</span>score_model, step_size <span class="org-operator">=</span>step_size)
 <span class="org-variable-name">final</span>,  <span class="org-variable-name">result</span>  <span class="org-operator">=</span> lax.scan(update_fun, x0, z)
</pre>
</div>

 <p>
Let's look at the result. Since the chain produces so many samples, we just plot a 2d histogram of them (choosing  <kbd>noise_level_idx</kbd> 17, though other indices in the
vicinity should work too). Note that the histogram has a finer resolution than the original MNIST images, which are \(28 \times 28\)
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plt.subplots(figsize <span class="org-operator">=</span>(6, 6))
 <span class="org-variable-name">h</span>  <span class="org-operator">=</span> ax.hist2d(result[:, 0], result[:, 1], cmap <span class="org-operator">=</span> <span class="org-string">"gray"</span>, bins <span class="org-operator">=</span>(50, 50))
ax.axis( <span class="org-string">"off"</span>)
</pre>
</div>


 <figure id="orgc072c4e"> <img src="../assets/images/diffusion_models/5-mnist-from-samples.webp" alt="5-mnist-from-samples.webp"></img></figure></div>
</div>
</div>
 <div id="outline-container-conclusion" class="outline-2">
 <h2 id="conclusion"> <a href="#conclusion">Conclusion</a></h2>
 <div class="outline-text-2" id="text-conclusion">
 <p>
This was a great way to learn jax and how diffusion works. Looking back,
treating images as distributions as I did above may be overkill: learning the
distribution over images directly would likely be both simpler and faster here.
I do like the fact that the score model generalizes from the grid points
\((i/27, j/27)_{i, j = 0}^{27}\) to arbitrary points \((x, y) \in [0, 1]^{2}\),
which is pretty cool and makes me wonder whether this could be used to
combine images of different resolutions, as long as the aspect ratio is the same.
</p>
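 <p>
As a minimal sketch of this off-grid evaluation (with a hypothetical placeholder function standing in for the trained score model, since only the querying pattern matters here):
</p>

```python
import jax.numpy as jnp
from jax import vmap

# Hypothetical stand-in for the trained score model: any function taking a
# 2d coordinate in [0, 1]^2 to a 2d score vector (NOT the trained model).
def score_fun(xy):
    return jnp.sin(2 * jnp.pi * xy)

def query_on_grid(fun, n):
    # Evaluate fun on an n x n grid over [0, 1]^2 -- nothing ties a
    # coordinate-based model to the 28 x 28 training grid.
    xs = jnp.linspace(0.0, 1.0, n)
    X, Y = jnp.meshgrid(xs, xs)
    XY = jnp.stack([X.ravel(), Y.ravel()], axis=-1)
    return vmap(fun)(XY).reshape(n, n, 2)

coarse = query_on_grid(score_fun, 28)   # original resolution
fine = query_on_grid(score_fun, 100)    # finer than the training grid
```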
</div>
</div>
 <div id="outline-container-reference" class="outline-2">
 <h2 id="reference"> <a href="#reference">Reference</a></h2>
 <div class="outline-text-2" id="text-reference">
 <style>.csl-entry{text-indent: -1.5em; margin-left: 1.5em;}</style> <div class="csl-bib-body">
   <div class="csl-entry"> <a id="citeproc_bib_item_1"></a>Song, Yang, and Stefano Ermon. 2020. “Generative Modeling by Estimating Gradients of the Data Distribution.” arXiv.  <a href="http://arxiv.org/abs/1907.05600">http://arxiv.org/abs/1907.05600</a>.</div>
   <div class="csl-entry"> <a id="citeproc_bib_item_2"></a>Tancik, Matthew, Pratul P. Srinivasan, Ben Mildenhall, Sara Fridovich-Keil, Nithin Raghavan, Utkarsh Singhal, Ravi Ramamoorthi, Jonathan T. Barron, and Ren Ng. 2020. “Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.” arXiv.  <a href="http://arxiv.org/abs/2006.10739">http://arxiv.org/abs/2006.10739</a>.</div>
</div>
</div>
</div>
 <div id="footnotes" class="Footnotes">
 <div id="text-footnotes">

 <div class="footdef"> <sup> <a id="fn.1" class="footnum" href="#fnr.1" role="doc-backlink">1</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">
Feels like there should be some way of looking at this through a
regularization lens, where \(\sigma\) takes the role of the regularization
strength in traditional supervised learning methods such as ridge regression.
</p></div></div>

 <div class="footdef"> <sup> <a id="fn.2" class="footnum" href="#fnr.2" role="doc-backlink">2</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">
Of which I know  <em>very little</em>.
</p></div></div>

 <div class="footdef"> <sup> <a id="fn.3" class="footnum" href="#fnr.3" role="doc-backlink">3</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">
Although when we convolve the inputs with Gaussians we will have that any point in \(\mathbb{R}^{2}\) will have positive probability, albeit maybe very small.
</p></div></div>


</div>
</div></main>]]></content>
  <link href="https://isakfalk.com/notes/diffusion-models.html"/>
  <id>https://isakfalk.com/notes/diffusion-models.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>Notes</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">Notes</h1>
</header> <ul class="org-ul"> <li> <a href="RC-retrospect.html">RC Retrospect</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2024-02-22 Thu]</span></span></li>
 <li> <a href="diffusion-models.html">Diffusion and score-based generative modeling</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2024-02-05 Mon]</span></span></li>
 <li> <a href="intro-to-jax.html">Intro to Jax</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2024-01-22 Mon]</span></span></li>
 <li> <a href="RC-halfbatch.html">RC Half-batch Retrospect</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2024-01-01 Mon]</span></span></li>
 <li> <a href="NeurIPS2023retrospect.html">NeurIPS 2023 Retrospect</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2023-12-20 Wed]</span></span></li>
 <li> <a href="RC-week-4.html">RC: Week 4</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2023-11-27 Mon]</span></span></li>
 <li> <a href="RC-week-3.html">RC: Week 3</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2023-11-20 Mon]</span></span></li>
 <li> <a href="RC-week-2.html">RC: Week 2</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2023-11-15 Wed]</span></span></li>
 <li> <a href="RC-week-1.html">RC: Week 1</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2023-11-14 Tue]</span></span></li>
 <li> <a href="note-for-showcasing-design.html">Notes for showcasing design</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2023-11-10 Fri]</span></span></li>
 <li> <a href="building-this-website.html">Building this website</a>  <span class="timestamp-wrapper"> <span class="timestamp">[2023-11-09 Thu]</span></span></li>
</ul></main>]]></content>
  <link href="https://isakfalk.com/notes/index.html"/>
  <id>https://isakfalk.com/notes/index.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>Intro to Jax</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">Intro to Jax</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#imports">Imports</a></li>
 <li> <a href="#introduction">Introduction</a></li>
 <li> <a href="#backend-xla">Backend: XLA</a></li>
 <li> <a href="#what-is-jax">What is jax?</a></li>
 <li> <a href="#jax-primitives-jit-grad-vmap">Jax primitives:  <kbd>jit</kbd>,  <kbd>grad</kbd>,  <kbd>vmap</kbd></a>
 <ul> <li> <a href="#jit"> <kbd>jit</kbd></a></li>
 <li> <a href="#grad"> <kbd>grad</kbd></a></li>
 <li> <a href="#vmap"> <kbd>vmap</kbd></a></li>
 <li> <a href="#composition">Composition</a></li>
</ul></li>
 <li> <a href="#building-a-neural-network-from-scratch">Building a neural network from scratch</a>
 <ul> <li> <a href="#creating-the-neural-network">Creating the neural network</a></li>
 <li> <a href="#training">Training</a></li>
</ul></li>
</ul></div>
</details></nav> <div id="outline-container-imports" class="outline-2">
 <h2 id="imports"> <a href="#imports">Imports</a></h2>
 <div class="outline-text-2" id="text-imports">
 <p>
Some of the libraries we will use throughout this post are imported below.
</p>

 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">import</span> time

 <span class="org-keyword">import</span> numpy  <span class="org-keyword">as</span> np
 <span class="org-keyword">import</span> matplotlib.pyplot  <span class="org-keyword">as</span> plt
 <span class="org-keyword">import</span> matplotlib  <span class="org-keyword">as</span> mpl
 <span class="org-keyword">import</span> seaborn  <span class="org-keyword">as</span> sns
</pre>
</div>
</div>
</div>
 <div id="outline-container-introduction" class="outline-2">
 <h2 id="introduction"> <a href="#introduction">Introduction</a></h2>
 <div class="outline-text-2" id="text-introduction">
 <p>
The  <a href="https://jax.readthedocs.io/en/latest/notebooks/quickstart.html">Jax Quickstart</a> tutorial states
</p>
 <blockquote>
 <p>
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
</p>
</blockquote>
 <p>
What does this mean? And how does this differ from other deep learning libraries such as torch and tensorflow?
</p>

 <p>
As is standard we will import some jax libraries and functions
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">import</span> jax
 <span class="org-keyword">from</span> jax  <span class="org-keyword">import</span> jit, grad, vmap
 <span class="org-keyword">from</span> jax  <span class="org-keyword">import</span> random
 <span class="org-keyword">import</span> jax.numpy  <span class="org-keyword">as</span> jnp
 <span class="org-keyword">import</span> jax.scipy  <span class="org-keyword">as</span> jscp
</pre>
</div>
</div>
</div>
 <div id="outline-container-backend-xla" class="outline-2">
 <h2 id="backend-xla"> <a href="#backend-xla">Backend: XLA</a></h2>
 <div class="outline-text-2" id="text-backend-xla">
 <p>
Jax is essentially a compiler: it uses the  <a href="https://www.tensorflow.org/xla/architecture">XLA compiler</a> to turn python
code and vector operations into machine instructions for
different computer architectures. The standard accelerator in deep learning is the GPU, but there are others, for example
</p>
 <ul class="org-ul"> <li>CPUs</li>
 <li> <a href="https://en.wikipedia.org/wiki/Tensor_Processing_Unit">TPUs</a></li>
 <li> <a href="https://www.graphcore.ai/products/ipu">IPUs</a></li>
</ul> <p>
or other specially created hardware which accelerates operations or makes them
more efficient in some way. The  <strong>point is that python is slow, and XLA makes this
code very fast using techniques such as operation fusion and dead-code
elimination</strong>. Personally, this feels like a pretty future-proof way of
decoupling how we specify what we want (e.g. python+jax) from how it is made
to run on hardware (here, XLA). It reminds me of how LSP has solved the same
decoupling problem for code editing in editors <sup> <a id="fnr.1" class="footref" href="#fn.1" role="doc-backlink">1</a></sup>. There seems to be ever more
specialized hardware being created, e.g. for inference of LLMs ( <a href="https://www.positron.ai/">like this</a>,
one of several LLM inference hardware companies I saw at  <a href="notes/NeurIPS2023retrospect.html">NeurIPS 2023</a>), so who
knows what funky architectures will become available in the future.
</p>
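 <p>
As a quick way to see which backend XLA is targeting on your machine (a small sketch; the output depends on your hardware):
</p>

```python
import jax

# The backend XLA compiles for on this machine, e.g. "cpu", "gpu" or "tpu"
backend = jax.default_backend()
# The concrete devices jax can place computations on
devices = jax.devices()
print(backend, devices)
```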
</div>
</div>
 <div id="outline-container-what-is-jax" class="outline-2">
 <h2 id="what-is-jax"> <a href="#what-is-jax">What is jax?</a></h2>
 <div class="outline-text-2" id="text-what-is-jax">
 <p>
Jax is a reimplementation of the established linear algebra and scientific
computing stack for python, including  <kbd>numpy</kbd> and  <kbd>scipy</kbd>, together with a just-in-time compiler and
automatic differentiation. To really hammer this home: jax has
reimplemented a subset of both of these packages, and that subset seems pretty
feature-complete. The current state of this API can be found in  <a href="https://jax.readthedocs.io/en/latest/jax.html">the docs</a>.
</p>
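 <p>
One practical difference from numpy worth knowing up front: jax arrays are immutable, so in-place updates go through the  <kbd>.at</kbd> syntax instead (a small sketch):
</p>

```python
import jax.numpy as jnp

x = jnp.arange(4.0)
# jax arrays are immutable: instead of x[0] = 10.0 (which raises an error),
# the .at syntax returns a new, updated array
y = x.at[0].set(10.0)
print(x)  # x is unchanged
print(y)
```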
</div>
</div>
 <div id="outline-container-jax-primitives-jit-grad-vmap" class="outline-2">
 <h2 id="jax-primitives-jit-grad-vmap"> <a href="#jax-primitives-jit-grad-vmap">Jax primitives:  <kbd>jit</kbd>,  <kbd>grad</kbd>,  <kbd>vmap</kbd></a></h2>
 <div class="outline-text-2" id="text-jax-primitives-jit-grad-vmap">
 <p>
There are 3 functions which are integral to almost any jax program.
</p>
</div>
 <div id="outline-container-jit" class="outline-3">
 <h3 id="jit"> <a href="#jit"> <kbd>jit</kbd></a></h3>
 <div class="outline-text-3" id="text-jit">
 <p>
The  <kbd>jit</kbd> function takes a large subset of python together with jax functions
and compiles it down to fast XLA kernels. Below is a very
quick benchmark of how  <kbd>jit</kbd> speeds up matrix-matrix multiplication.
</p>

 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">def</span>  <span class="org-function-name">jax_matmul</span>(A, B):
     <span class="org-keyword">return</span> A @ B

 <span class="org-variable-name">jit_jax_matmul</span>  <span class="org-operator">=</span> jit(jax_matmul)

 <span class="org-keyword">import</span> timeit
 <span class="org-variable-name">n</span>,  <span class="org-variable-name">p</span>,  <span class="org-variable-name">k</span>  <span class="org-operator">=</span> 10 <span class="org-operator">**</span>4, 10 <span class="org-operator">**</span>4, 10 <span class="org-operator">**</span>4
 <span class="org-variable-name">A</span>  <span class="org-operator">=</span> jnp.ones((n, p))
 <span class="org-variable-name">B</span>  <span class="org-operator">=</span> jnp.ones((p, k))
jit_jax_matmul(A, B)  <span class="org-comment-delimiter"># </span> <span class="org-comment">Trace the jit function once
</span> <span class="org-builtin">print</span>(f <span class="org-string">"jax: </span>{timeit.timeit( <span class="org-keyword">lambda</span>: jax_matmul(A, B).block_until_ready(), number <span class="org-operator">=</span>10)} <span class="org-string">"</span>)
 <span class="org-builtin">print</span>(f <span class="org-string">"jax (JIT): </span>{timeit.timeit( <span class="org-keyword">lambda</span>: jit_jax_matmul(A, B).block_until_ready(), number <span class="org-operator">=</span>10)} <span class="org-string">"</span>)
</pre>
</div>

 <pre class="example">
jax: 0.37372643800335936
jax (JIT): 0.0003170749987475574
</pre>


 <p>
which is dramatically faster in this microbenchmark, though such gaps should be
taken with a grain of salt: a single matmul already dispatches to an efficient
kernel, so much of the difference is python dispatch overhead. The gains are
most meaningful when we jit code that does not already have an efficient
implementation (unlike a matmul). Additionally, jit lets us speed up code that
would require considerable vectorization effort in numpy, or that cannot be
vectorized at all.
</p>
</div>
</div>
 <div id="outline-container-grad" class="outline-3">
 <h3 id="grad"> <a href="#grad"> <kbd>grad</kbd></a></h3>
 <div class="outline-text-3" id="text-grad">
 <p>
The  <kbd>grad</kbd> function takes as input a function \(f\) mapping to \(\mathbb{R}\)
and spits out the gradient of that function \(\nabla f\). This can be a very
natural way of working with gradients if you are used to the math.
</p>

 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">def</span>  <span class="org-function-name">sum_of_squares</span>(x):
     <span class="org-keyword">return</span> jnp. <span class="org-builtin">sum</span>(x <span class="org-operator">**</span>2)

 <span class="org-variable-name">sum_of_squares_dx</span>  <span class="org-operator">=</span> grad(sum_of_squares)
</pre>
</div>

 <p>
The function  <kbd>sum_of_squares_dx</kbd> is the mathematical gradient of
 <kbd>sum_of_squares</kbd>. In the snippet below, note that jax handles randomness
explicitly by splitting the state (the key); read about it  <a href="https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html">here</a>.
</p>

 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">key</span>  <span class="org-operator">=</span> jax.random.PRNGKey(0)
 <span class="org-variable-name">key</span>,  <span class="org-variable-name">subkey</span>  <span class="org-operator">=</span> jax.random.split(key)
 <span class="org-variable-name">in_x</span>  <span class="org-operator">=</span> jax.random.normal(key, (3, 3))
 <span class="org-variable-name">dx</span>  <span class="org-operator">=</span> sum_of_squares_dx(in_x)
 <span class="org-builtin">print</span>(dx)
 <span class="org-builtin">print</span>(dx.shape)
</pre>
</div>

 <pre class="example">
[[-5.2211165   0.06770565  2.1726665 ]
 [-2.960598    3.0806496   2.125032  ]
 [ 1.0834967   0.0340456   0.544537  ]]
(3, 3)
</pre>
</div>
</div>
 <div id="outline-container-vmap" class="outline-3">
 <h3 id="vmap"> <a href="#vmap"> <kbd>vmap</kbd></a></h3>
 <div class="outline-text-3" id="text-vmap">
 <p>
The function  <kbd>vmap</kbd> allows you to lift a function to a batched function,
 <strong>without having to go through vectorization</strong>. For example, if we wanted to batch
the  <kbd>sum_of_squares</kbd> function we can do this by simply applying  <kbd>vmap</kbd>
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">batched_sum_of_squares</span>  <span class="org-operator">=</span> vmap(sum_of_squares)
 <span class="org-variable-name">x</span>  <span class="org-operator">=</span> jax.random.normal(key, (5, 3, 3))
 <span class="org-builtin">print</span>(batched_sum_of_squares(x))
 <span class="org-builtin">print</span>(batched_sum_of_squares(x).shape)
</pre>
</div>

 <pre class="example">
[ 7.109205   7.1214614 21.167786   6.137778   4.915494 ]
(5,)
</pre>


 <p>
This is pretty powerful: often it's easy to specify a function for a single sample
\(x\) but harder to vectorize it. For a standard neural network it may be pretty
simple, but imagine something like LLMs, GANs, or inputs which are
not points, e.g. sets. Additionally, we can use the  <kbd>in_axes</kbd> argument to batch
over some input arguments and leave others unbatched.
</p>

 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">def</span>  <span class="org-function-name">multi_matmul</span>(A, B, C):
     <span class="org-keyword">return</span> A @ B @ C

 <span class="org-comment-delimiter"># </span> <span class="org-comment">Batch according to first and third input argument, not second
</span> <span class="org-variable-name">vmap_multi_matmul</span>  <span class="org-operator">=</span> vmap(multi_matmul, in_axes <span class="org-operator">=</span>(0,  <span class="org-constant">None</span>, 0))

 <span class="org-variable-name">l</span>,  <span class="org-variable-name">n</span>,  <span class="org-variable-name">p</span>,  <span class="org-variable-name">d</span>,  <span class="org-variable-name">m</span>  <span class="org-operator">=</span> 3, 5, 7, 9, 11
 <span class="org-variable-name">A</span>  <span class="org-operator">=</span> jnp.ones((l, n, p))
 <span class="org-variable-name">B</span>  <span class="org-operator">=</span> jnp.ones((p, d))
 <span class="org-variable-name">C</span>  <span class="org-operator">=</span> jnp.ones((l, d, m))

 <span class="org-builtin">print</span>(vmap_multi_matmul(A, B, C).shape)  <span class="org-comment-delimiter"># </span> <span class="org-comment">l batches of (n, m) -> (l, n, m)</span>
</pre>
</div>

 <pre class="example">
(3, 5, 11)
</pre>
</div>
</div>
 <div id="outline-container-composition" class="outline-3">
 <h3 id="composition"> <a href="#composition">Composition</a></h3>
 <div class="outline-text-3" id="text-composition">
 <p>
You can compose all of these functions as you see fit
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">jit_batched_sum_of_squares_dx</span>  <span class="org-operator">=</span> jit(vmap(grad(sum_of_squares)))
 <span class="org-builtin">print</span>(jit_batched_sum_of_squares_dx(x).shape)
</pre>
</div>

 <pre class="example">
(5, 3, 3)
</pre>


 <p>
This allows us to utilize the autodiff framework fully.
</p>
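 <p>
Since the transformations compose, higher-order derivatives also come essentially for free; for instance, a small sketch using  <kbd>jax.hessian</kbd>:
</p>

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# The hessian of x -> sum(x^2) is the constant matrix 2 * I
hess = jax.hessian(sum_of_squares)(jnp.ones(3))
print(hess)
```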
</div>
</div>
</div>
 <div id="outline-container-building-a-neural-network-from-scratch" class="outline-2">
 <h2 id="building-a-neural-network-from-scratch"> <a href="#building-a-neural-network-from-scratch">Building a neural network from scratch</a></h2>
 <div class="outline-text-2" id="text-building-a-neural-network-from-scratch">
 <p>
We'll build an MLP using nothing but jax. We will train this on MNIST. To
load the data I'm using the  <a href="https://pypi.org/project/jax-dataloader/">jax-dataloader</a> library.
</p>

 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">import</span> jax_dataloader  <span class="org-keyword">as</span> jdl
 <span class="org-keyword">from</span> torchvision.datasets  <span class="org-keyword">import</span> MNIST

 <span class="org-variable-name">pt_ds</span>  <span class="org-operator">=</span> MNIST( <span class="org-string">"/tmp/mnist"</span>, download <span class="org-operator">=</span> <span class="org-constant">True</span>, transform <span class="org-operator">=</span> <span class="org-keyword">lambda</span> x: np.array(x, np.float32), train <span class="org-operator">=</span> <span class="org-constant">True</span>)
 <span class="org-variable-name">train_dataloader</span>  <span class="org-operator">=</span> jdl.DataLoader(pt_ds, backend <span class="org-operator">=</span> <span class="org-string">"pytorch"</span>, batch_size <span class="org-operator">=</span>128, shuffle <span class="org-operator">=</span> <span class="org-constant">True</span>)
 <span class="org-variable-name">pt_ds</span>  <span class="org-operator">=</span> MNIST( <span class="org-string">"/tmp/mnist"</span>, download <span class="org-operator">=</span> <span class="org-constant">True</span>, transform <span class="org-operator">=</span> <span class="org-keyword">lambda</span> x: np.array(x, np.float32), train <span class="org-operator">=</span> <span class="org-constant">False</span>)
 <span class="org-variable-name">test_dataloader</span>  <span class="org-operator">=</span> jdl.DataLoader(pt_ds, backend <span class="org-operator">=</span> <span class="org-string">"pytorch"</span>, batch_size <span class="org-operator">=</span>128, shuffle <span class="org-operator">=</span> <span class="org-constant">True</span>)
</pre>
</div>

 <p>
The jax library has some helpful functions for building neural networks. Here
we create the parameters and define a prediction function which, given a  <a href="https://jax.readthedocs.io/en/latest/pytrees.html">pytree</a> of
parameters and an input, outputs the predicted logits. Pytrees are a great
feature of jax: they allow us to intuitively and effectively work not only with raw
arrays but also with tree-like structures built by composing lists, tuples and
dictionaries, with arrays as leaves, and to map over these as if they
were arrays.
</p>
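 <p>
As a tiny sketch of why this matters (using  <kbd>jax.tree_util.tree_map</kbd>; the parameters here are a toy stand-in, not the network defined below):
</p>

```python
import jax
import jax.numpy as jnp

# A pytree: a list of dicts with arrays at the leaves, just like the MLP
# parameters built below
params = [
    {"W": jnp.ones((2, 2)), "b": jnp.zeros(2)},
    {"W": jnp.ones((2, 1)), "b": jnp.zeros(1)},
]

# tree_map applies a function to every leaf while preserving the structure,
# so e.g. a whole gradient-descent update can be a single tree_map call
scaled = jax.tree_util.tree_map(lambda p: 0.1 * p, params)
print(scaled[0]["W"])
```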
</div>
 <div id="outline-container-creating-the-neural-network" class="outline-3">
 <h3 id="creating-the-neural-network"> <a href="#creating-the-neural-network">Creating the neural network</a></h3>
 <div class="outline-text-3" id="text-creating-the-neural-network">
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">from</span> jax.nn  <span class="org-keyword">import</span> relu
 <span class="org-keyword">from</span> jax.nn.initializers  <span class="org-keyword">import</span> glorot_normal
 <span class="org-keyword">from</span> jax.scipy.special  <span class="org-keyword">import</span> logsumexp

 <span class="org-keyword">def</span>  <span class="org-function-name">create_mlp_weights</span>(num_layers:  <span class="org-builtin">int</span>, in_dim:  <span class="org-builtin">int</span>, out_dim:  <span class="org-builtin">int</span>, hidden_dim:  <span class="org-builtin">int</span>, key):
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Create helper function for generating weights and biases in each layer
</span>     <span class="org-keyword">def</span>  <span class="org-function-name">create_layer_weights</span>(in_dim, out_dim, key):
         <span class="org-keyword">return</span> {
             <span class="org-string">"W"</span>: glorot_normal()(key, (in_dim, out_dim)),
             <span class="org-string">"b"</span>: jnp.zeros(out_dim)
        }
     <span class="org-variable-name">params</span>  <span class="org-operator">=</span> []
     <span class="org-variable-name">key</span>,  <span class="org-variable-name">subkey</span>  <span class="org-operator">=</span> jax.random.split(key)
     <span class="org-comment-delimiter"># </span> <span class="org-comment">Fill out parameter list with dictionary of layer-weights and biases
</span>    params.append(create_layer_weights(in_dim, hidden_dim, subkey))
     <span class="org-keyword">for</span> _  <span class="org-keyword">in</span>  <span class="org-builtin">range</span>(1, num_layers):
         <span class="org-variable-name">key</span>,  <span class="org-variable-name">subkey</span>  <span class="org-operator">=</span> jax.random.split(key)
        params.append(create_layer_weights(hidden_dim, hidden_dim, subkey))
     <span class="org-variable-name">key</span>,  <span class="org-variable-name">subkey</span>  <span class="org-operator">=</span> jax.random.split(key)
    params.append(create_layer_weights(hidden_dim, out_dim, subkey))
     <span class="org-keyword">return</span> params

 <span class="org-keyword">def</span>  <span class="org-function-name">predict</span>(params, x):
     <span class="org-keyword">for</span> layer  <span class="org-keyword">in</span> params[: <span class="org-operator">-</span>1]:
         <span class="org-variable-name">x</span>  <span class="org-operator">=</span> relu(x @ layer[ <span class="org-string">"W"</span>]  <span class="org-operator">+</span> layer[ <span class="org-string">"b"</span>])
     <span class="org-variable-name">logits</span>  <span class="org-operator">=</span> x @ params[ <span class="org-operator">-</span>1][ <span class="org-string">"W"</span>]  <span class="org-operator">+</span> params[ <span class="org-operator">-</span>1][ <span class="org-string">"b"</span>]
     <span class="org-keyword">return</span> logits  <span class="org-operator">-</span> logsumexp(logits)
</pre>
</div>

 <p>
Let's pick some reasonable defaults. We see that all shapes are correct and we have batched the  <kbd>predict</kbd> function.
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-variable-name">num_layers</span>  <span class="org-operator">=</span> 3
 <span class="org-variable-name">in_dim</span>  <span class="org-operator">=</span> 28  <span class="org-operator">*</span> 28
 <span class="org-variable-name">out_dim</span>  <span class="org-operator">=</span> 10
 <span class="org-variable-name">hidden_dim</span>  <span class="org-operator">=</span> 128
 <span class="org-variable-name">key</span>  <span class="org-operator">=</span> jax.random.PRNGKey(2023)
 <span class="org-variable-name">params</span>  <span class="org-operator">=</span> create_mlp_weights(num_layers, in_dim, out_dim, hidden_dim, key)
 <span class="org-builtin">print</span>(predict(params, jnp.ones(28  <span class="org-operator">*</span> 28)))

 <span class="org-variable-name">batched_predict</span>  <span class="org-operator">=</span> vmap(predict, in_axes <span class="org-operator">=</span>( <span class="org-constant">None</span>, 0))
 <span class="org-builtin">print</span>(batched_predict(params, jnp.ones((4, 28  <span class="org-operator">*</span> 28))).shape)
 <span class="org-builtin">print</span>( <span class="org-builtin">len</span>(params))
 <span class="org-builtin">print</span>( <span class="org-builtin">type</span>(params[0][ <span class="org-string">"W"</span>]))
</pre>
</div>

 <pre class="example">
[-3.3419425 -1.4851335 -2.5466485 -3.1445212 -1.8924606 -2.5047162
 -2.622343  -2.6072748 -1.5674857 -3.5270252]
(4, 10)
4
<class 'jaxlib.xla_extension.ArrayImpl'>
</pre>
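 <p>
As a quick aside, the <code>in_axes=(None, 0)</code> argument is what tells <code>vmap</code> to hold the parameters fixed while mapping over axis 0 of the input batch. Here is a minimal, self-contained sketch of this on a toy function (the function and names here are ours, not part of the network above):
</p>
 <div class="org-src-container">
 <pre class="src src-python">
import jax.numpy as jnp
from jax import vmap

def affine(w, x):
    return w * x + 1.0

# in_axes=(None, 0): w is broadcast unchanged, x is mapped over axis 0.
batched_affine = vmap(affine, in_axes=(None, 0))
out = batched_affine(2.0, jnp.arange(3.0))
print(out)  # [1. 3. 5.]
</pre>
</div>
 <p>
This is the same pattern as <code>batched_predict</code> above: the parameters play the role of <code>w</code> and the image batch the role of <code>x</code>.
</p>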
</div>
</div>
 <div id="outline-container-training" class="outline-3">
 <h3 id="training"> <a href="#training">Training</a></h3>
 <div class="outline-text-3" id="text-training">
 <p>
Now we write the helper functions to train this network. In particular, we use
the  <a href="https://jax.readthedocs.io/en/latest/pytrees.html">pytree</a> functionality of jax to update the parameters, which form a pytree
since they are a list of dictionaries of arrays.
</p>
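 <p>
To see what this buys us, here is a minimal sketch of <code>tree_map</code> on a toy pytree with the same list-of-dicts structure as <code>params</code> (the values are made up for illustration):
</p>
 <div class="org-src-container">
 <pre class="src src-python">
import jax.tree_util as tree_util

# A toy pytree mirroring params: a list of dicts of leaves.
params = [{"W": 1.0, "b": 2.0}, {"W": 3.0, "b": 4.0}]
grads = [{"W": 0.5, "b": 0.5}, {"W": 0.5, "b": 0.5}]

# tree_map applies the function leaf-wise across both pytrees,
# returning a new pytree with the same structure.
updated = tree_util.tree_map(lambda w, g: w - 0.1 * g, params, grads)
print(updated)
</pre>
</div>
 <p>
This is exactly the shape of the SGD step in the <code>update</code> function: one <code>tree_map</code> over parameters and gradients, with no manual loop over layers.
</p>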
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">import</span> jax.tree_util  <span class="org-keyword">as</span> tree_util

 <span class="org-keyword">def</span>  <span class="org-function-name">one_hot</span>(x, k, dtype <span class="org-operator">=</span>jnp.float32):
   <span class="org-doc">"""Create a one-hot encoding of x of size k."""</span>
   <span class="org-keyword">return</span> jnp.array(x[:,  <span class="org-constant">None</span>]  <span class="org-operator">==</span> jnp.arange(k), dtype)

 <span class="org-type">@jit</span>
 <span class="org-keyword">def</span>  <span class="org-function-name">accuracy</span>(params, images, targets):
   <span class="org-variable-name">target_class</span>  <span class="org-operator">=</span> jnp.argmax(targets, axis <span class="org-operator">=</span>1)
   <span class="org-variable-name">predicted_class</span>  <span class="org-operator">=</span> jnp.argmax(batched_predict(params, images), axis <span class="org-operator">=</span>1)
   <span class="org-keyword">return</span> jnp.mean(predicted_class  <span class="org-operator">==</span> target_class)

 <span class="org-keyword">def</span>  <span class="org-function-name">loss</span>(params, images, targets):
   <span class="org-variable-name">preds</span>  <span class="org-operator">=</span> batched_predict(params, images)
   <span class="org-keyword">return</span>  <span class="org-operator">-</span>jnp.mean(preds  <span class="org-operator">*</span> targets)

 <span class="org-type">@jit</span>
 <span class="org-keyword">def</span>  <span class="org-function-name">update</span>(params, x, y, step_size):
   <span class="org-variable-name">grads</span>  <span class="org-operator">=</span> grad(loss)(params, x, y)
   <span class="org-keyword">return</span> tree_util.tree_map( <span class="org-keyword">lambda</span> w, g: w  <span class="org-operator">-</span> step_size  <span class="org-operator">*</span> g, params, grads)

 <span class="org-variable-name">EPOCHS</span>  <span class="org-operator">=</span> 10
 <span class="org-variable-name">STEP_SIZE</span>  <span class="org-operator">=</span> 10  <span class="org-operator">**</span>  <span class="org-operator">-</span>2
 <span class="org-variable-name">train_acc</span>  <span class="org-operator">=</span> []
 <span class="org-variable-name">train_loss</span>  <span class="org-operator">=</span> []
 <span class="org-variable-name">test_acc</span>  <span class="org-operator">=</span> []
 <span class="org-variable-name">test_loss</span>  <span class="org-operator">=</span> []

 <span class="org-keyword">for</span> epoch  <span class="org-keyword">in</span>  <span class="org-builtin">range</span>(EPOCHS):
   <span class="org-builtin">print</span>( <span class="org-string">'Epoch'</span>, epoch)
   <span class="org-keyword">for</span> image, output  <span class="org-keyword">in</span> train_dataloader:
     <span class="org-variable-name">image</span>,  <span class="org-variable-name">output</span>  <span class="org-operator">=</span> jnp.array(image).reshape( <span class="org-operator">-</span>1, 28  <span class="org-operator">*</span> 28), one_hot(jnp.array(output), 10)
    train_acc.append(accuracy(params, image, output).item())
    train_loss.append(loss(params, image, output).item())
     <span class="org-variable-name">params</span>  <span class="org-operator">=</span> update(params, image, output, STEP_SIZE)
   <span class="org-builtin">print</span>(f <span class="org-string">'Train accuracy: </span>{np.mean(train_acc):.3f} <span class="org-string">'</span>)
   <span class="org-builtin">print</span>(f <span class="org-string">'Train loss: </span>{np.mean(train_loss):.3f} <span class="org-string">'</span>)
   <span class="org-variable-name">_test_acc</span>  <span class="org-operator">=</span> []
   <span class="org-variable-name">_test_loss</span>  <span class="org-operator">=</span> []
   <span class="org-keyword">for</span> image, output  <span class="org-keyword">in</span> test_dataloader:
     <span class="org-variable-name">image</span>,  <span class="org-variable-name">output</span>  <span class="org-operator">=</span> jnp.array(image).reshape( <span class="org-operator">-</span>1, 28  <span class="org-operator">*</span> 28), one_hot(jnp.array(output), 10)
    _test_acc.append(accuracy(params, image, output).item())
    _test_loss.append(loss(params, image, output).item())
  test_acc.append(_test_acc)
  test_loss.append(_test_loss)
   <span class="org-builtin">print</span>(f <span class="org-string">'Test accuracy: </span>{np.mean(test_acc):.3f} <span class="org-string">'</span>)
   <span class="org-builtin">print</span>(f <span class="org-string">'Test loss: </span>{np.mean(test_loss):.3f} <span class="org-string">'</span>)
</pre>
</div>

 <pre class="example" id="org30abb30">
Epoch 0
Train accuracy: 0.788
Train loss: 0.213
Test accuracy: 0.856
Test loss: 0.073
Epoch 1
Train accuracy: 0.832
Train loss: 0.135
Test accuracy: 0.872
Test loss: 0.062
Epoch 2
Train accuracy: 0.856
Train loss: 0.103
Test accuracy: 0.882
Test loss: 0.055
Epoch 3
Train accuracy: 0.872
Train loss: 0.085
Test accuracy: 0.889
Test loss: 0.051
Epoch 4
Train accuracy: 0.883
Train loss: 0.074
Test accuracy: 0.894
Test loss: 0.048
Epoch 5
Train accuracy: 0.892
Train loss: 0.065
Test accuracy: 0.898
Test loss: 0.045
Epoch 6
Train accuracy: 0.899
Train loss: 0.059
Test accuracy: 0.902
Test loss: 0.043
Epoch 7
Train accuracy: 0.905
Train loss: 0.054
Test accuracy: 0.904
Test loss: 0.042
Epoch 8
Train accuracy: 0.910
Train loss: 0.050
Test accuracy: 0.907
Test loss: 0.040
Epoch 9
Train accuracy: 0.914
Train loss: 0.046
Test accuracy: 0.909
Test loss: 0.039
</pre>

 <p>
Finally, we plot the learning curves.
</p>
 <div class="org-src-container">
 <pre class="src src-python">sns.set_theme( <span class="org-string">"notebook"</span>)
sns.set_style( <span class="org-string">"ticks"</span>)

 <span class="org-variable-name">iterations_per_epoch</span>  <span class="org-operator">=</span>  <span class="org-builtin">len</span>(train_dataloader)

 <span class="org-variable-name">fig</span>,  <span class="org-variable-name">ax</span>  <span class="org-operator">=</span> plt.subplots(2, 1)
ax[0].plot(np.array(train_loss), label <span class="org-operator">=</span> <span class="org-string">"train_loss"</span>)
ax[0].plot((np.arange( <span class="org-builtin">len</span>(test_loss))  <span class="org-operator">+</span> 1)  <span class="org-operator">*</span> iterations_per_epoch, np.array(test_loss).mean( <span class="org-operator">-</span>1), label <span class="org-operator">=</span> <span class="org-string">"test_loss"</span>)
ax[0].set_ylim([0.0, 0.1])
ax[0].legend()

ax[1].plot(np.array(train_acc), label <span class="org-operator">=</span> <span class="org-string">"train_acc"</span>)
ax[1].plot((np.arange( <span class="org-builtin">len</span>(test_acc))  <span class="org-operator">+</span> 1)  <span class="org-operator">*</span> iterations_per_epoch, np.array(test_acc).mean( <span class="org-operator">-</span>1), label <span class="org-operator">=</span> <span class="org-string">"test_acc"</span>)
ax[1].set_ylim([0.8, 1.0])
ax[1].legend()

plt.tight_layout()
</pre>
</div>


 <figure id="org6c60a32"> <img src="../assets/images/intro_to_jax/learning_curves.webp" alt="Loss and accuracy learning curves on train and test sets of an MLP on MNIST; loss decreases and accuracy increases" width="800" loading="lazy"></img> <figcaption> <span class="figure-number">Figure 1: </span>Both test and train loss go down and accuracy goes up as we train for longer</figcaption></figure></div>
</div>
</div>
 <div id="footnotes" class="Footnotes">
 <div id="text-footnotes">

 <div class="footdef"> <sup> <a id="fn.1" class="footnum" href="#fnr.1" role="doc-backlink">1</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">
LSP decoupled the implementation of code-editing features by allowing a single
language server to implement them, which editors then use through a frontend. In
this way the frontend relies on a consistent API but does not have to
reimplement the language features for every editor.
</p></div></div>


</div>
</div></main>]]></content>
  <link href="https://isakfalk.com/notes/intro-to-jax.html"/>
  <id>https://isakfalk.com/notes/intro-to-jax.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
<entry>
  <title>Notes for showcasing design</title>
  <author><name>Isak Falk</name></author>
  <content type="html"><![CDATA[<main id="content" class="content"> <header> <h1 class="title">Notes for showcasing design</h1>
</header> <nav id="table-of-contents" role="doc-toc"> <details> <summary>Table of Contents</summary> <div id="text-table-of-contents" role="doc-toc">
 <ul> <li> <a href="#what-this-is">What this is</a></li>
 <li> <a href="#markup">Markup</a></li>
 <li> <a href="#headlines">Headlines</a>
 <ul> <li> <a href="#headline-1-level-down">Headline: 1 level down</a></li>
</ul></li>
 <li> <a href="#lists">Lists</a></li>
 <li> <a href="#tables">Tables</a></li>
 <li> <a href="#links">Links</a></li>
 <li> <a href="#tags">Tags</a></li>
 <li> <a href="#dates">Dates</a></li>
 <li> <a href="#blocks">Blocks</a></li>
</ul></div>
</details></nav> <div id="outline-container-what-this-is" class="outline-2">
 <h2 id="what-this-is"> <a href="#what-this-is">What this is</a></h2>
 <div class="outline-text-2" id="text-what-this-is">
 <p>
This note is strictly for showcasing how the features that org-mode supports
translate into the actual html output. If you don't
know what org-mode is, you can  <a href="https://orgmode.org/">read about it here</a>. I'll basically copy/paste the
source used there, which shows most of the functionality, and put it here. The
note you are reading right now is the actual output of the build system using
this org-mode file as a source.
</p>
</div>
</div>
 <div id="outline-container-markup" class="outline-2">
 <h2 id="markup"> <a href="#markup">Markup</a></h2>
 <div class="outline-text-2" id="text-markup">
 <p>
Org is a markup language and can be used for all kinds of things, such as  <em>italics</em>,  <strong>bold</strong>,  <del>strikethrough</del> and  <span class="underline">underline</span>. It can also combine these styles, such as  <strong> <em> <del> <span class="underline">here</span></del></em></strong>. It also has the  <kbd>verbatim</kbd> and  <code>code</code> styles. Additionally we  <strong>may</strong>, if we choose to use the right publishing option, use sub and superscripts, like <sub>this</sub> or like <sup>that</sup>. To make a true underline we can use underscores for this. We should also have access to special symbols: π and also embed latex \(x^2 = \int \sin(y + z)\) or even do equations
</p>

\begin{equation}
\label{org48ba78a}
x^2 = \int \sin(y + z)
\end{equation}
 <p>
which may even be referenced \eqref{org48ba78a}.
</p>
</div>
</div>
 <div id="outline-container-headlines" class="outline-2">
 <h2 id="headlines"> <a href="#headlines">Headlines</a></h2>
 <div class="outline-text-2" id="text-headlines">
 <p>
Org-mode has headlines, we can descend down.
</p>
</div>
 <div id="outline-container-headline-1-level-down" class="outline-3">
 <h3 id="headline-1-level-down"> <a href="#headline-1-level-down">Headline: 1 level down</a></h3>
 <div class="outline-text-3" id="text-headline-1-level-down">
</div>
 <div id="outline-container-headline-2-levels-down" class="outline-4">
 <h4 id="headline-2-levels-down"> <a href="#headline-2-levels-down">Headline: 2 levels down</a></h4>
 <div class="outline-text-4" id="text-headline-2-levels-down">
</div>
 <ul class="org-ul"> <li> <a id="headline-3-levels-down"></a> <a href="#headline-3-levels-down">Headline: 3 levels down</a> <br></br> <div class="outline-text-5" id="text-headline-3-levels-down">
</div>
 <ul class="org-ul"> <li> <a id="headline-4-levels-down"></a> <a href="#headline-4-levels-down">Headline: 4 levels down</a> <br></br> <div class="outline-text-6" id="text-headline-4-levels-down">
 <p>
Note that due to chosen options during the org-publishing export, levels below some point will be made into lists or just ignored.
</p>
</div>
</li>
</ul></li>
</ul></div>
</div>
</div>
 <div id="outline-container-lists" class="outline-2">
 <h2 id="lists"> <a href="#lists">Lists</a></h2>
 <div class="outline-text-2" id="text-lists">
 <p>
We can create unordered lists like:
</p>
 <ul class="org-ul"> <li>The first list item</li>
 <li>The second</li>
 <li>And so forth</li>
</ul> <p>
If we want, we can also create ordered items as
</p>
 <ol class="org-ol"> <li>First item
 <ul class="org-ul"> <li>We can also nest them</li>
 <li>Oh yeah</li>
</ul></li>
 <li>Second item</li>
</ol> <p>
and so forth.
</p>

 <p>
Finally, we can also make lists with descriptions
</p>
 <dl class="org-dl"> <dt>First element</dt> <dd>This is the first element</dd>
 <dt>Second element</dt> <dd>This is the second element</dd>
</dl></div>
</div>
 <div id="outline-container-tables" class="outline-2">
 <h2 id="tables"> <a href="#tables">Tables</a></h2>
 <div class="outline-text-2" id="text-tables">
 <p>
We can create tables, with captions
</p>

 <table> <caption class="t-above"> <span class="table-number">Table 1:</span> Skills that I have acquired over the years</caption>

 <colgroup> <col class="org-left"></col> <col class="org-right"></col> <col class="org-right"></col></colgroup> <thead> <tr> <th scope="col" class="org-left">Skill</th>
 <th scope="col" class="org-right">Years</th>
 <th scope="col" class="org-right">Level (out of 10)</th>
</tr></thead> <tbody> <tr> <td class="org-left">Webdev</td>
 <td class="org-right">0</td>
 <td class="org-right">2</td>
</tr> <tr> <td class="org-left">ML</td>
 <td class="org-right">9</td>
 <td class="org-right">9</td>
</tr></tbody></table></div>
</div>
 <div id="outline-container-links" class="outline-2">
 <h2 id="links"> <a href="#links">Links</a></h2>
 <div class="outline-text-2" id="text-links">
 <p>
We can link to many, many things in different ways. Internally we can  <a href="#lists">link to other headings</a> and also to other  <a href="index.html">files</a> entirely (but the links need to be relative for this to work when exporting to html). Finally, we have custom link types (such as a  <code>yt:</code> YouTube link, which did not survive this export) and  <a href="https://orgmode.org/">more</a>, each one handled in its own way internally by org-mode. However, how these are exported varies and not all may be supported.
</p>

 <p>
We can also link to images and style them for example here setting the width to be 300 pixels
</p>

 <figure id="org570676b"> <img src="../assets/images/me_profile.webp" alt="me_profile.webp" width="300"></img> <figcaption> <span class="figure-number">Figure 1: </span>This is me</figcaption></figure> <p>
We can also link internal targets globally like this  <a href="#org570676b">picture of myself</a>, with a true internal link being this  <a id="org5c6d3b2"></a>. This also works for lists
</p>
 <ol class="org-ol"> <li>One</li>
 <li> <a id="org4331cb2"></a> And two</li>
</ol> <p>
Look how we can link  <a href="#org4331cb2">2</a>!
</p>
</div>
</div>
 <div id="outline-container-tags" class="outline-2">
 <h2 id="tags"> <a href="#tags">Tags    <span class="tag"> <span class="tag1">tag1</span>  <span class="tag2">tag2</span></span></a></h2>
 <div class="outline-text-2" id="text-tags">
 <p>
We can also add tags. Maybe for categorising notes or headings.
</p>
</div>
</div>
 <div id="outline-container-dates" class="outline-2">
 <h2 id="dates"> <a href="#dates">Dates</a></h2>
 <div class="outline-text-2" id="text-dates">
 <p>
We can add dates inline  <span class="timestamp-wrapper"> <span class="timestamp"><2023-11-10 Fri></span></span>.
</p>
</div>
</div>
 <div id="outline-container-blocks" class="outline-2">
 <h2 id="blocks"> <a href="#blocks">Blocks</a></h2>
 <div class="outline-text-2" id="text-blocks">
 <p>
There are a great number of predefined blocks
</p>
 <blockquote>
 <p>
To be, or not to be, that is the question
</p>
</blockquote>

 <p>
We can make notes
</p>
 <div class="notes" id="orgf66c0e0">
 <p>
A quick note
</p>

</div>

 <p>
Or center text
</p>
 <div class="org-center">
 <p>
Let's center some text
</p>
</div>

 <p>
Finally, here is an example
</p>
 <pre class="example" id="org34ee860">
Here is a quick example
</pre>

 <p class="verse">
This is a verse. <br></br></p>

 <p>
But source blocks are probably the most useful ones
</p>
 <div class="org-src-container">
 <pre class="src src-python"> <span class="org-keyword">while</span>  <span class="org-constant">True</span>:
     <span class="org-builtin">print</span>( <span class="org-string">"Emacs ❤️ Org"</span>)
</pre>
</div>
</div>
</div>
 <div id="footnotes" class="Footnotes">
 <div id="text-footnotes">

 <div class="footdef"> <sup> <a id="fn.1" class="footnum" href="#fnr.1" role="doc-backlink">1</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">
Here is the footnote. Rendering of course depends on the publish / export.
</p></div></div>

 <div class="footdef"> <sup> <a id="fn.name-of-note" class="footnum" href="#fnr.name-of-note" role="doc-backlink">2</a></sup> <div class="footpara" role="doc-footnote"> <p class="footpara">Here is an inline footnote, wonder what to put in here.</p></div></div>


</div>
</div></main>]]></content>
  <link href="https://isakfalk.com/notes/note-for-showcasing-design.html"/>
  <id>https://isakfalk.com/notes/note-for-showcasing-design.html</id>
  <updated>2025-12-28T21:51:00-05:00</updated>
</entry>
</feed>
